use crate::quic::{BoxedBidiStream, BoxedRecvStream};
use std::{
any::Any,
fmt::{self, Debug},
sync::{Arc, RwLock},
};
/// An incoming WebTransport stream, handed to a [`WebTransportDispatch`] handler via
/// [`WebTransportDispatcher::dispatch`].
///
/// `#[fieldwork(get)]` derives getter methods for the fields below.
#[derive(fieldwork::Fieldwork)]
#[fieldwork(get)]
pub enum WebTransportStream {
    /// A bidirectional stream.
    Bidi {
        /// Numeric session identifier associated with this stream
        /// (presumably the WebTransport session id — confirm against the producer).
        session_id: u64,
        /// The underlying bidirectional stream from `crate::quic`.
        stream: BoxedBidiStream,
        /// Bytes carried alongside the stream. NOTE(review): presumably data already read from
        /// the stream (e.g. past its header) that the handler must consume first — confirm.
        buffer: Vec<u8>,
    },
    /// A unidirectional stream; only a receive half (`BoxedRecvStream`) is available.
    Uni {
        /// Numeric session identifier associated with this stream
        /// (presumably the WebTransport session id — confirm against the producer).
        session_id: u64,
        /// The underlying receive-only stream from `crate::quic`.
        stream: BoxedRecvStream,
        /// Bytes carried alongside the stream. NOTE(review): presumably data already read from
        /// the stream (e.g. past its header) that the handler must consume first — confirm.
        buffer: Vec<u8>,
    },
}
impl Debug for WebTransportStream {
    /// Shows the variant name and `session_id` only; the stream handle and the buffered bytes
    /// are elided and rendered as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (variant_name, id) = match self {
            Self::Bidi { session_id, .. } => ("WebTransportStream::Bidi", session_id),
            Self::Uni { session_id, .. } => ("WebTransportStream::Uni", session_id),
        };
        f.debug_struct(variant_name)
            .field("session_id", id)
            .finish_non_exhaustive()
    }
}
/// A handler that receives WebTransport streams from a [`WebTransportDispatcher`].
///
/// The `Any` supertrait lets the dispatcher recover the concrete handler type from the
/// type-erased `Arc<dyn WebTransportDispatch>` it stores (see
/// [`WebTransportDispatcher::get_or_init_with`]). `Send + Sync` is required because the handler
/// is shared through an `Arc` by every clone of the dispatcher.
pub trait WebTransportDispatch: Any + Send + Sync {
    /// Handles one incoming stream. Called through a shared reference.
    fn dispatch(&self, stream: WebTransportStream);
}
/// Internal state of a [`WebTransportDispatcher`].
enum DispatchState {
    /// No handler installed yet; incoming streams are queued in the order they were dispatched.
    Buffering(Vec<WebTransportStream>),
    /// A handler has been installed; streams are forwarded to it.
    Active(Arc<dyn WebTransportDispatch>),
}
/// A cloneable, lock-protected router for incoming WebTransport streams.
///
/// Streams dispatched before a handler is installed (via `get_or_init_with`) are buffered and
/// replayed to the handler once it is installed. All clones share the same underlying state,
/// because the state lives behind a shared `Arc<RwLock<_>>`.
#[derive(Clone)]
pub struct WebTransportDispatcher(Arc<RwLock<DispatchState>>);
impl Default for WebTransportDispatcher {
    /// Starts in the buffering state: no handler installed and no queued streams.
    fn default() -> Self {
        let empty_queue = DispatchState::Buffering(Vec::default());
        let shared = Arc::new(RwLock::new(empty_queue));
        Self(shared)
    }
}
impl Debug for WebTransportDispatcher {
    /// Formats as `WebTransportDispatcher("Buffering(N streams)")` or
    /// `WebTransportDispatcher("Active")`.
    ///
    /// Mirroring the standard library's `RwLock` `Debug` impl, this never blocks and never
    /// panics:
    /// - a poisoned lock is still inspected, since the state is only read for display;
    /// - a lock currently held for writing (possibly by this very thread) is shown as
    ///   `WebTransportDispatcher(<locked>)` instead of waiting on it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_tuple("WebTransportDispatcher");
        let state = match self.0.try_read() {
            Ok(guard) => guard,
            // A panic elsewhere while the lock was held does not make the state unreadable.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => poisoned.into_inner(),
            // Blocking here could deadlock if the current thread holds the write lock.
            Err(std::sync::TryLockError::WouldBlock) => {
                return out.field(&format_args!("<locked>")).finish();
            }
        };
        let label = match &*state {
            DispatchState::Buffering(buf) => format!("Buffering({} streams)", buf.len()),
            DispatchState::Active(_) => "Active".to_string(),
        };
        out.field(&label).finish()
    }
}
impl WebTransportDispatcher {
    /// Creates a dispatcher with no handler installed. Incoming streams are buffered until
    /// [`get_or_init_with`](Self::get_or_init_with) installs one. Equivalent to
    /// [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Routes `stream` to the installed handler, or buffers it if no handler is installed yet.
    ///
    /// The handler is invoked only after the internal lock has been released. Handler code may
    /// therefore call back into this dispatcher (or a clone of it) without re-acquiring a lock
    /// the current thread already holds.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn dispatch(&self, stream: WebTransportStream) {
        // Fast path: a handler is installed; a brief shared lock is enough to fetch it.
        if let Some(handler) = self.active_handler() {
            handler.dispatch(stream);
            return;
        }

