aranya_runtime/storage/memory.rs

1use alloc::{boxed::Box, collections::BTreeMap, string::String, sync::Arc, vec::Vec};
2use core::ops::{Bound, Deref};
3
4use buggy::{Bug, BugExt, bug};
5use vec1::Vec1;
6
7use crate::{
8    Address, Checkpoint, Command, CommandId, Fact, FactIndex, FactPerspective, GraphId, Keys,
9    Location, Perspective, PolicyId, Prior, Priority, Query, QueryMut, Revertable, Segment,
10    Storage, StorageError, StorageProvider,
11};
12
13#[derive(Debug)]
14pub struct MemCommand {
15    priority: Priority,
16    id: CommandId,
17    parent: Prior<Address>,
18    policy: Option<Box<[u8]>>,
19    data: Box<[u8]>,
20    max_cut: usize,
21}
22
23impl MemCommand {
24    fn from_cmd<C: Command>(command: &C, max_cut: usize) -> Self {
25        let policy = command.policy().map(Box::from);
26
27        MemCommand {
28            priority: command.priority(),
29            id: command.id(),
30            parent: command.parent(),
31            policy,
32            data: command.bytes().into(),
33            max_cut,
34        }
35    }
36}
37
38impl Command for MemCommand {
39    fn priority(&self) -> Priority {
40        self.priority.clone()
41    }
42
43    fn id(&self) -> CommandId {
44        self.id
45    }
46
47    fn parent(&self) -> Prior<Address> {
48        self.parent
49    }
50
51    fn policy(&self) -> Option<&[u8]> {
52        self.policy.as_deref()
53    }
54
55    fn bytes(&self) -> &[u8] {
56        &self.data
57    }
58
59    fn max_cut(&self) -> Result<usize, Bug> {
60        Ok(self.max_cut)
61    }
62}
63
64#[derive(Default)]
65pub struct MemStorageProvider {
66    storage: BTreeMap<GraphId, MemStorage>,
67}
68
69impl MemStorageProvider {
70    pub const fn new() -> MemStorageProvider {
71        MemStorageProvider {
72            storage: BTreeMap::new(),
73        }
74    }
75}
76
77impl StorageProvider for MemStorageProvider {
78    type Perspective = MemPerspective;
79    type Storage = MemStorage;
80    type Segment = MemSegment;
81
82    fn new_perspective(&mut self, policy_id: PolicyId) -> Self::Perspective {
83        MemPerspective::new_unrooted(policy_id)
84    }
85
86    fn new_storage(
87        &mut self,
88        update: Self::Perspective,
89    ) -> Result<(GraphId, &mut Self::Storage), StorageError> {
90        use alloc::collections::btree_map::Entry;
91
92        if update.commands.is_empty() {
93            return Err(StorageError::EmptyPerspective);
94        }
95        let graph_id = GraphId::from(update.commands[0].command.id.into_id());
96        let entry = match self.storage.entry(graph_id) {
97            Entry::Vacant(v) => v,
98            Entry::Occupied(_) => return Err(StorageError::StorageExists),
99        };
100
101        let mut storage = MemStorage::new();
102        let segment = storage.write(update)?;
103        storage.commit(segment)?;
104        Ok((graph_id, entry.insert(storage)))
105    }
106
107    fn get_storage(&mut self, graph: GraphId) -> Result<&mut Self::Storage, StorageError> {
108        self.storage
109            .get_mut(&graph)
110            .ok_or(StorageError::NoSuchStorage)
111    }
112
113    fn remove_storage(&mut self, graph: GraphId) -> Result<(), StorageError> {
114        self.storage
115            .remove(&graph)
116            .ok_or(StorageError::NoSuchStorage)?;
117
118        Ok(())
119    }
120
121    fn list_graph_ids(
122        &mut self,
123    ) -> Result<impl Iterator<Item = Result<GraphId, StorageError>>, StorageError> {
124        Ok(self.storage.keys().copied().map(Ok))
125    }
126}
127
/// Facts stored under a single fact name, keyed by their key tuple. A `None`
/// value records that the fact was deleted.
type FactMap = BTreeMap<Keys, Option<Box<[u8]>>>;
/// Fact maps grouped by fact name.
type NamedFactMap = BTreeMap<String, FactMap>;
130
131pub struct MemStorage {
132    segments: Vec<MemSegment>,
133    commands: BTreeMap<CommandId, Location>,
134    head: Option<Location>,
135}
136
137impl MemStorage {
138    fn new() -> Self {
139        Self {
140            segments: Vec::new(),
141            commands: BTreeMap::new(),
142            head: None,
143        }
144    }
145
146    fn new_segment(
147        &mut self,
148        prior: Prior<Location>,
149        policy: PolicyId,
150        mut commands: Vec1<CommandData>,
151        facts: MemFactIndex,
152        max_cut: usize,
153    ) -> Result<MemSegment, StorageError> {
154        let index = self.segments.len();
155
156        for (i, command) in commands.iter_mut().enumerate() {
157            command.command.max_cut = max_cut.checked_add(i).assume("must not overflow")?;
158        }
159
160        let segment = MemSegmentInner {
161            prior,
162            index,
163            policy,
164            commands,
165            facts,
166        };
167
168        let cell = MemSegment::from(segment);
169        self.segments.push(cell.clone());
170
171        Ok(cell)
172    }
173}
174
175impl Drop for MemStorage {
176    // Ensure the segments are dropped high to low, which helps avoid a stack
177    // overflow on dropping really long Arc chains.
178    fn drop(&mut self) {
179        while self.segments.pop().is_some() {}
180    }
181}
182
183impl Storage for MemStorage {
184    type Perspective = MemPerspective;
185    type Segment = MemSegment;
186    type FactIndex = MemFactIndex;
187    type FactPerspective = MemFactPerspective;
188
189    fn get_command_id(&self, location: Location) -> Result<CommandId, StorageError> {
190        let segment = self.get_segment(location)?;
191        let command = segment
192            .get_command(location)
193            .ok_or(StorageError::CommandOutOfBounds(location))?;
194        Ok(command.id())
195    }
196
197    fn get_linear_perspective(
198        &self,
199        parent: Location,
200    ) -> Result<Option<Self::Perspective>, StorageError> {
201        let segment = self.get_segment(parent)?;
202        let command = segment
203            .get_command(parent)
204            .ok_or(StorageError::CommandOutOfBounds(parent))?;
205        let parent_addr = command.address()?;
206
207        let policy = segment.policy;
208        let prior_facts: FactPerspectivePrior = if parent == segment.head_location() {
209            segment.facts.clone().into()
210        } else {
211            let mut facts = MemFactPerspective::new(segment.facts.prior.clone().into());
212            for data in &segment.commands[..=parent.command] {
213                facts.apply_updates(&data.updates);
214            }
215            if facts.map.is_empty() {
216                facts.prior
217            } else {
218                facts.into()
219            }
220        };
221        let prior = Prior::Single(parent);
222        let parents = Prior::Single(parent_addr);
223
224        let max_cut = self
225            .get_segment(parent)?
226            .get_command(parent)
227            .assume("location must exist")?
228            .max_cut()?
229            .checked_add(1)
230            .assume("must not overflow")?;
231        let perspective = MemPerspective::new(prior, parents, policy, prior_facts, max_cut);
232
233        Ok(Some(perspective))
234    }
235
236    fn get_fact_perspective(
237        &self,
238        location: Location,
239    ) -> Result<Self::FactPerspective, StorageError> {
240        let segment = self.get_segment(location)?;
241
242        if location == segment.head_location()
243            || segment.commands.iter().all(|cmd| cmd.updates.is_empty())
244        {
245            return Ok(MemFactPerspective::new(segment.facts.clone().into()));
246        }
247
248        let mut facts = MemFactPerspective::new(segment.facts.prior.clone().into());
249        for data in &segment.commands[..=location.command] {
250            facts.apply_updates(&data.updates);
251        }
252
253        Ok(facts)
254    }
255
256    fn new_merge_perspective(
257        &self,
258        left: Location,
259        right: Location,
260        _last_common_ancestor: (Location, usize),
261        policy_id: PolicyId,
262        braid: MemFactIndex,
263    ) -> Result<Option<Self::Perspective>, StorageError> {
264        // TODO(jdygert): ensure braid belongs to this storage.
265        // TODO(jdygert): ensure braid ends at given command?
266
267        let left_segment = self.get_segment(left)?;
268        let left_policy_id = left_segment.policy;
269        let right_segment = self.get_segment(right)?;
270        let right_policy_id = right_segment.policy;
271
272        if (policy_id != left_policy_id) && (policy_id != right_policy_id) {
273            return Err(StorageError::PolicyMismatch);
274        }
275
276        let prior = Prior::Merge(left, right);
277
278        let left_command = left_segment
279            .get_command(left)
280            .ok_or(StorageError::CommandOutOfBounds(left))?;
281        let right_command = right_segment
282            .get_command(right)
283            .ok_or(StorageError::CommandOutOfBounds(right))?;
284        let parents = Prior::Merge(left_command.address()?, right_command.address()?);
285
286        let left_distance = left_command.max_cut()?;
287        let right_distance = right_command.max_cut()?;
288        let max_cut = left_distance
289            .max(right_distance)
290            .checked_add(1)
291            .assume("must not overflow")?;
292
293        let perspective = MemPerspective::new(prior, parents, policy_id, braid.into(), max_cut);
294
295        Ok(Some(perspective))
296    }
297
298    fn get_segment(&self, location: Location) -> Result<MemSegment, StorageError> {
299        self.segments
300            .get(location.segment)
301            .ok_or(StorageError::SegmentOutOfBounds(location))
302            .cloned()
303    }
304
305    fn get_head(&self) -> Result<Location, StorageError> {
306        Ok(self.head.assume("storage has head after init")?)
307    }
308
309    fn write(&mut self, update: Self::Perspective) -> Result<Self::Segment, StorageError> {
310        let facts = self.write_facts(update.facts)?;
311
312        let commands: Vec1<CommandData> = update
313            .commands
314            .try_into()
315            .map_err(|_| StorageError::EmptyPerspective)?;
316
317        let segment_index = self.segments.len();
318
319        // Add the commands to the segment
320        for (command_index, data) in commands.iter().enumerate() {
321            let new_location = Location::new(segment_index, command_index);
322            self.commands.insert(data.command.id(), new_location);
323        }
324
325        let segment =
326            self.new_segment(update.prior, update.policy, commands, facts, update.max_cut)?;
327
328        Ok(segment)
329    }
330
331    fn write_facts(
332        &mut self,
333        facts: Self::FactPerspective,
334    ) -> Result<Self::FactIndex, StorageError> {
335        let prior = match facts.prior {
336            FactPerspectivePrior::None => None,
337            FactPerspectivePrior::FactPerspective(prior) => Some(self.write_facts(*prior)?),
338            FactPerspectivePrior::FactIndex(prior) => Some(prior),
339        };
340        if facts.map.is_empty() {
341            if let Some(prior) = prior {
342                return Ok(prior);
343            }
344        }
345        Ok(MemFactIndex(Arc::new(MemFactsInner {
346            map: facts.map,
347            prior,
348        })))
349    }
350
351    fn commit(&mut self, segment: Self::Segment) -> Result<(), StorageError> {
352        // TODO(jdygert): ensure segment belongs to self?
353
354        if let Some(head) = self.head {
355            if !self.is_ancestor(head, &segment)? {
356                return Err(StorageError::HeadNotAncestor);
357            }
358        }
359
360        self.head = Some(segment.head_location());
361        Ok(())
362    }
363}
364
365#[derive(Clone, Debug)]
366pub struct MemFactIndex(Arc<MemFactsInner>);
367
368impl Deref for MemFactIndex {
369    type Target = MemFactsInner;
370    fn deref(&self) -> &Self::Target {
371        self.0.deref()
372    }
373}
374
375impl MemFactIndex {
376    #[cfg(all(test, feature = "graphviz"))]
377    fn name(&self) -> String {
378        format!("\"{:p}\"", Arc::as_ptr(&self.0))
379    }
380}
381
/// The contents of a [`MemFactIndex`]: the facts recorded at this level plus
/// the index it is layered on. Lookups consult `map` first and then fall
/// back to `prior`.
#[derive(Debug)]
pub struct MemFactsInner {
    /// Facts recorded at this level; `None` values are deletions.
    map: NamedFactMap,
    /// The index this one is layered on, if any.
    prior: Option<MemFactIndex>,
}
387
388pub(crate) fn find_prefixes<'m, 'p: 'm>(
389    map: &'m FactMap,
390    prefix: &'p [Box<[u8]>],
391) -> impl Iterator<Item = (&'m Keys, Option<&'m [u8]>)> + 'm {
392    map.range::<[Box<[u8]>], _>((Bound::Included(prefix), Bound::Unbounded))
393        .take_while(|(k, _)| k.starts_with(prefix))
394        .map(|(k, v)| (k, v.as_deref()))
395}
396
397impl FactIndex for MemFactIndex {}
398impl Query for MemFactIndex {
399    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
400        let mut prior = Some(self.deref());
401        while let Some(facts) = prior {
402            if let Some(slot) = facts.map.get(name).and_then(|m| m.get(keys)) {
403                return Ok(slot.as_ref().cloned());
404            }
405            prior = facts.prior.as_deref();
406        }
407        Ok(None)
408    }
409
410    type QueryIterator = Box<dyn Iterator<Item = Result<Fact, StorageError>>>;
411    fn query_prefix(
412        &self,
413        name: &str,
414        prefix: &[Box<[u8]>],
415    ) -> Result<Self::QueryIterator, StorageError> {
416        Ok(Box::from(
417            self.query_prefix_inner(name, prefix)
418                .into_iter()
419                // remove deleted facts
420                .filter_map(|(key, value)| Some(Ok(Fact { key, value: value? }))),
421        ))
422    }
423}
424
425impl MemFactIndex {
426    fn query_prefix_inner(&self, name: &str, prefix: &[Box<[u8]>]) -> FactMap {
427        let mut matches = BTreeMap::new();
428
429        let mut prior = Some(self.deref());
430        // walk backwards along fact indices
431        while let Some(facts) = prior {
432            if let Some(map) = facts.map.get(name) {
433                for (k, v) in find_prefixes(map, prefix) {
434                    // don't override, if we've already found the fact (including deletions)
435                    if !matches.contains_key(k) {
436                        matches.insert(k.clone(), v.map(Into::into));
437                    }
438                }
439            }
440            prior = facts.prior.as_deref();
441        }
442
443        matches
444    }
445}
446
/// A stored command together with the fact updates attached to it.
#[derive(Debug)]
struct CommandData {
    command: MemCommand,
    /// Fact updates made before this command was added, taken from the
    /// perspective's pending updates.
    updates: Vec<Update>,
}

