use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use bao_tree::io::fsm::{encode_ranges_validated, Outboard};
use futures::future::BoxFuture;
use iroh_io::stats::{
SliceReaderStats, StreamWriterStats, TrackingSliceReader, TrackingStreamWriter,
};
use iroh_io::{AsyncStreamWriter, TokioStreamWriter};
use serde::{Deserialize, Serialize};
use tracing::{debug, debug_span, info, trace, warn};
use tracing_futures::Instrument;
use crate::baomap::*;
use crate::hashseq::parse_hash_seq;
use crate::protocol::{GetRequest, RangeSpec, Request, RequestToken};
use crate::util::{BlobFormat, RpcError, Tag};
use crate::Hash;
/// Events emitted by the provider while serving requests.
#[derive(Debug, Clone)]
pub enum Event {
    /// A blob or collection was added and associated with a tag.
    TaggedBlobAdded {
        /// The hash of the added data.
        hash: Hash,
        /// The format of the added data.
        format: BlobFormat,
        /// The tag associated with the data.
        tag: Tag,
    },
    /// A client connected (emitted when a request stream is accepted).
    ClientConnected {
        /// An unique connection id.
        connection_id: u64,
    },
    /// A get request was received.
    GetRequestReceived {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// The request token attached to the request, if any.
        token: Option<RequestToken>,
        /// The hash the client requested.
        hash: Hash,
    },
    /// A custom get request was received.
    CustomGetRequestReceived {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// The request token attached to the request, if any.
        token: Option<RequestToken>,
        /// The length of the custom request payload.
        len: usize,
    },
    /// The transfer of a hash sequence (collection) has started.
    TransferHashSeqStarted {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// The number of blobs in the sequence.
        num_blobs: u64,
    },
    /// A single child blob of a transfer completed.
    TransferBlobCompleted {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// The hash of the completed blob.
        hash: Hash,
        /// The index of the blob within the sequence.
        index: u64,
        /// The size of the blob.
        size: u64,
    },
    /// The whole transfer completed successfully.
    TransferCompleted {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// Statistics gathered during the transfer.
        stats: Box<TransferStats>,
    },
    /// The transfer was aborted before completion.
    TransferAborted {
        /// The id of the connection the request arrived on.
        connection_id: u64,
        /// An identifier for this request within the connection.
        request_id: u64,
        /// Statistics gathered before the abort, if any were collected.
        stats: Option<Box<TransferStats>>,
    },
}
/// Statistics gathered for a single transfer.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransferStats {
    /// Statistics for writing to the response stream.
    pub send: StreamWriterStats,
    /// Statistics for reading from the data source.
    pub read: SliceReaderStats,
    /// Total wall-clock duration of the transfer (set by the caller).
    pub duration: Duration,
}
/// Progress updates for an add operation (serializable for RPC transport).
#[derive(Debug, Serialize, Deserialize)]
pub enum AddProgress {
    /// An item was found; it is referred to by `id` from now on.
    Found {
        /// A unique id for this entry.
        id: u64,
        /// The name of the entry.
        name: String,
        /// The size of the entry.
        size: u64,
    },
    /// Progress was made ingesting item `id`.
    Progress {
        /// The unique id of the entry.
        id: u64,
        /// The current offset within the entry.
        offset: u64,
    },
    /// Item `id` is done and its hash is known.
    Done {
        /// The unique id of the entry.
        id: u64,
        /// The hash of the entry.
        hash: Hash,
    },
    /// The whole operation finished.
    AllDone {
        /// The hash of the added data.
        hash: Hash,
        /// The format of the added data.
        format: BlobFormat,
        /// The tag associated with the data.
        tag: Tag,
    },
    /// The operation failed with the given error and is aborted.
    Abort(RpcError),
}
/// Progress updates for a get operation (serializable for RPC transport).
#[derive(Debug, Serialize, Deserialize)]
pub enum GetProgress {
    /// A connection was established.
    Connected,
    /// An item was found; it is referred to by `id` from now on.
    Found {
        /// A unique id for this blob.
        id: u64,
        /// The hash of the blob.
        hash: Hash,
        /// The size of the blob.
        size: u64,
    },
    /// The root item turned out to be a collection.
    FoundCollection {
        /// The hash of the collection.
        hash: Hash,
        /// The number of blobs in the collection, if known.
        num_blobs: Option<u64>,
        /// The total size of all blobs, if known.
        total_blobs_size: Option<u64>,
    },
    /// Progress was made downloading item `id`.
    Progress {
        /// The unique id of the blob.
        id: u64,
        /// The current offset within the blob.
        offset: u64,
    },
    /// Item `id` finished downloading.
    Done {
        /// The unique id of the blob.
        id: u64,
    },
    /// All network activity is finished.
    NetworkDone {
        /// Number of bytes written over the network.
        bytes_written: u64,
        /// Number of bytes read over the network.
        bytes_read: u64,
        /// Time spent on network activity.
        elapsed: Duration,
    },
    /// An export of item `id` to a target started.
    Export {
        /// The unique id of the blob.
        id: u64,
        /// The hash of the blob.
        hash: Hash,
        /// The size of the blob.
        size: u64,
        /// The export target (e.g. a path rendered as a string).
        target: String,
    },
    /// Progress was made exporting item `id`.
    ExportProgress {
        /// The unique id of the blob.
        id: u64,
        /// The current offset within the export.
        offset: u64,
    },
    /// The operation failed with the given error and is aborted.
    Abort(RpcError),
    /// The whole operation finished.
    AllDone,
}
/// A custom authorization handler deciding whether a request may be served.
pub trait RequestAuthorizationHandler: Send + Sync + Debug + 'static {
    /// Decide whether to authorize `request` given its optional `token`.
    ///
    /// Returning `Ok(())` accepts the request; returning an error rejects it
    /// (the caller aborts the transfer and propagates the error).
    fn authorize(
        &self,
        token: Option<RequestToken>,
        request: &Request,
    ) -> BoxFuture<'static, anyhow::Result<()>>;
}
/// Read a single request message from `reader` and deserialize it.
///
/// Reads at most `MAX_MESSAGE_SIZE` bytes from the stream, then decodes the
/// payload with postcard. Read and decode errors are propagated to the caller.
pub async fn read_request(mut reader: quinn::RecvStream) -> Result<Request> {
    let raw = reader.read_to_end(crate::protocol::MAX_MESSAGE_SIZE).await?;
    Ok(postcard::from_bytes::<Request>(&raw)?)
}
/// Transfer the requested ranges of a root blob and, if requested, its children.
///
/// When more than the root (offset 0) is requested, the root data is parsed as
/// a hash sequence and the requested ranges of each child blob are looked up in
/// `db` and sent. Read and send statistics are accumulated into `stats`.
///
/// Returns [`SentStatus::Sent`] when everything requested was written, or
/// [`SentStatus::NotFound`] when a requested child blob is missing from `db`.
pub async fn transfer_collection<D: Map, E: EventSender>(
    request: GetRequest,
    db: &D,
    writer: &mut ResponseWriter<E>,
    mut outboard: D::Outboard,
    mut data: D::DataReader,
    stats: &mut TransferStats,
) -> Result<SentStatus> {
    let hash = request.hash;
    // If only the root blob (offset 0) is requested, there is no need to
    // parse the root data as a hash sequence at all.
    let just_root = matches!(request.ranges.as_single(), Some((0, _)));
    let mut c = if !just_root {
        // Parse the root blob as a hash sequence to obtain the child hashes.
        let (stream, num_blobs) = parse_hash_seq(&mut data).await?;
        writer
            .events
            .send(Event::TransferHashSeqStarted {
                connection_id: writer.connection_id(),
                request_id: writer.request_id(),
                num_blobs,
            })
            .await;
        Some(stream)
    } else {
        None
    };
    // Offset of the last item we served; used to skip over unrequested children.
    let mut prev = 0;
    for (offset, ranges) in request.ranges.iter_non_empty() {
        // Fresh tracking writer per item so per-item send stats can be summed.
        let mut tw = writer.tracking_writer();
        if offset == 0 {
            // Offset 0 is the root blob itself.
            debug!("writing ranges '{:?}' of sequence {}", ranges, hash);
            let mut tracking_reader = TrackingSliceReader::new(&mut data);
            encode_ranges_validated(
                &mut tracking_reader,
                &mut outboard,
                &ranges.to_chunk_ranges(),
                &mut tw,
            )
            .await?;
            stats.read += tracking_reader.stats();
            stats.send += tw.stats();
            debug!(
                "finished writing ranges '{:?}' of collection {}",
                ranges, hash
            );
        } else {
            let c = c.as_mut().context("collection parser not available")?;
            // Fixed log message typo ("wrtiting").
            debug!("writing ranges '{:?}' of child {}", ranges, offset);
            // Skip children between the last served offset and this one.
            if prev < offset - 1 {
                c.skip(offset - prev - 1).await?;
            }
            if let Some(hash) = c.next().await? {
                // Yield so long transfers do not starve other tasks.
                tokio::task::yield_now().await;
                let (status, size, blob_read_stats) = send_blob(db, hash, ranges, &mut tw).await?;
                stats.send += tw.stats();
                stats.read += blob_read_stats;
                if SentStatus::NotFound == status {
                    writer.inner.finish().await?;
                    return Ok(status);
                }
                writer
                    .events
                    .send(Event::TransferBlobCompleted {
                        connection_id: writer.connection_id(),
                        request_id: writer.request_id(),
                        hash,
                        // Index within the sequence: offset 1 is child 0.
                        index: offset - 1,
                        size,
                    })
                    .await;
            } else {
                // The sequence has fewer children than requested; stop early.
                break;
            }
            prev = offset;
        }
    }
    debug!("done writing");
    Ok(SentStatus::Sent)
}
/// A sink for provider [`Event`]s.
pub trait EventSender: Clone + Sync + Send + 'static {
    /// Send `event`, returning a future that completes when the send is done.
    fn send(&self, event: Event) -> BoxFuture<()>;
}
/// Serve an incoming QUIC connection.
///
/// Waits for the handshake to complete, then accepts bi-directional streams in
/// a loop; each stream is handled as one request on the local pool. Returns
/// (instead of erroring) when the handshake fails or no more streams can be
/// accepted.
pub async fn handle_connection<D: Map, E: EventSender>(
    connecting: quinn::Connecting,
    db: D,
    events: E,
    authorization_handler: Arc<dyn RequestAuthorizationHandler>,
    rt: crate::util::runtime::Handle,
) {
    let remote_addr = connecting.remote_address();
    let connection = match connecting.await {
        Ok(conn) => conn,
        Err(err) => {
            warn!(%remote_addr, "Error connecting: {err:#}");
            return;
        }
    };
    // quinn's stable id serves as the connection id reported in events.
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id, %remote_addr);
    async move {
        while let Ok((writer, reader)) = connection.accept_bi().await {
            // The stream id of the receive side doubles as the request id.
            let request_id = reader.id().index();
            let span = debug_span!("stream", stream_id = %request_id);
            let writer = ResponseWriter {
                connection_id,
                events: events.clone(),
                inner: writer,
            };
            // NOTE(review): this fires once per accepted stream, not once per
            // connection — confirm that is intended.
            events.send(Event::ClientConnected { connection_id }).await;
            // Clone per stream so the spawned task owns its handles.
            let db = db.clone();
            let authorization_handler = authorization_handler.clone();
            rt.local_pool().spawn_pinned(|| {
                async move {
                    if let Err(err) = handle_stream(db, reader, writer, authorization_handler).await
                    {
                        warn!("error: {err:#?}",);
                    }
                }
                .instrument(span)
            });
        }
    }
    .instrument(span)
    .await
}
/// Handle a single request stream: read, authorize, then dispatch the request.
///
/// On a read or authorization failure the transfer is reported as aborted via
/// the writer's event sink before the error is returned.
async fn handle_stream<D: Map, E: EventSender>(
    db: D,
    reader: quinn::RecvStream,
    writer: ResponseWriter<E>,
    authorization_handler: Arc<dyn RequestAuthorizationHandler>,
) -> Result<()> {
    debug!("reading request");
    let request = match read_request(reader).await {
        Ok(request) => request,
        Err(err) => {
            writer.notify_transfer_aborted(None).await;
            return Err(err);
        }
    };
    debug!("authorizing request");
    let authorized = authorization_handler
        .authorize(request.token().cloned(), &request)
        .await;
    if let Err(err) = authorized {
        writer.notify_transfer_aborted(None).await;
        return Err(err);
    }
    match request {
        Request::Get(get) => handle_get(db, get, writer).await,
    }
}
/// Handle a single get request: look up the entry and transfer it.
///
/// Emits a [`Event::GetRequestReceived`] event first, then either transfers
/// the requested data (notifying completion or abort with the gathered stats)
/// or, when the hash is unknown, notifies an abort and finishes the stream.
pub async fn handle_get<D: Map, E: EventSender>(
    db: D,
    request: GetRequest,
    mut writer: ResponseWriter<E>,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "received request");
    writer
        .events
        .send(Event::GetRequestReceived {
            hash,
            connection_id: writer.connection_id(),
            request_id: writer.request_id(),
            token: request.token().cloned(),
        })
        .await;
    // Unknown hash: report the abort and close the response stream.
    let entry = match db.get(&hash) {
        Some(entry) => entry,
        None => {
            debug!("not found {}", hash);
            writer.notify_transfer_aborted(None).await;
            writer.inner.finish().await?;
            return Ok(());
        }
    };
    let mut stats = Box::<TransferStats>::default();
    let started = std::time::Instant::now();
    let outcome = transfer_collection(
        request,
        &db,
        &mut writer,
        entry.outboard().await?,
        entry.data_reader().await?,
        &mut stats,
    )
    .await;
    stats.duration = started.elapsed();
    match outcome {
        Ok(SentStatus::Sent) => {
            writer.notify_transfer_completed(&hash, stats).await;
        }
        Ok(SentStatus::NotFound) => {
            writer.notify_transfer_aborted(Some(stats)).await;
        }
        Err(err) => {
            writer.notify_transfer_aborted(Some(stats)).await;
            return Err(err);
        }
    }
    debug!("finished response");
    Ok(())
}
/// Bundles the send stream of a request with the event sink and the id of the
/// connection the request arrived on.
#[derive(Debug)]
pub struct ResponseWriter<E> {
    // The QUIC send stream the response is written to.
    inner: quinn::SendStream,
    // Sink for provider events emitted while handling the request.
    events: E,
    // The id of the connection this response belongs to.
    connection_id: u64,
}
impl<E: EventSender> ResponseWriter<E> {
    /// A stream writer for the inner stream that tracks send statistics.
    fn tracking_writer(
        &mut self,
    ) -> TrackingStreamWriter<TokioStreamWriter<&mut quinn::SendStream>> {
        TrackingStreamWriter::new(TokioStreamWriter(&mut self.inner))
    }