        // Slow path: take the exclusive lock and re-check, because a handler may have been
        // installed between releasing the read lock and acquiring the write lock.
        let mut state = self.0.write().expect("dispatcher lock poisoned");
        let handler = match &mut *state {
            DispatchState::Buffering(buffered) => {
                buffered.push(stream);
                return;
            }
            DispatchState::Active(handler) => Arc::clone(handler),
        };
        // Never run handler code while holding the lock.
        drop(state);
        handler.dispatch(stream);
    }

    /// Returns the installed handler as a `T`, installing one built by `init` if none exists.
    ///
    /// - No handler installed: `init` is called once, its result is installed, and every stream
    ///   buffered so far is replayed to it in buffered order (after the lock is released).
    ///   `Some(handler)` is returned.
    /// - A handler of type `T` is installed: it is returned and `init` is not called.
    /// - A handler of a different type is installed: returns `None` and `init` is not called.
    ///
    /// `init` runs while the internal write lock is held, which guarantees that concurrent
    /// callers cannot both install a handler. Consequently `init` must not call `dispatch` or
    /// `get_or_init_with` on this dispatcher or any clone of it.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned, which includes the case where a previous `init`
    /// call panicked.
    pub fn get_or_init_with<T: WebTransportDispatch>(
        &self,
        init: impl FnOnce() -> T,
    ) -> Option<Arc<T>> {
        if let Some(handler) = self.active_handler() {
            return downcast_arc(handler);
        }

        let mut state = self.0.write().expect("dispatcher lock poisoned");
        if let DispatchState::Active(handler) = &*state {
            // Another caller installed a handler after our read-locked check.
            return downcast_arc(Arc::clone(handler));
        }

        let handler = Arc::new(init());
        let previous = std::mem::replace(
            &mut *state,
            DispatchState::Active(Arc::clone(&handler) as Arc<dyn WebTransportDispatch>),
        );
        // Release the lock before running handler code so the handler may re-enter.
        drop(state);

        let DispatchState::Buffering(buffered) = previous else {
            unreachable!("state was checked to be Buffering while holding the write lock")
        };
        for stream in buffered {
            handler.dispatch(stream);
        }
        Some(handler)
    }

    /// Returns a clone of the installed handler, if any. The read lock is held only for the
    /// duration of this call.
    fn active_handler(&self) -> Option<Arc<dyn WebTransportDispatch>> {
        let state = self.0.read().expect("dispatcher lock poisoned");
        match &*state {
            DispatchState::Active(handler) => Some(Arc::clone(handler)),
            DispatchState::Buffering(_) => None,
        }
    }
}
/// Recovers the concrete handler type `T` from a type-erased handler.
///
/// Returns `None` when the erased handler is not a `T`; the input `Arc` is dropped in that case.
/// On success the returned `Arc<T>` shares the same allocation and reference count as `arc`.
fn downcast_arc<T: Any + Send + Sync>(arc: Arc<dyn WebTransportDispatch>) -> Option<Arc<T>> {
    // Trait upcasting: `WebTransportDispatch: Any + Send + Sync`, so the erased handler can be
    // viewed as `dyn Any`, whose type id is that of the concrete handler.
    let as_any: Arc<dyn Any + Send + Sync> = arc;
    Arc::downcast::<T>(as_any).ok()
}