/// The immutable contents of a [`MemSegment`].
#[derive(Debug)]
pub struct MemSegmentInner {
    /// Position of this segment in `MemStorage::segments`.
    index: usize,
    /// Location(s) of the parent command(s) of the segment's first command.
    prior: Prior<Location>,
    /// Policy ID of the perspective this segment was written from.
    policy: PolicyId,
    /// The segment's commands, in order (a `Vec1`, so never empty).
    commands: Vec1<CommandData>,
    /// Fact index as of the segment's last command.
    facts: MemFactIndex,
}
461
462#[derive(Clone, Debug)]
463pub struct MemSegment(Arc<MemSegmentInner>);
464
465impl Deref for MemSegment {
466    type Target = MemSegmentInner;
467
468    fn deref(&self) -> &Self::Target {
469        self.0.deref()
470    }
471}
472
473impl From<MemSegmentInner> for MemSegment {
474    fn from(segment: MemSegmentInner) -> Self {
475        MemSegment(Arc::new(segment))
476    }
477}
478
479impl Segment for MemSegment {
480    type FactIndex = MemFactIndex;
481    type Command<'a> = &'a MemCommand;
482
483    fn head(&self) -> Result<&MemCommand, StorageError> {
484        Ok(&self.commands.last().command)
485    }
486
487    fn first(&self) -> &MemCommand {
488        &self.commands.first().command
489    }
490
491    fn head_location(&self) -> Location {
492        Location {
493            segment: self.index,
494            command: self
495                .commands
496                .len()
497                .checked_sub(1)
498                .expect("commands.len() must be > 0"),
499        }
500    }
501
502    fn first_location(&self) -> Location {
503        Location {
504            segment: self.index,
505            command: 0,
506        }
507    }
508
509    fn contains(&self, location: Location) -> bool {
510        location.segment == self.index
511    }
512
513    fn policy(&self) -> PolicyId {
514        self.policy
515    }
516
517    fn prior(&self) -> Prior<Location> {
518        self.prior
519    }
520
521    fn get_command(&self, location: Location) -> Option<&MemCommand> {
522        if location.segment != self.index {
523            return None;
524        }
525
526        self.commands.get(location.command).map(|d| &d.command)
527    }
528
529    fn get_from(&self, location: Location) -> Vec<&MemCommand> {
530        if location.segment != self.index {
531            return Vec::new();
532        }
533
534        self.commands[location.command..]
535            .iter()
536            .map(|d| &d.command)
537            .collect()
538    }
539
540    fn get_from_max_cut(&self, max_cut: usize) -> Result<Option<Location>, StorageError> {
541        for (i, command) in self.commands.iter().enumerate() {
542            if command.command.max_cut == max_cut {
543                return Ok(Some(Location {
544                    segment: self.index,
545                    command: i,
546                }));
547            }
548        }
549        Ok(None)
550    }
551
552    fn longest_max_cut(&self) -> Result<usize, StorageError> {
553        Ok(self.commands.last().command.max_cut)
554    }
555
556    fn shortest_max_cut(&self) -> usize {
557        self.commands[0].command.max_cut
558    }
559
560    fn skip_list(&self) -> &[(Location, usize)] {
561        &[]
562    }
563
564    fn facts(&self) -> Result<Self::FactIndex, StorageError> {
565        Ok(self.facts.clone())
566    }
567}
568
/// A fact update: `(fact name, keys, value)`; a `None` value is a deletion.
type Update = (String, Keys, Option<Box<[u8]>>);