    /// The id of the connection this response belongs to.
    fn connection_id(&self) -> u64 {
        self.connection_id
    }

    /// The request id, derived from the stream id of the send stream.
    fn request_id(&self) -> u64 {
        self.inner.id().index()
    }

    /// Log a summary of the given transfer statistics.
    fn print_stats(stats: &TransferStats) {
        let send = stats.send.total();
        let read = stats.read.total();
        let total_sent_bytes = send.size;
        let send_duration = send.stats.duration;
        let read_duration = read.stats.duration;
        let total_duration = stats.duration;
        let other_duration = total_duration
            .saturating_sub(send_duration)
            .saturating_sub(read_duration);
        // Guard against division by zero: an aborted transfer may not have
        // performed a single send operation, leaving `count` at 0.
        let avg_send_size = total_sent_bytes.checked_div(send.stats.count).unwrap_or(0);
        info!(
            "sent {} bytes in {}s",
            total_sent_bytes,
            total_duration.as_secs_f64()
        );
        debug!(
            "{}s sending, {}s reading, {}s other",
            send_duration.as_secs_f64(),
            read_duration.as_secs_f64(),
            other_duration.as_secs_f64()
        );
        trace!(
            "send_count: {} avg_send_size {}",
            send.stats.count,
            avg_send_size,
        )
    }

