use std::collections::VecDeque;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Poll, ready};
use crate::encryption::{Chacha20Cipher, EncryptionKey, encrypt_recovered_shards};
use crate::erasure_coding::{self, ErasureCoder};
use crate::hosts::{Hosts, RPCError};
use crate::rhp4::{Client, Transport};
use crate::time::{Duration, Elapsed, Instant, sleep};
use crate::{AppKey, DownloadOptions, Object, Sector, ShardProgress, ShardProgressCallback, Slab};
use bytes::{Buf, Bytes};
use chacha20::cipher::StreamCipher;
use sia_core::rhp4::SEGMENT_SIZE;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::task::JoinSet;
use tokio_util::task::AbortOnDropHandle;
/// Errors that can occur while downloading and reconstructing an object.
#[derive(Debug, Error)]
pub enum DownloadError {
/// An underlying I/O operation failed.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// The erasure coder reported an error.
#[error("encoder error: {0}")]
Encoder(#[from] erasure_coding::Error),
/// Too few shards could be obtained to rebuild a slab.
/// Fields are `(shards obtained so far, shards required)`.
#[error("not enough shards: {0}/{1}")]
NotEnoughShards(u8, u8),
/// A requested range was rejected.
// NOTE(review): not constructed in this part of the file; the two fields are
// presumably the start and end of the offending range — confirm.
#[error("invalid range: {0}-{1}")]
OutOfRange(usize, usize),
/// An operation exceeded its deadline.
#[error("timeout error: {0}")]
Timeout(#[from] Elapsed),
/// Acquiring a semaphore permit failed.
#[error("semaphore error: {0}")]
SemaphoreError(#[from] tokio::sync::AcquireError),
/// A spawned task could not be joined.
#[error("join error: {0}")]
JoinError(#[from] tokio::task::JoinError),
/// The slab metadata failed validation (e.g. `min_shards` is zero or larger
/// than the number of sectors).
#[error("invalid slab: {0}")]
InvalidSlab(String),
/// A host RPC failed.
#[error("rhp4 error: {0}")]
RPC(#[from] RPCError),
/// Free-form error message (used, for example, for invalid download options).
#[error("custom error: {0}")]
Custom(String),
/// Returned by reads issued after the download has already failed.
#[error("download errored")]
Errored,
}
/// A single sector to download together with its position in the slab.
struct SectorTask {
// The sector (host key and root) to read.
sector: Sector,
// Index of this sector within the slab's sector list; used to place the
// downloaded data in the shard vector.
shard_index: usize,
}
/// Typestate: no shards have been downloaded yet.
struct AwaitingRecovery {
// Download candidates, after being passed through `Hosts::prioritize`.
sectors: Vec<SectorTask>,
}
/// Typestate: enough shards have been downloaded to attempt reconstruction.
struct ShardsRecovered {
// Byte offset within each sector at which the downloaded range starts.
shard_offset: usize,
// One entry per sector of the slab, indexed by shard index; `None` for
// shards that were not downloaded.
shards: Vec<Option<Vec<u8>>>,
}
/// Typestate: the data shards have been reconstructed and can be written out.
struct SlabDecoded {
// The first `min_shards` shards after reconstruction.
data_shards: Vec<Bytes>,
}
/// Recovery of (a range of) one slab, modelled as a typestate machine:
/// `AwaitingRecovery` -> `ShardsRecovered` -> `SlabDecoded`.
struct SlabRecovery<State, T: Transport> {
// Host pool used to read sectors.
client: Hosts<T>,
// Key used to authenticate sector reads.
account_key: Arc<AppKey>,
// Index of the slab within the object; reported in progress callbacks.
slab_index: usize,
// Number of shards required to reconstruct the slab.
min_shards: u8,
// Key handed to `encrypt_recovered_shards` when decoding.
encryption_key: EncryptionKey,
// Byte offset of the requested range within the slab's data.
offset: usize,
// Number of bytes requested from the slab.
length: usize,
// Data specific to the current stage of the recovery.
state: State,
}
impl<T: Transport> SlabRecovery<AwaitingRecovery, T> {
fn new(
client: Hosts<T>,
account_key: Arc<AppKey>,
slab: ChunkSlab,
) -> Result<Self, DownloadError> {
if slab.slab.min_shards == 0 {
return Err(DownloadError::InvalidSlab(
"min_shards cannot be 0".to_string(),
));
} else if slab.slab.min_shards as usize > slab.slab.sectors.len() {
return Err(DownloadError::InvalidSlab(format!(
"min_shards {} cannot be greater than number of sectors {}",
slab.slab.min_shards,
slab.slab.sectors.len()
)));
}
let mut sectors = slab
.slab
.sectors
.iter()
.enumerate()
.map(|(i, sector)| SectorTask {
sector: sector.clone(),
shard_index: i,
})
.collect::<Vec<_>>();
client.prioritize(&mut sectors, |task| &task.sector.host_key);
Ok(Self {
client,
account_key,
slab_index: slab.index,
min_shards: slab.slab.min_shards,
encryption_key: slab.slab.encryption_key,
offset: slab.slab.offset as usize,
length: slab.slab.length as usize,
state: AwaitingRecovery { sectors },
})
}
async fn recover_shard(
client: Hosts<T>,
account_key: Arc<AppKey>,
task: SectorTask,
slab_index: usize,
sector_offset: usize,
sector_length: usize,
) -> Result<(usize, Vec<u8>, ShardProgress), DownloadError> {
let start = Instant::now();
let data = client
.read_sector(
task.sector.host_key,
&account_key.0,
task.sector.root,
sector_offset,
sector_length,
Duration::from_secs(60),
)
.await?;
let elapsed = start.elapsed();
Ok((
task.shard_index,
Vec::from(data),
ShardProgress {
host_key: task.sector.host_key,
shard_size: sector_length,
shard_index: task.shard_index,
slab_index,
elapsed,
},
)) }
async fn recover_shards(
self,
shard_downloaded: Option<ShardProgressCallback>,
) -> Result<SlabRecovery<ShardsRecovered, T>, DownloadError> {
let mut shard_tasks = JoinSet::new();
let mut shards = vec![None; self.state.sectors.len()];
let mut sectors = VecDeque::from(self.state.sectors);
let min_shards = self.min_shards;
let client = self.client;
let account_key = self.account_key;
let encryption_key = self.encryption_key;
let chunk_size = SEGMENT_SIZE * self.min_shards as usize;
let start = (self.offset / chunk_size) * SEGMENT_SIZE;
let end = (self.offset + self.length).div_ceil(chunk_size) * SEGMENT_SIZE;
let shard_offset = start;
let shard_length = end - start;
for i in 0..self.min_shards {
let task = sectors
.pop_front()
.ok_or(DownloadError::NotEnoughShards(i, self.min_shards))?;
join_set_spawn!(
&mut shard_tasks,
Self::recover_shard(
client.clone(),
account_key.clone(),
task,
self.slab_index,
shard_offset,
shard_length,
)
);
}
let mut recovered_shards: u8 = 0;
loop {
tokio::select! {
Some(res) = shard_tasks.join_next() => {
match res? {
Ok((index, data, progress)) => {
shards[index] = Some(data);
recovered_shards += 1;
if recovered_shards <= min_shards && let Some(callback) = &shard_downloaded {
callback(progress);
}
if recovered_shards >= min_shards {
return Ok(SlabRecovery {
client,
account_key,
min_shards,
slab_index: self.slab_index,
encryption_key,
offset: self.offset,
length: self.length,
state: ShardsRecovered {
shard_offset,
shards,
},
});
}
},
Err(_) => {
if recovered_shards as usize + shard_tasks.len() + sectors.len() < min_shards as usize {
return Err(DownloadError::NotEnoughShards(recovered_shards, min_shards));
} else if let Some(task) = sectors.pop_front() {
join_set_spawn!(&mut shard_tasks, Self::recover_shard(client.clone(), account_key.clone(), task, self.slab_index, shard_offset, shard_length));
}
}
}
},
_ = sleep(Duration::from_millis(500)), if !sectors.is_empty() => {
let task = sectors.pop_front().expect("sectors should not be empty");
join_set_spawn!(&mut shard_tasks, Self::recover_shard(client.clone(), account_key.clone(), task, self.slab_index, shard_offset, shard_length));
},
}
}
}
}
impl<T: Transport> SlabRecovery<ShardsRecovered, T> {
fn decode(self) -> Result<SlabRecovery<SlabDecoded, T>, DownloadError> {
let parity_shards = self.state.shards.len() - self.min_shards as usize;
let rs = ErasureCoder::new(self.min_shards as usize, parity_shards)?;
let mut shards = self.state.shards;
encrypt_recovered_shards(
&self.encryption_key,
0,
self.state.shard_offset,
&mut shards,
);
rs.reconstruct_data_shards(&mut shards)?;
let data_shards = shards
.into_iter()
.take(self.min_shards as usize)
.map(|s| Bytes::from(s.unwrap())) .collect();
Ok(SlabRecovery {
client: self.client,
account_key: self.account_key,
min_shards: self.min_shards,
slab_index: self.slab_index,
encryption_key: self.encryption_key,
offset: self.offset,
length: self.length,
state: SlabDecoded { data_shards },
})
}
}
impl<T: Transport> SlabRecovery<SlabDecoded, T> {
    /// Writes the requested `length` bytes of the decoded slab to `w`.
    ///
    /// The shards were read starting at a stripe boundary, so the bytes that
    /// precede the requested offset within that first stripe are skipped.
    async fn write<W: AsyncWrite + Unpin>(self, w: &mut W) -> Result<(), DownloadError> {
        let data_shards = &self.state.data_shards;
        let stripe_size = SEGMENT_SIZE * data_shards.len();
        let skip = self.offset % stripe_size;
        ErasureCoder::write_data_shards(w, data_shards, skip, self.length).await?;
        Ok(())
    }
}
/// A slab whose `offset`/`length` have been narrowed to a single chunk.
struct ChunkSlab {
// Copy of the source slab with `offset` and `length` set to the chunk range.
slab: Slab,
// Index of the source slab within the object's slab list.
index: usize,
}
/// Iterator that splits a byte range of an object into chunks of at most `N`
/// bytes, never spanning more than one slab per chunk.
struct ChunkIter<const N: usize> {
// All slabs of the object, in order.
slabs: Vec<Slab>,
// Index of the slab the next chunk is taken from.
slab_idx: usize,
// Offset into the current slab's data (relative to the slab's own `offset`).
offset: u64,
// Number of bytes that still have to be yielded.
remaining: u64,
}
impl<const N: usize> ChunkIter<N> {
    /// Creates an iterator over `length` bytes starting `offset` bytes into
    /// the concatenated data of `slabs`.
    ///
    /// Slabs that lie entirely before `offset` are skipped; the stored offset
    /// is relative to the first slab that contains data at or after `offset`.
    fn new(slabs: Vec<Slab>, offset: u64, length: u64) -> Self {
        let mut offset_in_slab = offset;
        let mut slab_idx = 0;
        for slab in &slabs {
            let slab_length = slab.length as u64;
            if offset_in_slab < slab_length {
                break;
            }
            offset_in_slab -= slab_length;
            slab_idx += 1;
        }

        Self {
            slabs,
            slab_idx,
            offset: offset_in_slab,
            remaining: length,
        }
    }
}
impl<const N: usize> Iterator for ChunkIter<N> {
    type Item = ChunkSlab;

    /// Yields the next chunk: a copy of the current slab whose `offset` and
    /// `length` describe at most `N` bytes, together with the slab's index.
    ///
    /// Returns `None` once the requested number of bytes has been yielded, or
    /// when the slabs hold fewer bytes than requested. Slabs that contribute
    /// no bytes are skipped, matching what `ChunkIter::new` does.
    fn next(&mut self) -> Option<ChunkSlab> {
        if self.remaining == 0 {
            return None;
        }

        // Skip slabs without any data at the current position (e.g. empty
        // slabs) so that no zero-length chunk is ever produced.
        while let Some(slab) = self.slabs.get(self.slab_idx) {
            let slab_len = slab.length as u64;
            if self.offset < slab_len {
                break;
            }
            self.offset -= slab_len;
            self.slab_idx += 1;
        }

        let slab_index = self.slab_idx;
        // The requested range extends past the last slab: stop instead of
        // panicking on an out-of-bounds index.
        let slab = self.slabs.get(slab_index)?;
        let slab_len = slab.length as u64;
        let slab_offset = slab.offset as u64 + self.offset;
        let chunk_len = (slab_len - self.offset)
            .min(self.remaining)
            .min(N as u64);
        if chunk_len == 0 {
            // Only possible when `N == 0`; without this guard the iterator
            // would yield empty chunks forever.
            return None;
        }

        let mut chunk = slab.clone();
        // The chunk lies within the slab's own range, whose offset and length
        // are stored in the same field types.
        chunk.offset = slab_offset as u32;
        chunk.length = chunk_len as u32;

        self.offset += chunk_len;
        if self.offset >= slab_len {
            self.offset = 0;
            self.slab_idx += 1;
        }
        self.remaining -= chunk_len;

        Some(ChunkSlab {
            slab: chunk,
            index: slab_index,
        })
    }
}
/// Maximum number of bytes per download chunk (256 KiB); used as the chunk
/// bound of the `ChunkIter` owned by `Download`.
const CHUNK_SIZE: usize = 1 << 18;
/// A streaming download of (a range of) an object.
///
/// Chunks are recovered by background tasks and handed out in order, either
/// through the `AsyncRead` implementation or through `read_chunk`.
pub struct Download {
// Host pool the chunk tasks read sectors from.
hosts: Hosts<Client>,
// Key used to authenticate sector reads.
account_key: Arc<AppKey>,
// Stream cipher applied to every chunk, in order, as it is consumed.
cipher: Chacha20Cipher,
// Bytes of the current chunk that have not been handed to the reader yet.
buf: Bytes,
// In-flight chunk tasks in the order their data has to be delivered.
queue: VecDeque<AbortOnDropHandle<Result<Vec<u8>, DownloadError>>>,
// Source of the chunks that still have to be scheduled.
chunk_iter: ChunkIter<CHUNK_SIZE>,
// Set after the first failure; all later reads fail immediately.
errored: bool,
// Optional callback forwarded to every chunk's shard recovery.
shard_downloaded: Option<ShardProgressCallback>,
}
impl AsyncRead for Download {
    /// Copies already decrypted bytes into `buf`, waiting for the next chunk
    /// task when the internal buffer is empty.
    ///
    /// Returning `Ready(Ok(()))` without writing anything signals end of
    /// stream, which only happens once no chunk task is queued any more. A
    /// chunk that resolves to zero bytes is skipped rather than being
    /// mistaken for the end of the stream.
    ///
    /// After the first failure every call fails with
    /// [`DownloadError::Errored`] wrapped in an I/O error.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if self.errored {
            return Poll::Ready(Err(std::io::Error::other(DownloadError::Errored)));
        }

        loop {
            if !self.buf.is_empty() {
                self.drain_buf(buf);
                return Poll::Ready(Ok(()));
            }

            let Some(chunk_handle) = self.queue.front_mut() else {
                // Nothing buffered and nothing in flight: end of stream.
                return Poll::Ready(Ok(()));
            };
            let mut chunk = match ready!(Pin::new(chunk_handle).poll(cx)) {
                Ok(Ok(data)) => data,
                Ok(Err(e)) => {
                    self.set_err();
                    return Poll::Ready(Err(std::io::Error::other(e)));
                }
                Err(e) => {
                    self.set_err();
                    return Poll::Ready(Err(std::io::Error::other(e)));
                }
            };

            // The front task is finished: retire it and keep the pipeline full.
            self.queue.pop_front();
            self.spawn_next();
            self.cipher.apply_keystream(&mut chunk);
            self.buf = Bytes::from(chunk);
            // Loop: either drain the new buffer or, if the chunk was empty,
            // move on to the next queued chunk.
        }
    }
}
impl Download {
    /// Moves as many buffered bytes as fit into `buf`.
    fn drain_buf(&mut self, buf: &mut tokio::io::ReadBuf<'_>) {
        let to_copy = std::cmp::min(buf.remaining(), self.buf.len());
        buf.put_slice(&self.buf[..to_copy]);
        self.buf.advance(to_copy);
    }

    /// Marks the download as failed and drops all buffered data and queued
    /// chunk tasks.
    fn set_err(&mut self) {
        self.errored = true;
        self.buf = Bytes::new();
        self.queue.clear();
    }

    /// Schedules the recovery of the next chunk, if any chunk is left.
    fn spawn_next(&mut self) {
        if let Some(chunk_slab) = self.chunk_iter.next() {
            let hosts = self.hosts.clone();
            let account_key = self.account_key.clone();
            let shard_progress_callback = self.shard_downloaded.clone();
            self.queue
                .push_back(AbortOnDropHandle::new(maybe_spawn!(async move {
                    let len = chunk_slab.slab.length as usize;
                    let mut buf = Vec::with_capacity(len);
                    SlabRecovery::new(hosts, account_key, chunk_slab)?
                        .recover_shards(shard_progress_callback)
                        .await?
                        .decode()?
                        .write(&mut buf)
                        .await?;
                    Ok(buf)
                })));
        }
    }

    /// Returns the next piece of decrypted data, or an empty vector once the
    /// download is complete.
    ///
    /// Bytes left over from a previous `AsyncRead` call are returned first.
    ///
    /// The front chunk task is awaited in place and only removed from the
    /// queue after it has completed. Dropping the returned future while it is
    /// waiting therefore neither aborts nor skips the chunk; the next call
    /// resumes waiting for the same chunk.
    ///
    /// # Errors
    /// Returns the error of the failed chunk task; every later call returns
    /// [`DownloadError::Errored`].
    #[doc(hidden)]
    pub async fn read_chunk(&mut self) -> Result<Vec<u8>, DownloadError> {
        if self.errored {
            return Err(DownloadError::Errored);
        }
        if !self.buf.is_empty() {
            return Ok(std::mem::take(&mut self.buf).to_vec());
        }

        let Some(chunk_handle) = self.queue.front_mut() else {
            return Ok(Vec::new());
        };
        // Await through a mutable reference so that the handle stays in the
        // queue if this future is cancelled.
        let outcome = chunk_handle.await;
        let mut result = match outcome {
            Ok(Ok(data)) => data,
            Ok(Err(e)) => {
                self.set_err();
                return Err(e);
            }
            Err(e) => {
                self.set_err();
                return Err(e.into());
            }
        };

        self.queue.pop_front();
        self.spawn_next();
        self.cipher.apply_keystream(&mut result);
        Ok(result)
    }

    /// Creates a download for `object` and immediately schedules up to
    /// `options.max_inflight` chunk tasks.
    ///
    /// The requested length is clamped to the number of bytes available after
    /// `options.offset`.
    ///
    /// # Errors
    /// Returns [`DownloadError::Custom`] when `options.max_inflight` is zero.
    pub(crate) fn new(
        object: &Object,
        hosts: Hosts<Client>,
        account_key: Arc<AppKey>,
        options: DownloadOptions,
    ) -> Result<Self, DownloadError> {
        if options.max_inflight == 0 {
            return Err(DownloadError::Custom(
                "max_inflight must be greater than 0".to_string(),
            ));
        }

        let object_size = object.size();
        let cipher = object.cipher(options.offset);
        let available = object_size.saturating_sub(options.offset);
        let remaining = options.length.unwrap_or(available).min(available);
        let slabs = object.slabs().to_vec();
        let chunk_iter = ChunkIter::new(slabs, options.offset, remaining);

        let mut download = Self {
            hosts,
            account_key,
            cipher,
            buf: Bytes::new(),
            queue: VecDeque::with_capacity(options.max_inflight),
            chunk_iter,
            errored: false,
            shard_downloaded: options.shard_downloaded,
        };
        for _ in 0..options.max_inflight {
            download.spawn_next();
        }
        Ok(download)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Cursor;
    use std::sync::Arc;
    use bytes::BytesMut;
    use rand::Rng;
    use sia_core::rhp4::SECTOR_SIZE;
    use sia_core::signing::PrivateKey;
    use sia_core::types::v2::NetAddress;
    use crate::hosts::Hosts;
    use crate::upload::{upload_object, upload_slabs};
    use crate::{Host, ShardProgress, UploadOptions};

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Builds a host pool on top of `transport` that is populated with 60
    /// hosts using freshly generated random keys.
    fn mock_hosts(transport: &Client) -> Hosts<Client> {
        let hosts = Hosts::new(transport.clone());
        hosts.update(
            (0..60)
                .map(|_| Host {
                    public_key: PrivateKey::from_seed(&rand::random()).public_key(),
                    addresses: vec![NetAddress {
                        protocol: sia_core::types::v2::Protocol::QUIC,
                        address: "localhost:1234".to_string(),
                    }],
                    country_code: "US".to_string(),
                    latitude: 0.0,
                    longitude: 0.0,
                    good_for_upload: true,
                })
                .collect(),
            true,
        );
        hosts
    }

    /// Returns `len` random bytes.
    fn random_bytes(len: usize) -> Bytes {
        let mut data = BytesMut::zeroed(len);
        rand::rng().fill_bytes(&mut data);
        data.freeze()
    }

    /// Downloads a whole object while the transport delays the first reads,
    /// and checks that the data is still returned in order.
    #[sia_core_derive::cross_target_test]
    async fn test_out_of_order_download() {
        let upload_options = UploadOptions::default();
        let optimal_data_size = upload_options.data_shards as usize * SECTOR_SIZE;
        let transport = Client::new();
        let hosts = mock_hosts(&transport);
        let data = random_bytes(optimal_data_size);
        let app_key = Arc::new(AppKey::import(rand::random()));
        transport.set_initial_read_delay(Duration::from_millis(500));
        let obj = upload_object(
            hosts.clone(),
            app_key.clone(),
            Object::default(),
            Cursor::new(data.clone()),
            UploadOptions::default(),
        )
        .await
        .unwrap();
        let mut recovered_data = Vec::with_capacity(optimal_data_size);
        let mut download = Download::new(
            &obj,
            hosts.clone(),
            app_key.clone(),
            DownloadOptions::default(),
        )
        .unwrap();
        tokio::io::copy(&mut download, &mut recovered_data)
            .await
            .unwrap();
        assert_eq!(data, recovered_data);
    }

    /// Recovers various sub-ranges of a single slab and compares them with
    /// the uploaded data.
    #[sia_core_derive::cross_target_test]
    async fn test_slab_recovery() {
        let upload_options = UploadOptions::default();
        let optimal_data_size = upload_options.data_shards as usize * SECTOR_SIZE;
        let transport = Client::new();
        let hosts = mock_hosts(&transport);
        let data = random_bytes(optimal_data_size);
        let app_key = Arc::new(AppKey::import(rand::random()));
        let slabs = upload_slabs(
            hosts.clone(),
            app_key.clone(),
            Cursor::new(data.clone()),
            upload_options,
        )
        .await
        .unwrap();
        let test_cases: Vec<(&str, usize, usize)> = vec![
            ("full slab", 0, optimal_data_size),
            ("first half", 0, optimal_data_size / 2),
            ("second half", optimal_data_size / 2, optimal_data_size / 2),
            ("first 30 bytes", 0, 30),
            ("middle 30 bytes", optimal_data_size / 2 - 15, 30),
            ("last 30 bytes", optimal_data_size - 30, 30),
            ("first 4KiB", 0, 4096),
            ("middle 4KiB", optimal_data_size / 2 - 2048, 4096),
            ("last 4KiB", optimal_data_size - 4096, 4096),
        ];
        for (name, offset, length) in test_cases {
            let mut slab = slabs[0].clone();
            slab.offset = offset as u32;
            slab.length = length as u32;
            let mut recovered_data = Vec::with_capacity(length);
            SlabRecovery::new(hosts.clone(), app_key.clone(), ChunkSlab { slab, index: 0 })
                .unwrap()
                .recover_shards(None)
                .await
                .unwrap()
                .decode()
                .unwrap()
                .write(&mut recovered_data)
                .await
                .unwrap();
            assert_eq!(
                &data[offset..offset + length],
                &recovered_data[..],
                "mismatch for case: {name}"
            );
        }
    }

    /// Downloads a three-slab object and checks the number and contents of
    /// the shard progress callbacks.
    #[sia_core_derive::cross_target_test]
    async fn test_slab_recovery_progress_callback() {
        let upload_options = UploadOptions::default();
        let min_shards = upload_options.data_shards as usize;
        let total_shards = min_shards + upload_options.parity_shards as usize;
        let optimal_data_size = upload_options.optimal_data_size();
        let num_slabs = 3;
        let transport = Client::new();
        let hosts = mock_hosts(&transport);
        let data_size = optimal_data_size * num_slabs;
        let data = random_bytes(data_size);
        let app_key = Arc::new(AppKey::import(rand::random()));
        let obj = upload_object(
            hosts.clone(),
            app_key.clone(),
            Object::default(),
            Cursor::new(data.clone()),
            upload_options,
        )
        .await
        .unwrap();
        assert_eq!(obj.slabs().len(), num_slabs);
        let progress: Arc<std::sync::Mutex<Vec<ShardProgress>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let progress_clone = progress.clone();
        let opts = DownloadOptions::default().on_shard_downloaded(move |p: ShardProgress| {
            progress_clone.lock().unwrap().push(p);
        });
        let mut recovered_data = Vec::with_capacity(data_size);
        let mut download = Download::new(&obj, hosts.clone(), app_key.clone(), opts).unwrap();
        tokio::io::copy(&mut download, &mut recovered_data)
            .await
            .unwrap();
        assert_eq!(data, recovered_data);
        let events = progress.lock().unwrap();
        // Every chunk of every slab reports exactly `min_shards` shards.
        let chunks_per_slab = optimal_data_size.div_ceil(CHUNK_SIZE);
        let expected_total = chunks_per_slab * min_shards * num_slabs;
        assert_eq!(
            events.len(),
            expected_total,
            "expected {expected_total} progress callbacks ({chunks_per_slab} chunks × {min_shards} shards × {num_slabs} slabs), got {}",
            events.len()
        );
        let mut per_slab: std::collections::HashMap<usize, usize> =
            std::collections::HashMap::new();
        for event in events.iter() {
            assert!(
                event.shard_size > 0 && event.shard_size <= SECTOR_SIZE,
                "shard_size {} out of range",
                event.shard_size
            );
            assert!(
                event.shard_index < total_shards,
                "shard_index {} out of range for total_shards {}",
                event.shard_index,
                total_shards
            );
            *per_slab.entry(event.slab_index).or_default() += 1;
        }
        for slab_idx in 0..num_slabs {
            assert!(
                per_slab.contains_key(&slab_idx),
                "slab {slab_idx} had no progress callbacks"
            );
        }
    }
}