graphannis_core/graph/storage/
linear.rs1use super::{
2 deserialize_gs_field, legacy::LinearGraphStorageV1, load_statistics_from_location,
3 save_statistics_to_toml, serialize_gs_field, EdgeContainer, GraphStatistic, GraphStorage,
4};
5use crate::{
6 annostorage::{
7 inmemory::AnnoStorageImpl, AnnotationStorage, EdgeAnnotationStorage, NodeAnnotationStorage,
8 },
9 dfs::CycleSafeDFS,
10 errors::Result,
11 types::{Edge, NodeID, NumValue},
12};
13use rustc_hash::FxHashSet;
14use serde::{Deserialize, Serialize};
15use std::{clone::Clone, collections::HashMap, path::Path};
16
17#[derive(Serialize, Deserialize, Clone)]
18pub(crate) struct RelativePosition<PosT> {
19 pub root: NodeID,
20 pub pos: PosT,
21}
22
23#[derive(Serialize, Deserialize, Clone)]
24pub struct LinearGraphStorage<PosT: NumValue> {
25 node_to_pos: HashMap<NodeID, RelativePosition<PosT>>,
26 node_chains: HashMap<NodeID, Vec<NodeID>>,
27 annos: AnnoStorageImpl<Edge>,
28 stats: Option<GraphStatistic>,
29}
30
31impl<PosT> LinearGraphStorage<PosT>
32where
33 PosT: NumValue,
34{
35 pub fn new() -> LinearGraphStorage<PosT> {
36 LinearGraphStorage {
37 node_to_pos: HashMap::default(),
38 node_chains: HashMap::default(),
39 annos: AnnoStorageImpl::new(),
40 stats: None,
41 }
42 }
43
44 pub fn clear(&mut self) -> Result<()> {
45 self.node_to_pos.clear();
46 self.node_chains.clear();
47 self.annos.clear()?;
48 self.stats = None;
49 Ok(())
50 }
51
52 fn copy_edge_annos_for_node(
53 &mut self,
54 source_node: NodeID,
55 orig: &dyn GraphStorage,
56 ) -> Result<()> {
57 let out_edges = orig.get_outgoing_edges(source_node);
60 for target in out_edges {
61 let target = target?;
62 let e = Edge {
63 source: source_node,
64 target,
65 };
66 let edge_annos = orig.get_anno_storage().get_annotations_for_item(&e)?;
67 for a in edge_annos {
68 self.annos.insert(e.clone(), a)?;
69 }
70 }
71 Ok(())
72 }
73}
74
75impl<PosT> Default for LinearGraphStorage<PosT>
76where
77 PosT: NumValue,
78{
79 fn default() -> Self {
80 LinearGraphStorage::new()
81 }
82}
83
84impl<PosT: 'static> EdgeContainer for LinearGraphStorage<PosT>
85where
86 PosT: NumValue,
87{
88 fn get_outgoing_edges<'a>(
89 &'a self,
90 node: NodeID,
91 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
92 if let Some(pos) = self.node_to_pos.get(&node) {
93 if let Some(chain) = self.node_chains.get(&pos.root) {
95 let next_pos = pos.pos.clone() + PosT::one();
96 if let Some(next_pos) = next_pos.to_usize() {
97 if next_pos < chain.len() {
98 return Box::from(std::iter::once(Ok(chain[next_pos])));
99 }
100 }
101 }
102 }
103 Box::from(std::iter::empty())
104 }
105
106 fn get_ingoing_edges<'a>(
107 &'a self,
108 node: NodeID,
109 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
110 if let Some(pos) = self.node_to_pos.get(&node) {
111 if let Some(chain) = self.node_chains.get(&pos.root) {
113 if let Some(pos) = pos.pos.to_usize() {
114 if let Some(previous_pos) = pos.checked_sub(1) {
115 return Box::from(std::iter::once(Ok(chain[previous_pos])));
116 }
117 }
118 }
119 }
120 Box::from(std::iter::empty())
121 }
122
123 fn has_ingoing_edges(&self, node: NodeID) -> Result<bool> {
124 let result = self
125 .node_to_pos
126 .get(&node)
127 .map(|pos| !pos.pos.is_zero())
128 .unwrap_or(false);
129 Ok(result)
130 }
131
132 fn source_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
133 let it = self
136 .node_chains
137 .iter()
138 .flat_map(|(_root, chain)| chain.iter().rev().skip(1))
139 .cloned()
140 .map(Ok);
141
142 Box::new(it)
143 }
144
145 fn root_nodes<'a>(&'a self) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
146 let it = self.node_chains.keys().copied().map(Ok);
147 Box::new(it)
148 }
149
150 fn get_statistics(&self) -> Option<&GraphStatistic> {
151 self.stats.as_ref()
152 }
153}
154
155impl<PosT: 'static> GraphStorage for LinearGraphStorage<PosT>
156where
157 for<'de> PosT: NumValue + Deserialize<'de> + Serialize,
158{
159 fn get_anno_storage(&self) -> &dyn EdgeAnnotationStorage {
160 &self.annos
161 }
162
163 fn serialization_id(&self) -> String {
164 format!("LinearO{}V1", std::mem::size_of::<PosT>() * 8)
165 }
166
167 fn load_from(location: &Path) -> Result<Self>
168 where
169 for<'de> Self: std::marker::Sized + Deserialize<'de>,
170 {
171 let legacy_path = location.join("component.bin");
172 let mut result: Self = if legacy_path.is_file() {
173 let component: LinearGraphStorageV1<PosT> =
174 deserialize_gs_field(location, "component")?;
175 Self {
176 node_to_pos: component.node_to_pos,
177 node_chains: component.node_chains,
178 annos: component.annos,
179 stats: component.stats.map(GraphStatistic::from),
180 }
181 } else {
182 let stats = load_statistics_from_location(location)?;
183 Self {
184 node_to_pos: deserialize_gs_field(location, "node_to_pos")?,
185 node_chains: deserialize_gs_field(location, "node_chains")?,
186 annos: deserialize_gs_field(location, "annos")?,
187 stats,
188 }
189 };
190
191 result.annos.after_deserialization();
192 Ok(result)
193 }
194
195 fn save_to(&self, location: &Path) -> Result<()> {
196 serialize_gs_field(&self.node_to_pos, "node_to_pos", location)?;
197 serialize_gs_field(&self.node_chains, "node_chains", location)?;
198 serialize_gs_field(&self.annos, "annos", location)?;
199 save_statistics_to_toml(location, self.stats.as_ref())?;
200 Ok(())
201 }
202
203 fn find_connected<'a>(
204 &'a self,
205 source: NodeID,
206 min_distance: usize,
207 max_distance: std::ops::Bound<usize>,
208 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
209 if let Some(start_pos) = self.node_to_pos.get(&source) {
210 if let Some(chain) = self.node_chains.get(&start_pos.root) {
211 if let Some(offset) = start_pos.pos.to_usize() {
212 if let Some(min_distance) = offset.checked_add(min_distance) {
213 if min_distance < chain.len() {
214 let max_distance = match max_distance {
215 std::ops::Bound::Unbounded => {
216 return Box::new(chain[min_distance..].iter().map(|n| Ok(*n)));
217 }
218 std::ops::Bound::Included(max_distance) => {
219 offset + max_distance + 1
220 }
221 std::ops::Bound::Excluded(max_distance) => offset + max_distance,
222 };
223 let max_distance = std::cmp::min(chain.len(), max_distance);
225 if min_distance < max_distance {
226 return Box::new(
227 chain[min_distance..max_distance].iter().map(|n| Ok(*n)),
228 );
229 }
230 }
231 }
232 }
233 }
234 }
235 Box::new(std::iter::empty())
236 }
237
238 fn find_connected_inverse<'a>(
239 &'a self,
240 source: NodeID,
241 min_distance: usize,
242 max_distance: std::ops::Bound<usize>,
243 ) -> Box<dyn Iterator<Item = Result<NodeID>> + 'a> {
244 if let Some(start_pos) = self.node_to_pos.get(&source) {
245 if let Some(chain) = self.node_chains.get(&start_pos.root) {
246 if let Some(offset) = start_pos.pos.to_usize() {
247 let max_distance = match max_distance {
248 std::ops::Bound::Unbounded => 0,
249 std::ops::Bound::Included(max_distance) => {
250 offset.saturating_sub(max_distance)
251 }
252 std::ops::Bound::Excluded(max_distance) => {
253 offset.saturating_sub(max_distance + 1)
254 }
255 };
256
257 if let Some(min_distance) = offset.checked_sub(min_distance) {
258 if min_distance < chain.len() && max_distance <= min_distance {
259 return Box::new(
261 chain[max_distance..=min_distance].iter().map(|n| Ok(*n)),
262 );
263 } else if max_distance < chain.len() {
264 return Box::new(
266 chain[max_distance..chain.len()].iter().map(|n| Ok(*n)),
267 );
268 }
269 }
270 }
271 }
272 }
273 Box::new(std::iter::empty())
274 }
275
276 fn distance(&self, source: NodeID, target: NodeID) -> Result<Option<usize>> {
277 if source == target {
278 return Ok(Some(0));
279 }
280
281 if let (Some(source_pos), Some(target_pos)) =
282 (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
283 {
284 if source_pos.root == target_pos.root && source_pos.pos <= target_pos.pos {
285 let diff = target_pos.pos.clone() - source_pos.pos.clone();
286 if let Some(diff) = diff.to_usize() {
287 return Ok(Some(diff));
288 }
289 }
290 }
291 Ok(None)
292 }
293
294 fn is_connected(
295 &self,
296 source: NodeID,
297 target: NodeID,
298 min_distance: usize,
299 max_distance: std::ops::Bound<usize>,
300 ) -> Result<bool> {
301 if let (Some(source_pos), Some(target_pos)) =
302 (self.node_to_pos.get(&source), self.node_to_pos.get(&target))
303 {
304 if source_pos.root == target_pos.root && source_pos.pos <= target_pos.pos {
305 let diff = target_pos.pos.clone() - source_pos.pos.clone();
306 if let Some(diff) = diff.to_usize() {
307 match max_distance {
308 std::ops::Bound::Unbounded => {
309 return Ok(diff >= min_distance);
310 }
311 std::ops::Bound::Included(max_distance) => {
312 return Ok(diff >= min_distance && diff <= max_distance);
313 }
314 std::ops::Bound::Excluded(max_distance) => {
315 return Ok(diff >= min_distance && diff < max_distance);
316 }
317 }
318 }
319 }
320 }
321
322 Ok(false)
323 }
324
325 fn copy(
326 &mut self,
327 _node_annos: &dyn NodeAnnotationStorage,
328 orig: &dyn GraphStorage,
329 ) -> Result<()> {
330 self.clear()?;
331
332 let roots: Result<FxHashSet<NodeID>> = orig.root_nodes().collect();
334 let roots = roots?;
335
336 for root_node in &roots {
337 self.copy_edge_annos_for_node(*root_node, orig)?;
338
339 let mut chain: Vec<NodeID> = vec![*root_node];
341 let pos: RelativePosition<PosT> = RelativePosition {
342 root: *root_node,
343 pos: PosT::zero(),
344 };
345 self.node_to_pos.insert(*root_node, pos);
346
347 let dfs = CycleSafeDFS::new(orig.as_edgecontainer(), *root_node, 1, usize::MAX);
348 for step in dfs {
349 let step = step?;
350
351 self.copy_edge_annos_for_node(step.node, orig)?;
352
353 if let Some(pos) = PosT::from_usize(chain.len()) {
354 let pos: RelativePosition<PosT> = RelativePosition {
355 root: *root_node,
356 pos,
357 };
358 self.node_to_pos.insert(step.node, pos);
359 }
360 chain.push(step.node);
361 }
362 chain.shrink_to_fit();
363 self.node_chains.insert(*root_node, chain);
364 }
365
366 self.node_chains.shrink_to_fit();
367 self.node_to_pos.shrink_to_fit();
368
369 self.stats = orig.get_statistics().cloned();
370 self.annos.calculate_statistics()?;
371
372 Ok(())
373 }
374
375 fn inverse_has_same_cost(&self) -> bool {
376 true
377 }
378
379 fn as_edgecontainer(&self) -> &dyn EdgeContainer {
380 self
381 }
382}
383
384#[cfg(test)]
385mod tests;