    /// Notify listeners that the transfer completed, logging the stats.
    async fn notify_transfer_completed(&self, hash: &Hash, stats: Box<TransferStats>) {
        info!("transfer completed for {}", hash);
        Self::print_stats(&stats);
        self.events
            .send(Event::TransferCompleted {
                connection_id: self.connection_id(),
                request_id: self.request_id(),
                stats,
            })
            .await;
    }

    /// Notify listeners that the transfer was aborted, logging the stats if
    /// any were collected.
    async fn notify_transfer_aborted(&self, stats: Option<Box<TransferStats>>) {
        if let Some(stats) = &stats {
            Self::print_stats(stats);
        };
        self.events
            .send(Event::TransferAborted {
                connection_id: self.connection_id(),
                request_id: self.request_id(),
                stats,
            })
            .await;
    }
}
/// Whether the requested data was found and sent.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SentStatus {
    /// The requested data was found and written out.
    Sent,
    /// The requested data was not present in the store.
    NotFound,
}
/// Send the requested `ranges` of the blob `name` from `db` to `writer`.
///
/// Returns the sent status, the total size of the blob (0 when not found),
/// and the read statistics gathered while encoding.
pub async fn send_blob<D: Map, W: AsyncStreamWriter>(
    db: &D,
    name: Hash,
    ranges: &RangeSpec,
    writer: W,
) -> Result<(SentStatus, u64, SliceReaderStats)> {
    // Missing blob: report not-found with empty stats instead of erroring.
    let entry = match db.get(&name) {
        Some(entry) => entry,
        _ => {
            debug!("blob not found {}", hex::encode(name));
            return Ok((SentStatus::NotFound, 0, SliceReaderStats::default()));
        }
    };
    let outboard = entry.outboard().await?;
    let size = outboard.tree().size().0;
    let mut reader = TrackingSliceReader::new(entry.data_reader().await?);
    let encode_res = encode_ranges_validated(
        &mut reader,
        outboard,
        &ranges.to_chunk_ranges(),
        writer,
    )
    .await;
    // Log the outcome before propagating a potential encode error.
    debug!("done sending blob {} {:?}", name, encode_res);
    encode_res?;
    Ok((SentStatus::Sent, size, reader.stats()))
}