graphannis_core/graph/storage/
linear.rs

1use super::{
2    deserialize_gs_field, legacy::LinearGraphStorageV1, load_statistics_from_location,
3    save_statistics_to_toml, serialize_gs_field, EdgeContainer, GraphStatistic, GraphStorage,
4};
5use crate::{
6    annostorage::{
7        inmemory::AnnoStorageImpl, AnnotationStorage, EdgeAnnotationStorage, NodeAnnotationStorage,
8    },
9    dfs::CycleSafeDFS,
10    errors::Result,
11    types::{Edge, NodeID, NumValue},
12};
13use rustc_hash::FxHashSet;
14use serde::{Deserialize, Serialize};
15use std::{clone::Clone, collections::HashMap, path::Path};
16
17#[derive(Serialize, Deserialize, Clone)]
18pub(crate) struct RelativePosition<PosT> {
19    pub root: NodeID,
20    pub pos: PosT,
21}
22
23#[derive(Serialize, Deserialize, Clone)]
24pub struct LinearGraphStorage<PosT: NumValue> {
25    node_to_pos: HashMap<NodeID, RelativePosition<PosT>>,
26    node_chains: HashMap<NodeID, Vec<NodeID>>,
27    annos: AnnoStorageImpl<Edge>,
28    stats: Option<GraphStatistic>,
29}
30
31impl<PosT> LinearGraphStorage<PosT>
32where
33    PosT: NumValue,
34{
35    pub fn new() -> LinearGraphStorage<PosT> {
36        LinearGraphStorage {
37            node_to_pos: HashMap::default(),
38            node_chains: HashMap::default(),
39            annos: AnnoStorageImpl::new(),
40            stats: None,
41        }
42    }
43
44    pub fn clear(&mut self) -> Result<()> {
45        self.node_to_pos.clear();
46        self.node_chains.clear();
47        self.annos.clear()?;
48        self.stats = None;
49        Ok(())
50    }
51
52    fn copy_edge_annos_for_node(
53        &mut self,
54        source_node: NodeID,
55        orig: &dyn GraphStorage,
56    ) -> Result<()> {
57        // Iterate over the outgoing edges of this node to add the edge
58        // annotations
59        let out_edges = orig.get_outgoing_edges(source_node);
60        for target in out_edges {
61            let target = target?;
62            let e = Edge {
63                source: source_node,
64                target,
65            };
66            let edge_annos = orig.get_anno_storage().get_annotations_for_item(&e)?;
67            for a in edge_annos {
68                self.annos.insert(e.clone(), a)?;
69            }
70        }
71        Ok(())
72    }
73}
74
75impl<PosT> Default for LinearGraphStorage<PosT>
76where
77    PosT: NumValue,
78{
79    fn default() -> Self {
80        LinearGraphStorage::new()
81    }
82}
83
84impl<PosT: 'static> EdgeContainer for LinearGraphStorage<PosT>
85where
86    PosT: NumValue,
87{
88    fn get_outgoing_edges<'a>(
89        &'a self,
90        node: NodeID,
91    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
92        if let Some(pos) = self.node_to_pos.get(&node) {
93            // find the next node in the chain
94            if let Some(chain) = self.node_chains.get(&pos.root) {
95                let next_pos = pos.pos.clone() + PosT::one();
96                if let Some(next_pos) = next_pos.to_usize() {
97                    if next_pos < chain.len() {
98                        return Box::from(std::iter::once(Ok(chain[next_pos])));
99                    }
100                }
101            }
102        }
103        Box::from(std::iter::empty())
104    }
105
106    fn get_ingoing_edges<'a>(
107        &'a self,
108        node: NodeID,
109    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
110        if let Some(pos) = self.node_to_pos.get(&node) {
111            // find the previous node in the chain
112            if let Some(chain) = self.node_chains.get(&pos.root) {
113                if let Some(pos) = pos.pos.to_usize() {
114                    if let Some(previous_pos) = pos.checked_sub(1) {
115                        return Box::from(std::iter::once(Ok(chain[previous_pos])));
116                    }
117                }
118            }
119        }
120        Box::from(std::iter::empty())
121    }
122
123    fn has_ingoing_edges(&self, node: NodeID) -> Result<bool> {
124        let result = self
125            .node_to_pos
126            .get(&node)
127            .map(|pos| !pos.pos.is_zero())
128            .unwrap_or(false);
129        Ok(result)
130    }
131
132    fn source_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
133        // use the node chains to find source nodes, but always skip the last element
134        // because the last element is only a target node, not a source node
135        let it = self
136            .node_chains
137            .iter()
138            .flat_map(|(_root, chain)| chain.iter().rev().skip(1))
139            .cloned()
140            .map(Ok);
141
142        Box::new(it)
143    }
144
145    fn root_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
146        let it = self.node_chains.keys().copied().map(Ok);
147        Box::new(it)
148    }
149
150    fn get_statistics(&self) -> Option<&GraphStatistic> {
151        self.stats.as_ref()
152    }
153}
154
155impl<PosT: 'static> GraphStorage for LinearGraphStorage<PosT>
156where
157    for<'de> PosT: NumValue + Deserialize<'de> + Serialize,
158{
159    fn get_anno_storage(&self) -> &dyn EdgeAnnotationStorage {
160        &self.annos
161    }
162
163    fn serialization_id(&self) -> String {
164        format!("LinearO{}V1", std::mem::size_of::<PosT>() * 8)
165    }
166
167    fn load_from(location: &Path) -> Result<Self>
168    where
169        for<'de> Self: std::marker::Sized + Deserialize<'de>,
170    {
171        let legacy_path = location.join("component.bin");
172        let mut result: Self = if legacy_path.is_file() {
173            let component: LinearGraphStorageV1<PosT> =
174                deserialize_gs_field(location, "component")?;
175            Self {
176                node_to_pos: component.node_to_pos,
177                node_chains: component.node_chains,
178                annos: component.annos,
179                stats: component.stats.map(GraphStatistic::from),
180            }
181        } else {
182            let stats = load_statistics_from_location(location)?;
183            Self {
184                node_to_pos: deserialize_gs_field(location, "node_to_pos")?,
185                node_chains: deserialize_gs_field(location, "node_chains")?,
186                annos: deserialize_gs_field(location, "annos")?,
187                stats,
188            }
189        };
190
191        result.annos.after_deserialization();
192        Ok(result)
193    }
194
195    fn save_to(&self, location: &Path) -> Result<()> {
196        serialize_gs_field(&self.node_to_pos, "node_to_pos", location)?;
197        serialize_gs_field(&self.node_chains, "node_chains", location)?;
198        serialize_gs_field(&self.annos, "annos", location)?;
199        save_statistics_to_toml(location, self.stats.as_ref())?;
200        Ok(())
201    }
202
203    fn find_connected<'a>(
204        &'a self,
205        source: NodeID,
206        min_distance: usize,
207        max_distance: std::ops::Bound<usize>,
208    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
209        if let Some(start_pos) = self.node_to_pos.get(&source) {
210            if let Some(chain) = self.node_chains.get(&start_pos.root) {
211                if let Some(offset) = start_pos.pos.to_usize() {
212                    if let Some(min_distance) = offset.checked_add(min_distance) {
213                        if min_distance < chain.len() {
214                            let max_distance = match max_distance {
215                                std::ops::Bound::Unbounded => {
216                                    return Box::new(chain[min_distance..].iter().map(|n| Ok(*n)));
217                                }
218                                std::ops::Bound::Included(max_distance) => {
219                                    offset + max_distance + 1
220                                }
221                                std::ops::Bound::Excluded(max_distance) => offset + max_distance,
222                            };
223                            // clip to chain length
224                            let max_distance = std::cmp::min(chain.len(), max_distance);
225                            if min_distance < max_distance {
226                                return Box::new(
227                                    chain[min_distance..max_distance].iter().map(|n| Ok(*n)),
228                                );
229                            }
230                        }
231                    }
232                }
233            }
234        }
235        Box::new(std::iter::empty())
236    }
237
238    fn find_connected_inverse<'a>(
239        &'a self,
240        source: NodeID,
241        min_distance: usize,
242        max_distance: std::ops::Bound<usize>,
243    ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
244        if let Some(start_pos) = self.node_to_pos.get(&source) {
245            if let Some(chain) = self.node_chains.get(&start_pos.root) {
246                if let Some(offset) = start_pos.pos.to_usize() {
247                    let max_distance = match max_distance {
248                        std::ops::Bound::Unbounded => 0,
249                        std::ops::Bound::Included(max_distance) => {
250                            offset.saturating_sub(max_distance)
251                        }
252                        std::ops::Bound::Excluded(max_distance) => {
253                            offset.saturating_sub(max_distance + 1)
254                        }
255                    };
256
257                    if let Some(min_distance) = offset.checked_sub(min_distance) {
258                        if min_distance < chain.len() && max_distance <= min_distance {
259                            // return all entries in the chain between min_distance..max_distance (inclusive)
260                            return Box::new(
261                                chain[max_distance..=min_distance].iter().map(|n| Ok(*n)),
262                            );
263                        } else if max_distance < chain.len() {
264                            // return all entries in the chain between min_distance..max_distance
265                            return Box::new(
266                                chain[max_distance..chain.len()].iter().map(|n| Ok(*n)),
267                            );
268                        }
269                    }
270                }
271            }
272        }
273        Box::new(std::iter::empty())
274    }
275
276    fn distance(&self, source: NodeID, target: NodeID) -> Result<Option<usize>> {
277        if source == target {
278            return Ok(Some(0));
279        }
280
281        if let (Some(source_pos), Some(target_pos)) =
282            (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
283        {
284            if source_pos.root == target_pos.root && source_pos.pos <= target_pos.pos {
285                let diff = target_pos.pos.clone() - source_pos.pos.clone();
286                if let Some(diff) = diff.to_usize() {
287                    return Ok(Some(diff));
288                }
289            }
290        }
291        Ok(None)
292    }
293
294    fn is_connected(
295        &self,
296        source: NodeID,
297        target: NodeID,
298        min_distance: usize,
299        max_distance: std::ops::Bound<usize>,
300    ) -> Result<bool> {
301        if let (Some(source_pos), Some(target_pos)) =
302            (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
303        {
304            if source_pos.root == target_pos.root && source_pos.pos <= target_pos.pos {
305                let diff = target_pos.pos.clone() - source_pos.pos.clone();
306                if let Some(diff) = diff.to_usize() {
307                    match max_distance {
308                        std::ops::Bound::Unbounded => {
309                            return Ok(diff >= min_distance);
310                        }
311                        std::ops::Bound::Included(max_distance) => {
312                            return Ok(diff >= min_distance && diff <= max_distance);
313                        }
314                        std::ops::Bound::Excluded(max_distance) => {
315                            return Ok(diff >= min_distance && diff < max_distance);
316                        }
317                    }
318                }
319            }
320        }
321
322        Ok(false)
323    }
324
325    fn copy(
326        &mut self,
327        _node_annos: &dyn NodeAnnotationStorage,
328        orig: &dyn GraphStorage,
329    ) -> Result<()> {
330        self.clear()?;
331
332        // find all roots of the component
333        let roots: Result<FxHashSet<NodeID>> = orig.root_nodes().collect();
334        let roots = roots?;
335
336        for root_node in &roots {
337            self.copy_edge_annos_for_node(*root_node, orig)?;
338
339            // iterate over all edges beginning from the root
340            let mut chain: Vec<NodeID> = vec![*root_node];
341            let pos: RelativePosition<PosT> = RelativePosition {
342                root: *root_node,
343                pos: PosT::zero(),
344            };
345            self.node_to_pos.insert(*root_node, pos);
346
347            let dfs = CycleSafeDFS::new(orig.as_edgecontainer(), *root_node, 1, usize::MAX);
348            for step in dfs {
349                let step = step?;
350
351                self.copy_edge_annos_for_node(step.node, orig)?;
352
353                if let Some(pos) = PosT::from_usize(chain.len()) {
354                    let pos: RelativePosition<PosT> = RelativePosition {
355                        root: *root_node,
356                        pos,
357                    };
358                    self.node_to_pos.insert(step.node, pos);
359                }
360                chain.push(step.node);
361            }
362            chain.shrink_to_fit();
363            self.node_chains.insert(*root_node, chain);
364        }
365
366        self.node_chains.shrink_to_fit();
367        self.node_to_pos.shrink_to_fit();
368
369        self.stats = orig.get_statistics().cloned();
370        self.annos.calculate_statistics()?;
371
372        Ok(())
373    }
374
375    fn inverse_has_same_cost(&self) -> bool {
376        true
377    }
378
379    fn as_edgecontainer(&self) -> &dyn EdgeContainer {
380        self
381    }
382}
383
384#[cfg(test)]
385mod tests;