pub mod capsule;
pub mod error;
pub mod flow_control;
pub mod stream;
pub mod varint;
use std::collections::{HashMap, VecDeque};
use crate::connection::Role;
pub use capsule::{Capsule, CapsuleDecoder, CapsuleEncoder, capsule_type};
pub use error::{WtError, WtErrorKind, WtResult};
pub use flow_control::WtFlowControl;
pub use stream::{RecvState, SendState, WtStream, WtStreamId};
pub use varint::{MAX_VALUE, decode as varint_decode, encode as varint_encode, encoded_len};
/// Lifecycle state of a [`WtSession`].
///
/// A session is created in `Initial`, becomes `Active` after
/// [`WtSession::initiate`], may move to `Draining` (locally via
/// [`WtSession::drain`] or when the peer sends a drain capsule), and
/// finishes in `Closed`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WtSessionState {
    /// Freshly created and not yet initiated; this is the [`Default`] value.
    #[default]
    Initial,
    /// Streams may be opened and data/datagrams may be sent.
    Active,
    /// A drain was requested; opening streams and sending are still allowed.
    Draining,
    /// Closed locally or by the peer.
    Closed,
}
/// Notification produced while a [`WtSession`] processes incoming capsules.
///
/// Events are queued in arrival order and retrieved with
/// [`WtSession::poll_event`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WtEvent {
    /// The peer opened a stream this session had not seen before.
    ///
    /// Queued ahead of the [`WtEvent::StreamData`] that carried it.
    StreamOpened {
        /// ID of the peer-initiated stream.
        stream_id: WtStreamId,
        /// `true` for a bidirectional stream, `false` for a unidirectional one.
        bidirectional: bool,
    },
    /// A WT_STREAM capsule was accepted; `data` may be empty (e.g. a bare FIN).
    StreamData {
        /// Stream the payload belongs to.
        stream_id: WtStreamId,
        /// Payload bytes, in order.
        data: Vec<u8>,
        /// `true` when the peer finished its sending side with this capsule.
        fin: bool,
    },
    /// The peer sent WT_RESET_STREAM.
    ///
    /// Also reported for stream IDs the session is not tracking.
    StreamReset {
        /// Stream that was reset.
        stream_id: WtStreamId,
        /// Application error code supplied by the peer.
        error_code: u64,
    },
    /// The peer sent WT_STOP_SENDING.
    ///
    /// Also reported for stream IDs the session is not tracking.
    StopSending {
        /// Stream the peer no longer wants data on.
        stream_id: WtStreamId,
        /// Application error code supplied by the peer.
        error_code: u64,
    },
    /// A datagram capsule arrived.
    DatagramReceived {
        /// Datagram payload.
        data: Vec<u8>,
    },
    /// The peer asked to drain the session while it was active.
    SessionDraining,
    /// The peer closed the session.
    SessionClosed {
        /// Application error code supplied by the peer.
        error_code: u32,
        /// Human-readable reason supplied by the peer.
        reason: String,
    },
}
/// Initial flow-control limits used when creating a [`WtSession`].
///
/// Session-wide limits seed the session's flow controller. Per-stream limits
/// are applied when a stream is created:
/// - locally opened bidirectional streams use `initial_max_stream_data_bidi_local`;
/// - peer-opened bidirectional streams use `initial_max_stream_data_bidi_remote`;
/// - unidirectional streams (from either side) use `initial_max_stream_data_uni`.
#[derive(Clone, Debug)]
pub struct WtConfig {
    /// Session-wide data limit handed to the session flow controller.
    pub initial_max_data: u64,
    /// Per-stream limit for bidirectional streams this endpoint opens.
    pub initial_max_stream_data_bidi_local: u64,
    /// Per-stream limit for bidirectional streams the peer opens.
    pub initial_max_stream_data_bidi_remote: u64,
    /// Per-stream limit for unidirectional streams.
    pub initial_max_stream_data_uni: u64,
    /// Initial bidirectional stream-count limit for the flow controller.
    pub initial_max_streams_bidi: u64,
    /// Initial unidirectional stream-count limit for the flow controller.
    pub initial_max_streams_uni: u64,
}

