use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use core::net::SocketAddr;
use futures::stream;
use futures_lite::{Stream, StreamExt};
use iroh::{
blobs::{util::SetTagOption, Hash},
client::{blobs::ReadAtLen, Iroh as Client},
};
use std::sync::Arc;
use std::{path::Path, str::FromStr};
use uuid::Uuid;
use crate::storage::{
self, ByteStream, ChunkId, ChunkIdMapper, ChunkStream, Error as StorageError, Storage,
};
/// Size in bytes of the fixed chunks served by `iter_chunks` and `download_chunk`;
/// only the last chunk of a blob may be shorter.
const CHUNK_SIZE: u64 = 1 << 10;
/// Source of the iroh RPC client used by `IrohStorage`.
///
/// Implementors are stored behind an `Arc<dyn ClientProvider>` inside a
/// cloneable storage handle, hence the `Send + Sync` requirement.
trait ClientProvider
where
    Self: Send + Sync,
{
    /// Borrows the client owned (directly or indirectly) by this provider.
    fn client(&self) -> &Client;
}
/// `Storage` backend that stores blobs in an iroh node.
///
/// Cloning is cheap: every clone shares the same `Arc`-wrapped client provider,
/// which may be a bare RPC client or a spawned node owned by this storage.
#[derive(Clone)]
pub struct IrohStorage {
    /// Hands out the iroh client; shared between clones.
    client_provider: Arc<dyn ClientProvider>,
}
/// `ClientProvider` that wraps a bare RPC client and owns no node.
#[derive(Clone)]
struct ClientHolder {
    client: Client,
}

impl ClientProvider for ClientHolder {
    fn client(&self) -> &Client {
        let Self { client } = self;
        client
    }
}
/// `ClientProvider` that owns a spawned iroh node and lends out its client.
///
/// Holding the `Node` itself (not just a clone of its client) ties the node's
/// lifetime to this holder.
#[derive(Clone)]
struct NodeHolder<S> {
    node: iroh::node::Node<S>,
}

impl<S: iroh::blobs::store::Store> ClientProvider for NodeHolder<S> {
    fn client(&self) -> &Client {
        self.node.client()
    }
}
impl IrohStorage {
    /// Connects an RPC client to the iroh node whose data directory is `root`.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be established.
    pub async fn from_path(root: impl AsRef<Path>) -> Result<Self> {
        let client = Client::connect_path(root).await?;
        Ok(Self::from_client(client))
    }

    /// Connects an RPC client to the iroh node listening at `addr`.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be established.
    pub async fn from_addr(addr: SocketAddr) -> Result<Self> {
        let client = Client::connect_addr(addr).await?;
        Ok(Self::from_client(client))
    }

    /// Wraps an existing client. The node it talks to is not owned by the storage.
    pub fn from_client(client: Client) -> Self {
        Self {
            client_provider: Arc::new(ClientHolder { client }),
        }
    }

