repo_stream/
drive.rs

//! Consume a CAR from an AsyncRead, producing an ordered stream of records

use crate::disk::{DiskError, DiskStore};
use crate::process::Processable;
use ipld_core::cid::Cid;
use iroh_car::CarReader;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::Infallible;
use tokio::{io::AsyncRead, sync::mpsc};

use crate::mst::{Commit, Node};
use crate::walk::{Step, WalkError, Walker};

/// Errors that can happen while consuming and emitting blocks and records
#[derive(Debug, thiserror::Error)]
pub enum DriveError {
    #[error("Error from iroh_car: {0}")]
    CarReader(#[from] iroh_car::Error),
    #[error("Failed to decode commit block: {0}")]
    BadBlock(#[from] serde_ipld_dagcbor::DecodeError<Infallible>),
    #[error("The Commit block referenced by the root was not found")]
    MissingCommit,
    #[error("The MST block {0} could not be found")]
    MissingBlock(Cid),
    #[error("Failed to walk the mst tree: {0}")]
    WalkError(#[from] WalkError),
    #[error("CAR file had no roots")]
    MissingRoot,
    #[error("Storage error")]
    StorageError(#[from] DiskError),
    #[error("Encode error: {0}")]
    BincodeEncodeError(#[from] bincode::error::EncodeError),
    #[error("Tried to send on a closed channel")]
    ChannelSendError, // SendError takes <T> which we don't need
    #[error("Failed to join a task: {0}")]
    JoinError(#[from] tokio::task::JoinError),
}

#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
    #[error(transparent)]
    BincodeDecodeError(#[from] bincode::error::DecodeError),
    #[error("extra bytes remained after decoding")]
    ExtraGarbage,
}

/// An in-order chunk of Rkey + (processed) Block pairs
pub type BlockChunk<T> = Vec<(String, T)>;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum MaybeProcessedBlock<T> {
    /// A block that's *probably* a Node (but we can't know yet)
    ///
    /// It *can be* a record that suspiciously looks a lot like a node, so we
    /// cannot eagerly turn it into a Node. We only know for sure what it is
    /// when we actually walk down the MST
    Raw(Vec<u8>),
    /// A processed record from a block that was definitely not a Node
    ///
    /// Processing has to be fallible because the CAR can have totally-unused
    /// blocks, which can just be garbage. Since we're eagerly trying to process
    /// record blocks without knowing for sure that they *are* records, we
    /// discard any definitely-not-nodes that fail processing and keep their
    /// error in the buffer in their place. If we later try to retrieve one as a
    /// record, we can surface the error.
    ///
    /// If we _never_ needed this block, then we may have wasted a bit of effort
    /// trying to process it. Oh well.
    ///
    /// There's an alternative here, which would be to kick unprocessable blocks
    /// back to Raw, or maybe even a new RawUnprocessable variant. Then we could
    /// surface the typed error later if needed by trying to reprocess.
    Processed(T),
}

impl<T: Processable> Processable for MaybeProcessedBlock<T> {
    /// TODO this is probably a little broken
    fn get_size(&self) -> usize {
        use std::{cmp::max, mem::size_of};

        // enum is always as big as its biggest member?
        let base_size = max(size_of::<Vec<u8>>(), size_of::<T>());

        let extra = match self {
            Self::Raw(bytes) => bytes.len(),
            Self::Processed(t) => t.get_size(),
        };

        base_size + extra
    }
}

impl<T> MaybeProcessedBlock<T> {
    fn maybe(process: fn(Vec<u8>) -> T, data: Vec<u8>) -> Self {
        if Node::could_be(&data) {
            MaybeProcessedBlock::Raw(data)
        } else {
            MaybeProcessedBlock::Processed(process(data))
        }
    }
}

/// Read a CAR file, buffering blocks in memory or to disk
pub enum Driver<R: AsyncRead + Unpin, T: Processable> {
    /// All blocks fit within the memory limit
    ///
    /// You probably want to check the commit's signature. You can go ahead and
    /// walk the MST right away.
    Memory(Commit, MemDriver<T>),
    /// Blocks exceed the memory limit
    ///
    /// You'll need to provide a `DiskStore` to continue. The commit will be
    /// returned and can be validated only once all blocks are loaded.
    Disk(NeedDisk<R, T>),
}

/// Builder-style driver setup
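///
/// A rough end-to-end sketch (any `AsyncRead + Unpin` source of CAR bytes
/// works; the empty slice below is only a stand-in so the example compiles):
///
/// ```no_run
/// # use repo_stream::drive::{DriveError, Driver, DriverBuilder};
/// # #[tokio::main]
/// # async fn main() -> Result<(), DriveError> {
/// // stand-in CAR source; use a real AsyncRead (file, socket, ...) in practice
/// let reader: &[u8] = &[];
/// match DriverBuilder::new()
///     .with_mem_limit_mb(16)
///     .load_car(reader)
///     .await?
/// {
///     Driver::Memory(commit, mem_driver) => {
///         // small repo: check `commit`, then walk via `MemDriver::next_chunk`
///         let _ = (commit, mem_driver);
///     }
///     Driver::Disk(needed) => {
///         // large repo: hand a `DiskStore` to `NeedDisk::finish_loading`
///         let _ = needed;
///     }
/// }
/// # Ok(())
/// # }
/// ```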
pub struct DriverBuilder {
    pub mem_limit_mb: usize,
}

impl Default for DriverBuilder {
    fn default() -> Self {
        Self { mem_limit_mb: 16 }
    }
}

impl DriverBuilder {
    /// Begin configuring the driver with defaults
    pub fn new() -> Self {
        Default::default()
    }
    /// Set the in-memory size limit, in MiB
    ///
    /// Default: 16 MiB
    pub fn with_mem_limit_mb(self, new_limit: usize) -> Self {
        Self {
            mem_limit_mb: new_limit,
        }
    }
    /// Set the block processor
    ///
    /// Default: noop, raw blocks will be emitted
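    ///
    /// An illustrative sketch (the truncating processor below is made up for
    /// the example; any `fn(Vec<u8>) -> T` where `T: Processable` works):
    ///
    /// ```no_run
    /// # use repo_stream::drive::{DriveError, DriverBuilder};
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), DriveError> {
    /// // hypothetical processor: keep only the first 64 bytes of each record
    /// fn head_64(mut raw: Vec<u8>) -> Vec<u8> {
    ///     raw.truncate(64);
    ///     raw
    /// }
    ///
    /// # let reader: &[u8] = &[];
    /// let _driver = DriverBuilder::new()
    ///     .with_block_processor(head_64)
    ///     .load_car(reader)
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```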
    pub fn with_block_processor<T: Processable>(
        self,
        p: fn(Vec<u8>) -> T,
    ) -> DriverBuilderWithProcessor<T> {
        DriverBuilderWithProcessor {
            mem_limit_mb: self.mem_limit_mb,
            block_processor: p,
        }
    }
    /// Begin processing an atproto MST from a CAR file
    pub async fn load_car<R: AsyncRead + Unpin>(
        self,
        reader: R,
    ) -> Result<Driver<R, Vec<u8>>, DriveError> {
        Driver::load_car(reader, crate::process::noop, self.mem_limit_mb).await
    }
}

/// Builder-style driver intermediate step
///
/// start from `DriverBuilder`
pub struct DriverBuilderWithProcessor<T: Processable> {
    pub mem_limit_mb: usize,
    pub block_processor: fn(Vec<u8>) -> T,
}

impl<T: Processable> DriverBuilderWithProcessor<T> {
    /// Set the in-memory size limit, in MiB
    ///
    /// Default: 16 MiB
    pub fn with_mem_limit_mb(mut self, new_limit: usize) -> Self {
        self.mem_limit_mb = new_limit;
        self
    }
    /// Begin processing an atproto MST from a CAR file
    pub async fn load_car<R: AsyncRead + Unpin>(
        self,
        reader: R,
    ) -> Result<Driver<R, T>, DriveError> {
        Driver::load_car(reader, self.block_processor, self.mem_limit_mb).await
    }
}

impl<R: AsyncRead + Unpin, T: Processable> Driver<R, T> {
    /// Begin processing an atproto MST from a CAR file
    ///
    /// Blocks will be loaded, processed, and buffered in memory. If the entire
    /// processed size is under the `mem_limit_mb` limit, a `Driver::Memory`
    /// will be returned along with a `Commit` ready for validation.
    ///
    /// If the `mem_limit_mb` limit is reached before loading all blocks, the
    /// partial state will be returned as `Driver::Disk(needed)`, which can be
    /// resumed by providing a `DiskStore` for on-disk block storage.
    pub async fn load_car(
        reader: R,
        process: fn(Vec<u8>) -> T,
        mem_limit_mb: usize,
    ) -> Result<Driver<R, T>, DriveError> {
        let max_size = mem_limit_mb * 2_usize.pow(20);
        let mut mem_blocks = HashMap::new();

        let mut car = CarReader::new(reader).await?;

        let root = *car
            .header()
            .roots()
            .first()
            .ok_or(DriveError::MissingRoot)?;
        log::debug!("root: {root:?}");

        let mut commit = None;

        // try to load all the blocks into memory
        let mut mem_size = 0;
        while let Some((cid, data)) = car.next_block().await? {
            // the root commit is a Special Third Kind of block that we need to make
            // sure not to optimistically send to the processing function
            if cid == root {
                let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
                commit = Some(c);
                continue;
            }

            // remaining possible types: node, record, other. optimistically process
            let maybe_processed = MaybeProcessedBlock::maybe(process, data);

            // stash (maybe processed) blocks in memory as long as we have room
            mem_size += std::mem::size_of::<Cid>() + maybe_processed.get_size();
            mem_blocks.insert(cid, maybe_processed);
            if mem_size >= max_size {
                return Ok(Driver::Disk(NeedDisk {
                    car,
                    root,
                    process,
                    max_size,
                    mem_blocks,
                    commit,
                }));
            }
        }

        // all blocks loaded and we fit in memory! hopefully we found the commit...
        let commit = commit.ok_or(DriveError::MissingCommit)?;

        let walker = Walker::new(commit.data);

        Ok(Driver::Memory(
            commit,
            MemDriver {
                blocks: mem_blocks,
                walker,
                process,
            },
        ))
    }
}

/// The core driver between the block stream and MST walker
///
/// In the future, PDSs will export CARs in a stream-friendly order that will
/// enable processing them with tiny memory overhead. But that future is not
/// here yet.
///
/// CARs are almost always in a stream-unfriendly order, so I'm reverting the
/// optimistic stream features: we load all blocks first, then walk the MST.
///
/// This makes things much simpler: we only need to worry about spilling to disk
/// in one place, and we always have a reasonable expectation about how much
/// work the init function will do. We can drop the CAR reader before walking,
/// so the sync/async boundaries become a little easier to work around.
#[derive(Debug)]
pub struct MemDriver<T: Processable> {
    blocks: HashMap<Cid, MaybeProcessedBlock<T>>,
    walker: Walker,
    process: fn(Vec<u8>) -> T,
}

impl<T: Processable> MemDriver<T> {
    /// Step through the record outputs, in rkey order
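    ///
    /// A usage sketch (a `MemDriver` comes out of `Driver::Memory`; the empty
    /// reader here only exists so the example compiles):
    ///
    /// ```no_run
    /// # use repo_stream::drive::{DriveError, Driver, DriverBuilder};
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), DriveError> {
    /// # let reader: &[u8] = &[];
    /// if let Driver::Memory(_commit, mut mem_driver) =
    ///     DriverBuilder::new().load_car(reader).await?
    /// {
    ///     while let Some(pairs) = mem_driver.next_chunk(256).await? {
    ///         for (rkey, record) in pairs {
    ///             println!("{rkey}: size={}", record.len());
    ///         }
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```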
    pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk<T>>, DriveError> {
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            // walk as far as we can until we run out of blocks or find a record
            match self.walker.step(&mut self.blocks, self.process)? {
                Step::Missing(cid) => return Err(DriveError::MissingBlock(cid)),
                Step::Finish => break,
                Step::Found { rkey, data } => {
                    out.push((rkey, data));
                    continue;
                }
            };
        }

        if out.is_empty() {
            Ok(None)
        } else {
            Ok(Some(out))
        }
    }
}

/// A partially memory-loaded CAR file that needs disk spillover to continue
pub struct NeedDisk<R: AsyncRead + Unpin, T: Processable> {
    car: CarReader<R>,
    root: Cid,
    process: fn(Vec<u8>) -> T,
    max_size: usize,
    mem_blocks: HashMap<Cid, MaybeProcessedBlock<T>>,
    pub commit: Option<Commit>,
}

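/// Serialize a value (in practice a `MaybeProcessedBlock`) with bincode's
/// standard config for writing into the `DiskStore`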
fn encode(v: impl Serialize) -> Result<Vec<u8>, bincode::error::EncodeError> {
    bincode::serde::encode_to_vec(v, bincode::config::standard())
}

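/// Decode a block previously written by `encode`, rejecting trailing bytes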
pub(crate) fn decode<T: Processable>(bytes: &[u8]) -> Result<T, DecodeError> {
    let (t, n) = bincode::serde::decode_from_slice(bytes, bincode::config::standard())?;
    if n != bytes.len() {
        return Err(DecodeError::ExtraGarbage);
    }
    Ok(t)
}

impl<R: AsyncRead + Unpin, T: Processable + Send + 'static> NeedDisk<R, T> {
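    /// Spill the already-buffered blocks, plus the rest of the CAR stream, into
    /// the provided `DiskStore`, then hand back the `Commit` and a `DiskDriver`
    /// for walking the MST from disk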
    pub async fn finish_loading(
        mut self,
        mut store: DiskStore,
    ) -> Result<(Commit, DiskDriver<T>), DriveError> {
        // move store in and back out so we can manage lifetimes
        // dump mem blocks into the store
        store = tokio::task::spawn(async move {
            let mut writer = store.get_writer()?;

            let kvs = self
                .mem_blocks
                .into_iter()
                .map(|(k, v)| Ok(encode(v).map(|v| (k.to_bytes(), v))?));

            writer.put_many(kvs)?;
            writer.commit()?;
            Ok::<_, DriveError>(store)
        })
        .await??;

        let (tx, mut rx) = mpsc::channel::<Vec<(Cid, MaybeProcessedBlock<T>)>>(2);

        let store_worker = tokio::task::spawn_blocking(move || {
            let mut writer = store.get_writer()?;

            while let Some(chunk) = rx.blocking_recv() {
                let kvs = chunk
                    .into_iter()
                    .map(|(k, v)| Ok(encode(v).map(|v| (k.to_bytes(), v))?));
                writer.put_many(kvs)?;
            }

            writer.commit()?;
            Ok::<_, DriveError>(store)
        }); // await later

        // dump the rest to disk (in chunks)
        log::debug!("dumping the rest of the stream...");
        loop {
            let mut mem_size = 0;
            let mut chunk = vec![];
            loop {
                let Some((cid, data)) = self.car.next_block().await? else {
                    break;
                };
                // we still gotta keep checking for the root since we might not have it
                if cid == self.root {
                    let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
                    self.commit = Some(c);
                    continue;
                }
                // remaining possible types: node, record, other. optimistically process
                // TODO: get the actual in-memory size to compute disk spill
                let maybe_processed = MaybeProcessedBlock::maybe(self.process, data);
                mem_size += std::mem::size_of::<Cid>() + maybe_processed.get_size();
                chunk.push((cid, maybe_processed));
                if mem_size >= self.max_size {
                    // soooooo if we're setting the db cache to max_size and then letting
                    // multiple chunks in the queue that are >= max_size, then at any time
                    // we might be using some multiple of max_size?
                    break;
                }
            }
            if chunk.is_empty() {
                break;
            }
            tx.send(chunk)
                .await
                .map_err(|_| DriveError::ChannelSendError)?;
        }
        drop(tx);
        log::debug!("done. waiting for worker to finish...");

        store = store_worker.await??;

        log::debug!("worker finished.");

        let commit = self.commit.ok_or(DriveError::MissingCommit)?;

        let walker = Walker::new(commit.data);

        Ok((
            commit,
            DiskDriver {
                process: self.process,
                state: Some(BigState { store, walker }),
            },
        ))
    }
}

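/// The disk store and walker travel together so they can be moved in and out
/// of blocking tasks as one unit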
struct BigState {
    store: DiskStore,
    walker: Walker,
}

/// MST walker that reads from disk instead of an in-memory hashmap
pub struct DiskDriver<T: Clone> {
    process: fn(Vec<u8>) -> T,
    state: Option<BigState>,
}

// for doctests only
#[doc(hidden)]
pub fn _get_fake_disk_driver() -> DiskDriver<Vec<u8>> {
    use crate::process::noop;
    DiskDriver {
        process: noop,
        state: None,
    }
}

impl<T: Processable + Send + 'static> DiskDriver<T> {
    /// Walk the MST returning up to `n` rkey + record pairs
    ///
    /// ```no_run
    /// # use repo_stream::{drive::{DiskDriver, DriveError, _get_fake_disk_driver}, process::noop};
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), DriveError> {
    /// # let mut disk_driver = _get_fake_disk_driver();
    /// while let Some(pairs) = disk_driver.next_chunk(256).await? {
    ///     for (rkey, record) in pairs {
    ///         println!("{rkey}: size={}", record.len());
    ///     }
    /// }
    /// let store = disk_driver.reset_store().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk<T>>, DriveError> {
        let process = self.process;

        // state should only *ever* be None transiently while inside here
        let mut state = self.state.take().expect("DiskDriver must have Some(state)");

        // the big pain here is that we don't want to leave self.state in an
        // invalid state (None), so all the error paths have to make sure it
        // comes out again.
        let (state, res) = tokio::task::spawn_blocking(
            move || -> (BigState, Result<BlockChunk<T>, DriveError>) {
                let mut reader_res = state.store.get_reader();
                let reader: &mut _ = match reader_res {
                    Ok(ref mut r) => r,
                    Err(ref mut e) => {
                        // unfortunately we can't return the error directly because
                        // (for some reason) it's attached to the lifetime of the
                        // reader?
                        // hack a mem::swap so we can get it out :/
                        let e_swapped = e.steal();
                        // the pain: `state` *has to* outlive the reader
                        drop(reader_res);
                        return (state, Err(e_swapped.into()));
                    }
                };

                let mut out = Vec::with_capacity(n);

                for _ in 0..n {
                    // walk as far as we can until we run out of blocks or find a record
                    let step = match state.walker.disk_step(reader, process) {
                        Ok(s) => s,
                        Err(e) => {
                            // the pain: `state` *has to* outlive the reader
                            drop(reader_res);
                            return (state, Err(e.into()));
                        }
                    };
                    match step {
                        Step::Missing(cid) => {
                            // the pain: `state` *has to* outlive the reader
                            drop(reader_res);
                            return (state, Err(DriveError::MissingBlock(cid)));
                        }
                        Step::Finish => break,
                        Step::Found { rkey, data } => out.push((rkey, data)),
                    };
                }

                // `state` *has to* outlive the reader
                drop(reader_res);

                (state, Ok::<_, DriveError>(out))
            },
        )
        .await?; // on tokio JoinError, we'll be left with invalid state :(

        // *must* restore state before dealing with the actual result
        self.state = Some(state);

        let out = res?;

        if out.is_empty() {
            Ok(None)
        } else {
            Ok(Some(out))
        }
    }

    fn read_tx_blocking(
        &mut self,
        n: usize,
        tx: mpsc::Sender<Result<BlockChunk<T>, DriveError>>,
    ) -> Result<(), mpsc::error::SendError<Result<BlockChunk<T>, DriveError>>> {
        let BigState { store, walker } = self.state.as_mut().expect("valid state");
        let mut reader = match store.get_reader() {
            Ok(r) => r,
            Err(e) => return tx.blocking_send(Err(e.into())),
        };

        loop {
            let mut out: BlockChunk<T> = Vec::with_capacity(n);

            for _ in 0..n {
                // walk as far as we can until we run out of blocks or find a record

                let step = match walker.disk_step(&mut reader, self.process) {
                    Ok(s) => s,
                    Err(e) => return tx.blocking_send(Err(e.into())),
                };

                match step {
                    Step::Missing(cid) => {
                        return tx.blocking_send(Err(DriveError::MissingBlock(cid)));
                    }
                    Step::Finish => {
                        // send any partially-filled chunk before finishing so
                        // the tail records aren't dropped
                        if !out.is_empty() {
                            tx.blocking_send(Ok(out))?;
                        }
                        return Ok(());
                    }
                    Step::Found { rkey, data } => {
                        out.push((rkey, data));
                        continue;
                    }
                };
            }

            if out.is_empty() {
                break;
            }
            tx.blocking_send(Ok(out))?;
        }

        Ok(())
    }

    /// Spawn the disk reading task into a tokio blocking thread
    ///
    /// The idea is to avoid so much sending back and forth to the blocking
    /// thread, letting a blocking task do all the disk reading work and sending
    /// records and rkeys back through an `mpsc` channel instead.
    ///
    /// This might also allow the disk work to continue while processing the
    /// records. It's still not yet clear if this method actually has much
    /// benefit over just using `.next_chunk(n)`.
    ///
    /// ```no_run
    /// # use repo_stream::{drive::{DiskDriver, DriveError, _get_fake_disk_driver}, process::noop};
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), DriveError> {
    /// # let mut disk_driver = _get_fake_disk_driver();
    /// let (mut rx, join) = disk_driver.to_channel(512);
    /// while let Some(recvd) = rx.recv().await {
    ///     let pairs = recvd?;
    ///     for (rkey, record) in pairs {
    ///         println!("{rkey}: size={}", record.len());
    ///     }
    /// }
    /// let store = join.await?.reset_store().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn to_channel(
        mut self,
        n: usize,
    ) -> (
        mpsc::Receiver<Result<BlockChunk<T>, DriveError>>,
        tokio::task::JoinHandle<Self>,
    ) {
        let (tx, rx) = mpsc::channel::<Result<BlockChunk<T>, DriveError>>(1);

        // sketchy: nothing forces the caller to await the returned join handle,
        // so this worker may keep running detached
        let chan_task = tokio::task::spawn_blocking(move || {
            if let Err(mpsc::error::SendError(_)) = self.read_tx_blocking(n, tx) {
                log::debug!("big car reader exited early due to dropped receiver channel");
            }
            self
        });

        (rx, chan_task)
    }

    /// Reset the disk storage so it can be reused. You must call this.
    ///
    /// Ideally we'd put this in an `impl Drop`, but since it makes blocking
    /// calls, that would be risky in an async context. For now you just have to
    /// carefully make sure you call it.
    ///
    /// The sqlite store is returned, so it can be reused for another
    /// `DiskDriver`.
    pub async fn reset_store(mut self) -> Result<DiskStore, DriveError> {
        let BigState { store, .. } = self.state.take().expect("valid state");
        Ok(store.reset().await?)
    }
}