impl Default for WtConfig {
    /// 1 MiB of session data, 256 KiB per stream, and 100 streams of each kind.
    fn default() -> Self {
        const KIB: u64 = 1024;
        const PER_STREAM_WINDOW: u64 = 256 * KIB;
        const STREAM_COUNT: u64 = 100;

        Self {
            initial_max_data: KIB * KIB,
            initial_max_stream_data_bidi_local: PER_STREAM_WINDOW,
            initial_max_stream_data_bidi_remote: PER_STREAM_WINDOW,
            initial_max_stream_data_uni: PER_STREAM_WINDOW,
            initial_max_streams_bidi: STREAM_COUNT,
            initial_max_streams_uni: STREAM_COUNT,
        }
    }
}
/// Protocol state for a single WebTransport session carried as capsules.
///
/// The session performs no I/O itself:
/// - received bytes are handed in with `feed` and applied with `process`;
/// - bytes to transmit are collected with `poll_output`;
/// - notifications are collected with `poll_event`.
#[derive(Debug)]
pub struct WtSession {
    /// Whether this endpoint is the client or the server.
    role: Role,
    /// Limits the session was created with; per-stream initial limits are
    /// read from here whenever a stream is created.
    config: WtConfig,
    /// Current lifecycle state.
    state: WtSessionState,
    /// Streams opened locally or by the peer, keyed by stream ID.
    streams: HashMap<WtStreamId, WtStream>,
    /// Session-wide data and stream-count accounting.
    flow_control: WtFlowControl,
    /// Buffers and parses incoming capsule bytes.
    capsule_decoder: CapsuleDecoder,
    /// Serializes outgoing capsules.
    capsule_encoder: CapsuleEncoder,
    /// Encoded bytes waiting to be taken by `poll_output`.
    output_buffer: VecDeque<u8>,
    /// Notifications waiting to be taken by `poll_event`.
    events: VecDeque<WtEvent>,
    /// ID to assign to the next locally opened bidirectional stream.
    next_bidi_stream_id: WtStreamId,
    /// ID to assign to the next locally opened unidirectional stream.
    next_uni_stream_id: WtStreamId,
}
impl WtSession {
    /// Creates a client-side session configured with `config`.
    ///
    /// The session starts in [`WtSessionState::Initial`]; call
    /// [`initiate`](Self::initiate) before opening streams or sending.
    #[must_use]
    pub fn client(config: WtConfig) -> Self {
        Self::new(Role::Client, config)
    }

    /// Creates a server-side session configured with `config`.
    ///
    /// The session starts in [`WtSessionState::Initial`]; call
    /// [`initiate`](Self::initiate) before opening streams or sending.
    #[must_use]
    pub fn server(config: WtConfig) -> Self {
        Self::new(Role::Server, config)
    }

    /// Shared constructor for [`client`](Self::client) and [`server`](Self::server).
    ///
    /// Session-level flow control is seeded from `config`. The first local
    /// bidirectional and unidirectional stream IDs are chosen according to `role`.
    fn new(role: Role, config: WtConfig) -> Self {
        let is_client = role == Role::Client;
        let flow_control = WtFlowControl::new(
            config.initial_max_data,
            config.initial_max_streams_bidi,
            config.initial_max_streams_uni,
        );
        Self {
            role,
            config,
            state: WtSessionState::Initial,
            streams: HashMap::new(),
            flow_control,
            capsule_decoder: CapsuleDecoder::new(),
            capsule_encoder: CapsuleEncoder::new(),
            output_buffer: VecDeque::new(),
            events: VecDeque::new(),
            next_bidi_stream_id: stream::stream_id::first(is_client, true),
            next_uni_stream_id: stream::stream_id::first(is_client, false),
        }
    }

    /// Returns whether this session acts as the client or the server.
    #[must_use]
    pub const fn role(&self) -> Role {
        self.role
    }

    /// Returns the current lifecycle state.
    #[must_use]
    pub const fn state(&self) -> WtSessionState {
        self.state
    }