/// An in-progress extension of a graph.
///
/// Commands are appended with `add_command`; `MemStorage::write` turns the
/// perspective into a new segment.
#[derive(Debug)]
pub struct MemPerspective {
    /// Storage location(s) of the parent command(s).
    prior: Prior<Location>,
    /// Address(es) of the parent command(s); used as the head address while
    /// no commands have been added.
    parents: Prior<Address>,
    policy: PolicyId,
    /// Fact view: prior facts plus every update made in this perspective.
    facts: MemFactPerspective,
    /// Commands added so far, each with the fact updates attached to it.
    commands: Vec<CommandData>,
    /// Fact updates not yet attached to a command.
    current_updates: Vec<Update>,
    /// Max cut given to the first command when the perspective is written
    /// into a segment.
    max_cut: usize,
}
581
582#[derive(Debug)]
583enum FactPerspectivePrior {
584    None,
585    FactPerspective(Box<MemFactPerspective>),
586    FactIndex(MemFactIndex),
587}
588
589impl From<MemFactIndex> for FactPerspectivePrior {
590    fn from(value: MemFactIndex) -> Self {
591        Self::FactIndex(value)
592    }
593}
594
595impl From<Option<MemFactIndex>> for FactPerspectivePrior {
596    fn from(value: Option<MemFactIndex>) -> Self {
597        value.map_or(Self::None, Self::FactIndex)
598    }
599}
600
601impl From<MemFactPerspective> for FactPerspectivePrior {
602    fn from(value: MemFactPerspective) -> Self {
603        Self::FactPerspective(Box::new(value))
604    }
605}
606
607#[derive(Debug)]
608pub struct MemFactPerspective {
609    map: NamedFactMap,
610    prior: FactPerspectivePrior,
611}
612
613impl MemFactPerspective {
614    fn new(prior_facts: FactPerspectivePrior) -> MemFactPerspective {
615        Self {
616            map: NamedFactMap::new(),
617            prior: prior_facts,
618        }
619    }
620
621    fn clear(&mut self) {
622        self.map.clear();
623    }
624
625    fn apply_updates(&mut self, updates: &[Update]) {
626        for (name, key, value) in updates {
627            self.map
628                .entry(name.clone())
629                .or_default()
630                .insert(key.clone(), value.clone());
631        }
632    }
633}
634
635impl MemPerspective {
636    fn new(
637        prior: Prior<Location>,
638        parents: Prior<Address>,
639        policy: PolicyId,
640        prior_facts: FactPerspectivePrior,
641        max_cut: usize,
642    ) -> Self {
643        Self {
644            prior,
645            parents,
646            policy,
647            facts: MemFactPerspective::new(prior_facts),
648            commands: Vec::new(),
649            current_updates: Vec::new(),
650            max_cut,
651        }
652    }
653
654    fn new_unrooted(policy: PolicyId) -> Self {
655        Self {
656            prior: Prior::None,
657            parents: Prior::None,
658            policy,
659            facts: MemFactPerspective::new(FactPerspectivePrior::None),
660            commands: Vec::new(),
661            current_updates: Vec::new(),
662            max_cut: 0,
663        }
664    }
665}
666
667impl Revertable for MemPerspective {
668    fn checkpoint(&self) -> Checkpoint {
669        Checkpoint {
670            index: self.commands.len(),
671        }
672    }
673
674    fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
675        if checkpoint.index == self.commands.len() {
676            return Ok(());
677        }
678
679        if checkpoint.index > self.commands.len() {
680            bug!(
681                "A checkpoint's index should always be less than or equal to the length of a perspective's command history!"
682            );
683        }
684
685        self.commands.truncate(checkpoint.index);
686        self.facts.clear();
687        self.current_updates.clear();
688        for data in &self.commands {
689            self.facts.apply_updates(&data.updates);
690        }
691
692        Ok(())
693    }
694}
695
696impl Perspective for MemPerspective {
697    fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
698        if command.parent() != self.head_address()? {
699            return Err(StorageError::PerspectiveHeadMismatch);
700        }
701
702        let entry = CommandData {
703            command: MemCommand::from_cmd(command, self.head_address()?.next_max_cut()?),
704            updates: core::mem::take(&mut self.current_updates),
705        };
706        self.commands.push(entry);
707        Ok(self.commands.len())
708    }
709
710    fn policy(&self) -> PolicyId {
711        self.policy
712    }
713
714    fn includes(&self, id: CommandId) -> bool {
715        self.commands.iter().any(|cmd| cmd.command.id == id)
716    }
717
718    fn head_address(&self) -> Result<Prior<Address>, Bug> {
719        Ok(if let Some(last) = self.commands.last() {
720            Prior::Single(last.command.address()?)
721        } else {
722            self.parents
723        })
724    }
725}
726
727impl FactPerspective for MemPerspective {}
728
729impl Query for MemPerspective {
730    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
731        self.facts.query(name, keys)
732    }
733
734    type QueryIterator = <MemFactPerspective as Query>::QueryIterator;
735    fn query_prefix(
736        &self,
737        name: &str,
738        prefix: &[Box<[u8]>],
739    ) -> Result<Self::QueryIterator, StorageError> {
740        self.facts.query_prefix(name, prefix)
741    }
742}
743
744impl QueryMut for MemPerspective {
745    fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
746        self.facts.insert(name.clone(), keys.clone(), value.clone());
747        self.current_updates.push((name, keys, Some(value)));
748    }
749
750    fn delete(&mut self, name: String, keys: Keys) {
751        self.facts.delete(name.clone(), keys.clone());
752        self.current_updates.push((name, keys, None));
753    }
754}
755
756impl MemFactPerspective {
757    fn query_prefix_inner(&self, name: &str, prefix: &[Box<[u8]>]) -> FactMap {
758        let map = self.map.get(name);
759        let mut matches = match &self.prior {
760            FactPerspectivePrior::None => BTreeMap::new(),
761            FactPerspectivePrior::FactPerspective(fp) => fp.query_prefix_inner(name, prefix),
762            FactPerspectivePrior::FactIndex(fi) => fi.query_prefix_inner(name, prefix),
763        };
764        if let Some(map) = map {
765            for (k, v) in find_prefixes(map, prefix) {
766                // overwrite "earlier" facts
767                matches.insert(k.clone(), v.map(Into::into));
768            }
769        }
770        matches
771    }
772}
773
774impl FactPerspective for MemFactPerspective {}
775
776impl Query for MemFactPerspective {
777    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
778        if let Some(wrapped) = self.map.get(name).and_then(|m| m.get(keys)) {
779            return Ok(wrapped.as_deref().map(Box::from));
780        }
781        match &self.prior {
782            FactPerspectivePrior::None => Ok(None),
783            FactPerspectivePrior::FactPerspective(prior) => prior.query(name, keys),
784            FactPerspectivePrior::FactIndex(prior) => prior.query(name, keys),
785        }
786    }
787
788    type QueryIterator = Box<dyn Iterator<Item = Result<Fact, StorageError>>>;
789    fn query_prefix(
790        &self,
791        name: &str,
792        prefix: &[Box<[u8]>],
793    ) -> Result<Self::QueryIterator, StorageError> {
794        Ok(Box::from(
795            self.query_prefix_inner(name, prefix)
796                .into_iter()
797                // remove deleted facts
798                .filter_map(|(key, value)| Some(Ok(Fact { key, value: value? }))),
799        ))
800    }
801}
802
803impl QueryMut for MemFactPerspective {
804    fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
805        self.map.entry(name).or_default().insert(keys, Some(value));
806    }
807
808    fn delete(&mut self, name: String, keys: Keys) {
809        self.map.entry(name).or_default().insert(keys, None);
810    }
811}
812
/// Test-only helpers that render a [`MemStorage`] as a Graphviz `.dot` file.
#[cfg(all(test, feature = "graphviz"))]
pub mod graphviz {
    #![allow(clippy::unwrap_used)]

