graphannis_core/graph/storage/
linear.rs1use super::{
2 EdgeContainer, GraphStatistic, GraphStorage, deserialize_gs_field,
3 legacy::LinearGraphStorageV1, load_statistics_from_location, save_statistics_to_toml,
4 serialize_gs_field,
5};
6use crate::{
7 annostorage::{
8 AnnotationStorage, EdgeAnnotationStorage, NodeAnnotationStorage, inmemory::AnnoStorageImpl,
9 },
10 dfs::CycleSafeDFS,
11 errors::Result,
12 types::{Edge, NodeID, NumValue},
13};
14use rustc_hash::FxHashSet;
15use serde::{Deserialize, Serialize};
16use std::{clone::Clone, collections::HashMap, path::Path};
17
18#[derive(Serialize, Deserialize, Clone)]
19pub(crate) struct RelativePosition<PosT> {
20 pub root: NodeID,
21 pub pos: PosT,
22}
23
24#[derive(Serialize, Deserialize, Clone)]
25pub struct LinearGraphStorage<PosT: NumValue> {
26 node_to_pos: HashMap<NodeID, RelativePosition<PosT>>,
27 node_chains: HashMap<NodeID, Vec<NodeID>>,
28 annos: AnnoStorageImpl<Edge>,
29 stats: Option<GraphStatistic>,
30}
31
32impl<PosT> LinearGraphStorage<PosT>
33where
34 PosT: NumValue,
35{
36 pub fn new() -> LinearGraphStorage<PosT> {
37 LinearGraphStorage {
38 node_to_pos: HashMap::default(),
39 node_chains: HashMap::default(),
40 annos: AnnoStorageImpl::new(),
41 stats: None,
42 }
43 }
44
45 pub fn clear(&mut self) -> Result<()> {
46 self.node_to_pos.clear();
47 self.node_chains.clear();
48 self.annos.clear()?;
49 self.stats = None;
50 Ok(())
51 }
52
53 fn copy_edge_annos_for_node(
54 &mut self,
55 source_node: NodeID,
56 orig: &dyn GraphStorage,
57 ) -> Result<()> {
58 let out_edges = orig.get_outgoing_edges(source_node);
61 for target in out_edges {
62 let target = target?;
63 let e = Edge {
64 source: source_node,
65 target,
66 };
67 let edge_annos = orig.get_anno_storage().get_annotations_for_item(&e)?;
68 for a in edge_annos {
69 self.annos.insert(e.clone(), a)?;
70 }
71 }
72 Ok(())
73 }
74}
75
76impl<PosT> Default for LinearGraphStorage<PosT>
77where
78 PosT: NumValue,
79{
80 fn default() -> Self {
81 LinearGraphStorage::new()
82 }
83}
84
85impl<PosT: 'static> EdgeContainer for LinearGraphStorage<PosT>
86where
87 PosT: NumValue,
88{
89 fn get_outgoing_edges<'a>(
90 &'a self,
91 node: NodeID,
92 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
93 if let Some(pos) = self.node_to_pos.get(&node) {
94 if let Some(chain) = self.node_chains.get(&pos.root) {
96 let next_pos = pos.pos.clone() + PosT::one();
97 if let Some(next_pos) = next_pos.to_usize()
98 && next_pos < chain.len()
99 {
100 return Box::from(std::iter::once(Ok(chain[next_pos])));
101 }
102 }
103 }
104 Box::from(std::iter::empty())
105 }
106
107 fn get_ingoing_edges<'a>(
108 &'a self,
109 node: NodeID,
110 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
111 if let Some(pos) = self.node_to_pos.get(&node) {
112 if let Some(chain) = self.node_chains.get(&pos.root)
114 && let Some(pos) = pos.pos.to_usize()
115 && let Some(previous_pos) = pos.checked_sub(1)
116 {
117 return Box::from(std::iter::once(Ok(chain[previous_pos])));
118 }
119 }
120 Box::from(std::iter::empty())
121 }
122
123 fn has_ingoing_edges(&self, node: NodeID) -> Result<bool> {
124 let result = self
125 .node_to_pos
126 .get(&node)
127 .map(|pos| !pos.pos.is_zero())
128 .unwrap_or(false);
129 Ok(result)
130 }
131
132 fn source_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
133 let it = self
136 .node_chains
137 .iter()
138 .flat_map(|(_root, chain)| chain.iter().rev().skip(1))
139 .cloned()
140 .map(Ok);
141
142 Box::new(it)
143 }
144
145 fn root_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
146 let it = self.node_chains.keys().copied().map(Ok);
147 Box::new(it)
148 }
149
150 fn get_statistics(&self) -> Option<&GraphStatistic> {
151 self.stats.as_ref()
152 }
153}
154
155impl<PosT: 'static> GraphStorage for LinearGraphStorage<PosT>
156where
157 for<'de> PosT: NumValue + Deserialize<'de> + Serialize,
158{
159 fn get_anno_storage(&self) -> &dyn EdgeAnnotationStorage {
160 &self.annos
161 }
162
163 fn serialization_id(&self) -> String {
164 format!("LinearO{}V1", std::mem::size_of::<PosT>() * 8)
165 }
166
167 fn load_from(location: &Path) -> Result<Self>
168 where
169 for<'de> Self: std::marker::Sized + Deserialize<'de>,
170 {
171 let legacy_path = location.join("component.bin");
172 let mut result: Self = if legacy_path.is_file() {
173 let component: LinearGraphStorageV1<PosT> =
174 deserialize_gs_field(location, "component")?;
175 Self {
176 node_to_pos: component.node_to_pos,
177 node_chains: component.node_chains,
178 annos: component.annos,
179 stats: component.stats.map(GraphStatistic::from),
180 }
181 } else {
182 let stats = load_statistics_from_location(location)?;
183 Self {
184 node_to_pos: deserialize_gs_field(location, "node_to_pos")?,
185 node_chains: deserialize_gs_field(location, "node_chains")?,
186 annos: deserialize_gs_field(location, "annos")?,
187 stats,
188 }
189 };
190
191 result.annos.after_deserialization();
192 Ok(result)
193 }
194
195 fn save_to(&self, location: &Path) -> Result<()> {
196 serialize_gs_field(&self.node_to_pos, "node_to_pos", location)?;
197 serialize_gs_field(&self.node_chains, "node_chains", location)?;
198 serialize_gs_field(&self.annos, "annos", location)?;
199 save_statistics_to_toml(location, self.stats.as_ref())?;
200 Ok(())
201 }
202
203 fn find_connected<'a>(
204 &'a self,
205 source: NodeID,
206 min_distance: usize,
207 max_distance: std::ops::Bound<usize>,
208 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
209 if let Some(start_pos) = self.node_to_pos.get(&source)
210 && let Some(chain) = self.node_chains.get(&start_pos.root)
211 && let Some(offset) = start_pos.pos.to_usize()
212 && let Some(min_distance) = offset.checked_add(min_distance)
213 && min_distance < chain.len()
214 {
215 let max_distance = match max_distance {
216 std::ops::Bound::Unbounded => {
217 return Box::new(chain[min_distance..].iter().map(|n| Ok(*n)));
218 }
219 std::ops::Bound::Included(max_distance) => offset + max_distance + 1,
220 std::ops::Bound::Excluded(max_distance) => offset + max_distance,
221 };
222 let max_distance = std::cmp::min(chain.len(), max_distance);
224 if min_distance < max_distance {
225 return Box::new(chain[min_distance..max_distance].iter().map(|n| Ok(*n)));
226 }
227 }
228 Box::new(std::iter::empty())
229 }
230
231 fn find_connected_inverse<'a>(
232 &'a self,
233 source: NodeID,
234 min_distance: usize,
235 max_distance: std::ops::Bound<usize>,
236 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
237 if let Some(start_pos) = self.node_to_pos.get(&source)
238 && let Some(chain) = self.node_chains.get(&start_pos.root)
239 && let Some(offset) = start_pos.pos.to_usize()
240 {
241 let max_distance = match max_distance {
242 std::ops::Bound::Unbounded => 0,
243 std::ops::Bound::Included(max_distance) => offset.saturating_sub(max_distance),
244 std::ops::Bound::Excluded(max_distance) => offset.saturating_sub(max_distance + 1),
245 };
246
247 if let Some(min_distance) = offset.checked_sub(min_distance) {
248 if min_distance < chain.len() && max_distance <= min_distance {
249 return Box::new(chain[max_distance..=min_distance].iter().map(|n| Ok(*n)));
251 } else if max_distance < chain.len() {
252 return Box::new(chain[max_distance..chain.len()].iter().map(|n| Ok(*n)));
254 }
255 }
256 }
257 Box::new(std::iter::empty())
258 }
259
260 fn distance(&self, source: NodeID, target: NodeID) -> Result<Option<usize>> {
261 if source == target {
262 return Ok(Some(0));
263 }
264
265 if let (Some(source_pos), Some(target_pos)) =
266 (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
267 && source_pos.root == target_pos.root
268 && source_pos.pos <= target_pos.pos
269 {
270 let diff = target_pos.pos.clone() - source_pos.pos.clone();
271 if let Some(diff) = diff.to_usize() {
272 return Ok(Some(diff));
273 }
274 }
275 Ok(None)
276 }
277
278 fn is_connected(
279 &self,
280 source: NodeID,
281 target: NodeID,
282 min_distance: usize,
283 max_distance: std::ops::Bound<usize>,
284 ) -> Result<bool> {
285 if let (Some(source_pos), Some(target_pos)) =
286 (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
287 && source_pos.root == target_pos.root
288 && source_pos.pos <= target_pos.pos
289 {
290 let diff = target_pos.pos.clone() - source_pos.pos.clone();
291 if let Some(diff) = diff.to_usize() {
292 match max_distance {
293 std::ops::Bound::Unbounded => {
294 return Ok(diff >= min_distance);
295 }
296 std::ops::Bound::Included(max_distance) => {
297 return Ok(diff >= min_distance && diff <= max_distance);
298 }
299 std::ops::Bound::Excluded(max_distance) => {
300 return Ok(diff >= min_distance && diff < max_distance);
301 }
302 }
303 }
304 }
305
306 Ok(false)
307 }
308
309 fn copy(
310 &mut self,
311 _node_annos: &dyn NodeAnnotationStorage,
312 orig: &dyn GraphStorage,
313 ) -> Result<()> {
314 self.clear()?;
315
316 let roots: Result<FxHashSet<NodeID>> = orig.root_nodes().collect();
318 let roots = roots?;
319
320 for root_node in &roots {
321 self.copy_edge_annos_for_node(*root_node, orig)?;
322
323 let mut chain: Vec<NodeID> = vec![*root_node];
325 let pos: RelativePosition<PosT> = RelativePosition {
326 root: *root_node,
327 pos: PosT::zero(),
328 };
329 self.node_to_pos.insert(*root_node, pos);
330
331 let dfs = CycleSafeDFS::new(orig.as_edgecontainer(), *root_node, 1, usize::MAX);
332 for step in dfs {
333 let step = step?;
334
335 self.copy_edge_annos_for_node(step.node, orig)?;
336
337 if let Some(pos) = PosT::from_usize(chain.len()) {
338 let pos: RelativePosition<PosT> = RelativePosition {
339 root: *root_node,
340 pos,
341 };
342 self.node_to_pos.insert(step.node, pos);
343 }
344 chain.push(step.node);
345 }
346 chain.shrink_to_fit();
347 self.node_chains.insert(*root_node, chain);
348 }
349
350 self.node_chains.shrink_to_fit();
351 self.node_to_pos.shrink_to_fit();
352
353 self.stats = orig.get_statistics().cloned();
354 self.annos.calculate_statistics()?;
355
356 Ok(())
357 }
358
359 fn inverse_has_same_cost(&self) -> bool {
360 true
361 }
362
363 fn as_edgecontainer(&self) -> &dyn EdgeContainer {
364 self
365 }
366}
367
368#[cfg(test)]
369mod tests;