use crate::{
Bytes, HashMap,
disk::{DiskError, DiskStore},
mst::MstNode,
walk::Output,
};
use cid::Cid;
use iroh_car::CarReader;
use std::convert::Infallible;
use tokio::{io::AsyncRead, sync::mpsc};
use crate::mst::Commit;
use crate::walk::{WalkError, Walker};
#[derive(Debug, thiserror::Error)]
pub enum DriveError {
#[error("Error from iroh_car: {0}")]
CarReader(#[from] iroh_car::Error),
#[error("Failed to decode commit block: {0}")]
BadBlock(#[from] serde_ipld_dagcbor::DecodeError<Infallible>),
#[error("The Commit block reference by the root was not found")]
MissingCommit,
#[error("Failed to walk the mst tree: {0}")]
WalkError(#[from] WalkError),
#[error("CAR file had no roots")]
MissingRoot,
#[error("Storage error")]
StorageError(#[from] DiskError),
#[error("Tried to send on a closed channel")]
ChannelSendError, #[error("Failed to join a task: {0}")]
JoinError(#[from] tokio::task::JoinError),
}
/// A batch of walker [`Output`] items, as handed back by the drivers' `next_chunk` methods.
pub type BlockChunk = Vec<Output>;
/// The bytes of a block, tagged by whether the block processor has already
/// been applied to them.
#[derive(Debug, Clone)]
pub(crate) enum MaybeProcessedBlock {
/// Original block bytes, kept untouched because `MstNode::could_be` accepted them.
Raw(Bytes),
/// Bytes returned by the block processor.
Processed(Bytes),
}
impl MaybeProcessedBlock {
    /// Classifies `data`: anything `MstNode::could_be` accepts is stored
    /// untouched as `Raw`; everything else is run through `process` right away
    /// and stored as `Processed`.
    pub(crate) fn maybe(process: fn(Bytes) -> Bytes, data: Bytes) -> Self {
        if !MstNode::could_be(&data) {
            return Self::Processed(process(data));
        }
        Self::Raw(data)
    }

    /// Length in bytes of the payload, whichever variant holds it.
    pub(crate) fn len(&self) -> usize {
        let (Self::Raw(payload) | Self::Processed(payload)) = self;
        payload.len()
    }

    /// Serializes the block for storage: the payload followed by a single tag
    /// byte (`0x00` for `Raw`, `0x01` for `Processed`).
    pub(crate) fn into_bytes(self) -> Bytes {
        let (mut payload, tag) = match self {
            Self::Raw(bytes) => (bytes, 0x00),
            Self::Processed(bytes) => (bytes, 0x01),
        };
        payload.push(tag);
        payload
    }

    /// Inverse of [`Self::into_bytes`]: strips the trailing tag byte and
    /// rebuilds the variant it names. A `0x00` tag yields `Raw`; any other
    /// value yields `Processed`.
    ///
    /// # Panics
    ///
    /// Panics if `b` is empty, since there is no tag byte to remove.
    pub(crate) fn from_bytes(mut b: Bytes) -> Self {
        match b.pop().unwrap() {
            0x00 => Self::Raw(b),
            _ => Self::Processed(b),
        }
    }
}
/// Result of loading a CAR: either everything fit in memory, or disk storage
/// is needed to finish reading it.
pub enum Driver<R: AsyncRead + Unpin> {
/// The whole CAR was read without reaching the memory limit; carries the
/// decoded commit and a driver that walks the in-memory blocks.
Memory(Commit, MemDriver),
/// The memory limit was reached before the CAR was fully read; the
/// remaining blocks must be loaded via `NeedDisk::finish_loading`.
Disk(NeedDisk<R>),
}
/// Identity block processor: returns `block` unchanged.
///
/// This is the processor a default `DriverBuilder` uses.
#[inline]
pub fn noop(block: Bytes) -> Bytes {
block
}
/// Configuration for loading a CAR into a [`Driver`].
#[derive(Debug, Clone)]
pub struct DriverBuilder {
/// In-memory block budget; `Driver::load_car` multiplies it by 2^20 to get
/// a byte count.
pub mem_limit_mb: usize,
/// Function applied to the bytes of every block that is not a candidate
/// MST node.
pub block_processor: fn(Bytes) -> Bytes,
}
impl Default for DriverBuilder {
fn default() -> Self {
Self {
mem_limit_mb: 16,
block_processor: noop,
}
}
}
impl DriverBuilder {
pub fn new() -> Self {
Default::default()
}
pub fn with_mem_limit_mb(mut self, new_limit: usize) -> Self {
self.mem_limit_mb = new_limit;
self
}
pub fn with_block_processor(mut self, new_processor: fn(Bytes) -> Bytes) -> DriverBuilder {
self.block_processor = new_processor;
self
}
pub async fn load_car<R: AsyncRead + Unpin>(&self, reader: R) -> Result<Driver<R>, DriveError> {
Driver::load_car(reader, self.block_processor, self.mem_limit_mb).await
}
}
impl<R: AsyncRead + Unpin> Driver<R> {
    /// Reads a CAR from `reader`, buffering its blocks in memory.
    ///
    /// Blocks that could be MST nodes are kept raw; every other block is passed
    /// through `process` before being stored. The block whose CID equals the
    /// first root of the CAR header is decoded as the [`Commit`] and is not
    /// stored with the others.
    ///
    /// Returns [`Driver::Memory`] when the whole stream was read without the
    /// buffered bytes reaching `mem_limit_mb` MiB, and [`Driver::Disk`] as soon
    /// as that budget is reached (the rest of the stream is left unread).
    ///
    /// # Errors
    ///
    /// * [`DriveError::CarReader`] if the CAR stream cannot be read.
    /// * [`DriveError::MissingRoot`] if the CAR header has no roots.
    /// * [`DriveError::BadBlock`] if the commit or the MST root fails to decode.
    /// * [`DriveError::MissingCommit`] if the commit block, or the block it
    ///   points at, was not found.
    /// * [`DriveError::WalkError`] if the block the commit points at was
    ///   classified as a non-MST block.
    pub async fn load_car(
        reader: R,
        process: fn(Bytes) -> Bytes,
        mem_limit_mb: usize,
    ) -> Result<Driver<R>, DriveError> {
        // Saturate instead of overflowing: a huge limit must mean "effectively
        // unlimited", not a debug-build panic or a wrapped-around tiny budget.
        let max_size = mem_limit_mb.saturating_mul(1 << 20);
        let mut mem_blocks = HashMap::new();
        let mut car = CarReader::new(reader).await?;
        let root = *car
            .header()
            .roots()
            .first()
            .ok_or(DriveError::MissingRoot)?;
        log::debug!("root: {root:?}");
        let mut commit = None;
        let mut mem_size = 0;
        while let Some((cid, data)) = car.next_block().await? {
            // The root block is the commit; decode it and keep it out of the map.
            if cid == root {
                let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
                commit = Some(c);
                continue;
            }
            let maybe_processed = MaybeProcessedBlock::maybe(process, data);
            mem_size += maybe_processed.len();
            mem_blocks.insert(cid, maybe_processed);
            // Budget exhausted: hand everything gathered so far to the disk path.
            if mem_size >= max_size {
                return Ok(Driver::Disk(NeedDisk {
                    car,
                    root,
                    process,
                    max_size,
                    mem_blocks,
                    commit,
                }));
            }
        }
        let commit = commit.ok_or(DriveError::MissingCommit)?;
        // The commit's data CID must name a raw (unprocessed) MST node.
        let root_node: MstNode = match mem_blocks
            .get(&commit.data)
            .ok_or(DriveError::MissingCommit)?
        {
            MaybeProcessedBlock::Processed(_) => Err(WalkError::BadCommitFingerprint)?,
            MaybeProcessedBlock::Raw(bytes) => serde_ipld_dagcbor::from_slice(bytes)?,
        };
        let walker = Walker::new(root_node);
        Ok(Driver::Memory(
            commit,
            MemDriver {
                blocks: mem_blocks,
                walker,
                process,
            },
        ))
    }
}
/// Walks an MST whose blocks are all held in memory.
#[derive(Debug)]
pub struct MemDriver {
// Every non-commit block read from the CAR, keyed by CID.
blocks: HashMap<Cid, MaybeProcessedBlock>,
// Walker created from the MST root node the commit points at.
walker: Walker,
// Block processor, passed to the walker on every step.
process: fn(Bytes) -> Bytes,
}
impl MemDriver {
    /// Steps the walker up to `n` times and returns the outputs it produced.
    ///
    /// Returns `Ok(None)` when no output was produced at all (the walk is
    /// finished, or `n` is zero).
    ///
    /// # Errors
    ///
    /// Returns [`DriveError::WalkError`] if a walker step fails.
    pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk>, DriveError> {
        let mut chunk: BlockChunk = Vec::with_capacity(n);
        while chunk.len() < n {
            match self.walker.step(&mut self.blocks, self.process)? {
                Some(item) => chunk.push(item),
                None => break,
            }
        }
        Ok((!chunk.is_empty()).then_some(chunk))
    }
}
/// A partially-read CAR whose blocks reached the memory limit; call
/// `finish_loading` with a disk store to read the rest.
pub struct NeedDisk<R: AsyncRead + Unpin> {
// CAR reader, positioned after the last block read so far.
car: CarReader<R>,
// First root CID from the CAR header; identifies the commit block.
root: Cid,
// Block processor applied to blocks that are not candidate MST nodes.
process: fn(Bytes) -> Bytes,
// Memory budget in bytes.
max_size: usize,
// Blocks buffered in memory before the budget was reached.
mem_blocks: HashMap<Cid, MaybeProcessedBlock>,
/// The decoded commit, if its block has already been read from the stream.
pub commit: Option<Commit>,
}
impl<R: AsyncRead + Unpin> NeedDisk<R> {
    /// Finishes reading the CAR, spilling every block into `store`.
    ///
    /// The blocks already buffered in memory are written first. The rest of the
    /// stream is then read in chunks and handed to a blocking worker that
    /// writes them to the store. Finally the MST root node named by the commit
    /// is read back from the store to build the walker.
    ///
    /// # Errors
    ///
    /// * [`DriveError::CarReader`] if the CAR stream cannot be read.
    /// * [`DriveError::StorageError`] if the store fails.
    /// * [`DriveError::JoinError`] if a blocking task cannot be joined.
    /// * [`DriveError::ChannelSendError`] if the store worker stopped receiving.
    /// * [`DriveError::BadBlock`] if the commit or the MST root fails to decode.
    /// * [`DriveError::MissingCommit`] if the commit block, or the block it
    ///   points at, was not found.
    /// * [`DriveError::WalkError`] if the block the commit points at was stored
    ///   as a processed (non-MST) block.
    pub async fn finish_loading(
        mut self,
        mut store: DiskStore,
    ) -> Result<(Commit, DiskDriver), DriveError> {
        // `put_many` is synchronous disk I/O, so it must run on the blocking
        // pool (like the worker below) rather than stall an async worker thread.
        // Moving the map into a local keeps the closure from capturing `self`.
        let mem_blocks = self.mem_blocks;
        store = tokio::task::spawn_blocking(move || {
            let kvs = mem_blocks
                .into_iter()
                .map(|(k, v)| (k.to_bytes(), v.into_bytes()));
            store.put_many(kvs)?;
            Ok::<_, DriveError>(store)
        })
        .await??;

        // Capacity 1: at most one chunk is queued while the worker writes
        // another, which bounds how much unread data sits in memory.
        let (tx, mut rx) = mpsc::channel::<Vec<(Cid, MaybeProcessedBlock)>>(1);
        let store_worker = tokio::task::spawn_blocking(move || {
            while let Some(chunk) = rx.blocking_recv() {
                let kvs = chunk
                    .into_iter()
                    .map(|(k, v)| (k.to_bytes(), v.into_bytes()));
                store.put_many(kvs)?;
            }
            Ok::<_, DriveError>(store)
        });

        log::debug!("dumping the rest of the stream...");
        // Each chunk is capped at half the memory budget.
        let chunk_limit = self.max_size / 2;
        loop {
            let mut mem_size = 0;
            let mut chunk = vec![];
            loop {
                let Some((cid, data)) = self.car.next_block().await? else {
                    break;
                };
                // The root block is the commit; decode it instead of storing it.
                if cid == self.root {
                    let c: Commit = serde_ipld_dagcbor::from_slice(&data)?;
                    self.commit = Some(c);
                    continue;
                }
                let data = Bytes::from(data);
                let maybe_processed = MaybeProcessedBlock::maybe(self.process, data);
                mem_size += maybe_processed.len();
                chunk.push((cid, maybe_processed));
                if mem_size >= chunk_limit {
                    break;
                }
            }
            // An empty chunk means the stream is exhausted.
            if chunk.is_empty() {
                break;
            }
            tx.send(chunk)
                .await
                .map_err(|_| DriveError::ChannelSendError)?;
        }
        // Closing the channel lets the worker drain and hand the store back.
        drop(tx);
        log::debug!("done. waiting for worker to finish...");
        store = store_worker.await??;
        log::debug!("worker finished.");

        let commit = self.commit.ok_or(DriveError::MissingCommit)?;
        let db_bytes = store
            .get(&commit.data.to_bytes())
            .map_err(|e| DriveError::StorageError(DiskError::DbError(e)))?
            .ok_or(DriveError::MissingCommit)?;
        // The commit's data CID must name a raw (unprocessed) MST node.
        let node: MstNode = match MaybeProcessedBlock::from_bytes(db_bytes.to_vec()) {
            MaybeProcessedBlock::Processed(_) => Err(WalkError::BadCommitFingerprint)?,
            MaybeProcessedBlock::Raw(bytes) => serde_ipld_dagcbor::from_slice(&bytes)?,
        };
        let walker = Walker::new(node);
        Ok((
            commit,
            DiskDriver {
                process: self.process,
                state: Some(BigState { store, walker }),
            },
        ))
    }
}
// The disk store and walker bundled together, so `DiskDriver` can move both
// into a blocking task and get both back afterwards.
struct BigState {
store: DiskStore,
walker: Walker,
}
/// Walks an MST whose blocks live in a disk store.
pub struct DiskDriver {
// Block processor, passed to the walker on every step.
process: fn(Bytes) -> Bytes,
// `Option` so `next_chunk` can take the state out while a blocking task owns
// it and put it back afterwards; the methods on this type panic if it is `None`.
state: Option<BigState>,
}
/// Builds a `DiskDriver` with no backing state and the [`noop`] processor.
///
/// The result cannot actually be driven: the `DiskDriver` methods panic when
/// the state is absent.
#[doc(hidden)]
pub fn _get_fake_disk_driver() -> DiskDriver {
    let state = None;
    DiskDriver {
        state,
        process: noop,
    }
}
impl DiskDriver {
pub async fn next_chunk(&mut self, n: usize) -> Result<Option<BlockChunk>, DriveError> {
let process = self.process;
let mut state = self.state.take().expect("DiskDriver must have Some(state)");
let (state, res) =
tokio::task::spawn_blocking(move || -> (BigState, Result<BlockChunk, DriveError>) {
let mut out = Vec::with_capacity(n);
for _ in 0..n {
let step = match state.walker.disk_step(&mut state.store, process) {
Ok(s) => s,
Err(e) => {
return (state, Err(e.into()));
}
};
let Some(output) = step else {
break;
};
out.push(output);
}
(state, Ok::<_, DriveError>(out))
})
.await?;
self.state = Some(state);
let out = res?;
if out.is_empty() {
Ok(None)
} else {
Ok(Some(out))
}
}
fn read_tx_blocking(
&mut self,
n: usize,
tx: mpsc::Sender<Result<BlockChunk, DriveError>>,
) -> Result<(), mpsc::error::SendError<Result<BlockChunk, DriveError>>> {
let BigState { store, walker } = self.state.as_mut().expect("valid state");
loop {
let mut out: BlockChunk = Vec::with_capacity(n);
for _ in 0..n {
let step = match walker.disk_step(store, self.process) {
Ok(s) => s,
Err(e) => return tx.blocking_send(Err(e.into())),
};
let Some(output) = step else {
break;
};
out.push(output);
}
if out.is_empty() {
break;
}
tx.blocking_send(Ok(out))?;
}
Ok(())
}
pub fn to_channel(
mut self,
n: usize,
) -> (
mpsc::Receiver<Result<BlockChunk, DriveError>>,
tokio::task::JoinHandle<Self>,
) {
let (tx, rx) = mpsc::channel::<Result<BlockChunk, DriveError>>(1);
let chan_task = tokio::task::spawn_blocking(move || {
if let Err(mpsc::error::SendError(_)) = self.read_tx_blocking(n, tx) {
log::debug!("big car reader exited early due to dropped receiver channel");
}
self
});
(rx, chan_task)
}
pub async fn reset_store(mut self) -> Result<DiskStore, DriveError> {
let BigState { store, .. } = self.state.take().expect("valid state");
store.reset().await?;
Ok(store)
}
}