graphannis_core/graph/storage/linear.rs

1use super::{
2    EdgeContainer, GraphStatistic, GraphStorage, deserialize_gs_field,
3    legacy::LinearGraphStorageV1, load_statistics_from_location, save_statistics_to_toml,
4    serialize_gs_field,
5};
6use crate::{
7    annostorage::{
8        AnnotationStorage, EdgeAnnotationStorage, NodeAnnotationStorage, inmemory::AnnoStorageImpl,
9    },
10    dfs::CycleSafeDFS,
11    errors::Result,
12    types::{Edge, NodeID, NumValue},
13};
14use rustc_hash::FxHashSet;
15use serde::{Deserialize, Serialize};
16use std::{clone::Clone, collections::HashMap, path::Path};
17
18#[derive(Serialize, Deserialize, Clone)]
19pub(crate) struct RelativePosition<PosT> {
20    pub root: NodeID,
21    pub pos: PosT,
22}
23
24#[derive(Serialize, Deserialize, Clone)]
25pub struct LinearGraphStorage<PosT: NumValue> {
26    node_to_pos: HashMap<NodeID, RelativePosition<PosT>>,
27    node_chains: HashMap<NodeID, Vec<NodeID>>,
28    annos: AnnoStorageImpl<Edge>,
29    stats: Option<GraphStatistic>,
30}
31
32impl<PosT> LinearGraphStorage<PosT>
33where
34    PosT: NumValue,
35{
36    pub fn new() -> LinearGraphStorage<PosT> {
37        LinearGraphStorage {
38            node_to_pos: HashMap::default(),
39            node_chains: HashMap::default(),
40            annos: AnnoStorageImpl::new(),
41            stats: None,
42        }
43    }
44
45    pub fn clear(&mut self) -> Result<()> {
46        self.node_to_pos.clear();
47        self.node_chains.clear();
48        self.annos.clear()?;
49        self.stats = None;
50        Ok(())
51    }
52
53    fn copy_edge_annos_for_node(
54        &mut self,
55        source_node: NodeID,
56        orig: &dyn GraphStorage,
57    ) -> Result<()> {
58        // Iterate over the outgoing edges of this node to add the edge
59        // annotations
60        let out_edges = orig.get_outgoing_edges(source_node);
61        for target in out_edges {
62            let target = target?;
63            let e = Edge {
64                source: source_node,
65                target,
66            };
67            let edge_annos = orig.get_anno_storage().get_annotations_for_item(&e)?;
68            for a in edge_annos {
69                self.annos.insert(e.clone(), a)?;
70            }
71        }
72        Ok(())
73    }
74}
75
76impl<PosT> Default for LinearGraphStorage<PosT>
77where
78    PosT: NumValue,
79{
80    fn default() -> Self {
81        LinearGraphStorage::new()
82    }
83}
84
85impl<PosT: 'static> EdgeContainer for LinearGraphStorage<PosT>
86where
87    PosT: NumValue,
88{
89    fn get_outgoing_edges<'a>(
90        &'a self,
91        node: NodeID,
92    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
93        if let Some(pos) = self.node_to_pos.get(&node) {
94            // find the next node in the chain
95            if let Some(chain) = self.node_chains.get(&pos.root) {
96                let next_pos = pos.pos.clone() + PosT::one();
97                if let Some(next_pos) = next_pos.to_usize()
98                    && next_pos < chain.len()
99                {
100                    return Box::from(std::iter::once(Ok(chain[next_pos])));
101                }
102            }
103        }
104        Box::from(std::iter::empty())
105    }
106
107    fn get_ingoing_edges<'a>(
108        &'a self,
109        node: NodeID,
110    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
111        if let Some(pos) = self.node_to_pos.get(&node) {
112            // find the previous node in the chain
113            if let Some(chain) = self.node_chains.get(&pos.root)
114                && let Some(pos) = pos.pos.to_usize()
115                && let Some(previous_pos) = pos.checked_sub(1)
116            {
117                return Box::from(std::iter::once(Ok(chain[previous_pos])));
118            }
119        }
120        Box::from(std::iter::empty())
121    }
122
123    fn has_ingoing_edges(&self, node: NodeID) -> Result<bool> {
124        let result = self
125            .node_to_pos
126            .get(&node)
127            .map(|pos| !pos.pos.is_zero())
128            .unwrap_or(false);
129        Ok(result)
130    }
131
132    fn source_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
133        // use the node chains to find source nodes, but always skip the last element
134        // because the last element is only a target node, not a source node
135        let it = self
136            .node_chains
137            .iter()
138            .flat_map(|(_root, chain)| chain.iter().rev().skip(1))
139            .cloned()
140            .map(Ok);
141
142        Box::new(it)
143    }
144
145    fn root_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
146        let it = self.node_chains.keys().copied().map(Ok);
147        Box::new(it)
148    }
149
150    fn get_statistics(&self) -> Option<&GraphStatistic> {
151        self.stats.as_ref()
152    }
153}
154
155impl<PosT: 'static> GraphStorage for LinearGraphStorage<PosT>
156where
157    for<'de> PosT: NumValue + Deserialize<'de> + Serialize,
158{
159    fn get_anno_storage(&self) -> &dyn EdgeAnnotationStorage {
160        &self.annos
161    }
162
163    fn serialization_id(&self) -> String {
164        format!("LinearO{}V1", std::mem::size_of::<PosT>() * 8)
165    }
166
167    fn load_from(location: &Path) -> Result<Self>
168    where
169        for<'de> Self: std::marker::Sized + Deserialize<'de>,
170    {
171        let legacy_path = location.join("component.bin");
172        let mut result: Self = if legacy_path.is_file() {
173            let component: LinearGraphStorageV1<PosT> =
174                deserialize_gs_field(location, "component")?;
175            Self {
176                node_to_pos: component.node_to_pos,
177                node_chains: component.node_chains,
178                annos: component.annos,
179                stats: component.stats.map(GraphStatistic::from),
180            }
181        } else {
182            let stats = load_statistics_from_location(location)?;
183            Self {
184                node_to_pos: deserialize_gs_field(location, "node_to_pos")?,
185                node_chains: deserialize_gs_field(location, "node_chains")?,
186                annos: deserialize_gs_field(location, "annos")?,
187                stats,
188            }
189        };
190
191        result.annos.after_deserialization();
192        Ok(result)
193    }
194
195    fn save_to(&self, location: &Path) -> Result<()> {
196        serialize_gs_field(&self.node_to_pos, "node_to_pos", location)?;
197        serialize_gs_field(&self.node_chains, "node_chains", location)?;
198        serialize_gs_field(&self.annos, "annos", location)?;
199        save_statistics_to_toml(location, self.stats.as_ref())?;
200        Ok(())
201    }
202
203    fn find_connected<'a>(
204        &'a self,
205        source: NodeID,
206        min_distance: usize,
207        max_distance: std::ops::Bound<usize>,
208    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
209        if let Some(start_pos) = self.node_to_pos.get(&source)
210            && let Some(chain) = self.node_chains.get(&start_pos.root)
211            && let Some(offset) = start_pos.pos.to_usize()
212            && let Some(min_distance) = offset.checked_add(min_distance)
213            && min_distance < chain.len()
214        {
215            let max_distance = match max_distance {
216                std::ops::Bound::Unbounded => {
217                    return Box::new(chain[min_distance..].iter().map(|n| Ok(*n)));
218                }
219                std::ops::Bound::Included(max_distance) => offset + max_distance + 1,
220                std::ops::Bound::Excluded(max_distance) => offset + max_distance,
221            };
222            // clip to chain length
223            let max_distance = std::cmp::min(chain.len(), max_distance);
224            if min_distance < max_distance {
225                return Box::new(chain[min_distance..max_distance].iter().map(|n| Ok(*n)));
226            }
227        }
228        Box::new(std::iter::empty())
229    }
230
231    fn find_connected_inverse<'a>(
232        &'a self,
233        source: NodeID,
234        min_distance: usize,
235        max_distance: std::ops::Bound<usize>,
236    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
237        if let Some(start_pos) = self.node_to_pos.get(&source)
238            && let Some(chain) = self.node_chains.get(&start_pos.root)
239            && let Some(offset) = start_pos.pos.to_usize()
240        {
241            let max_distance = match max_distance {
242                std::ops::Bound::Unbounded => 0,
243                std::ops::Bound::Included(max_distance) => offset.saturating_sub(max_distance),
244                std::ops::Bound::Excluded(max_distance) => offset.saturating_sub(max_distance + 1),
245            };
246
247            if let Some(min_distance) = offset.checked_sub(min_distance) {
248                if min_distance < chain.len() && max_distance <= min_distance {
249                    // return all entries in the chain between min_distance..max_distance (inclusive)
250                    return Box::new(chain[max_distance..=min_distance].iter().map(|n| Ok(*n)));
251                } else if max_distance < chain.len() {
252                    // return all entries in the chain between min_distance..max_distance
253                    return Box::new(chain[max_distance..chain.len()].iter().map(|n| Ok(*n)));
254                }
255            }
256        }
257        Box::new(std::iter::empty())
258    }
259
260    fn distance(&self, source: NodeID, target: NodeID) -> Result<Option<usize>> {
261        if source == target {
262            return Ok(Some(0));
263        }
264
265        if let (Some(source_pos), Some(target_pos)) =
266            (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
267            && source_pos.root == target_pos.root
268            && source_pos.pos <= target_pos.pos
269        {
270            let diff = target_pos.pos.clone() - source_pos.pos.clone();
271            if let Some(diff) = diff.to_usize() {
272                return Ok(Some(diff));
273            }
274        }
275        Ok(None)
276    }
277
278    fn is_connected(
279        &self,
280        source: NodeID,
281        target: NodeID,
282        min_distance: usize,
283        max_distance: std::ops::Bound<usize>,
284    ) -> Result<bool> {
285        if let (Some(source_pos), Some(target_pos)) =
286            (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
287            && source_pos.root == target_pos.root
288            && source_pos.pos <= target_pos.pos
289        {
290            let diff = target_pos.pos.clone() - source_pos.pos.clone();
291            if let Some(diff) = diff.to_usize() {
292                match max_distance {
293                    std::ops::Bound::Unbounded => {
294                        return Ok(diff >= min_distance);
295                    }
296                    std::ops::Bound::Included(max_distance) => {
297                        return Ok(diff >= min_distance && diff <= max_distance);
298                    }
299                    std::ops::Bound::Excluded(max_distance) => {
300                        return Ok(diff >= min_distance && diff < max_distance);
301                    }
302                }
303            }
304        }
305
306        Ok(false)
307    }
308
309    fn copy(
310        &mut self,
311        _node_annos: &dyn NodeAnnotationStorage,
312        orig: &dyn GraphStorage,
313    ) -> Result<()> {
314        self.clear()?;
315
316        // find all roots of the component
317        let roots: Result<FxHashSet<NodeID>> = orig.root_nodes().collect();
318        let roots = roots?;
319
320        for root_node in &roots {
321            self.copy_edge_annos_for_node(*root_node, orig)?;
322
323            // iterate over all edges beginning from the root
324            let mut chain: Vec<NodeID> = vec![*root_node];
325            let pos: RelativePosition<PosT> = RelativePosition {
326                root: *root_node,
327                pos: PosT::zero(),
328            };
329            self.node_to_pos.insert(*root_node, pos);
330
331            let dfs = CycleSafeDFS::new(orig.as_edgecontainer(), *root_node, 1, usize::MAX);
332            for step in dfs {
333                let step = step?;
334
335                self.copy_edge_annos_for_node(step.node, orig)?;
336
337                if let Some(pos) = PosT::from_usize(chain.len()) {
338                    let pos: RelativePosition<PosT> = RelativePosition {
339                        root: *root_node,
340                        pos,
341                    };
342                    self.node_to_pos.insert(step.node, pos);
343                }
344                chain.push(step.node);
345            }
346            chain.shrink_to_fit();
347            self.node_chains.insert(*root_node, chain);
348        }
349
350        self.node_chains.shrink_to_fit();
351        self.node_to_pos.shrink_to_fit();
352
353        self.stats = orig.get_statistics().cloned();
354        self.annos.calculate_statistics()?;
355
356        Ok(())
357    }
358
359    fn inverse_has_same_cost(&self) -> bool {
360        true
361    }
362
363    fn as_edgecontainer(&self) -> &dyn EdgeContainer {
364        self
365    }
366}
367
368#[cfg(test)]
369mod tests;