    /// Takes ownership of a spawned node; it is kept for as long as this
    /// storage (or any clone of it) exists.
    pub fn from_node<S: iroh::blobs::store::Store + 'static>(node: iroh::node::Node<S>) -> Self {
        Self {
            client_provider: Arc::new(NodeHolder { node }),
        }
    }

    /// Spawns a new node with an in-memory store (`Node::memory`) owned by the storage.
    ///
    /// # Errors
    /// Returns an error if the node fails to spawn.
    pub async fn new_in_memory() -> Result<Self> {
        let node = iroh::node::Node::memory().spawn().await?;
        Ok(Self::from_node(node))
    }

    /// Spawns a new node with a persistent store rooted at `root`, owned by the storage.
    ///
    /// # Errors
    /// Returns an error if the store cannot be opened or the node fails to spawn.
    pub async fn new_permanent(root: impl AsRef<Path>) -> Result<Self> {
        let node = iroh::node::Node::persistent(root).await?.spawn().await?;
        // Keep the `Node` handle itself: iroh aborts the node's task once its last
        // handle is dropped, which would leave a bare client clone talking to a
        // node that is no longer running.
        Ok(Self::from_node(node))
    }

    /// Borrows the iroh client from whichever provider backs this storage.
    fn client(&self) -> &Client {
        self.client_provider.client()
    }
}
/// Parses a textual blob hash.
///
/// On failure, returns `StorageError::InvalidHash` carrying the offending input
/// and the parser's error message.
fn parse_hash(hash: &str) -> Result<Hash, StorageError> {
    match Hash::from_str(hash) {
        Ok(parsed) => Ok(parsed),
        Err(parse_err) => {
            let reason = parse_err.to_string();
            Err(StorageError::InvalidHash(hash.to_owned(), reason))
        }
    }
}
impl ChunkId for u64 {}
#[derive(Clone)]
pub struct IrohChunkIdMapper {
hash: String,
num_chunks: u64,
}
impl ChunkIdMapper<u64> for IrohChunkIdMapper {
fn index_to_id(&self, index: u64) -> Result<u64, StorageError> {
if index >= self.num_chunks {
return Err(StorageError::ChunkNotFound(
index.to_string(),
self.hash.clone(),
storage::wrap_error(anyhow::anyhow!("Chunk index out of bounds")),
));
}
Ok(index)
}
fn id_to_index(&self, chunk_id: &u64) -> Result<u64, StorageError> {
if *chunk_id >= self.num_chunks {
return Err(StorageError::ChunkNotFound(
chunk_id.to_string(),
self.hash.clone(),
storage::wrap_error(anyhow::anyhow!("Chunk id out of bounds")),
));
}
Ok(*chunk_id)
}
}
#[async_trait]
impl Storage for IrohStorage {
type ChunkId = u64;
type ChunkIdMapper = IrohChunkIdMapper;
async fn upload_bytes(
&self,
bytes: impl Into<Bytes> + Send,
) -> Result<storage::UploadResult, StorageError> {
let bytes = bytes.into();
let size = bytes.len();
let stream = chunked_bytes_stream(bytes, 1024 * 64).map(Ok);
let tag = format!("ent-{}", Uuid::new_v4());
let progress = self
.client()
.blobs()
.add_stream(
stream,
SetTagOption::Named(iroh::blobs::Tag::from(tag.clone())),
)
.await
.map_err(|e| StorageError::StorageError(storage::wrap_error(e)))?;
let blob = progress
.finish()
.await
.map_err(|e| StorageError::StorageError(storage::wrap_error(e)))?;
let mut info = std::collections::HashMap::new();
info.insert("tag".to_string(), tag);
Ok(storage::UploadResult {
hash: blob.hash.to_string(),
info,
size: size as u64,
})
}
async fn download_bytes(&self, hash: &str) -> Result<ByteStream, StorageError> {
let hash = parse_hash(hash)?;
let reader = self
.client()
.blobs()
.read(hash)
.await
.map_err(|e| StorageError::StorageError(storage::wrap_error(e)))?;
let stream = reader
.map(|res| res.map_err(|e| StorageError::StorageError(storage::wrap_error(e.into()))));
Ok(Box::pin(stream))
}
async fn iter_chunks(&self, hash: &str) -> Result<ChunkStream<Self::ChunkId>, StorageError> {
let hash = parse_hash(hash)?;
let reader = self.client().blobs().read(hash).await.map_err(|e| {
let err_str = e.to_string();
if err_str.contains("not found") {
StorageError::BlobNotFound(hash.to_string())
} else {
StorageError::StorageError(storage::wrap_error(e))
}
})?;
let total_size = reader.size();
let stream = stream::unfold(
(self.client().blobs().clone(), 0u64),
move |(client, offset)| async move {
if offset >= total_size {
return None;
}
let remaining = total_size - offset;
let len = std::cmp::min(CHUNK_SIZE, remaining);
let chunk_id = offset / CHUNK_SIZE;
Some(
match client
.read_at_to_bytes(hash, offset, ReadAtLen::Exact(len))
.await
{
Ok(chunk) => {
let new_offset = offset + len as u64;
((chunk_id, Ok(chunk)), (client, new_offset))
}
Err(e) => (
(
chunk_id,
Err(StorageError::StorageError(storage::wrap_error(e))),
),
(client, offset + len as u64),
),
},
)
},
);
Ok(Box::pin(stream))
}
async fn download_chunk(&self, hash: &str, chunk_id: u64) -> Result<Bytes, StorageError> {
let hash = parse_hash(hash)?;
let offset = chunk_id * CHUNK_SIZE;
self.client()
.blobs()
.read_at_to_bytes(hash, offset, ReadAtLen::AtMost(CHUNK_SIZE))
.await
.map_err(|e| {
StorageError::ChunkNotFound(
chunk_id.to_string(),
hash.to_string(),
storage::wrap_error(e),
)
})
}
async fn chunk_id_mapper(&self, hash: &str) -> Result<IrohChunkIdMapper, StorageError> {
let hash = parse_hash(hash).map_err(|_| StorageError::BlobNotFound(hash.to_string()))?;
let reader = self
.client()
.blobs()
.read(hash)
.await
.map_err(|_| StorageError::BlobNotFound(hash.to_string()))?;
Ok(IrohChunkIdMapper {
hash: hash.to_string(),
num_chunks: reader.size().div_ceil(CHUNK_SIZE),
})
}
}
/// Yields `b` as consecutive pieces of at most `c` bytes, in order.
///
/// Pieces are carved off the front with `Bytes::split_to`. Nothing is yielded
/// when `b` is empty, and also when `c == 0`.
fn chunked_bytes_stream(mut b: Bytes, c: usize) -> impl Stream<Item = Bytes> {
    let pieces = std::iter::from_fn(move || {
        if b.is_empty() || c == 0 {
            return None;
        }
        let take = c.min(b.len());
        Some(b.split_to(take))
    });
    futures_lite::stream::iter(pieces)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::StreamExt;

    /// Drains `iter_chunks` for `hash`, failing on the first chunk error.
    async fn collect_chunks(storage: &IrohStorage, hash: &str) -> Result<Vec<Bytes>> {
        let stream = storage.iter_chunks(hash).await?;
        let results: Vec<Result<Bytes, StorageError>> = stream.map(|res| res.1).collect().await;
        let bytes_vec = results.into_iter().collect::<Result<Vec<_>, _>>()?;
        Ok(bytes_vec)
    }

    #[tokio::test]
    async fn test_iter_chunks_small_blob() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from("Hello, World!");
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunks = collect_chunks(&storage, &hash).await?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], data);
        Ok(())
    }

    #[tokio::test]
    async fn test_iter_chunks_large_blob() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from(vec![0u8; 3000]);
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunks = collect_chunks(&storage, &hash).await?;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), 1024);
        assert_eq!(chunks[1].len(), 1024);
        assert_eq!(chunks[2].len(), 952);
        assert_eq!(Bytes::from(chunks.concat()), data);
        Ok(())
    }

    #[tokio::test]
    async fn test_iter_chunks_empty_blob() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::new();
        let upload_result = storage.upload_bytes(data).await?;
        let hash = upload_result.hash;
        let chunks = collect_chunks(&storage, &hash).await?;
        assert_eq!(chunks.len(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_iter_chunks_exact_multiple() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from(vec![0u8; 2048]);
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunks = collect_chunks(&storage, &hash).await?;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 1024);
        assert_eq!(chunks[1].len(), 1024);
        assert_eq!(Bytes::from(chunks.concat()), data);
        Ok(())
    }

    #[tokio::test]
    async fn test_iter_chunks_invalid_hash() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let result = storage.iter_chunks("invalid_hash").await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_download_chunk_small_blob() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from("Hello, World!");
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunk = storage.download_chunk(&hash, 0).await?;
        assert_eq!(chunk, data);
        Ok(())
    }

    #[tokio::test]
    async fn test_download_chunk_large_blob() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from(vec![0u8; 3000]);
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunk0 = storage.download_chunk(&hash, 0).await?;
        let chunk1 = storage.download_chunk(&hash, 1).await?;
        let chunk2 = storage.download_chunk(&hash, 2).await?;
        assert_eq!(chunk0.len(), 1024);
        assert_eq!(chunk1.len(), 1024);
        assert_eq!(chunk2.len(), 952);
        assert_eq!(Bytes::from([chunk0, chunk1, chunk2].concat()), data);
        Ok(())
    }

    #[tokio::test]
    async fn test_download_chunk_exact_multiple() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from(vec![0u8; 2048]);
        let upload_result = storage.upload_bytes(data.clone()).await?;
        let hash = upload_result.hash;
        let chunk0 = storage.download_chunk(&hash, 0).await?;
        let chunk1 = storage.download_chunk(&hash, 1).await?;
        assert_eq!(chunk0.len(), 1024);
        assert_eq!(chunk1.len(), 1024);
        assert_eq!(Bytes::from([chunk0, chunk1].concat()), data);
        Ok(())
    }

    #[tokio::test]
    async fn test_download_chunk_invalid_hash() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let result = storage.download_chunk("invalid_hash", 0).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_download_chunk_out_of_bounds() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from("Hello, World!");
        let upload_result = storage.upload_bytes(data).await?;
        let hash = upload_result.hash;
        let result = storage.download_chunk(&hash, 1).await;
        assert!(matches!(
            result,
            Err(StorageError::ChunkNotFound(c, h, _)) if h == hash && c == "1"
        ));
        Ok(())
    }

    #[tokio::test]
    async fn test_chunk_id_mapper() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = vec![0u8; 3000];
        let upload_result = storage.upload_bytes(data).await?;
        let hash = upload_result.hash;
        let mapper = storage.chunk_id_mapper(&hash).await?;
        assert_eq!(mapper.index_to_id(0)?, 0);
        assert_eq!(mapper.index_to_id(1)?, 1);
        assert_eq!(mapper.index_to_id(2)?, 2);
        assert_eq!(mapper.id_to_index(&2)?, 2);
        assert!(
            matches!(mapper.index_to_id(3), Err(StorageError::ChunkNotFound(c_id, h, _)) if c_id == "3" && h == hash),
            "Expected error because chunk index is out of bounds"
        );
        assert!(
            matches!(mapper.id_to_index(&3), Err(StorageError::ChunkNotFound(c_id, h, _)) if c_id == "3" && h == hash),
            "Expected error because chunk id is out of bounds"
        );

        let res = storage.chunk_id_mapper("invalid").await;
        assert!(
            matches!(res, Err(StorageError::BlobNotFound(h)) if h == "invalid"),
            "Expected error because hash is invalid"
        );

        // A well-formed hash of content that was never uploaded. Deriving it by
        // tweaking a character of `hash` could yield an unparsable string and
        // exercise the parse-error path instead of the missing-blob path.
        let non_existing_hash = Hash::new(b"content that was never uploaded").to_string();
        assert_ne!(non_existing_hash, hash);
        let res = storage.chunk_id_mapper(&non_existing_hash).await;
        assert!(
            matches!(res, Err(StorageError::BlobNotFound(h)) if h == non_existing_hash),
            "Expected error because hash does not exist"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_upload_bytes_metadata() -> Result<()> {
        let storage = IrohStorage::new_in_memory().await?;
        let data = Bytes::from("Hello, World!");
        let upload_result = storage.upload_bytes(data.clone()).await?;
        assert!(!upload_result.hash.is_empty(), "Hash should not be empty");
        assert_eq!(
            upload_result.size,
            data.len() as u64,
            "Size should match data length"
        );
        assert!(
            upload_result.info.contains_key("tag"),
            "Should contain tag info"
        );
        assert!(
            upload_result
                .info
                .get("tag")
                .is_some_and(|tag| tag.starts_with("ent-") && tag.len() == 40),
            "Tag should be in the format \"ent-<uuid>\""
        );
        Ok(())
    }
}