    use std::{fs::File, io::BufWriter};

    use dot_writer::{Attributes, DotWriter, Style};

    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Quoted node name `"segment:command"` for a location.
    fn loc(location: impl Into<Location>) -> String {
        let location = location.into();
        format!("\"{}:{}\"", location.segment, location.command)
    }

    /// Label for a fact index: the UTF-8 value stored under fact `seq` with
    /// the default key, or an empty string if that fact is absent or deleted.
    /// Panics if the value is not valid UTF-8.
    fn get_seq(p: &MemFactIndex) -> &str {
        if let Some(Some(seq)) = p.map.get("seq").and_then(|m| m.get(&Keys::default())) {
            std::str::from_utf8(seq).unwrap()
        } else {
            ""
        }
    }

    /// Writes the whole storage as a digraph.
    ///
    /// Each segment becomes a cluster of command nodes plus its fact index.
    /// Fact indices reachable only through `prior` links are drawn outside
    /// the clusters. Finally a `HEAD` marker points at the storage head.
    fn dotwrite(storage: &MemStorage, out: &mut DotWriter<'_>) {
        let mut graph = out.digraph();
        graph
            .graph_attributes()
            .set("compound", "true", false)
            .set("rankdir", "RL", false)
            .set_style(Style::Filled)
            .set("color", "grey", false);
        graph
            .node_attributes()
            .set("shape", "square", false)
            .set_style(Style::Filled)
            .set("color", "lightgrey", false);

        // Fact indices already assigned a node, keyed by node name.
        let mut seen_facts = std::collections::HashMap::new();
        // Fact indices that belong to no segment and are drawn after the clusters.
        let mut external_facts = Vec::new();

        for segment in &storage.segments {
            let mut cluster = graph.cluster();
            // Root segments are outlined green, merge segments crimson.
            match segment.prior {
                Prior::None => {
                    cluster.graph_attributes().set("color", "green", false);
                }
                Prior::Single(..) => {}
                Prior::Merge(..) => {
                    cluster.graph_attributes().set("color", "crimson", false);
                }
            }

            // Draw commands and edges between commands within the segment.
            for (i, cmd) in segment.commands.iter().enumerate() {
                {
                    let mut node = cluster.node_named(loc((segment.index, i)));
                    node.set_label(&cmd.command.id.short_b58());
                    // Init commands are houses, merge commands hexagons.
                    match cmd.command.parent {
                        Prior::None => {
                            node.set("shape", "house", false);
                        }
                        Prior::Single(..) => {}
                        Prior::Merge(..) => {
                            node.set("shape", "hexagon", false);
                        }
                    };
                }
                if i > 0 {
                    let previous = i.checked_sub(1).expect("i must be > 0");
                    cluster.edge(loc((segment.index, i)), loc((segment.index, previous)));
                }
            }

            // Draw edges to previous segments.
            let first = loc(segment.first_location());
            for p in segment.prior() {
                cluster.edge(&first, loc(p));
            }

            // Draw fact index for this segment.
            let curr = segment.facts.name();
            cluster
                .node_named(curr.clone())
                .set_label(get_seq(&segment.facts))
                .set("shape", "cylinder", false)
                .set("color", "black", false)
                .set("style", "solid", false);
            cluster
                .edge(loc(segment.head_location()), &curr)
                .attributes()
                .set("color", "red", false);

            seen_facts.insert(curr, segment.facts.clone());
            // Make sure prior facts of fact index will get processed later.
            // The walk stops at the first index that already has a node.
            let mut node = &segment.facts;
            while let Some(prior) = &node.prior {
                node = prior;
                let name = node.name();
                if seen_facts.insert(name, node.clone()).is_some() {
                    break;
                }
                external_facts.push(node.clone());
            }
        }

        // Default style for the external fact-index nodes drawn below.
        graph
            .node_attributes()
            .set("shape", "cylinder", false)
            .set("color", "black", false)
            .set("style", "solid", false);

        for fact in external_facts {
            // Draw nodes for fact indices not directly associated with a segment.
            graph.node_named(fact.name()).set_label(get_seq(&fact));

            // Draw edge to prior facts.
            if let Some(prior) = &fact.prior {
                graph
                    .edge(fact.name(), prior.name())
                    .attributes()
                    .set("color", "blue", false);
            }
        }

        // Draw edges to prior facts for fact indices in segments.
        for segment in &storage.segments {
            if let Some(prior) = &segment.facts.prior {
                graph
                    .edge(segment.facts.name(), prior.name())
                    .attributes()
                    .set("color", "blue", false);
            }
        }

        // Draw HEAD indicator.
        graph.node_named("HEAD").set("shape", "none", false);
        graph.edge("HEAD", loc(storage.get_head().unwrap()));
    }

