use std::net::SocketAddr;
use std::time::Duration;
use tokio::time::Instant;
use dashmap::DashMap;
use tokio::sync::oneshot;
use crate::transport::peer_connection::StreamId;
use crate::transport::peer_connection::streaming::StreamHandle;
/// Maximum age (one minute) of an unclaimed, parked stream; once older than this,
/// `OrphanStreamRegistry::gc_expired` cancels and discards it.
pub const ORPHAN_STREAM_TIMEOUT: Duration = Duration::from_millis(60_000);
/// One-minute claim timeout. NOTE(review): not referenced in this file; presumably the
/// `timeout` that callers pass to `OrphanStreamRegistry::claim_or_wait` — confirm at call sites.
pub const STREAM_CLAIM_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Registry key: a stream id is only meaningful together with the peer it came from, so
/// the same `StreamId` arriving from two different peers yields two distinct keys.
type StreamKey = (SocketAddr, StreamId);
/// Rendezvous point between incoming streams and the operations that expect them,
/// keyed by `(peer address, stream id)`.
///
/// A stream may be registered before anyone asks for it (it is then parked as an
/// "orphan"), or an operation may ask first (it then parks a oneshot waiter).
pub struct OrphanStreamRegistry {
    // Streams registered with no waiting claimer, together with the instant they were
    // parked; `gc_expired` uses that instant to cancel stale entries.
    orphan_streams: DashMap<StreamKey, (StreamHandle, Instant)>,
    // Senders for claimers currently blocked in `claim_or_wait`; `register_orphan`
    // removes the sender and delivers the handle through it.
    stream_waiters: DashMap<StreamKey, oneshot::Sender<StreamHandle>>,
    // Keys that have already been claimed, used to reject duplicate claims.
    // Pruned by `gc_expired` once it grows large.
    claimed_streams: DashMap<StreamKey, ()>,
}
impl OrphanStreamRegistry {
pub fn new() -> Self {
Self {
orphan_streams: DashMap::new(),
stream_waiters: DashMap::new(),
claimed_streams: DashMap::new(),
}
}
pub fn register_orphan(
&self,
peer_addr: SocketAddr,
stream_id: StreamId,
handle: StreamHandle,
) {
let key = (peer_addr, stream_id);
if let Some((_, waiter)) = self.stream_waiters.remove(&key) {
if waiter.send(handle).is_err() {
tracing::warn!(
%peer_addr,
stream_id = %stream_id,
"Failed to deliver orphan stream to waiter (receiver dropped)"
);
} else {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
"Delivered stream to waiting operation"
);
}
} else {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
"Registered orphan stream (metadata not yet received)"
);
self.orphan_streams.insert(key, (handle, Instant::now()));
}
}
pub async fn claim_or_wait(
&self,
peer_addr: SocketAddr,
stream_id: StreamId,
timeout: Duration,
) -> Result<StreamHandle, OrphanStreamError> {
let key = (peer_addr, stream_id);
use dashmap::mapref::entry::Entry;
match self.claimed_streams.entry(key) {
Entry::Occupied(_) => {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
"Stream already claimed (dedup)"
);
return Err(OrphanStreamError::AlreadyClaimed);
}
Entry::Vacant(entry) => {
entry.insert(());
}
}
if let Some((_, (handle, _))) = self.orphan_streams.remove(&key) {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
"Claimed orphan stream immediately"
);
return Ok(handle);
}
let (tx, rx) = oneshot::channel();
self.stream_waiters.insert(key, tx);
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
timeout_ms = timeout.as_millis(),
"Waiting for stream to arrive"
);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(handle)) => {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
"Stream arrived while waiting"
);
Ok(handle)
}
Ok(Err(_)) => {
self.stream_waiters.remove(&key);
self.claimed_streams.remove(&key);
tracing::warn!(
%peer_addr,
stream_id = %stream_id,
"Stream waiter cancelled unexpectedly"
);
Err(OrphanStreamError::WaiterCancelled)
}
Err(_) => {
self.stream_waiters.remove(&key);
self.claimed_streams.remove(&key);
tracing::warn!(
%peer_addr,
stream_id = %stream_id,
timeout_ms = timeout.as_millis(),
"Timeout waiting for stream"
);
Err(OrphanStreamError::Timeout)
}
}
}
pub fn gc_expired(&self) {
let now = Instant::now();
let mut expired_count = 0;
self.orphan_streams
.retain(|(peer_addr, stream_id), (handle, created)| {
if now.duration_since(*created) > ORPHAN_STREAM_TIMEOUT {
tracing::debug!(
%peer_addr,
stream_id = %stream_id,
age_secs = now.duration_since(*created).as_secs(),
"Garbage collecting expired orphan stream"
);
handle.cancel();
expired_count += 1;
false
} else {
true
}
});
if self.claimed_streams.len() > 1000 {
self.claimed_streams.clear();
}
if expired_count > 0 {
tracing::info!(
expired_count,
remaining = self.orphan_streams.len(),
"Garbage collected expired orphan streams"
);
}
}
#[cfg(test)]
pub fn orphan_count(&self) -> usize {
self.orphan_streams.len()
}
#[cfg(test)]
pub fn waiter_count(&self) -> usize {
self.stream_waiters.len()
}
pub fn start_gc_task(registry: std::sync::Arc<Self>) {
use crate::config::GlobalExecutor;
GlobalExecutor::spawn(Self::gc_task(registry));
}
async fn gc_task(registry: std::sync::Arc<Self>) {
use crate::config::GlobalRng;
let initial_delay = Duration::from_secs(GlobalRng::random_range(5u64..=15u64));
tokio::time::sleep(initial_delay).await;
const GC_INTERVAL: Duration = Duration::from_secs(5);
let mut interval = tokio::time::interval(GC_INTERVAL);
tracing::debug!("Orphan stream GC task started");
loop {
interval.tick().await;
registry.gc_expired();
}
}
}
impl Default for OrphanStreamRegistry {
fn default() -> Self {
Self::new()
}
}
/// Reasons a claim on a registered stream can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OrphanStreamError {
    /// No stream was registered for the key before the wait timed out.
    Timeout,
    /// The waiter's sending side was dropped without delivering a stream.
    WaiterCancelled,
    /// The key had already been claimed by an earlier call (duplicate claim).
    AlreadyClaimed,
}

