repo_stream/
drive.rs

1//! Consume a CAR from an AsyncRead, producing an ordered stream of records
2
3use crate::disk::{DiskError, DiskStore};
4use crate::process::Processable;
5use ipld_core::cid::Cid;
6use iroh_car::CarReader;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::convert::Infallible;
10use tokio::{io::AsyncRead, sync::mpsc};
11
12use crate::mst::{Commit, Node};
13use crate::walk::{Step, WalkError, Walker};
14
15/// Errors that can happen while consuming and emitting blocks and records
16#[derive(Debug, thiserror::Error)]
17pub enum DriveError {
18    #[error("Error from iroh_car: {0}")]
19    CarReader(#[from] iroh_car::Error),
20    #[error("Failed to decode commit block: {0}")]
21    BadBlock(#[from] serde_ipld_dagcbor::DecodeError<Infallible>),
22    #[error("The Commit block reference by the root was not found")]
23    MissingCommit,
24    #[error("The MST block {0} could not be found")]
25    MissingBlock(Cid),
26    #[error("Failed to walk the mst tree: {0}")]
27    WalkError(#[from] WalkError),
28    #[error("CAR file had no roots")]
29    MissingRoot,
30    #[error("Storage error")]
31    StorageError(#[from] DiskError),
32    #[error("Encode error: {0}")]
33    BincodeEncodeError(#[from] bincode::error::EncodeError),
34    #[error("Tried to send on a closed channel")]
35    ChannelSendError, // SendError takes <T> which we don't need
36    #[error("Failed to join a task: {0}")]
37    JoinError(#[from] tokio::task::JoinError),
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum DecodeError {
42    #[error(transparent)]
43    BincodeDecodeError(#[from] bincode::error::DecodeError),
44    #[error("extra bytes remained after decoding")]
45    ExtraGarbage,
46}
47
/// One batch of output, in rkey order: each entry pairs a record key with its
/// (processed) record
pub type BlockChunk<T> = Vec<(String, T)>;
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub(crate) enum MaybeProcessedBlock<T> {
53    /// A block that's *probably* a Node (but we can't know yet)
54    ///
55    /// It *can be* a record that suspiciously looks a lot like a node, so we
56    /// cannot eagerly turn it into a Node. We only know for sure what it is
57    /// when we actually walk down the MST
58    Raw(Vec<u8>),
59    /// A processed record from a block that was definitely not a Node
60    ///
61    /// Processing has to be fallible because the CAR can have totally-unused
62    /// blocks, which can just be garbage. since we're eagerly trying to process
63    /// record blocks without knowing for sure that they *are* records, we
64    /// discard any definitely-not-nodes that fail processing and keep their
65    /// error in the buffer for them. if we later try to retreive them as a
66    /// record, then we can surface the error.
67    ///
68    /// If we _never_ needed this block, then we may have wasted a bit of effort
69    /// trying to process it. Oh well.
70    ///
71    /// There's an alternative here, which would be to kick unprocessable blocks
72    /// back to Raw, or maybe even a new RawUnprocessable variant. Then we could
73    /// surface the typed error later if needed by trying to reprocess.
74    Processed(T),
75}
76
77impl<T: Processable> Processable for MaybeProcessedBlock<T> {
78    /// TODO this is probably a little broken
79    fn get_size(&self) -> usize {
80        use std::{cmp::max, mem::size_of};
81
82        // enum is always as big as its biggest member?
83        let base_size = max(size_of::<Vec<u8>>(), size_of::<T>());
84
85        let extra = match self {
86            Self::Raw(bytes) => bytes.len(),
87            Self::Processed(t) => t.get_size(),
88        };
89
90        base_size + extra
91    }
92}
93
94impl<T> MaybeProcessedBlock<T> {
95    fn maybe(process: fn(Vec<u8>) -> T, data: Vec<u8>) -> Self {
96        if Node::could_be(&data) {
97            MaybeProcessedBlock::Raw(data)
98        } else {
99            MaybeProcessedBlock::Processed(process(data))
100        }
101    }
102}
103
104/// Read a CAR file, buffering blocks in memory or to disk
105pub enum Driver<R: AsyncRead + Unpin, T: Processable> {
106    /// All blocks fit within the memory limit
107    ///
108    /// You probably want to check the commit's signature. You can go ahead and
109    /// walk the MST right away.
110    Memory(Commit, MemDriver<T>),
111    /// Blocks exceed the memory limit
112    ///
113    /// You'll need to provide a disk storage to continue. The commit will be
114    /// returned and can be validated only once all blocks are loaded.
115    Disk(NeedDisk<R, T>),
116}
117
118/// Builder-style driver setup
119#[derive(Debug, Clone)]
120pub struct DriverBuilder {
121    pub mem_limit_mb: usize,
122}
123
124impl Default for DriverBuilder {
125    fn default() -> Self {
126        Self { mem_limit_mb: 16 }
127    }
128}
129
130impl DriverBuilder {
131    /// Begin configuring the driver with defaults
132    pub fn new() -> Self {
133        Default::default()
134    }
135    /// Set the in-memory size limit, in MiB
136    ///
137    /// Default: 16 MiB
138    pub fn with_mem_limit_mb(self, new_limit: usize) -> Self {
139        Self {
140            mem_limit_mb: new_limit,
141        }
142    }
143    /// Set the block processor
144    ///
145    /// Default: noop, raw blocks will be emitted
146    pub fn with_block_processor<T: Processable>(
147        self,
148        p: fn(Vec<u8>) -> T,
149    ) -> DriverBuilderWithProcessor<T> {
150        DriverBuilderWithProcessor {
151            mem_limit_mb: self.mem_limit_mb,
152            block_processor: p,
153        }
154    }
155    /// Begin processing an atproto MST from a CAR file
156    pub async fn load_car<R: AsyncRead + Unpin>(
157        &self,
158        reader: R,
159    ) -> Result<Driver<R, Vec<u8>>, DriveError> {
160        Driver::load_car(reader, crate::process::noop, self.mem_limit_mb).await
161    }
162}
163
164/// Builder-style driver intermediate step
165///
166/// start from `DriverBuilder`
167#[derive(Debug, Clone)]
168pub struct DriverBuilderWithProcessor<T: Processable> {
169    pub mem_limit_mb: usize,
170    pub block_processor: fn(Vec<u8>) -> T,
171}
172
173impl<T: Processable> DriverBuilderWithProcessor<T> {
174    /// Set the in-memory size limit, in MiB
175    ///
176    /// Default: 16 MiB
177    pub fn with_mem_limit_mb(mut self, new_limit: usize) -> Self {
178        self.mem_limit_mb = new_limit;
179        self
180    }
181    /// Begin processing an atproto MST from a CAR file
182    pub async fn load_car<R: AsyncRead + Unpin>(
183        &self,
184        reader: R,
185    ) -> Result<Driver<R, T>, DriveError> {
186        Driver::load_car(reader, self.block_processor, self.mem_limit_mb).await
187    }
188}
189
190impl<R: AsyncRead + Unpin, T: Processable> Driver<R, T> {
191    /// Begin processing an atproto MST from a CAR file
192    ///
193    /// Blocks will be loaded, processed, and buffered in memory. If the entire
194    /// processed size is under the `mem_limit_mb` limit, a `Driver::Memory`
195    /// will be returned along with a `Commit` ready for validation.
196    ///
197    /// If the `mem_limit_mb` limit is reached before loading all blocks, the
198    /// partial state will be returned as `Driver::Disk(needed)`, which can be
199    /// resumed by providing a `SqliteStorage` for on-disk block storage.
200    pub async fn load_car(
201        reader: R,
202        process: fn(Vec<u8>) -> T,
203        mem_limit_mb: usize,
204    ) -> Result<Driver<R, T>, DriveError> {
205        let max_size = mem_limit_mb * 2_usize.pow(20);
206        let mut mem_blocks = HashMap::new();
207
208        let mut car = CarReader::new(reader).await?;
209
210        let root = *car
211            .header()
212            .roots()
213            .first()
214            .ok_or(DriveError::MissingRoot)?;
215        log::debug!("root: {root:?}");
216
217        let mut commit = None;
218
219        // try to load all the blocks into memory
220        let mut mem_size = 0;
221        while let Some((cid, data)) = car.next_block().await? {
222            // the root commit is a Special Third Kind of block that we need to make
223            // sure not to optimistically send to the processing function
224            if cid == root {
225                let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
226                commit = Some(c);
227                continue;
228            }
229
230            // remaining possible types: node, record, other. optimistically process
231            let maybe_processed = MaybeProcessedBlock::maybe(process, data);
232
233            // stash (maybe processed) blocks in memory as long as we have room
234            mem_size += std::mem::size_of::<Cid>() + maybe_processed.get_size();
235            mem_blocks.insert(cid, maybe_processed);
236            if mem_size >= max_size {
237                return Ok(Driver::Disk(NeedDisk {
238                    car,
239                    root,
240                    process,
241                    max_size,
242                    mem_blocks,
243                    commit,
244                }));
245            }
246        }
247
248        // all blocks loaded and we fit in memory! hopefully we found the commit...
249        let commit = commit.ok_or(DriveError::MissingCommit)?;
250
251        let walker = Walker::new(commit.data);
252
253        Ok(Driver::Memory(
254            commit,
255            MemDriver {
256                blocks: mem_blocks,
257                walker,
258                process,
259            },
260        ))
261    }
262}
263
264/// The core driver between the block stream and MST walker
265///
266/// In the future, PDSs will export CARs in a stream-friendly order that will
267/// enable processing them with tiny memory overhead. But that future is not
268/// here yet.
269///
270/// CARs are almost always in a stream-unfriendly order, so I'm reverting the
271/// optimistic stream features: we load all block first, then walk the MST.
272///
273/// This makes things much simpler: we only need to worry about spilling to disk
274/// in one place, and we always have a reasonable expecatation about how much
275/// work the init function will do. We can drop the CAR reader before walking,
276/// so the sync/async boundaries become a little easier to work around.
277#[derive(Debug)]
278pub struct MemDriver<T: Processable> {
279    blocks: HashMap<Cid, MaybeProcessedBlock<T>>,
280    walker: Walker,
281    process: fn(Vec<u8>) -> T,
282}
283
284impl<T: Processable> MemDriver<T> {
285    /// Step through the record outputs, in rkey order
286    pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk<T>>, DriveError> {
287        let mut out = Vec::with_capacity(n);
288        for _ in 0..n {
289            // walk as far as we can until we run out of blocks or find a record
290            match self.walker.step(&mut self.blocks, self.process)? {
291                Step::Missing(cid) => return Err(DriveError::MissingBlock(cid)),
292                Step::Finish => break,
293                Step::Found { rkey, data } => {
294                    out.push((rkey, data));
295                    continue;
296                }
297            };
298        }
299
300        if out.is_empty() {
301            Ok(None)
302        } else {
303            Ok(Some(out))
304        }
305    }
306}
307
308/// A partially memory-loaded car file that needs disk spillover to continue
309pub struct NeedDisk<R: AsyncRead + Unpin, T: Processable> {
310    car: CarReader<R>,
311    root: Cid,
312    process: fn(Vec<u8>) -> T,
313    max_size: usize,
314    mem_blocks: HashMap<Cid, MaybeProcessedBlock<T>>,
315    pub commit: Option<Commit>,
316}
317
318fn encode(v: impl Serialize) -> Result<Vec<u8>, bincode::error::EncodeError> {
319    bincode::serde::encode_to_vec(v, bincode::config::standard())
320}
321
322pub(crate) fn decode<T: Processable>(bytes: &[u8]) -> Result<T, DecodeError> {
323    let (t, n) = bincode::serde::decode_from_slice(bytes, bincode::config::standard())?;
324    if n != bytes.len() {
325        return Err(DecodeError::ExtraGarbage);
326    }
327    Ok(t)
328}
329
330impl<R: AsyncRead + Unpin, T: Processable + Send + 'static> NeedDisk<R, T> {
331    pub async fn finish_loading(
332        mut self,
333        mut store: DiskStore,
334    ) -> Result<(Commit, DiskDriver<T>), DriveError> {
335        // move store in and back out so we can manage lifetimes
336        // dump mem blocks into the store
337        store = tokio::task::spawn(async move {
338            let mut writer = store.get_writer()?;
339
340            let kvs = self
341                .mem_blocks
342                .into_iter()
343                .map(|(k, v)| Ok(encode(v).map(|v| (k.to_bytes(), v))?));
344
345            writer.put_many(kvs)?;
346            writer.commit()?;
347            Ok::<_, DriveError>(store)
348        })
349        .await??;
350
351        let (tx, mut rx) = mpsc::channel::<Vec<(Cid, MaybeProcessedBlock<T>)>>(1);
352
353        let store_worker = tokio::task::spawn_blocking(move || {
354            let mut writer = store.get_writer()?;
355
356            while let Some(chunk) = rx.blocking_recv() {
357                let kvs = chunk
358                    .into_iter()
359                    .map(|(k, v)| Ok(encode(v).map(|v| (k.to_bytes(), v))?));
360                writer.put_many(kvs)?;
361            }
362
363            writer.commit()?;
364            Ok::<_, DriveError>(store)
365        }); // await later
366
367        // dump the rest to disk (in chunks)
368        log::debug!("dumping the rest of the stream...");
369        loop {
370            let mut mem_size = 0;
371            let mut chunk = vec![];
372            loop {
373                let Some((cid, data)) = self.car.next_block().await? else {
374                    break;
375                };
376                // we still gotta keep checking for the root since we might not have it
377                if cid == self.root {
378                    let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
379                    self.commit = Some(c);
380                    continue;
381                }
382                // remaining possible types: node, record, other. optimistically process
383                // TODO: get the actual in-memory size to compute disk spill
384                let maybe_processed = MaybeProcessedBlock::maybe(self.process, data);
385                mem_size += std::mem::size_of::<Cid>() + maybe_processed.get_size();
386                chunk.push((cid, maybe_processed));
387                if mem_size >= self.max_size {
388                    // soooooo if we're setting the db cache to max_size and then letting
389                    // multiple chunks in the queue that are >= max_size, then at any time
390                    // we might be using some multiple of max_size?
391                    break;
392                }
393            }
394            if chunk.is_empty() {
395                break;
396            }
397            tx.send(chunk)
398                .await
399                .map_err(|_| DriveError::ChannelSendError)?;
400        }
401        drop(tx);
402        log::debug!("done. waiting for worker to finish...");
403
404        store = store_worker.await??;
405
406        log::debug!("worker finished.");
407
408        let commit = self.commit.ok_or(DriveError::MissingCommit)?;
409
410        let walker = Walker::new(commit.data);
411
412        Ok((
413            commit,
414            DiskDriver {
415                process: self.process,
416                state: Some(BigState { store, walker }),
417            },
418        ))
419    }
420}
421
422struct BigState {
423    store: DiskStore,
424    walker: Walker,
425}
426
427/// MST walker that reads from disk instead of an in-memory hashmap
428pub struct DiskDriver<T: Clone> {
429    process: fn(Vec<u8>) -> T,
430    state: Option<BigState>,
431}
432
433// for doctests only
434#[doc(hidden)]
435pub fn _get_fake_disk_driver() -> DiskDriver<Vec<u8>> {
436    use crate::process::noop;
437    DiskDriver {
438        process: noop,
439        state: None,
440    }
441}
442
443impl<T: Processable + Send + 'static> DiskDriver<T> {
444    /// Walk the MST returning up to `n` rkey + record pairs
445    ///
446    /// ```no_run
447    /// # use repo_stream::{drive::{DiskDriver, DriveError, _get_fake_disk_driver}, process::noop};
448    /// # #[tokio::main]
449    /// # async fn main() -> Result<(), DriveError> {
450    /// # let mut disk_driver = _get_fake_disk_driver();
451    /// while let Some(pairs) = disk_driver.next_chunk(256).await? {
452    ///     for (rkey, record) in pairs {
453    ///         println!("{rkey}: size={}", record.len());
454    ///     }
455    /// }
456    /// let store = disk_driver.reset_store().await?;
457    /// # Ok(())
458    /// # }
459    /// ```
460    pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk<T>>, DriveError> {
461        let process = self.process;
462
463        // state should only *ever* be None transiently while inside here
464        let mut state = self.state.take().expect("DiskDriver must have Some(state)");
465
466        // the big pain here is that we don't want to leave self.state in an
467        // invalid state (None), so all the error paths have to make sure it
468        // comes out again.
469        let (state, res) = tokio::task::spawn_blocking(
470            move || -> (BigState, Result<BlockChunk<T>, DriveError>) {
471                let mut reader_res = state.store.get_reader();
472                let reader: &mut _ = match reader_res {
473                    Ok(ref mut r) => r,
474                    Err(ref mut e) => {
475                        // unfortunately we can't return the error directly because
476                        // (for some reason) it's attached to the lifetime of the
477                        // reader?
478                        // hack a mem::swap so we can get it out :/
479                        let e_swapped = e.steal();
480                        // the pain: `state` *has to* outlive the reader
481                        drop(reader_res);
482                        return (state, Err(e_swapped.into()));
483                    }
484                };
485
486                let mut out = Vec::with_capacity(n);
487
488                for _ in 0..n {
489                    // walk as far as we can until we run out of blocks or find a record
490                    let step = match state.walker.disk_step(reader, process) {
491                        Ok(s) => s,
492                        Err(e) => {
493                            // the pain: `state` *has to* outlive the reader
494                            drop(reader_res);
495                            return (state, Err(e.into()));
496                        }
497                    };
498                    match step {
499                        Step::Missing(cid) => {
500                            // the pain: `state` *has to* outlive the reader
501                            drop(reader_res);
502                            return (state, Err(DriveError::MissingBlock(cid)));
503                        }
504                        Step::Finish => break,
505                        Step::Found { rkey, data } => out.push((rkey, data)),
506                    };
507                }
508
509                // `state` *has to* outlive the reader
510                drop(reader_res);
511
512                (state, Ok::<_, DriveError>(out))
513            },
514        )
515        .await?; // on tokio JoinError, we'll be left with invalid state :(
516
517        // *must* restore state before dealing with the actual result
518        self.state = Some(state);
519
520        let out = res?;
521
522        if out.is_empty() {
523            Ok(None)
524        } else {
525            Ok(Some(out))
526        }
527    }
528
529    fn read_tx_blocking(
530        &mut self,
531        n: usize,
532        tx: mpsc::Sender<Result<BlockChunk<T>, DriveError>>,
533    ) -> Result<(), mpsc::error::SendError<Result<BlockChunk<T>, DriveError>>> {
534        let BigState { store, walker } = self.state.as_mut().expect("valid state");
535        let mut reader = match store.get_reader() {
536            Ok(r) => r,
537            Err(e) => return tx.blocking_send(Err(e.into())),
538        };
539
540        loop {
541            let mut out: BlockChunk<T> = Vec::with_capacity(n);
542
543            for _ in 0..n {
544                // walk as far as we can until we run out of blocks or find a record
545
546                let step = match walker.disk_step(&mut reader, self.process) {
547                    Ok(s) => s,
548                    Err(e) => return tx.blocking_send(Err(e.into())),
549                };
550
551                match step {
552                    Step::Missing(cid) => {
553                        return tx.blocking_send(Err(DriveError::MissingBlock(cid)));
554                    }
555                    Step::Finish => return Ok(()),
556                    Step::Found { rkey, data } => {
557                        out.push((rkey, data));
558                        continue;
559                    }
560                };
561            }
562
563            if out.is_empty() {
564                break;
565            }
566            tx.blocking_send(Ok(out))?;
567        }
568
569        Ok(())
570    }
571
572    /// Spawn the disk reading task into a tokio blocking thread
573    ///
574    /// The idea is to avoid so much sending back and forth to the blocking
575    /// thread, letting a blocking task do all the disk reading work and sending
576    /// records and rkeys back through an `mpsc` channel instead.
577    ///
578    /// This might also allow the disk work to continue while processing the
579    /// records. It's still not yet clear if this method actually has much
580    /// benefit over just using `.next_chunk(n)`.
581    ///
582    /// ```no_run
583    /// # use repo_stream::{drive::{DiskDriver, DriveError, _get_fake_disk_driver}, process::noop};
584    /// # #[tokio::main]
585    /// # async fn main() -> Result<(), DriveError> {
586    /// # let mut disk_driver = _get_fake_disk_driver();
587    /// let (mut rx, join) = disk_driver.to_channel(512);
588    /// while let Some(recvd) = rx.recv().await {
589    ///     let pairs = recvd?;
590    ///     for (rkey, record) in pairs {
591    ///         println!("{rkey}: size={}", record.len());
592    ///     }
593    ///
594    /// }
595    /// let store = join.await?.reset_store().await?;
596    /// # Ok(())
597    /// # }
598    /// ```
599    pub fn to_channel(
600        mut self,
601        n: usize,
602    ) -> (
603        mpsc::Receiver<Result<BlockChunk<T>, DriveError>>,
604        tokio::task::JoinHandle<Self>,
605    ) {
606        let (tx, rx) = mpsc::channel::<Result<BlockChunk<T>, DriveError>>(1);
607
608        // sketch: this worker is going to be allowed to execute without a join handle
609        let chan_task = tokio::task::spawn_blocking(move || {
610            if let Err(mpsc::error::SendError(_)) = self.read_tx_blocking(n, tx) {
611                log::debug!("big car reader exited early due to dropped receiver channel");
612            }
613            self
614        });
615
616        (rx, chan_task)
617    }
618
619    /// Reset the disk storage so it can be reused. You must call this.
620    ///
621    /// Ideally we'd put this in an `impl Drop`, but since it makes blocking
622    /// calls, that would be risky in an async context. For now you just have to
623    /// carefully make sure you call it.
624    ///
625    /// The sqlite store is returned, so it can be reused for another
626    /// `DiskDriver`.
627    pub async fn reset_store(mut self) -> Result<DiskStore, DriveError> {
628        let BigState { store, .. } = self.state.take().expect("valid state");
629        Ok(store.reset().await?)
630    }
631}