use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use bytes::Bytes;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use super::error::BlobError;
use super::mesh::MeshBlobAdapter;
use crate::adapter::net::{MeshNode, Reliability, Stream, StreamConfig};
/// Subprotocol identifier reserved for blob-transfer control frames.
pub const SUBPROTOCOL_BLOB_TRANSFER: u16 = 0x1100;
/// Bit that places a stream id in the blob-transfer namespace.
const TRANSFER_STREAM_FLAG: u64 = 1u64 << 61;
/// Bit carried by channel stream ids; transfer ids must keep it clear.
const CHANNEL_STREAM_BIT: u64 = 1u64 << 48;
/// Mask for the 48 low bits that hold a transfer nonce.
const TRANSFER_NONCE_MASK: u64 = CHANNEL_STREAM_BIT - 1;

/// Builds the transfer stream id for `nonce`.
///
/// Only the low 48 bits of `nonce` are kept; higher bits are discarded so the
/// result never collides with the channel bit.
pub fn transfer_stream_id(nonce: u64) -> u64 {
    let low_bits = nonce & TRANSFER_NONCE_MASK;
    low_bits | TRANSFER_STREAM_FLAG
}

/// Returns `true` when `stream_id` has the transfer flag set and the channel bit clear.
pub fn is_transfer_stream_id(stream_id: u64) -> bool {
    let has_transfer_flag = (stream_id & TRANSFER_STREAM_FLAG) != 0;
    let has_channel_bit = (stream_id & CHANNEL_STREAM_BIT) != 0;
    has_transfer_flag && !has_channel_bit
}

/// Process-wide nonce source for transfer stream ids (starts at 1).
static TRANSFER_STREAM_NONCE: AtomicU64 = AtomicU64::new(1);

/// Allocates a fresh transfer stream id from the process-wide nonce counter.
///
/// Ids only repeat after the 48-bit nonce space wraps.
pub fn next_transfer_stream_id() -> u64 {
    transfer_stream_id(TRANSFER_STREAM_NONCE.fetch_add(1, Ordering::Relaxed))
}
/// Largest payload slice carried by one data event on a reply stream.
const DATA_FRAME_BYTES: usize = 8000;
/// Reply-stream window: credit for `DEFAULT_MAX_PENDING` full data frames.
const TRANSFER_STREAM_WINDOW_BYTES: u32 = {
    let frames = crate::adapter::net::ReliableStream::DEFAULT_MAX_PENDING as u32;
    frames * DATA_FRAME_BYTES as u32
};
/// Largest `total_len` a receiver accepts in a `Found` header (16 MiB).
const TRANSFER_MAX_CHUNK_BYTES: u64 = 16 << 20;
/// Grace period handed to `close_stream_graceful` when a holder finishes a reply.
const TRANSFER_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
/// Retry budget passed to `send_with_retry` for each event.
const SEND_RETRIES: usize = 64;
/// How far past `next_seq` a batch may arrive before it is dropped instead of buffered.
const MAX_REORDER_AHEAD: u64 = 1024;
/// Control frame decoded by `BlobTransferEngine::on_request`.
///
/// Encoded with postcard, which tags enum variants by index, so variant order is
/// part of the wire format and must not change.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferControl {
    /// Ask the receiving node to stream back the chunk identified by `hash`.
    Request { hash: [u8; 32] },
}

/// First event on a transfer reply stream, announcing whether the chunk is available.
///
/// Postcard-encoded; variant order is part of the wire format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferHeader {
    /// The holder has the chunk; `total_len` payload bytes follow on the stream.
    Found { total_len: u64 },
    /// The holder could not produce the chunk; no payload follows.
    NotFound,
}
/// Completion channel for an inbound transfer: receives the verified bytes or the failure.
type DoneTx = tokio::sync::oneshot::Sender<Result<Bytes, BlobError>>;

/// Outcome of feeding one event into a pending transfer's reassembler.
enum ReassembleStep {
    /// More events are needed before the transfer can finish.
    Continue,
    /// All `total_len` bytes (possibly zero) have arrived; the hash is not yet checked.
    Complete,
    /// The transfer cannot succeed; this error should be delivered.
    Fail(BlobError),
}