    /// Writes `storage` to `.ignore/{name}.dot`, creating `.ignore` if needed.
    /// Panics on any I/O failure.
    pub fn dot(storage: &MemStorage, name: &str) {
        std::fs::create_dir_all(".ignore").unwrap();
        dotwrite(
            storage,
            &mut DotWriter::from(&mut BufWriter::new(
                File::create(format!(".ignore/{name}.dot")).unwrap(),
            )),
        );
    }
}
963
#[cfg(test)]
mod test {
    use super::*;
    use crate::testing::dsl::{StorageBackend, test_suite};

    /// Builds a key tuple from string parts.
    fn keys_of(parts: &[&str]) -> Keys {
        parts.iter().map(|part| part.as_bytes()).collect()
    }

    /// `query_prefix` on a written fact index must return exactly the facts
    /// whose keys start with the prefix, sorted by key, with their values.
    #[test]
    fn test_query_prefix() {
        let mut storage = MemStorage::new();
        let mut perspective = MemFactPerspective::new(FactPerspectivePrior::None);

        let fact_name = "x";

        let raw_keys: &[&[&str]] = &[
            &["aa", "xy", "123"],
            &["aa", "xz", "123"],
            &["bb", "ccc"],
            &["bc", ""],
        ];
        let inserted: Vec<Keys> = raw_keys.iter().map(|parts| keys_of(parts)).collect();

        // Each fact's value is the Debug rendering of its keys.
        for keys in &inserted {
            perspective.insert(
                fact_name.into(),
                keys.clone(),
                format!("{keys:?}").into_bytes().into(),
            );
        }
        let index = storage.write_facts(perspective).unwrap();

        let raw_prefixes: &[&[&str]] = &[
            &["aa", "xy", "12"],
            &["aa", "xy"],
            &["aa", "xz"],
            &["aa", "x"],
            &["bb", ""],
            &["bb", "ccc"],
            &["bc", ""],
            &["bc", "", ""],
        ];

        for raw_prefix in raw_prefixes {
            let prefix = keys_of(raw_prefix);
            let found: Vec<_> = index.query_prefix(fact_name, &prefix).unwrap().collect();

            let mut expected: Vec<_> = inserted
                .iter()
                .filter(|keys| keys.starts_with(&prefix))
                .collect();
            expected.sort();

            assert_eq!(found.len(), expected.len());
            for (fact, want) in found.into_iter().zip(expected) {
                let fact = fact.unwrap();
                assert_eq!(&fact.key, want);
                assert_eq!(fact.value.as_ref(), format!("{want:?}").as_bytes());
            }
        }
    }

    /// Runs the shared storage test suite against fresh in-memory providers.
    struct MemBackend;
    impl StorageBackend for MemBackend {
        type StorageProvider = MemStorageProvider;

        fn provider(&mut self, _client_id: u64) -> Self::StorageProvider {
            MemStorageProvider::new()
        }
    }
    test_suite!(|| MemBackend);
}