mod channels;
mod counter;
mod dispatch;
mod drain;
mod event;
mod frame;
mod multi_packet;
mod output;
mod polling;
mod response;
mod shutdown;
mod state;
use std::{net::SocketAddr, sync::Arc};
pub use channels::ConnectionChannels;
use counter::ActiveConnection;
pub use counter::active_connection_count;
use event::Event;
use log::info;
use multi_packet::MultiPacketContext;
use output::{ActiveOutput, EventAvailability};
use state::ActorState;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
pub use crate::fairness::FairnessConfig;
use crate::{
app::Packet,
correlation::CorrelatableFrame,
fairness::FairnessTracker,
fragment::{FragmentationConfig, Fragmenter},
hooks::{ConnectionContext, ProtocolHooks},
push::{FrameLike, PushHandle, PushQueues},
response::{FrameStream, WireframeError},
session::ConnectionId,
};
/// Error returned when installing an output source on a [`ConnectionActor`]
/// would conflict with the output source that is already active.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Error)]
pub enum ConnectionStateError {
    /// A response stream was requested while a multi-packet channel is the
    /// active output.
    #[error("cannot set response while a multi-packet channel is active")]
    MultiPacketActive,
    /// A multi-packet channel was requested while a response stream is the
    /// active output.
    #[error("cannot set multi-packet channel while a response stream is active")]
    ResponseActive,
}
/// Per-connection actor that multiplexes outbound frame sources.
///
/// While running, the actor repeatedly waits on (in priority order) the
/// shutdown token, the high- and low-priority push queues, and the active
/// output (a multi-packet channel or a response stream), then dispatches the
/// resulting event.
pub struct ConnectionActor<F, E> {
    /// High-priority push queue; only polled while this is `Some`.
    high_rx: Option<mpsc::Receiver<F>>,
    /// Low-priority push queue; only polled while this is `Some`.
    /// Replaceable via `set_low_queue`.
    low_rx: Option<mpsc::Receiver<F>>,
    /// The response stream or multi-packet channel currently producing
    /// output, if any. At most one of the two can be active at a time.
    active_output: ActiveOutput<F, E>,
    /// Token whose cancellation produces a shutdown event.
    shutdown: CancellationToken,
    /// Active-connection guard created at construction; `run` takes and
    /// drops it when the connection finishes.
    counter: Option<ActiveConnection>,
    /// Protocol hooks; `on_connection_setup` is invoked during construction.
    hooks: ProtocolHooks<F, E>,
    /// Context value handed to the protocol hooks.
    ctx: ConnectionContext,
    /// Fairness tracker; its configuration is replaceable via `set_fairness`.
    fairness: FairnessTracker,
    /// Outbound fragmenter; `None` until `enable_fragmentation` is called.
    fragmenter: Option<Arc<Fragmenter>>,
    /// Connection id reported in lifecycle log messages. The constructors
    /// initialise it to `None`.
    connection_id: Option<ConnectionId>,
    /// Peer address reported in lifecycle log messages. The constructors
    /// initialise it to `None`.
    peer_addr: Option<SocketAddr>,
}
impl<F, E> ConnectionActor<F, E>
where
    F: FrameLike + CorrelatableFrame + Packet,
    E: std::fmt::Debug,
{
    /// Create an actor with the default [`ProtocolHooks`].
    ///
    /// Convenience wrapper around [`Self::with_hooks`] that bundles `queues`
    /// and `handle` into a [`ConnectionChannels`].
    #[must_use]
    pub fn new(
        queues: PushQueues<F>,
        handle: PushHandle<F>,
        response: Option<FrameStream<F, E>>,
        shutdown: CancellationToken,
    ) -> Self {
        Self::with_hooks(
            ConnectionChannels::new(queues, handle),
            response,
            shutdown,
            ProtocolHooks::<F, E>::default(),
        )
    }

    /// Create an actor with explicit protocol hooks.
    ///
    /// Construction registers an [`ActiveConnection`] guard and logs the
    /// opening together with the current active-connection count. It then
    /// calls `on_connection_setup` with the push handle taken from `channels`.
    ///
    /// The actor starts with [`FairnessConfig::default`], fragmentation
    /// disabled, and no connection id or peer address. If `response` is
    /// `Some`, it becomes the active output.
    #[must_use]
    pub fn with_hooks(
        channels: ConnectionChannels<F>,
        response: Option<FrameStream<F, E>>,
        shutdown: CancellationToken,
        hooks: ProtocolHooks<F, E>,
    ) -> Self {
        let ConnectionChannels { queues, handle } = channels;
        let counter = ActiveConnection::new();
        let mut actor = Self {
            high_rx: Some(queues.high_priority_rx),
            low_rx: Some(queues.low_priority_rx),
            active_output: response.map_or(ActiveOutput::None, ActiveOutput::Response),
            shutdown,
            counter: Some(counter),
            hooks,
            ctx: ConnectionContext,
            fairness: FairnessTracker::new(FairnessConfig::default()),
            fragmenter: None,
            connection_id: None,
            peer_addr: None,
        };
        info!(
            "connection opened: wireframe_active_connections={}, id={:?}, peer={:?}",
            counter::current_count(),
            actor.connection_id,
            actor.peer_addr
        );
        actor.hooks.on_connection_setup(handle, &mut actor.ctx);
        actor
    }

    /// Replace the configuration of the fairness tracker.
    pub fn set_fairness(&mut self, fairness: FairnessConfig) { self.fairness.set_config(fairness); }

    /// Enable outbound fragmentation.
    ///
    /// Builds a [`Fragmenter`] from `config.fragment_payload_cap`, replacing any
    /// fragmenter installed earlier.
    pub fn enable_fragmentation(&mut self, config: FragmentationConfig) {
        self.fragmenter = Some(Arc::new(Fragmenter::new(config.fragment_payload_cap)));
    }

    /// Install a response stream as the active output, or clear the output
    /// with `None`. An existing response stream is replaced.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionStateError::MultiPacketActive`] if a multi-packet
    /// channel is currently active. In that case the actor is left unchanged.
    pub fn set_response(
        &mut self,
        stream: Option<FrameStream<F, E>>,
    ) -> Result<(), ConnectionStateError> {
        if self.active_output.is_multi_packet() {
            return Err(ConnectionStateError::MultiPacketActive);
        }
        self.active_output = stream.map_or(ActiveOutput::None, ActiveOutput::Response);
        Ok(())
    }

    /// Install a multi-packet channel without a correlation id.
    ///
    /// Equivalent to `set_multi_packet_with_correlation(channel, None)`.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionStateError::ResponseActive`] if a response stream is
    /// currently active.
    pub fn set_multi_packet(
        &mut self,
        channel: Option<mpsc::Receiver<F>>,
    ) -> Result<(), ConnectionStateError> {
        self.set_multi_packet_with_correlation(channel, None)
    }

    /// Install a multi-packet channel as the active output, or clear the
    /// output with `None`.
    ///
    /// `correlation_id` is passed to `MultiPacketContext::install` together
    /// with the receiver.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionStateError::ResponseActive`] if a response stream is
    /// currently active. In that case the actor is left unchanged.
    pub fn set_multi_packet_with_correlation(
        &mut self,
        channel: Option<mpsc::Receiver<F>>,
        correlation_id: Option<u64>,
    ) -> Result<(), ConnectionStateError> {
        if self.active_output.is_response() {
            return Err(ConnectionStateError::ResponseActive);
        }
        self.active_output = match channel {
            Some(rx) => {
                let mut ctx = MultiPacketContext::new();
                ctx.install(Some(rx), correlation_id);
                ActiveOutput::MultiPacket(ctx)
            }
            None => ActiveOutput::None,
        };
        Ok(())
    }

    /// Replace the low-priority queue receiver.
    ///
    /// Passing `None` stops the low-priority source from being polled.
    pub fn set_low_queue(&mut self, queue: Option<mpsc::Receiver<F>>) { self.low_rx = queue; }

    /// Return a clone of the actor's shutdown token.
    ///
    /// Cancelling the token makes the actor observe a shutdown event. If it
    /// is cancelled before [`Self::run`] starts, `run` returns immediately.
    #[must_use]
    pub fn shutdown_token(&self) -> CancellationToken { self.shutdown.clone() }

    /// Drive the connection until the actor state reports completion.
    ///
    /// Each event is dispatched together with `out`, presumably so the
    /// dispatcher can collect outbound frames (TODO confirm against
    /// `dispatch_event`). If the shutdown token is already cancelled, the abort
    /// is logged, the active-connection guard is released, and `Ok(())` is
    /// returned without polling any source.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while dispatching an event. The
    /// "connection closed" log and the release of the active-connection guard
    /// happen on this path as well as on success.
    pub async fn run(&mut self, out: &mut Vec<F>) -> Result<(), WireframeError<E>> {
        if self.shutdown.is_cancelled() {
            info!(
                "connection aborted before start: id={:?}, peer={:?}",
                self.connection_id, self.peer_addr
            );
            let _ = self.counter.take();
            return Ok(());
        }
        let mut state = ActorState::new(
            self.active_output.is_response(),
            self.active_output.is_multi_packet(),
        );
        // Capture the outcome instead of propagating with `?`, so the cleanup
        // below runs on every exit path. Previously an error skipped both the
        // close log and the guard release.
        let result = loop {
            if state.is_done() {
                break Ok(());
            }
            if let Err(err) = self.poll_sources(&mut state, out).await {
                break Err(err);
            }
        };
        info!(
            "connection closed: id={:?}, peer={:?}",
            self.connection_id, self.peer_addr
        );
        let _ = self.counter.take();
        result
    }

    /// Decide which event sources may be polled in the current state.
    ///
    /// The push queues are polled while their receivers exist. The
    /// multi-packet channel and response stream are polled only while they
    /// are the active output and the actor is not shutting down.
    fn compute_availability(&self, state: &ActorState) -> EventAvailability {
        EventAvailability {
            high: self.high_rx.is_some(),
            low: self.low_rx.is_some(),
            multi_packet: self.active_output.is_multi_packet() && !state.is_shutting_down(),
            response: self.active_output.is_response() && !state.is_shutting_down(),
        }
    }

    /// Wait for the next event from any enabled source.
    ///
    /// `biased` makes `select!` poll the branches in declaration order:
    /// shutdown (only while the state is active), then the high queue, the low
    /// queue, the multi-packet channel, and the response stream. If every
    /// branch is disabled by its guard, [`Event::Idle`] is returned.
    #[expect(
        clippy::integer_division_remainder_used,
        reason = "tokio::select! expands to modulus operations internally"
    )]
    async fn next_event(&mut self, state: &ActorState) -> Event<F, E> {
        let avail = self.compute_availability(state);
        // Split the active output into independent borrows so each select
        // branch can poll its own source.
        let (multi_rx, response_stream) = match &mut self.active_output {
            ActiveOutput::MultiPacket(ctx) => (ctx.channel_mut(), None),
            ActiveOutput::Response(stream) => (None, Some(stream)),
            ActiveOutput::None => (None, None),
        };
        tokio::select! {
            biased;
            () = Self::wait_shutdown(self.shutdown.clone()), if state.is_active() => Event::Shutdown,
            res = Self::poll_queue(self.high_rx.as_mut()), if avail.high => Event::High(res),
            res = Self::poll_queue(self.low_rx.as_mut()), if avail.low => Event::Low(res),
            res = Self::poll_queue(multi_rx), if avail.multi_packet => Event::MultiPacket(res),
            res = Self::poll_response(response_stream), if avail.response => Event::Response(res),
            else => Event::Idle,
        }
    }

    /// Wait for a single event and dispatch it, updating `state` and `out`.
    async fn poll_sources(
        &mut self,
        state: &mut ActorState,
        out: &mut Vec<F>,
    ) -> Result<(), WireframeError<E>> {
        let event = self.next_event(state).await;
        self.dispatch_event(event, state, out)
    }
}
#[cfg(all(not(loom), any(test, feature = "test-support")))]
pub mod test_support;