/// Receive-side state for one in-flight transfer, keyed by stream id in
/// `BlobTransferEngine::pending`.
struct PendingInbound {
    /// Peer id passed to `close_stream` when the transfer ends.
    holder: u64,
    /// Hash the reassembled bytes must match before they are delivered.
    expected_hash: [u8; 32],
    /// `None` until a `TransferHeader::Found` header has been accepted.
    total_len: Option<u64>,
    /// Payload bytes received so far, in sequence order.
    buf: Vec<u8>,
    /// Sequence number of the next batch to apply.
    next_seq: u64,
    /// Batches that arrived ahead of `next_seq`, held until the gap fills.
    reorder: BTreeMap<u64, Vec<Bytes>>,
    /// Completion sender; taken exactly once when the transfer finishes.
    done: Option<DoneTx>,
}
/// Moves hash-addressed chunks between mesh peers over dedicated reliable streams.
///
/// Serves `TransferControl::Request`s out of the local adapter, and reassembles
/// and verifies replies for transfers registered with `register_pending`.
pub struct BlobTransferEngine {
    /// Weak handle so the engine never keeps the mesh node alive.
    mesh: Weak<MeshNode>,
    /// Local chunk store consulted when serving requests.
    adapter: Arc<MeshBlobAdapter>,
    /// In-flight inbound transfers, keyed by transfer stream id.
    pending: DashMap<u64, PendingInbound>,
}
impl BlobTransferEngine {
pub fn new(mesh: &Arc<MeshNode>, adapter: Arc<MeshBlobAdapter>) -> Self {
Self {
mesh: Arc::downgrade(mesh),
adapter,
pending: DashMap::new(),
}
}
pub fn register_pending(
&self,
stream_id: u64,
holder: u64,
expected_hash: [u8; 32],
done: DoneTx,
) {
self.pending.insert(
stream_id,
PendingInbound {
holder,
expected_hash,
total_len: None,
buf: Vec::new(),
next_seq: 0,
reorder: BTreeMap::new(),
done: Some(done),
},
);
}
pub fn cancel_pending(&self, stream_id: u64) {
self.pending.remove(&stream_id);
}
pub fn on_reset(&self, stream_id: u64) {
self.finish(
stream_id,
Err(BlobError::Backend(
"transfer: holder reset stream (retransmit exhausted)".into(),
)),
);
}
pub fn on_request(&self, requester: u64, stream_id: u64, payload: &[u8]) {
let control: TransferControl = match postcard::from_bytes(payload) {
Ok(c) => c,
Err(e) => {
tracing::debug!(error = %e, requester, "blob transfer: bad control frame");
return;
}
};
let TransferControl::Request { hash } = control;
let Some(mesh) = self.mesh.upgrade() else {
return;
};
let adapter = self.adapter.clone();
tokio::spawn(async move {
serve_chunk(mesh, adapter, requester, stream_id, hash).await;
});
}
pub fn on_data(&self, stream_id: u64, seq: u64, events: Vec<Bytes>) {
let outcome = {
let mut entry = match self.pending.get_mut(&stream_id) {
Some(e) => e,
None => return, };
if seq < entry.next_seq
|| entry.reorder.contains_key(&seq)
|| seq >= entry.next_seq.saturating_add(MAX_REORDER_AHEAD)
{
return;
}
entry.reorder.insert(seq, events);
let mut outcome = ReassembleStep::Continue;
loop {
let ns = entry.next_seq;
let Some(ready) = entry.reorder.remove(&ns) else {
break;
};
entry.next_seq += 1;
for event in &ready {
outcome = Self::process_event(&mut entry, event);
if !matches!(outcome, ReassembleStep::Continue) {
break;
}
}
if !matches!(outcome, ReassembleStep::Continue) {
break;
}
}
outcome
};
match outcome {
ReassembleStep::Continue => {}
ReassembleStep::Fail(err) => self.finish(stream_id, Err(err)),
ReassembleStep::Complete => self.finish_verified(stream_id),
}
}
fn process_event(entry: &mut PendingInbound, event: &Bytes) -> ReassembleStep {
if entry.total_len.is_none() {
match postcard::from_bytes::<TransferHeader>(event) {
Ok(TransferHeader::NotFound) => {
ReassembleStep::Fail(BlobError::NotFound("transfer: holder NotFound".into()))
}
Ok(TransferHeader::Found { total_len }) if total_len > TRANSFER_MAX_CHUNK_BYTES => {
ReassembleStep::Fail(BlobError::Backend(format!(
"transfer: total_len {total_len} exceeds cap"
)))
}
Ok(TransferHeader::Found { total_len }) => {
entry.total_len = Some(total_len);
entry
.buf
.reserve(total_len.min(TRANSFER_MAX_CHUNK_BYTES) as usize);
if total_len == 0 {
ReassembleStep::Complete
} else {
ReassembleStep::Continue
}
}
Err(e) => {
ReassembleStep::Fail(BlobError::Backend(format!("transfer: bad header: {e}")))
}
}
} else {
let total = entry.total_len.unwrap_or(0);
if (entry.buf.len() as u64).saturating_add(event.len() as u64) > total {
ReassembleStep::Fail(BlobError::Backend(
"transfer: holder sent more than total_len".into(),
))
} else {
entry.buf.extend_from_slice(event);
if entry.buf.len() as u64 >= total {
ReassembleStep::Complete
} else {
ReassembleStep::Continue
}
}
}
}
fn finish(&self, stream_id: u64, result: Result<Bytes, BlobError>) {
if let Some((_, mut pending)) = self.pending.remove(&stream_id) {
if let Some(tx) = pending.done.take() {
let _ = tx.send(result);
}
self.close_receive_stream(pending.holder, stream_id);
}
}
fn finish_verified(&self, stream_id: u64) {
let Some((_, mut pending)) = self.pending.remove(&stream_id) else {
return;
};
let bytes = std::mem::take(&mut pending.buf);
let result = {
let computed: [u8; 32] = blake3::hash(&bytes).into();
if computed == pending.expected_hash {
Ok(Bytes::from(bytes))
} else {
Err(BlobError::HashMismatch {
expected: pending.expected_hash,
actual: computed,
})
}
};
if let Some(tx) = pending.done.take() {
let _ = tx.send(result);
}
self.close_receive_stream(pending.holder, stream_id);
}
fn close_receive_stream(&self, holder: u64, stream_id: u64) {
if let Some(mesh) = self.mesh.upgrade() {
mesh.close_stream(holder, stream_id);
}
}
}
/// Answers one chunk request from `requester` on the reply stream `stream_id`.
///
/// Opens a reliable, scheduled stream back to the requester and looks `hash` up in
/// `adapter`. It then sends either `TransferHeader::NotFound`, or a `Found` header
/// followed by the payload in `DATA_FRAME_BYTES`-sized events. A send failure stops
/// the transfer early. Once the stream has been opened, it is always closed gracefully.
async fn serve_chunk(
    mesh: Arc<MeshNode>,
    adapter: Arc<MeshBlobAdapter>,
    requester: u64,
    stream_id: u64,
    hash: [u8; 32],
) {
    let reply_cfg = StreamConfig::new()
        .with_reliability(Reliability::Reliable)
        .with_scheduled(true)
        .with_window_bytes(TRANSFER_STREAM_WINDOW_BYTES)
        .with_fairness_weight(1);
    let opened = mesh.open_stream(requester, stream_id, reply_cfg);
    let stream = match opened {
        Ok(stream) => stream,
        Err(e) => {
            tracing::debug!(error = %e, requester, "blob transfer: open reply stream failed");
            return;
        }
    };

    // Every lookup failure is reported to the requester as NotFound.
    let fetched = adapter.fetch_chunk(&hash).await;
    if let Ok(bytes) = fetched {
        let header = TransferHeader::Found {
            total_len: bytes.len() as u64,
        };
        let header_sent = send_one(&mesh, &stream, postcard_event(&header))
            .await
            .is_ok();
        if header_sent {
            for frame in bytes.chunks(DATA_FRAME_BYTES) {
                let payload = Bytes::copy_from_slice(frame);
                if send_one(&mesh, &stream, payload).await.is_err() {
                    // The stream is unusable; skip the remaining frames.
                    break;
                }
            }
        }
    } else {
        let _ = send_one(&mesh, &stream, postcard_event(&TransferHeader::NotFound)).await;
    }

    mesh.close_stream_graceful(requester, stream_id, TRANSFER_TIMEOUT)
        .await;
}
/// Serializes `value` with postcard into a single stream event.
///
/// If serialization fails, this returns an empty event instead of panicking. The
/// receiver's header parser rejects an empty event as malformed.
fn postcard_event<T: Serialize>(value: &T) -> Bytes {
    postcard::to_allocvec(value)
        .map(Bytes::from)
        .unwrap_or_default()
}
/// Sends a single event on `stream`, using `SEND_RETRIES` as the retry budget.
///
/// Failures are logged at debug level and collapsed to `Err(())`; callers only need
/// to know whether to keep sending.
async fn send_one(mesh: &Arc<MeshNode>, stream: &Stream, event: Bytes) -> Result<(), ()> {
    let batch = std::slice::from_ref(&event);
    match mesh.send_with_retry(stream, batch, SEND_RETRIES).await {
        Ok(sent) => Ok(sent),
        Err(e) => {
            tracing::debug!(error = %e, "blob transfer: stream send failed");
            Err(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh receive-side entry that has not seen a header yet.
    fn fresh_pending() -> PendingInbound {
        PendingInbound {
            holder: 7,
            expected_hash: [0u8; 32],
            total_len: None,
            buf: Vec::new(),
            next_seq: 0,
            reorder: BTreeMap::new(),
            done: None,
        }
    }

    /// Feeds one event into `entry` through the engine's reassembler.
    fn feed(entry: &mut PendingInbound, event: Bytes) -> ReassembleStep {
        BlobTransferEngine::process_event(entry, &event)
    }

    #[test]
    fn transfer_ids_are_disjoint_from_channel_and_control_streams() {
        let channel_like = CHANNEL_STREAM_BIT | 0xDEAD_BEEF_CAFE;
        assert!(!is_transfer_stream_id(channel_like));
        assert!(!is_transfer_stream_id(u64::MAX));
        assert!(!is_transfer_stream_id(SUBPROTOCOL_BLOB_TRANSFER as u64));
        assert!(!is_transfer_stream_id(0x1000));
    }

    #[test]
    fn transfer_ids_round_trip_and_self_identify() {
        for nonce in [1u64, 42, 0xFFFF, (1 << 48) - 1] {
            let id = transfer_stream_id(nonce);
            assert!(is_transfer_stream_id(id), "id {id:#x} must self-identify");
            assert_eq!(id & CHANNEL_STREAM_BIT, 0);
            assert_ne!(id & TRANSFER_STREAM_FLAG, 0);
            // The nonce must be recoverable from the id, or this is not a round trip.
            assert_eq!(id & TRANSFER_NONCE_MASK, nonce, "nonce must round-trip");
        }
    }

    #[test]
    fn allocator_yields_distinct_transfer_ids() {
        let a = next_transfer_stream_id();
        let b = next_transfer_stream_id();
        assert_ne!(a, b);
        assert!(is_transfer_stream_id(a) && is_transfer_stream_id(b));
    }

    #[test]
    fn reassembly_completes_on_exact_length() {
        let mut entry = fresh_pending();
        let header = postcard_event(&TransferHeader::Found { total_len: 5 });
        assert!(matches!(feed(&mut entry, header), ReassembleStep::Continue));
        assert_eq!(entry.total_len, Some(5));
        assert!(matches!(
            feed(&mut entry, Bytes::from_static(b"abc")),
            ReassembleStep::Continue
        ));
        assert!(matches!(
            feed(&mut entry, Bytes::from_static(b"de")),
            ReassembleStep::Complete
        ));
        assert_eq!(&entry.buf[..], &b"abcde"[..]);
    }

    #[test]
    fn reassembly_rejects_overrun() {
        let mut entry = fresh_pending();
        let header = postcard_event(&TransferHeader::Found { total_len: 2 });
        assert!(matches!(feed(&mut entry, header), ReassembleStep::Continue));
        assert!(matches!(
            feed(&mut entry, Bytes::from_static(b"abc")),
            ReassembleStep::Fail(BlobError::Backend(_))
        ));
        assert!(entry.buf.is_empty());
    }

    #[test]
    fn reassembly_handles_empty_and_missing_chunks() {
        let mut empty = fresh_pending();
        let header = postcard_event(&TransferHeader::Found { total_len: 0 });
        assert!(matches!(feed(&mut empty, header), ReassembleStep::Complete));

        let mut missing = fresh_pending();
        let header = postcard_event(&TransferHeader::NotFound);
        assert!(matches!(
            feed(&mut missing, header),
            ReassembleStep::Fail(BlobError::NotFound(_))
        ));
    }

    #[test]
    fn reassembly_rejects_oversized_and_malformed_headers() {
        let mut oversized = fresh_pending();
        let header = postcard_event(&TransferHeader::Found {
            total_len: TRANSFER_MAX_CHUNK_BYTES + 1,
        });
        assert!(matches!(
            feed(&mut oversized, header),
            ReassembleStep::Fail(BlobError::Backend(_))
        ));
        assert_eq!(oversized.total_len, None);

        let mut malformed = fresh_pending();
        assert!(matches!(
            feed(&mut malformed, Bytes::new()),
            ReassembleStep::Fail(BlobError::Backend(_))
        ));
    }
}