    /// Returns `true` only in [`WtSessionState::Active`] (`false` while draining).
    #[must_use]
    pub const fn is_active(&self) -> bool {
        matches!(self.state, WtSessionState::Active)
    }

    /// Returns `true` once the session is [`WtSessionState::Closed`].
    #[must_use]
    pub const fn is_closed(&self) -> bool {
        matches!(self.state, WtSessionState::Closed)
    }

    /// Moves the session from `Initial` to `Active`.
    ///
    /// # Errors
    ///
    /// Returns a session-state error if the session has already left
    /// [`WtSessionState::Initial`].
    pub fn initiate(&mut self) -> WtResult<()> {
        if self.state != WtSessionState::Initial {
            return Err(WtError::session_state_error("session already initiated"));
        }
        self.state = WtSessionState::Active;
        Ok(())
    }

    /// Buffers received bytes for capsule decoding.
    ///
    /// Nothing is parsed until [`process`](Self::process) is called. Returns the
    /// number of bytes accepted, which is always `data.len()`.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn feed(&mut self, data: &[u8]) -> WtResult<usize> {
        self.capsule_decoder.feed(data);
        Ok(data.len())
    }

    /// Decodes every complete capsule buffered by [`feed`](Self::feed) and
    /// applies it, queuing any resulting [`WtEvent`]s.
    ///
    /// # Errors
    ///
    /// Stops at the first decoding error, or the first error raised while
    /// handling a capsule, and returns it. Capsules handled before the failure
    /// remain applied.
    pub fn process(&mut self) -> WtResult<()> {
        while let Some(capsule) = self.capsule_decoder.decode()? {
            self.handle_capsule(capsule)?;
        }
        Ok(())
    }

    /// Takes every byte queued for transmission, or returns `None` if nothing is queued.
    #[must_use]
    pub fn poll_output(&mut self) -> Option<Vec<u8>> {
        if self.output_buffer.is_empty() {
            None
        } else {
            Some(self.output_buffer.drain(..).collect())
        }
    }

    /// Pops the oldest pending event, if any.
    #[must_use]
    pub fn poll_event(&mut self) -> Option<WtEvent> {
        self.events.pop_front()
    }

    /// Returns `true` if [`poll_output`](Self::poll_output) would yield bytes.
    #[must_use]
    pub fn has_output(&self) -> bool {
        !self.output_buffer.is_empty()
    }

    /// Opens a locally-initiated bidirectional stream and returns its ID.
    ///
    /// # Errors
    ///
    /// - A session-state error if the session is neither active nor draining.
    /// - A flow-control error if the bidirectional stream limit is reached.
    pub fn open_bidi_stream(&mut self) -> WtResult<WtStreamId> {
        self.open_stream(true)
    }

    /// Opens a locally-initiated unidirectional stream and returns its ID.
    ///
    /// # Errors
    ///
    /// - A session-state error if the session is neither active nor draining.
    /// - A flow-control error if the unidirectional stream limit is reached.
    pub fn open_uni_stream(&mut self) -> WtResult<WtStreamId> {
        self.open_stream(false)
    }

    /// Allocates the next local stream ID of the requested kind and starts tracking it.
    ///
    /// No capsule is queued here; the peer learns about the stream when data
    /// is first sent on it.
    fn open_stream(&mut self, bidirectional: bool) -> WtResult<WtStreamId> {
        if !matches!(
            self.state,
            WtSessionState::Active | WtSessionState::Draining
        ) {
            return Err(WtError::session_state_error(
                "cannot open stream: session not active or draining",
            ));
        }
        if bidirectional {
            if !self.flow_control.can_open_bidi_stream() {
                return Err(WtError::flow_control_error("bidi stream limit reached"));
            }
        } else if !self.flow_control.can_open_uni_stream() {
            return Err(WtError::flow_control_error("uni stream limit reached"));
        }
        let stream_id = if bidirectional {
            let id = self.next_bidi_stream_id;
            self.next_bidi_stream_id = stream::stream_id::next(id);
            id
        } else {
            let id = self.next_uni_stream_id;
            self.next_uni_stream_id = stream::stream_id::next(id);
            id
        };
        // Locally-initiated bidi streams use the "local" per-stream limit;
        // peer-initiated ones (see `handle_stream_data`) use the "remote" one.
        let initial_max_data = if bidirectional {
            self.config.initial_max_stream_data_bidi_local
        } else {
            self.config.initial_max_stream_data_uni
        };
        let stream = WtStream::new(stream_id, initial_max_data, bidirectional);
        self.streams.insert(stream_id, stream);
        self.flow_control.opened_stream(bidirectional);
        Ok(stream_id)
    }

    /// Sends `data` on `stream_id`, optionally finishing the sending side with `fin`.
    ///
    /// The WT_STREAM capsule is queued only after both stream-level and
    /// session-level accounting succeed. A failed send therefore never puts
    /// bytes on the wire.
    ///
    /// # Errors
    ///
    /// - A session-state error if the session is neither active nor draining.
    /// - An invalid-stream-id error if the stream is unknown.
    /// - A stream-state error if the stream cannot send.
    /// - Any error from the stream's or the session's send accounting.
    ///
    /// NOTE(review): if the session-level check fails after the stream-level
    /// accounting succeeded, the stream's send offset has already advanced.
    pub fn send_stream_data(
        &mut self,
        stream_id: WtStreamId,
        data: &[u8],
        fin: bool,
    ) -> WtResult<()> {
        if !matches!(
            self.state,
            WtSessionState::Active | WtSessionState::Draining
        ) {
            return Err(WtError::session_state_error(
                "cannot send data: session not active or draining",
            ));
        }
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or_else(|| WtError::invalid_stream_id("stream not found"))?;
        if !stream.can_send() {
            return Err(WtError::stream_state_error("cannot send on this stream"));
        }
        let len = data.len() as u64;
        stream.send_data(len, fin)?;
        self.flow_control.consume_send(len)?;
        // Only now that every check has passed is the data committed to output.
        self.queue_capsule(&Capsule::WtStream {
            stream_id,
            data: data.to_vec(),
            fin,
        });
        Ok(())
    }

    /// Aborts the sending side of `stream_id` with `error_code`.
    ///
    /// The queued WT_RESET_STREAM carries the stream's current send offset as
    /// its reliable size.
    ///
    /// # Errors
    ///
    /// - An invalid-stream-id error if the stream is unknown.
    /// - A stream-state error if the stream cannot send.
    pub fn reset_stream(&mut self, stream_id: WtStreamId, error_code: u64) -> WtResult<()> {
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or_else(|| WtError::invalid_stream_id("stream not found"))?;
        if !stream.can_send() {
            return Err(WtError::stream_state_error(
                "cannot send WT_RESET_STREAM: stream not in valid send state",
            ));
        }
        let reliable_size = stream.send_offset();
        stream.send_reset();
        self.queue_capsule(&Capsule::WtResetStream {
            stream_id,
            error_code,
            reliable_size,
        });
        Ok(())
    }

    /// Asks the peer to stop sending on `stream_id`, using `error_code`.
    ///
    /// # Errors
    ///
    /// - An invalid-stream-id error if the stream is unknown.
    /// - A stream-state error if WT_STOP_SENDING was already sent for this stream.
    pub fn stop_sending(&mut self, stream_id: WtStreamId, error_code: u64) -> WtResult<()> {
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or_else(|| WtError::invalid_stream_id("stream not found"))?;
        if stream.stop_sending_sent() {
            return Err(WtError::stream_state_error(
                "cannot send WT_STOP_SENDING: already sent",
            ));
        }
        stream.set_stop_sending_sent();
        self.queue_capsule(&Capsule::WtStopSending {
            stream_id,
            error_code,
        });
        Ok(())
    }

    /// Queues a datagram capsule carrying `data`.
    ///
    /// # Errors
    ///
    /// A session-state error if the session is neither active nor draining.
    pub fn send_datagram(&mut self, data: &[u8]) -> WtResult<()> {
        if !matches!(
            self.state,
            WtSessionState::Active | WtSessionState::Draining
        ) {
            return Err(WtError::session_state_error(
                "cannot send datagram: session not active or draining",
            ));
        }
        self.queue_capsule(&Capsule::Datagram {
            data: data.to_vec(),
        });
        Ok(())
    }

    /// Queues WT_CLOSE_SESSION and marks the session closed.
    ///
    /// # Errors
    ///
    /// A session-state error if the session is already closed.
    pub fn close(&mut self, error_code: u32, reason: &str) -> WtResult<()> {
        if self.state == WtSessionState::Closed {
            return Err(WtError::session_state_error("session already closed"));
        }
        self.queue_capsule(&Capsule::WtCloseSession {
            error_code,
            reason: reason.to_string(),
        });
        self.state = WtSessionState::Closed;
        Ok(())
    }

    /// Queues WT_MAX_DATA advertising `maximum`.
    ///
    /// This neither checks the session state nor updates local flow-control
    /// state; [`grow_recv_window`](Self::grow_recv_window) does both steps.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn send_max_data(&mut self, maximum: u64) -> WtResult<()> {
        self.queue_capsule(&Capsule::WtMaxData { maximum });
        Ok(())
    }

    /// Queues WT_MAX_STREAM_DATA advertising `maximum` for `stream_id`.
    ///
    /// This neither checks that the stream exists nor updates local state.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn send_max_stream_data(&mut self, stream_id: WtStreamId, maximum: u64) -> WtResult<()> {
        self.queue_capsule(&Capsule::WtMaxStreamData { stream_id, maximum });
        Ok(())
    }

    /// Queues WT_MAX_STREAMS advertising `maximum` for the given direction.
    ///
    /// This does not update local flow-control state.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn send_max_streams(&mut self, maximum: u64, bidirectional: bool) -> WtResult<()> {
        self.queue_capsule(&Capsule::WtMaxStreams {
            maximum,
            bidirectional,
        });
        Ok(())
    }

    /// Read-only access to session-level flow control.
    #[must_use]
    pub const fn flow_control(&self) -> &WtFlowControl {
        &self.flow_control
    }

    /// Mutable access to session-level flow control.
    ///
    /// Changes made through this reference are not advertised to the peer.
    pub fn flow_control_mut(&mut self) -> &mut WtFlowControl {
        &mut self.flow_control
    }

    /// Looks up a tracked stream by ID.
    #[must_use]
    pub fn stream(&self, stream_id: WtStreamId) -> Option<&WtStream> {
        self.streams.get(&stream_id)
    }

    /// Raises the session receive limit by `increment` and advertises the new
    /// limit with WT_MAX_DATA.
    ///
    /// # Errors
    ///
    /// Propagates errors from the flow controller. In that case nothing is queued.
    pub fn grow_recv_window(&mut self, increment: u64) -> WtResult<()> {
        self.flow_control.add_recv_max(increment)?;
        let new_max = self.flow_control.recv_max();
        self.send_max_data(new_max)
    }

    /// Raises the receive limit of `stream_id` by `increment` (saturating at
    /// `u64::MAX`) and advertises it with WT_MAX_STREAM_DATA.
    ///
    /// # Errors
    ///
    /// An invalid-stream-id error if the stream is unknown.
    pub fn grow_stream_recv_window(
        &mut self,
        stream_id: WtStreamId,
        increment: u64,
    ) -> WtResult<()> {
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or_else(|| WtError::invalid_stream_id("stream not found"))?;
        let new_max = stream.recv_max().saturating_add(increment);
        stream.update_recv_max(new_max);
        self.send_max_stream_data(stream_id, new_max)
    }

    /// Raises the local stream-count limit for one direction by `increment`
    /// and advertises the new value with WT_MAX_STREAMS.
    ///
    /// NOTE(review): presumably this is the limit on streams the peer may
    /// open; confirm against `WtFlowControl::add_max_streams_local`.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn grow_max_streams(&mut self, increment: u64, bidirectional: bool) -> WtResult<()> {
        self.flow_control
            .add_max_streams_local(increment, bidirectional);
        let new_max = if bidirectional {
            self.flow_control.max_streams_bidi_local()
        } else {
            self.flow_control.max_streams_uni_local()
        };
        self.send_max_streams(new_max, bidirectional)
    }

    /// The configuration this session was created with.
    #[must_use]
    pub const fn config(&self) -> &WtConfig {
        &self.config
    }

    /// Queues WT_DRAIN_SESSION and moves the session to `Draining`.
    ///
    /// # Errors
    ///
    /// A session-state error unless the session is currently active.
    pub fn drain(&mut self) -> WtResult<()> {
        if self.state != WtSessionState::Active {
            return Err(WtError::session_state_error(
                "cannot drain: session not active",
            ));
        }
        self.queue_capsule(&Capsule::WtDrainSession);
        self.state = WtSessionState::Draining;
        Ok(())
    }

    /// Encodes `capsule` and appends the bytes to the output buffer.
    fn queue_capsule(&mut self, capsule: &Capsule) {
        self.capsule_encoder.encode(capsule);
        self.output_buffer.extend(self.capsule_encoder.take());
    }

    /// Applies one decoded capsule received from the peer.
    fn handle_capsule(&mut self, capsule: Capsule) -> WtResult<()> {
        match capsule {
            Capsule::Datagram { data } => {
                self.events.push_back(WtEvent::DatagramReceived { data });
            }
            Capsule::WtStream {
                stream_id,
                data,
                fin,
            } => {
                self.handle_stream_data(stream_id, data, fin)?;
            }
            Capsule::WtResetStream {
                stream_id,
                error_code,
                reliable_size,
            } => {
                // Known streams are validated and updated; the event is
                // reported even for stream IDs that are not tracked.
                if let Some(stream) = self.streams.get_mut(&stream_id) {
                    if !stream.can_recv() {
                        return Err(WtError::stream_state_error(
                            "WT_RESET_STREAM received for stream not in valid state",
                        ));
                    }
                    if reliable_size < stream.recv_offset() {
                        return Err(WtError::stream_state_error(format!(
                            "WT_RESET_STREAM reliable_size {} is less than recv_offset {}",
                            reliable_size,
                            stream.recv_offset()
                        )));
                    }
                    stream.recv_reset();
                }
                self.events.push_back(WtEvent::StreamReset {
                    stream_id,
                    error_code,
                });
            }
            Capsule::WtStopSending {
                stream_id,
                error_code,
            } => {
                if let Some(stream) = self.streams.get_mut(&stream_id) {
                    if stream.stop_sending_received() {
                        return Err(WtError::stream_state_error(
                            "duplicate WT_STOP_SENDING received",
                        ));
                    }
                    stream.set_stop_sending_received();
                }
                self.events.push_back(WtEvent::StopSending {
                    stream_id,
                    error_code,
                });
            }
            Capsule::WtMaxData { maximum } => {
                self.flow_control.update_send_max(maximum)?;
            }
            Capsule::WtMaxStreamData { stream_id, maximum } => {
                // Untracked stream IDs are ignored.
                if let Some(stream) = self.streams.get_mut(&stream_id) {
                    if stream.stop_sending_sent() {
                        return Err(WtError::stream_state_error(
                            "WT_MAX_STREAM_DATA received after WT_STOP_SENDING",
                        ));
                    }
                    stream.update_send_max(maximum)?;
                }
            }
            Capsule::WtMaxStreams {
                maximum,
                bidirectional,
            } => {
                self.flow_control
                    .update_max_streams(maximum, bidirectional)?;
            }
            // Blocked signals are advisory; the session takes no action on them.
            Capsule::WtDataBlocked { maximum: _ } => {}
            Capsule::WtStreamDataBlocked {
                stream_id,
                maximum: _,
            } => {
                // Only a tracked stream that can neither send nor receive is rejected.
                if let Some(stream) = self.streams.get(&stream_id)
                    && !stream.can_recv()
                    && !stream.can_send()
                {
                    return Err(WtError::stream_state_error(
                        "WT_STREAM_DATA_BLOCKED received for stream not in valid state",
                    ));
                }
            }
            Capsule::WtStreamsBlocked {
                maximum: _,
                bidirectional: _,
            } => {}
            Capsule::WtCloseSession { error_code, reason } => {
                // Only the first close is reported.
                if self.state != WtSessionState::Closed {
                    self.state = WtSessionState::Closed;
                    self.events
                        .push_back(WtEvent::SessionClosed { error_code, reason });
                }
            }
            Capsule::WtDrainSession => {
                if self.state == WtSessionState::Active {
                    self.state = WtSessionState::Draining;
                    self.events.push_back(WtEvent::SessionDraining);
                }
            }
            // Padding and unrecognized capsule types are skipped.
            Capsule::Padding { .. } | Capsule::Unknown { .. } => {}
        }
        Ok(())
    }

    /// Applies a received WT_STREAM capsule.
    ///
    /// Steps, in order:
    /// 1. Reject an empty non-FIN capsule on a stream that has already
    ///    received data.
    /// 2. For a previously unseen ID: accept it only if it is peer-initiated
    ///    and within the stream limit, then start tracking it and queue
    ///    [`WtEvent::StreamOpened`].
    /// 3. Account the bytes at stream and session level and queue
    ///    [`WtEvent::StreamData`].
    fn handle_stream_data(
        &mut self,
        stream_id: WtStreamId,
        data: Vec<u8>,
        fin: bool,
    ) -> WtResult<()> {
        let is_new_stream = !self.streams.contains_key(&stream_id);
        // A lookup hit already implies the stream is not new.
        if data.is_empty()
            && !fin
            && self
                .streams
                .get(&stream_id)
                .is_some_and(|stream| stream.has_received_data())
        {
            return Err(WtError::stream_state_error(
                "empty WT_STREAM capsule without FIN on existing stream",
            ));
        }
        if is_new_stream {
            let is_peer_initiated = match self.role {
                Role::Client => stream::stream_id::is_server_initiated(stream_id),
                Role::Server => stream::stream_id::is_client_initiated(stream_id),
            };
            if !is_peer_initiated {
                return Err(WtError::stream_state_error(
                    "received stream with locally-initiated stream ID",
                ));
            }
            if !self.flow_control.can_accept_stream(stream_id) {
                return Err(WtError::flow_control_error("peer exceeded stream limit"));
            }
            let bidirectional = stream::stream_id::is_bidirectional(stream_id);
            let initial_max_data = if bidirectional {
                self.config.initial_max_stream_data_bidi_remote
            } else {
                self.config.initial_max_stream_data_uni
            };
            let stream = WtStream::new(stream_id, initial_max_data, bidirectional);
            self.streams.insert(stream_id, stream);
            self.events.push_back(WtEvent::StreamOpened {
                stream_id,
                bidirectional,
            });
        }
        let stream = self
            .streams
            .get_mut(&stream_id)
            .expect("stream must exist after insert or lookup");
        stream.recv_data(data.len() as u64, fin)?;
        if !data.is_empty() {
            stream.set_has_received_data();
        }
        self.flow_control.consume_recv(data.len() as u64)?;
        self.events.push_back(WtEvent::StreamData {
            stream_id,
            data,
            fin,
        });
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a client session that has already been initiated.
    fn active_client() -> WtSession {
        let mut session = WtSession::client(WtConfig::default());
        session.initiate().unwrap();
        session
    }

    #[test]
    fn test_client_session_creation() {
        let session = WtSession::client(WtConfig::default());
        assert_eq!(session.role(), Role::Client);
        assert_eq!(session.state(), WtSessionState::Initial);
    }

    #[test]
    fn test_server_session_creation() {
        let session = WtSession::server(WtConfig::default());
        assert_eq!(session.role(), Role::Server);
        assert_eq!(session.state(), WtSessionState::Initial);
    }

    #[test]
    fn test_session_initiate() {
        let session = active_client();
        assert_eq!(session.state(), WtSessionState::Active);
        assert!(session.is_active());
        assert!(!session.is_closed());
    }

    #[test]
    fn test_initiate_twice_fails() {
        let mut session = active_client();
        assert!(session.initiate().is_err());
        assert_eq!(session.state(), WtSessionState::Active);
    }

    #[test]
    fn test_operations_rejected_before_initiate() {
        let mut session = WtSession::client(WtConfig::default());
        assert!(session.open_bidi_stream().is_err());
        assert!(session.open_uni_stream().is_err());
        assert!(session.send_datagram(b"early").is_err());
        assert!(session.drain().is_err());
        assert!(!session.has_output());
    }

    #[test]
    fn test_open_bidi_stream() {
        let mut session = active_client();
        let stream_id = session.open_bidi_stream().unwrap();
        assert!(stream::stream_id::is_client_initiated(stream_id));
        assert!(stream::stream_id::is_bidirectional(stream_id));
        assert!(session.stream(stream_id).is_some());
    }

    #[test]
    fn test_open_uni_stream() {
        let mut session = active_client();
        let stream_id = session.open_uni_stream().unwrap();
        assert!(stream::stream_id::is_client_initiated(stream_id));
        assert!(stream::stream_id::is_unidirectional(stream_id));
        assert!(session.stream(stream_id).is_some());
    }

    #[test]
    fn test_server_opens_server_initiated_streams() {
        let mut session = WtSession::server(WtConfig::default());
        session.initiate().unwrap();
        let bidi = session.open_bidi_stream().unwrap();
        assert!(stream::stream_id::is_server_initiated(bidi));
        assert!(stream::stream_id::is_bidirectional(bidi));
        let uni = session.open_uni_stream().unwrap();
        assert!(stream::stream_id::is_server_initiated(uni));
        assert!(stream::stream_id::is_unidirectional(uni));
    }

    #[test]
    fn test_send_on_unknown_stream_fails() {
        let mut session = active_client();
        let opened = session.open_bidi_stream().unwrap();
        let unknown = stream::stream_id::next(opened);
        assert!(session.send_stream_data(unknown, b"x", false).is_err());
        assert!(!session.has_output());
    }

    #[test]
    fn test_send_datagram() {
        let mut session = active_client();
        session.send_datagram(b"hello").unwrap();
        assert!(session.has_output());
    }

    #[test]
    fn test_poll_output_drains_buffer() {
        let mut session = active_client();
        assert!(session.poll_output().is_none());
        session.send_datagram(b"hello").unwrap();
        let bytes = session.poll_output().expect("datagram should be queued");
        assert!(!bytes.is_empty());
        assert!(!session.has_output());
        assert!(session.poll_output().is_none());
    }

    #[test]
    fn test_feed_and_process_empty_input() {
        let mut session = active_client();
        assert_eq!(session.feed(&[]).unwrap(), 0);
        session.process().unwrap();
        assert!(session.poll_event().is_none());
    }

    #[test]
    fn test_close_session() {
        let mut session = active_client();
        session.close(0, "normal close").unwrap();
        assert_eq!(session.state(), WtSessionState::Closed);
        assert!(session.is_closed());
        assert!(session.has_output());
    }

    #[test]
    fn test_close_twice_fails() {
        let mut session = active_client();
        session.close(0, "first").unwrap();
        assert!(session.close(0, "second").is_err());
        assert!(session.open_bidi_stream().is_err());
    }

    #[test]
    fn test_drain_session() {
        let mut session = active_client();
        session.drain().unwrap();
        assert_eq!(session.state(), WtSessionState::Draining);
        assert!(!session.is_active());
    }

    #[test]
    fn test_draining_session_still_opens_streams_but_cannot_redrain() {
        let mut session = active_client();
        session.drain().unwrap();
        assert!(session.open_bidi_stream().is_ok());
        assert!(session.drain().is_err());
        session.close(0, "done").unwrap();
        assert_eq!(session.state(), WtSessionState::Closed);
    }
}