repo_stream/walk.rs
//! Depth-first, left-to-right traversal of an atproto Merkle Search Tree (MST)
2
3use crate::drive::MaybeProcessedBlock;
4use crate::mst::Node;
5use ipld_core::cid::Cid;
6use std::collections::HashMap;
7use std::error::Error;
8
9/// Errors that can happen while walking
10#[derive(Debug, thiserror::Error)]
11pub enum Trip<E: Error> {
12 #[error("empty mst nodes are not allowed")]
13 NodeEmpty,
14 #[error("Failed to decode commit block: {0}")]
15 BadCommit(Box<dyn std::error::Error>),
16 #[error("Action node error: {0}")]
17 RkeyError(#[from] RkeyError),
18 #[error("Process failed: {0}")]
19 ProcessFailed(E),
20 #[error("Encountered an rkey out of order while walking the MST")]
21 RkeyOutOfOrder,
22}
23
/// Errors from invalid Rkeys
///
/// Produced while rebuilding an MST entry's full key from the previous key in the
/// same node plus the entry's own suffix.
#[derive(Debug, thiserror::Error)]
pub enum RkeyError {
    /// The entry's `prefix_len` is longer than the previous key in the same node
    /// (for a node's first entry, any non-zero `prefix_len`).
    #[error("Failed to compute an rkey due to invalid prefix_len")]
    EntryPrefixOutOfbounds,
    /// The reconstructed key bytes are not valid UTF-8.
    #[error("RKey was not utf-8")]
    EntryRkeyNotUtf8(#[from] std::string::FromUtf8Error),
}
32
/// Walker outputs
///
/// Returned by [`Walker::step`] on success.
#[derive(Debug)]
pub enum Step<T> {
    /// We need a CID but it's not in the block store
    ///
    /// Give the needed CID to the driver so it can load blocks until it's found.
    /// The walker's position is unchanged, so calling `step` again after the block
    /// has been inserted resumes the walk where it stopped.
    Rest(Cid),
    /// Reached the end of the MST! yay!
    Finish,
    /// A record was found!
    ///
    /// `rkey` is the record's full MST key; `data` is the result of `process` on the
    /// record block (or a clone of the block's already-processed value).
    Step { rkey: String, data: T },
}
45
/// A pending unit of work on the walker's stack
#[derive(Debug, Clone, PartialEq)]
enum Need {
    /// An MST node block that still has to be loaded, decoded and expanded
    Node(Cid),
    /// A record block to emit, together with its fully reconstructed key
    Record { rkey: String, cid: Cid },
}
51
52fn push_from_node(stack: &mut Vec<Need>, node: &Node) -> Result<(), RkeyError> {
53 let mut entries = Vec::with_capacity(node.entries.len());
54
55 let mut prefix = vec![];
56 for entry in &node.entries {
57 let mut rkey = vec![];
58 let pre_checked = prefix
59 .get(..entry.prefix_len)
60 .ok_or(RkeyError::EntryPrefixOutOfbounds)?;
61 rkey.extend_from_slice(pre_checked);
62 rkey.extend_from_slice(&entry.keysuffix);
63 prefix = rkey.clone();
64
65 entries.push(Need::Record {
66 rkey: String::from_utf8(rkey)?,
67 cid: entry.value,
68 });
69 if let Some(ref tree) = entry.tree {
70 entries.push(Need::Node(*tree));
71 }
72 }
73
74 entries.reverse();
75 stack.append(&mut entries);
76
77 if let Some(tree) = node.left {
78 stack.push(Need::Node(tree));
79 }
80 Ok(())
81}
82
/// Traverser of an atproto MST
///
/// Walks the tree from left-to-right in depth-first order
#[derive(Debug)]
pub struct Walker {
    /// Pending work; the next item to handle is the last element (top of stack)
    stack: Vec<Need>,
    /// The most recently emitted rkey, used to check that rkeys strictly increase.
    /// It starts as `""`, so an empty rkey would be rejected as out of order.
    prev: String,
}
91
92impl Walker {
93 pub fn new(tree_root_cid: Cid) -> Self {
94 Self {
95 stack: vec![Need::Node(tree_root_cid)],
96 prev: "".to_string(),
97 }
98 }
99
100 /// Advance through nodes until we find a record or can't go further
101 pub fn step<T: Clone, E: Error>(
102 &mut self,
103 blocks: &mut HashMap<Cid, MaybeProcessedBlock<T, E>>,
104 process: impl Fn(&[u8]) -> Result<T, E>,
105 ) -> Result<Step<T>, Trip<E>> {
106 loop {
107 let Some(mut need) = self.stack.last() else {
108 log::trace!("tried to walk but we're actually done.");
109 return Ok(Step::Finish);
110 };
111
112 match &mut need {
113 Need::Node(cid) => {
114 log::trace!("need node {cid:?}");
115 let Some(block) = blocks.remove(cid) else {
116 log::trace!("node not found, resting");
117 return Ok(Step::Rest(*cid));
118 };
119
120 let MaybeProcessedBlock::Raw(data) = block else {
121 return Err(Trip::BadCommit("failed commit fingerprint".into()));
122 };
123 let node = serde_ipld_dagcbor::from_slice::<Node>(&data)
124 .map_err(|e| Trip::BadCommit(e.into()))?;
125
126 // found node, make sure we remember
127 self.stack.pop();
128
129 // queue up work on the found node next
130 push_from_node(&mut self.stack, &node)?;
131 }
132 Need::Record { rkey, cid } => {
133 log::trace!("need record {cid:?}");
134 let Some(data) = blocks.get_mut(cid) else {
135 log::trace!("record block not found, resting");
136 return Ok(Step::Rest(*cid));
137 };
138 let rkey = rkey.clone();
139 let data = match data {
140 MaybeProcessedBlock::Raw(data) => process(data),
141 MaybeProcessedBlock::Processed(Ok(t)) => Ok(t.clone()),
142 bad => {
143 // big hack to pull the error out -- this corrupts
144 // a block, so we should not continue trying to work
145 let mut steal = MaybeProcessedBlock::Raw(vec![]);
146 std::mem::swap(&mut steal, bad);
147 let MaybeProcessedBlock::Processed(Err(e)) = steal else {
148 unreachable!();
149 };
150 return Err(Trip::ProcessFailed(e));
151 }
152 };
153
154 // found node, make sure we remember
155 self.stack.pop();
156
157 log::trace!("emitting a block as a step. depth={}", self.stack.len());
158 let data = data.map_err(Trip::ProcessFailed)?;
159
160 // rkeys *must* be in order or else the tree is invalid (or
161 // we have a bug)
162 if rkey <= self.prev {
163 return Err(Trip::RkeyOutOfOrder);
164 }
165 self.prev = rkey.clone();
166
167 return Ok(Step::Step { rkey, data });
168 }
169 }
170 }
171 }
172}
173
#[cfg(test)]
mod test {
    use super::*;
    use crate::mst::Entry;

    /// Block map type used by the walker tests: records "process" to their byte length
    type Blocks = HashMap<Cid, MaybeProcessedBlock<usize, std::io::Error>>;

    fn cid1() -> Cid {
        "bafyreihixenvk3ahqbytas4hk4a26w43bh6eo3w6usjqtxkpzsvi655a3m"
            .parse()
            .unwrap()
    }
    fn cid2() -> Cid {
        "QmY7Yh4UquoXHLPFo2XbhXkhBvFoPwmQUSa92pxnxjQuPU"
            .parse()
            .unwrap()
    }
    fn cid3() -> Cid {
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap()
    }
    fn cid4() -> Cid {
        "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR"
            .parse()
            .unwrap()
    }
    fn cid5() -> Cid {
        "QmSnuWmxptJZdLJpKRarxBMS2Ju2oANVrgbr2xWbie9b2D"
            .parse()
            .unwrap()
    }
    fn cid6() -> Cid {
        "QmdmQXB2mzChmMeKY47C43LxUdg1NDJ5MWcKMKxDu7RgQm"
            .parse()
            .unwrap()
    }
    fn cid7() -> Cid {
        "bafybeiaysi4s6lnjev27ln5icwm6tueaw2vdykrtjkwiphwekaywqhcjze"
            .parse()
            .unwrap()
    }
    fn cid8() -> Cid {
        "bafyreif3tfdpr5n4jdrbielmcapwvbpcthepfkwq2vwonmlhirbjmotedi"
            .parse()
            .unwrap()
    }
    fn cid9() -> Cid {
        "bafyreicnokmhmrnlp2wjhyk2haep4tqxiptwfrp2rrs7rzq7uk766chqvq"
            .parse()
            .unwrap()
    }

    /// Build an MST entry whose key suffix is given as a string
    fn entry(keysuffix: &str, prefix_len: usize, value: Cid, tree: Option<Cid>) -> Entry {
        Entry {
            keysuffix: keysuffix.into(),
            prefix_len,
            value,
            tree,
        }
    }

    fn record(rkey: &str, cid: Cid) -> Need {
        Need::Record {
            rkey: rkey.into(),
            cid,
        }
    }

    /// The stack's contents in the order the walker will pop them
    fn visit_order(stack: Vec<Need>) -> Vec<Need> {
        stack.into_iter().rev().collect()
    }

    /// Stand-in record processor: the record block's size in bytes
    fn len_of(block: &[u8]) -> Result<usize, std::io::Error> {
        Ok(block.len())
    }

    #[test]
    fn test_next_from_node_empty() {
        let node = Node {
            left: None,
            entries: vec![],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(stack.last(), None);
    }

    #[test]
    fn test_needs_from_node_just_left() {
        let node = Node {
            left: Some(cid1()),
            entries: vec![],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(stack.last(), Some(Need::Node(cid1())).as_ref());
    }

    #[test]
    fn test_needs_from_node_just_one_record() {
        let node = Node {
            left: None,
            entries: vec![entry("asdf", 0, cid1(), None)],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(visit_order(stack), vec![record("asdf", cid1())]);
    }

    #[test]
    fn test_needs_from_node_two_records() {
        let node = Node {
            left: None,
            entries: vec![
                entry("asdf", 0, cid1(), None),
                entry("gh", 2, cid2(), None),
            ],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(
            visit_order(stack),
            vec![record("asdf", cid1()), record("asgh", cid2())]
        );
    }

    #[test]
    fn test_needs_from_node_with_both() {
        let node = Node {
            left: None,
            entries: vec![entry("asdf", 0, cid1(), Some(cid2()))],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(
            visit_order(stack),
            vec![record("asdf", cid1()), Need::Node(cid2())]
        );
    }

    #[test]
    fn test_needs_from_node_left_and_record() {
        let node = Node {
            left: Some(cid1()),
            entries: vec![entry("asdf", 0, cid2(), None)],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(
            visit_order(stack),
            vec![Need::Node(cid1()), record("asdf", cid2())]
        );
    }

    #[test]
    fn test_needs_from_full_node() {
        let node = Node {
            left: Some(cid1()),
            entries: vec![
                entry("asdf", 0, cid2(), Some(cid3())),
                entry("ghi", 1, cid4(), Some(cid5())),
                entry("jkl", 2, cid6(), Some(cid7())),
                entry("mno", 4, cid8(), Some(cid9())),
            ],
        };
        let mut stack = vec![];
        push_from_node(&mut stack, &node).unwrap();
        assert_eq!(
            visit_order(stack),
            vec![
                Need::Node(cid1()),
                record("asdf", cid2()),
                Need::Node(cid3()),
                record("aghi", cid4()),
                Need::Node(cid5()),
                record("agjkl", cid6()),
                Need::Node(cid7()),
                record("agjkmno", cid8()),
                Need::Node(cid9()),
            ]
        );
    }

    #[test]
    fn test_needs_from_node_prefix_out_of_bounds() {
        // the first entry has no previous key to share a prefix with
        let node = Node {
            left: Some(cid1()),
            entries: vec![entry("asdf", 1, cid2(), None)],
        };
        let mut stack = vec![];
        let result = push_from_node(&mut stack, &node);
        assert!(matches!(result, Err(RkeyError::EntryPrefixOutOfbounds)));
        assert!(stack.is_empty());
    }

    #[test]
    fn test_needs_from_node_rkey_not_utf8() {
        let node = Node {
            left: None,
            entries: vec![Entry {
                keysuffix: vec![0xffu8].into(),
                prefix_len: 0,
                value: cid1(),
                tree: None,
            }],
        };
        let mut stack = vec![];
        let result = push_from_node(&mut stack, &node);
        assert!(matches!(result, Err(RkeyError::EntryRkeyNotUtf8(_))));
        assert!(stack.is_empty());
    }

    #[test]
    fn test_walker_rests_on_missing_root() {
        let mut walker = Walker::new(cid1());
        let mut blocks = Blocks::new();
        let step = walker.step(&mut blocks, len_of);
        assert!(matches!(step, Ok(Step::Rest(cid)) if cid == cid1()));
    }

    #[test]
    fn test_walker_finishes_when_nothing_is_left() {
        let mut walker = Walker {
            stack: vec![],
            prev: String::new(),
        };
        let mut blocks = Blocks::new();
        assert!(matches!(walker.step(&mut blocks, len_of), Ok(Step::Finish)));
    }

    #[test]
    fn test_walker_rejects_processed_node_block() {
        let mut walker = Walker::new(cid1());
        let mut blocks = Blocks::new();
        blocks.insert(cid1(), MaybeProcessedBlock::Processed(Ok(1)));
        assert!(matches!(
            walker.step(&mut blocks, len_of),
            Err(Trip::BadCommit(_))
        ));
    }

    #[test]
    fn test_walker_emits_records_in_order() {
        // both records share one block, which must survive the first visit
        let mut walker = Walker {
            stack: vec![record("b", cid1()), record("a", cid1())],
            prev: String::new(),
        };
        let mut blocks = Blocks::new();
        blocks.insert(cid1(), MaybeProcessedBlock::Raw(b"hello".to_vec()));

        for expected in ["a", "b"] {
            match walker.step(&mut blocks, len_of) {
                Ok(Step::Step { rkey, data }) => {
                    assert_eq!(rkey, expected);
                    assert_eq!(data, 5);
                }
                other => panic!("unexpected walker output: {other:?}"),
            }
        }
        assert!(matches!(walker.step(&mut blocks, len_of), Ok(Step::Finish)));
    }

    #[test]
    fn test_walker_rejects_out_of_order_rkeys() {
        let mut walker = Walker {
            stack: vec![record("a", cid1()), record("b", cid1())],
            prev: String::new(),
        };
        let mut blocks = Blocks::new();
        blocks.insert(cid1(), MaybeProcessedBlock::Raw(b"hello".to_vec()));

        assert!(matches!(
            walker.step(&mut blocks, len_of),
            Ok(Step::Step { .. })
        ));
        assert!(matches!(
            walker.step(&mut blocks, len_of),
            Err(Trip::RkeyOutOfOrder)
        ));
    }

    #[test]
    fn test_walker_surfaces_stored_process_error() {
        let mut walker = Walker {
            stack: vec![record("a", cid1())],
            prev: String::new(),
        };
        let mut blocks = Blocks::new();
        blocks.insert(
            cid1(),
            MaybeProcessedBlock::Processed(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "nope",
            ))),
        );
        assert!(matches!(
            walker.step(&mut blocks, len_of),
            Err(Trip::ProcessFailed(_))
        ));
    }
}