use std::collections::HashSet;
use std::io::Read;
use bytes::Bytes;
use mnem_core::id::Cid;
use mnem_core::store::{Blockstore, blockstore::recompute_cid};
use crate::car::{CarBlockReader, CarHeader, read_header, usize_to_u64};
use crate::error::TransportError;
/// Default cap on cumulative block payload bytes for one import: 4 GiB.
///
/// This is the limit [`import`] hands to [`import_with_limit`].
pub const DEFAULT_MAX_IMPORT_BYTES: u64 = 4 * 1024 * 1024 * 1024;
/// Summary returned by a successful import.
#[derive(Debug, Clone, Default)]
pub struct ImportStats {
/// Number of blocks read from the stream and written to the blockstore
/// (incremented once per block record, so repeated CIDs count each time).
pub blocks: u64,
/// Sum of the payload lengths of all imported blocks, in bytes.
pub bytes: u64,
/// Root CIDs exactly as listed in the stream's header.
pub roots: Vec<Cid>,
}
/// Imports a CAR stream from `r` into `bs` using the default size cap.
///
/// Equivalent to calling [`import_with_limit`] with
/// [`DEFAULT_MAX_IMPORT_BYTES`] as the limit.
///
/// # Errors
///
/// Returns whatever [`import_with_limit`] returns for the same input.
pub fn import<R, B>(r: &mut R, bs: &B) -> Result<ImportStats, TransportError>
where
    R: Read + ?Sized,
    B: Blockstore + ?Sized,
{
    let limit = DEFAULT_MAX_IMPORT_BYTES;
    import_with_limit(r, bs, limit)
}
/// Imports a CAR stream from `r` into `bs`, verifying every block and
/// enforcing a cap on the cumulative payload size.
///
/// For each block read from the stream the function:
/// 1. recomputes the CID from the payload and compares it to the claimed CID,
/// 2. adds the payload length to the running total and checks it against
///    `max_total_bytes`,
/// 3. writes the block to `bs` via `put_trusted`.
///
/// After the stream is exhausted, every root listed in the header must have
/// appeared as a block in the body.
///
/// On success the current tracing span's `block_count` and `bytes` fields are
/// recorded and the same totals are returned in [`ImportStats`].
///
/// # Errors
///
/// * [`TransportError::UnsupportedHash`] if `recompute_cid` cannot recompute
///   the CID for a block's hash function.
/// * [`TransportError::CidMismatch`] if the recomputed CID differs from the
///   claimed one.
/// * [`TransportError::SizeLimit`] if the cumulative payload size exceeds
///   `max_total_bytes` (or overflows `u64`, reported as `u64::MAX`).
/// * [`TransportError::MissingRoot`] for the first header root (in header
///   order) that never appeared as a block.
/// * Any error propagated from reading the header, reading a block, or
///   writing to the blockstore.
///
/// Blocks written before an error is detected are not removed by this
/// function.
#[tracing::instrument(
    name = "import_with_limit",
    level = "info",
    target = "mnem::transport::import",
    skip(r, bs),
    fields(
        max_total_bytes,
        block_count = tracing::field::Empty,
        bytes = tracing::field::Empty,
    )
)]
pub fn import_with_limit<R, B>(
    r: &mut R,
    bs: &B,
    max_total_bytes: u64,
) -> Result<ImportStats, TransportError>
where
    R: Read + ?Sized,
    B: Blockstore + ?Sized,
{
    let header: CarHeader = read_header(r)?;
    let roots = header.roots;

    // Only the header roots need to be tracked to answer "was every root
    // present in the body?". Remembering every imported CID would let a
    // stream of many tiny blocks consume memory far beyond the byte cap.
    let mut pending_roots: HashSet<&Cid> = roots.iter().collect();

    let mut reader = CarBlockReader::new(r);
    let mut blocks: u64 = 0;
    let mut bytes: u64 = 0;

    while let Some((claimed_cid, data)) = reader.next_block()? {
        // Never trust the CID in the stream: rehash the payload first.
        let computed = recompute_cid(&claimed_cid, &data)
            .ok_or_else(|| TransportError::UnsupportedHash(claimed_cid.multihash().code()))?;
        if computed != claimed_cid {
            return Err(TransportError::CidMismatch {
                claimed: claimed_cid,
                computed,
            });
        }

        // Enforce the cap before the block reaches the store.
        let payload_len = usize_to_u64(data.len());
        let next_total = bytes
            .checked_add(payload_len)
            .ok_or(TransportError::SizeLimit {
                limit: max_total_bytes,
                observed: u64::MAX,
            })?;
        if next_total > max_total_bytes {
            return Err(TransportError::SizeLimit {
                limit: max_total_bytes,
                observed: next_total,
            });
        }
        bytes = next_total;

        // Mark the root (if it is one) as seen, then hand the CID to the
        // store by value; no per-block clone is needed.
        pending_roots.remove(&claimed_cid);
        bs.put_trusted(claimed_cid, Bytes::from(data))?;
        blocks += 1;
    }

    // Report the first missing root in header order.
    if let Some(missing) = roots.iter().find(|root| pending_roots.contains(*root)) {
        return Err(TransportError::MissingRoot {
            root: missing.clone(),
        });
    }

    let span = tracing::Span::current();
    span.record("block_count", blocks);
    span.record("bytes", bytes);

    Ok(ImportStats {
        blocks,
        bytes,
        roots,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use mnem_core::codec::hash_to_cid;
    use mnem_core::store::MemoryBlockstore;
    use serde::Serialize;
    use crate::car::{write_block, write_header};
    use crate::export::export;

    /// Minimal serializable payload used to mint blocks for the tests.
    #[derive(Serialize)]
    struct Leaf {
        tag: &'static str,
        n: u32,
    }

    /// Exports a store holding one `Leaf { tag: "ok", n: 1 }` block to CAR
    /// bytes. Returns the CAR bytes and the length of the leaf's payload.
    fn single_leaf_car() -> (Vec<u8>, usize) {
        let store = MemoryBlockstore::new();
        let (payload, leaf_cid) = hash_to_cid(&Leaf { tag: "ok", n: 1 }).unwrap();
        let payload_len = payload.len();
        store.put(leaf_cid.clone(), payload).unwrap();
        let mut car_bytes = Vec::new();
        export(&store, &leaf_cid, &mut car_bytes).unwrap();
        (car_bytes, payload_len)
    }

    /// Flipping the final byte of the CAR must surface as a CID mismatch.
    #[test]
    fn import_rejects_tampered_cid() {
        let (mut car_bytes, _payload_len) = single_leaf_car();
        *car_bytes
            .last_mut()
            .expect("exported CAR must not be empty") ^= 0xff;

        let target = MemoryBlockstore::new();
        let outcome = import(&mut car_bytes.as_slice(), &target).unwrap_err();
        match outcome {
            TransportError::CidMismatch { .. } => {}
            other => panic!("expected CidMismatch, got {other:?}"),
        }
    }

    /// A header root with no matching block in the body is rejected.
    #[test]
    fn import_rejects_header_root_not_in_body() {
        let (_unused_payload, declared_root) =
            hash_to_cid(&Leaf { tag: "real", n: 1 }).unwrap();
        let (body_payload, body_cid) = hash_to_cid(&Leaf { tag: "fake", n: 2 }).unwrap();

        let mut car_bytes = Vec::new();
        write_header(
            &mut car_bytes,
            &CarHeader {
                version: 1,
                roots: vec![declared_root.clone()],
            },
        )
        .unwrap();
        write_block(&mut car_bytes, &body_cid, &body_payload).unwrap();

        let target = MemoryBlockstore::new();
        let outcome = import(&mut car_bytes.as_slice(), &target).unwrap_err();
        match outcome {
            TransportError::MissingRoot { root } => assert_eq!(root, declared_root),
            other => panic!("expected MissingRoot, got {other:?}"),
        }
    }

    /// A cap one byte below the payload size must trip the size limit.
    #[test]
    fn import_enforces_total_bytes_cap() {
        let (car_bytes, payload_len) = single_leaf_car();
        let cap = u64::try_from(payload_len).unwrap() - 1;

        let target = MemoryBlockstore::new();
        let outcome = import_with_limit(&mut car_bytes.as_slice(), &target, cap).unwrap_err();
        match outcome {
            TransportError::SizeLimit { limit, observed } => {
                assert_eq!(limit, cap);
                assert!(observed > limit);
            }
            other => panic!("expected SizeLimit, got {other:?}"),
        }
    }
}