repo_stream/
walk.rs

1//! Depth-first MST traversal
2
3use crate::drive::MaybeProcessedBlock;
4use crate::mst::Node;
5use ipld_core::cid::Cid;
6use std::collections::HashMap;
7use std::error::Error;
8
9/// Errors that can happen while walking
10#[derive(Debug, thiserror::Error)]
11pub enum Trip<E: Error> {
12    #[error("empty mst nodes are not allowed")]
13    NodeEmpty,
14    #[error("Failed to decode commit block: {0}")]
15    BadCommit(Box<dyn std::error::Error>),
16    #[error("Action node error: {0}")]
17    RkeyError(#[from] RkeyError),
18    #[error("Process failed: {0}")]
19    ProcessFailed(E),
20    #[error("Encountered an rkey out of order while walking the MST")]
21    RkeyOutOfOrder,
22}
23
24/// Errors from invalid Rkeys
25#[derive(Debug, thiserror::Error)]
26pub enum RkeyError {
27    #[error("Failed to compute an rkey due to invalid prefix_len")]
28    EntryPrefixOutOfbounds,
29    #[error("RKey was not utf-8")]
30    EntryRkeyNotUtf8(#[from] std::string::FromUtf8Error),
31}
32
33/// Walker outputs
34#[derive(Debug)]
35pub enum Step<T> {
36    /// We need a CID but it's not in the block store
37    ///
38    /// Give the needed CID to the driver so it can load blocks until it's found
39    Rest(Cid),
40    /// Reached the end of the MST! yay!
41    Finish,
42    /// A record was found!
43    Step { rkey: String, data: T },
44}
45
46#[derive(Debug, Clone, PartialEq)]
47enum Need {
48    Node(Cid),
49    Record { rkey: String, cid: Cid },
50}
51
52fn push_from_node(stack: &mut Vec<Need>, node: &Node) -> Result<(), RkeyError> {
53    let mut entries = Vec::with_capacity(node.entries.len());
54
55    let mut prefix = vec![];
56    for entry in &node.entries {
57        let mut rkey = vec![];
58        let pre_checked = prefix
59            .get(..entry.prefix_len)
60            .ok_or(RkeyError::EntryPrefixOutOfbounds)?;
61        rkey.extend_from_slice(pre_checked);
62        rkey.extend_from_slice(&entry.keysuffix);
63        prefix = rkey.clone();
64
65        entries.push(Need::Record {
66            rkey: String::from_utf8(rkey)?,
67            cid: entry.value,
68        });
69        if let Some(ref tree) = entry.tree {
70            entries.push(Need::Node(*tree));
71        }
72    }
73
74    entries.reverse();
75    stack.append(&mut entries);
76
77    if let Some(tree) = node.left {
78        stack.push(Need::Node(tree));
79    }
80    Ok(())
81}
82
83/// Traverser of an atproto MST
84///
85/// Walks the tree from left-to-right in depth-first order
86#[derive(Debug)]
87pub struct Walker {
88    stack: Vec<Need>,
89    prev: String,
90}
91
92impl Walker {
93    pub fn new(tree_root_cid: Cid) -> Self {
94        Self {
95            stack: vec![Need::Node(tree_root_cid)],
96            prev: "".to_string(),
97        }
98    }
99
100    /// Advance through nodes until we find a record or can't go further
101    pub fn step<T: Clone, E: Error>(
102        &mut self,
103        blocks: &mut HashMap<Cid, MaybeProcessedBlock<T, E>>,
104        process: impl Fn(&[u8]) -> Result<T, E>,
105    ) -> Result<Step<T>, Trip<E>> {
106        loop {
107            let Some(mut need) = self.stack.last() else {
108                log::trace!("tried to walk but we're actually done.");
109                return Ok(Step::Finish);
110            };
111
112            match &mut need {
113                Need::Node(cid) => {
114                    log::trace!("need node {cid:?}");
115                    let Some(block) = blocks.remove(cid) else {
116                        log::trace!("node not found, resting");
117                        return Ok(Step::Rest(*cid));
118                    };
119
120                    let MaybeProcessedBlock::Raw(data) = block else {
121                        return Err(Trip::BadCommit("failed commit fingerprint".into()));
122                    };
123                    let node = serde_ipld_dagcbor::from_slice::<Node>(&data)
124                        .map_err(|e| Trip::BadCommit(e.into()))?;
125
126                    // found node, make sure we remember
127                    self.stack.pop();
128
129                    // queue up work on the found node next
130                    push_from_node(&mut self.stack, &node)?;
131                }
132                Need::Record { rkey, cid } => {
133                    log::trace!("need record {cid:?}");
134                    let Some(data) = blocks.get_mut(cid) else {
135                        log::trace!("record block not found, resting");
136                        return Ok(Step::Rest(*cid));
137                    };
138                    let rkey = rkey.clone();
139                    let data = match data {
140                        MaybeProcessedBlock::Raw(data) => process(data),
141                        MaybeProcessedBlock::Processed(Ok(t)) => Ok(t.clone()),
142                        bad => {
143                            // big hack to pull the error out -- this corrupts
144                            // a block, so we should not continue trying to work
145                            let mut steal = MaybeProcessedBlock::Raw(vec![]);
146                            std::mem::swap(&mut steal, bad);
147                            let MaybeProcessedBlock::Processed(Err(e)) = steal else {
148                                unreachable!();
149                            };
150                            return Err(Trip::ProcessFailed(e));
151                        }
152                    };
153
154                    // found node, make sure we remember
155                    self.stack.pop();
156
157                    log::trace!("emitting a block as a step. depth={}", self.stack.len());
158                    let data = data.map_err(Trip::ProcessFailed)?;
159
160                    // rkeys *must* be in order or else the tree is invalid (or
161                    // we have a bug)
162                    if rkey <= self.prev {
163                        return Err(Trip::RkeyOutOfOrder);
164                    }
165                    self.prev = rkey.clone();
166
167                    return Ok(Step::Step { rkey, data });
168                }
169            }
170        }
171    }
172}
173
174#[cfg(test)]
175mod test {
176    use super::*;
177    // use crate::mst::Entry;
178
179    fn cid1() -> Cid {
180        "bafyreihixenvk3ahqbytas4hk4a26w43bh6eo3w6usjqtxkpzsvi655a3m"
181            .parse()
182            .unwrap()
183    }
184    //     fn cid2() -> Cid {
185    //         "QmY7Yh4UquoXHLPFo2XbhXkhBvFoPwmQUSa92pxnxjQuPU"
186    //             .parse()
187    //             .unwrap()
188    //     }
189    //     fn cid3() -> Cid {
190    //         "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
191    //             .parse()
192    //             .unwrap()
193    //     }
194    //     fn cid4() -> Cid {
195    //         "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR"
196    //             .parse()
197    //             .unwrap()
198    //     }
199    //     fn cid5() -> Cid {
200    //         "QmSnuWmxptJZdLJpKRarxBMS2Ju2oANVrgbr2xWbie9b2D"
201    //             .parse()
202    //             .unwrap()
203    //     }
204    //     fn cid6() -> Cid {
205    //         "QmdmQXB2mzChmMeKY47C43LxUdg1NDJ5MWcKMKxDu7RgQm"
206    //             .parse()
207    //             .unwrap()
208    //     }
209    //     fn cid7() -> Cid {
210    //         "bafybeiaysi4s6lnjev27ln5icwm6tueaw2vdykrtjkwiphwekaywqhcjze"
211    //             .parse()
212    //             .unwrap()
213    //     }
214    //     fn cid8() -> Cid {
215    //         "bafyreif3tfdpr5n4jdrbielmcapwvbpcthepfkwq2vwonmlhirbjmotedi"
216    //             .parse()
217    //             .unwrap()
218    //     }
219    //     fn cid9() -> Cid {
220    //         "bafyreicnokmhmrnlp2wjhyk2haep4tqxiptwfrp2rrs7rzq7uk766chqvq"
221    //             .parse()
222    //             .unwrap()
223    //     }
224
225    #[test]
226    fn test_next_from_node_empty() {
227        let node = Node {
228            left: None,
229            entries: vec![],
230        };
231        let mut stack = vec![];
232        push_from_node(&mut stack, &node).unwrap();
233        assert_eq!(stack.last(), None);
234    }
235
236    #[test]
237    fn test_needs_from_node_just_left() {
238        let node = Node {
239            left: Some(cid1()),
240            entries: vec![],
241        };
242        let mut stack = vec![];
243        push_from_node(&mut stack, &node).unwrap();
244        assert_eq!(stack.last(), Some(Need::Node(cid1())).as_ref());
245    }
246
247    //     #[test]
248    //     fn test_needs_from_node_just_one_record() {
249    //         let node = Node {
250    //             left: None,
251    //             entries: vec![Entry {
252    //                 keysuffix: "asdf".into(),
253    //                 prefix_len: 0,
254    //                 value: cid1(),
255    //                 tree: None,
256    //             }],
257    //         };
258    //         assert_eq!(
259    //             needs_from_node(node).unwrap(),
260    //             vec![Need::Record {
261    //                 rkey: "asdf".into(),
262    //                 cid: cid1(),
263    //             },]
264    //         );
265    //     }
266
267    //     #[test]
268    //     fn test_needs_from_node_two_records() {
269    //         let node = Node {
270    //             left: None,
271    //             entries: vec![
272    //                 Entry {
273    //                     keysuffix: "asdf".into(),
274    //                     prefix_len: 0,
275    //                     value: cid1(),
276    //                     tree: None,
277    //                 },
278    //                 Entry {
279    //                     keysuffix: "gh".into(),
280    //                     prefix_len: 2,
281    //                     value: cid2(),
282    //                     tree: None,
283    //                 },
284    //             ],
285    //         };
286    //         assert_eq!(
287    //             needs_from_node(node).unwrap(),
288    //             vec![
289    //                 Need::Record {
290    //                     rkey: "asdf".into(),
291    //                     cid: cid1(),
292    //                 },
293    //                 Need::Record {
294    //                     rkey: "asgh".into(),
295    //                     cid: cid2(),
296    //                 },
297    //             ]
298    //         );
299    //     }
300
301    //     #[test]
302    //     fn test_needs_from_node_with_both() {
303    //         let node = Node {
304    //             left: None,
305    //             entries: vec![Entry {
306    //                 keysuffix: "asdf".into(),
307    //                 prefix_len: 0,
308    //                 value: cid1(),
309    //                 tree: Some(cid2()),
310    //             }],
311    //         };
312    //         assert_eq!(
313    //             needs_from_node(node).unwrap(),
314    //             vec![
315    //                 Need::Record {
316    //                     rkey: "asdf".into(),
317    //                     cid: cid1(),
318    //                 },
319    //                 Need::Node(cid2()),
320    //             ]
321    //         );
322    //     }
323
324    //     #[test]
325    //     fn test_needs_from_node_left_and_record() {
326    //         let node = Node {
327    //             left: Some(cid1()),
328    //             entries: vec![Entry {
329    //                 keysuffix: "asdf".into(),
330    //                 prefix_len: 0,
331    //                 value: cid2(),
332    //                 tree: None,
333    //             }],
334    //         };
335    //         assert_eq!(
336    //             needs_from_node(node).unwrap(),
337    //             vec![
338    //                 Need::Node(cid1()),
339    //                 Need::Record {
340    //                     rkey: "asdf".into(),
341    //                     cid: cid2(),
342    //                 },
343    //             ]
344    //         );
345    //     }
346
347    //     #[test]
348    //     fn test_needs_from_full_node() {
349    //         let node = Node {
350    //             left: Some(cid1()),
351    //             entries: vec![
352    //                 Entry {
353    //                     keysuffix: "asdf".into(),
354    //                     prefix_len: 0,
355    //                     value: cid2(),
356    //                     tree: Some(cid3()),
357    //                 },
358    //                 Entry {
359    //                     keysuffix: "ghi".into(),
360    //                     prefix_len: 1,
361    //                     value: cid4(),
362    //                     tree: Some(cid5()),
363    //                 },
364    //                 Entry {
365    //                     keysuffix: "jkl".into(),
366    //                     prefix_len: 2,
367    //                     value: cid6(),
368    //                     tree: Some(cid7()),
369    //                 },
370    //                 Entry {
371    //                     keysuffix: "mno".into(),
372    //                     prefix_len: 4,
373    //                     value: cid8(),
374    //                     tree: Some(cid9()),
375    //                 },
376    //             ],
377    //         };
378    //         assert_eq!(
379    //             needs_from_node(node).unwrap(),
380    //             vec![
381    //                 Need::Node(cid1()),
382    //                 Need::Record {
383    //                     rkey: "asdf".into(),
384    //                     cid: cid2(),
385    //                 },
386    //                 Need::Node(cid3()),
387    //                 Need::Record {
388    //                     rkey: "aghi".into(),
389    //                     cid: cid4(),
390    //                 },
391    //                 Need::Node(cid5()),
392    //                 Need::Record {
393    //                     rkey: "agjkl".into(),
394    //                     cid: cid6(),
395    //                 },
396    //                 Need::Node(cid7()),
397    //                 Need::Record {
398    //                     rkey: "agjkmno".into(),
399    //                     cid: cid8(),
400    //                 },
401    //                 Need::Node(cid9()),
402    //             ]
403    //         );
404    //     }
405}