impl std::fmt::Display for OrphanStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Timeout => "timeout waiting for stream",
            Self::WaiterCancelled => "stream waiter was cancelled",
            Self::AlreadyClaimed => "stream already claimed (duplicate)",
        };
        f.write_str(message)
    }
}

/// Lets the error flow through `?` / `Box<dyn Error>`; it wraps no underlying source error.
impl std::error::Error for OrphanStreamError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GlobalExecutor;

    /// Time given to a spawned claimer to install its waiter before the test proceeds.
    const WAITER_REGISTRATION_DELAY: Duration = Duration::from_millis(50);
    /// Age used to fake a stale orphan. GC only expires entries whose age is strictly
    /// greater than `ORPHAN_STREAM_TIMEOUT`. Using exactly the timeout would make the
    /// test flaky whenever two `Instant::now()` reads coincide, so stay a second past it.
    const EXPIRED_ORPHAN_AGE: Duration =
        ORPHAN_STREAM_TIMEOUT.saturating_add(Duration::from_secs(1));

    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    fn dummy_addr_2() -> SocketAddr {
        "127.0.0.2:9000".parse().unwrap()
    }

    fn make_test_handle(stream_id: StreamId) -> StreamHandle {
        StreamHandle::new(stream_id, 1000)
    }

    #[test]
    fn test_orphan_registry_new() {
        let registry = OrphanStreamRegistry::new();
        assert_eq!(registry.orphan_count(), 0);
        assert_eq!(registry.waiter_count(), 0);
    }

    #[tokio::test]
    async fn test_orphan_claim_immediate() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let handle = make_test_handle(stream_id);
        let addr = dummy_addr();
        registry.register_orphan(addr, stream_id, handle);
        assert_eq!(registry.orphan_count(), 1);
        let claimed = registry
            .claim_or_wait(addr, stream_id, Duration::from_secs(1))
            .await;
        assert!(claimed.is_ok());
        assert_eq!(registry.orphan_count(), 0);
    }

    #[tokio::test]
    async fn test_orphan_wait_then_register() {
        let registry = std::sync::Arc::new(OrphanStreamRegistry::new());
        let stream_id = StreamId::next();
        let addr = dummy_addr();
        let registry_clone = registry.clone();
        let waiter = GlobalExecutor::spawn(async move {
            registry_clone
                .claim_or_wait(addr, stream_id, Duration::from_secs(5))
                .await
        });
        tokio::time::sleep(WAITER_REGISTRATION_DELAY).await;
        assert_eq!(registry.waiter_count(), 1);
        let handle = make_test_handle(stream_id);
        registry.register_orphan(addr, stream_id, handle);
        let result = waiter.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(registry.waiter_count(), 0);
    }

    #[tokio::test]
    async fn test_duplicate_claim_returns_already_claimed() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let handle = make_test_handle(stream_id);
        let addr = dummy_addr();
        registry.register_orphan(addr, stream_id, handle);
        let result = registry
            .claim_or_wait(addr, stream_id, Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
        let result = registry
            .claim_or_wait(addr, stream_id, Duration::from_secs(5))
            .await;
        assert!(matches!(result, Err(OrphanStreamError::AlreadyClaimed)));
    }

    #[tokio::test]
    async fn test_orphan_timeout() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let addr = dummy_addr();
        let result = registry
            .claim_or_wait(addr, stream_id, WAITER_REGISTRATION_DELAY)
            .await;
        assert!(matches!(result, Err(OrphanStreamError::Timeout)));
    }

    /// A timed-out claim must not leave a waiter behind, and it must release its dedup
    /// marker so a retry can claim the stream once it arrives.
    #[tokio::test]
    async fn test_claim_released_after_timeout() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let addr = dummy_addr();
        let result = registry
            .claim_or_wait(addr, stream_id, WAITER_REGISTRATION_DELAY)
            .await;
        assert!(matches!(result, Err(OrphanStreamError::Timeout)));
        assert_eq!(registry.waiter_count(), 0);

        registry.register_orphan(addr, stream_id, make_test_handle(stream_id));
        let retry = registry
            .claim_or_wait(addr, stream_id, Duration::from_secs(1))
            .await;
        assert!(retry.is_ok());
        assert_eq!(registry.orphan_count(), 0);
    }

    #[test]
    fn test_gc_expired() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let handle = make_test_handle(stream_id);
        let addr = dummy_addr();
        registry.orphan_streams.insert(
            (addr, stream_id),
            (handle, Instant::now() - EXPIRED_ORPHAN_AGE),
        );
        assert_eq!(registry.orphan_count(), 1);
        registry.gc_expired();
        assert_eq!(registry.orphan_count(), 0);
    }

    #[test]
    fn test_gc_preserves_fresh() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let handle = make_test_handle(stream_id);
        let addr = dummy_addr();
        registry.register_orphan(addr, stream_id, handle);
        assert_eq!(registry.orphan_count(), 1);
        registry.gc_expired();
        assert_eq!(registry.orphan_count(), 1);
    }

    #[tokio::test]
    async fn test_different_peers_same_stream_id_no_collision() {
        let registry = OrphanStreamRegistry::new();
        let stream_id = StreamId::next();
        let addr_a = dummy_addr();
        let addr_b = dummy_addr_2();
        let handle_a = make_test_handle(stream_id);
        let handle_b = make_test_handle(stream_id);
        registry.register_orphan(addr_a, stream_id, handle_a);
        registry.register_orphan(addr_b, stream_id, handle_b);
        assert_eq!(registry.orphan_count(), 2);
        let result_a = registry
            .claim_or_wait(addr_a, stream_id, Duration::from_secs(1))
            .await;
        assert!(result_a.is_ok());
        let result_b = registry
            .claim_or_wait(addr_b, stream_id, Duration::from_secs(1))
            .await;
        assert!(result_b.is_ok());
        assert_eq!(registry.orphan_count(), 0);
    }
}