use std::{
fmt::Debug,
ops::{Deref, DerefMut},
sync::{Arc, Mutex, MutexGuard},
task::{Context, Poll, Waker},
time::Duration,
};
use crate::{
cid::ConnectionId,
error::{Error, ErrorKind, QuicError},
frame::FrameType,
role::Role,
time::MaxIdleTimer,
};
pub mod core;
pub mod error;
pub mod handy;
pub mod io;
pub mod preferred_address;
pub use self::{
core::{
ClientParameters, ParameterId, ParameterValue, ParameterValueType, PeerParameters,
ServerParameters,
},
io::*,
};
/// Connection IDs that the peer's transport parameters must match before they
/// are accepted (checked by `Parameters::authenticate_cids`).
///
/// An `Option` field holding `None` means that value has not been recorded yet.
#[derive(Debug, Clone, Copy)]
enum Requirements {
/// Requirements a client applies to the server's parameters.
Client {
/// Recorded by `Parameters::initial_scid_from_peer_need_equal`; compared with
/// the server's `InitialSourceConnectionId` parameter.
initial_scid: Option<ConnectionId>,
/// Recorded by `Parameters::retry_scid_from_server_need_equal`
/// (presumably the SCID of a Retry packet — confirm with callers).
retry_scid: Option<ConnectionId>,
/// Supplied to `Parameters::new_client`; compared with the server's
/// `OriginalDestinationConnectionId` parameter.
origin_dcid: ConnectionId,
},
/// Requirements a server applies to the client's parameters.
Server {
/// Recorded by `Parameters::initial_scid_from_peer_need_equal`; compared with
/// the client's `InitialSourceConnectionId` parameter.
initial_scid: Option<ConnectionId>,
},
}
/// Local and peer transport parameters of one connection, together with the
/// bookkeeping needed to authenticate the peer's set and to notify waiters.
#[derive(Debug)]
pub struct Parameters {
// Bit set of `CLIENT_READY` / `SERVER_READY`: which parameter sets may be read.
state: u8,
// Client parameters: the local set on a client, the peer's set on a server.
client: Arc<ClientParameters>,
// Server parameters: the local set on a server, the peer's set on a client.
server: Arc<ServerParameters>,
// Server parameters handed to `new_client` as `remembered` (presumably cached
// from an earlier connection — confirm); dropped once the peer's parameters
// are authenticated.
remembered: Option<Arc<ServerParameters>>,
// Connection IDs the peer's parameters must match.
requirements: Requirements,
// Wakers registered by `poll_ready`; woken on readiness and on drop.
wakers: Vec<Waker>,
}
impl Drop for Parameters {
fn drop(&mut self) {
self.wake_all();
}
}
impl Parameters {
    /// Bit of `state` that is set once the client parameters may be read.
    const CLIENT_READY: u8 = 1;
    /// Bit of `state` that is set once the server parameters may be read.
    const SERVER_READY: u8 = 2;

    /// Creates the parameter state of a client endpoint.
    ///
    /// * `client`: the local parameters, readable immediately.
    /// * `remembered`: server parameters the caller already holds, if any
    ///   (presumably cached from an earlier connection; confirm with callers).
    ///   They are returned by [`Parameters::remembered`] until the server's own
    ///   parameters have been authenticated.
    /// * `origin_dcid`: the value that the server's
    ///   `OriginalDestinationConnectionId` parameter must equal.
    pub fn new_client(
        client: ClientParameters,
        remembered: Option<ServerParameters>,
        origin_dcid: ConnectionId,
    ) -> Self {
        Self {
            state: Self::CLIENT_READY,
            client: Arc::new(client),
            server: Arc::default(),
            remembered: remembered.map(Arc::new),
            requirements: Requirements::Client {
                origin_dcid,
                initial_scid: None,
                retry_scid: None,
            },
            wakers: Vec::with_capacity(2),
        }
    }

    /// Creates the parameter state of a server endpoint.
    ///
    /// `server` holds the local parameters, which are readable immediately.
    pub fn new_server(server: ServerParameters) -> Self {
        Self {
            state: Self::SERVER_READY,
            client: Arc::default(),
            server: Arc::new(server),
            remembered: None,
            requirements: Requirements::Server { initial_scid: None },
            wakers: Vec::with_capacity(2),
        }
    }

    /// Returns the local role, fixed by the constructor that was used.
    pub fn role(&self) -> Role {
        match self.requirements {
            Requirements::Client { .. } => Role::Client,
            Requirements::Server { .. } => Role::Server,
        }
    }

    /// Returns the client parameters, or `None` while they are not readable.
    ///
    /// On a server they become readable once the client's parameters have
    /// been authenticated.
    pub fn client(&self) -> Option<&Arc<ClientParameters>> {
        if self.state & Self::CLIENT_READY != 0 {
            Some(&self.client)
        } else {
            None
        }
    }

    /// Returns the server parameters, or `None` while they are not readable.
    ///
    /// On a client they become readable once the server's parameters have
    /// been authenticated.
    pub fn server(&self) -> Option<&Arc<ServerParameters>> {
        if self.state & Self::SERVER_READY != 0 {
            Some(&self.server)
        } else {
            None
        }
    }

    /// Returns the remembered server parameters given to
    /// [`Parameters::new_client`], if any.
    ///
    /// They are cleared once the server's own parameters are authenticated.
    pub fn remembered(&self) -> Option<&Arc<ServerParameters>> {
        self.remembered.as_ref()
    }

    /// Reads `id` from the local parameter set, converted to `V`.
    ///
    /// Returns `None` when the underlying `get` yields nothing for `id`.
    pub fn get_local<V: TryFrom<ParameterValue>>(&self, id: ParameterId) -> Option<V> {
        match self.role() {
            Role::Client => self.client()?.get(id),
            Role::Server => self.server()?.get(id),
        }
    }

    /// Reads `id` from the peer's parameter set, converted to `V`.
    ///
    /// Returns `None` until the peer's parameters have been authenticated, or
    /// when the underlying `get` yields nothing for `id`.
    pub fn get_remote<V: TryFrom<ParameterValue>>(&self, id: ParameterId) -> Option<V> {
        match self.role() {
            Role::Client => self.server()?.get(id),
            Role::Server => self.client()?.get(id),
        }
    }

    /// Resolves once the peer's parameters have been received and authenticated.
    ///
    /// While pending, the task's waker is registered and is woken when the
    /// parameters become ready or when `self` is dropped. Each distinct waker is
    /// stored only once, so repeated polls do not make the list grow.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.is_remote_params_ready() {
            return Poll::Ready(());
        }
        let waker = cx.waker();
        if !self.wakers.iter().any(|registered| registered.will_wake(waker)) {
            self.wakers.push(waker.clone());
        }
        Poll::Pending
    }

    /// Returns whether a non-empty set of peer parameters has been stored,
    /// whether or not it has been authenticated yet.
    pub fn is_remote_params_received(&self) -> bool {
        match self.role() {
            Role::Client => !self.server.is_empty(),
            Role::Server => !self.client.is_empty(),
        }
    }

    /// Returns whether the peer's parameters have been received and authenticated.
    pub fn is_remote_params_ready(&self) -> bool {
        self.state == Self::CLIENT_READY | Self::SERVER_READY
    }

    /// Stores the peer's transport parameters.
    ///
    /// If the peer's initial source connection ID has already been recorded,
    /// the parameters are authenticated immediately. Otherwise that happens in
    /// [`Parameters::initial_scid_from_peer_need_equal`].
    ///
    /// # Errors
    /// Returns a `TransportParameter` error if a required connection-ID
    /// parameter is missing or does not match the recorded requirements.
    ///
    /// # Panics
    /// Panics if `params` belong to the local role rather than the peer's, or if
    /// peer parameters were already stored.
    pub fn recv_remote_params(
        &mut self,
        params: impl Into<PeerParameters>,
    ) -> Result<(), QuicError> {
        match params.into() {
            PeerParameters::Client(p) => {
                assert_eq!(self.role(), Role::Server);
                assert!(self.client.is_empty());
                self.client = Arc::new(p);
            }
            PeerParameters::Server(p) => {
                assert_eq!(self.role(), Role::Client);
                assert!(self.server.is_empty());
                self.server = Arc::new(p);
            }
        }
        self.finish_if_authenticated()
    }

    /// Wakes all registered wakers and empties the list.
    fn wake_all(&mut self) {
        for waker in self.wakers.drain(..) {
            waker.wake();
        }
    }

    /// Authenticates the peer's parameters, if they have been received, and
    /// switches to the ready state on success.
    ///
    /// Every successful authentication path (here and in
    /// `recv_remote_params`) ends in this same transition.
    fn finish_if_authenticated(&mut self) -> Result<(), QuicError> {
        if self.authenticate_cids()? {
            self.state = Self::CLIENT_READY | Self::SERVER_READY;
            // The peer's real parameters supersede the remembered ones.
            self.remembered = None;
            self.wake_all();
        }
        Ok(())
    }

    /// Records the peer's initial source connection ID.
    ///
    /// The peer's `InitialSourceConnectionId` parameter must equal this value.
    /// If the peer's parameters were already received, they are authenticated
    /// now.
    ///
    /// # Errors
    /// Returns a `TransportParameter` error if authentication fails.
    ///
    /// # Panics
    /// Panics if an initial source connection ID was already recorded.
    pub fn initial_scid_from_peer_need_equal(
        &mut self,
        cid: ConnectionId,
    ) -> Result<(), QuicError> {
        let initial_scid = match &mut self.requirements {
            Requirements::Client { initial_scid, .. } | Requirements::Server { initial_scid } => {
                initial_scid
            }
        };
        assert!(initial_scid.replace(cid).is_none());
        if self.is_remote_params_received() {
            self.finish_if_authenticated()?;
        }
        Ok(())
    }

    /// Records (client only) the source connection ID received from the
    /// server in a Retry.
    ///
    /// # Panics
    /// Panics when called on a server.
    pub fn retry_scid_from_server_need_equal(&mut self, cid: ConnectionId) {
        match &mut self.requirements {
            Requirements::Client { retry_scid, .. } => *retry_scid = Some(cid),
            Requirements::Server { .. } => panic!("server should never call this"),
        }
    }

    /// Returns the recorded initial source connection ID of the peer, if any.
    pub fn initial_scid_from_peer(&self) -> Option<ConnectionId> {
        match self.requirements {
            Requirements::Client { initial_scid, .. } | Requirements::Server { initial_scid } => {
                initial_scid
            }
        }
    }

    /// Checks the stored peer parameters against `requirements`.
    ///
    /// Returns `Ok(false)` while the peer's initial source connection ID has
    /// not been recorded, and `Ok(true)` when every check passes.
    ///
    /// # Errors
    /// Returns a `TransportParameter` error when a required connection-ID
    /// parameter is absent or differs from the recorded value. The peer
    /// controls these values, so their absence must not panic
    /// (RFC 9000 §7.3).
    fn authenticate_cids(&self) -> Result<bool, QuicError> {
        fn param_error(reason: &'static str) -> QuicError {
            QuicError::new(
                ErrorKind::TransportParameter,
                FrameType::Crypto.into(),
                reason,
            )
        }

        match self.requirements {
            Requirements::Client {
                initial_scid,
                // NOTE(review): RFC 9000 §7.3 requires the server's
                // retry_source_connection_id to equal this value, and to be
                // absent when no Retry was received. It is recorded but not
                // checked here; confirm whether that is intentional.
                retry_scid: _,
                origin_dcid,
            } => {
                let Some(initial_scid) = initial_scid else {
                    return Ok(false);
                };
                let server_initial_scid = self
                    .server
                    .get::<ConnectionId>(ParameterId::InitialSourceConnectionId)
                    .ok_or_else(|| {
                        param_error("Initial Source Connection ID from server is missing")
                    })?;
                if server_initial_scid != initial_scid {
                    return Err(param_error(
                        "Initial Source Connection ID from server mismatch",
                    ));
                }
                let server_origin_dcid = self
                    .server
                    .get::<ConnectionId>(ParameterId::OriginalDestinationConnectionId)
                    .ok_or_else(|| param_error("Original Destination Connection ID is missing"))?;
                if server_origin_dcid != origin_dcid {
                    return Err(param_error("Original Destination Connection ID mismatch"));
                }
                Ok(true)
            }
            Requirements::Server { initial_scid } => {
                let Some(initial_scid) = initial_scid else {
                    return Ok(false);
                };
                let client_initial_scid = self
                    .client
                    .get::<ConnectionId>(ParameterId::InitialSourceConnectionId)
                    .ok_or_else(|| {
                        param_error("Initial Source Connection ID from client is missing")
                    })?;
                if client_initial_scid != initial_scid {
                    return Err(param_error(
                        "Initial Source Connection ID from client mismatch",
                    ));
                }
                Ok(true)
            }
        }
    }

    /// Returns the idle timeout both sides agreed on.
    ///
    /// Returns `None` until both the local and the remote `MaxIdleTimeout` are
    /// readable. A zero value means that side imposes no limit, so the other
    /// side's value applies. When both are zero the result is
    /// [`Duration::MAX`]; otherwise the smaller of the two values applies.
    pub fn negotiated_max_idle_timeout(&self) -> Option<Duration> {
        let local_max_idle_timeout = self.get_local(ParameterId::MaxIdleTimeout)?;
        let remote_max_idle_timeout = self.get_remote(ParameterId::MaxIdleTimeout)?;
        Some(match (local_max_idle_timeout, remote_max_idle_timeout) {
            (Duration::ZERO, Duration::ZERO) => Duration::MAX,
            (Duration::ZERO, d) | (d, Duration::ZERO) => d,
            (d1, d2) => d1.min(d2),
        })
    }
}
/// Shared, lockable [`Parameters`] of a connection; clones share the same state.
///
/// The slot holds `Ok` until `ArcParameters::on_conn_error` replaces it with
/// the connection error. After that, `lock_guard` returns that error.
#[derive(Debug, Clone)]
pub struct ArcParameters(Arc<Mutex<Result<Parameters, Error>>>);
/// Lock guard returned by `ArcParameters::lock_guard`; dereferences to the
/// guarded [`Parameters`] and holds the mutex until dropped.
pub struct ParametersGuard<'a>(MutexGuard<'a, Result<Parameters, Error>>);
impl Deref for ParametersGuard<'_> {
    type Target = Parameters;

    /// Shared access to the guarded [`Parameters`].
    ///
    /// # Panics
    /// Panics if the slot holds an error. `lock_guard` only wraps a guard after
    /// seeing `Ok`, and the lock prevents the slot from changing meanwhile.
    fn deref(&self) -> &Self::Target {
        let slot = &*self.0;
        slot.as_ref().expect("parameters must be valid")
    }
}
impl DerefMut for ParametersGuard<'_> {
    /// Mutable access to the guarded [`Parameters`].
    ///
    /// # Panics
    /// Panics if the slot holds an error. `lock_guard` only wraps a guard after
    /// seeing `Ok`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let slot = &mut *self.0;
        slot.as_mut().expect("parameters must be valid")
    }
}
impl From<Parameters> for ArcParameters {
fn from(params: Parameters) -> Self {
Self(Arc::new(Mutex::new(Ok(params))))
}
}
impl ArcParameters {
    /// Locks the parameters for reading or updating.
    ///
    /// # Errors
    /// Returns a clone of the connection error stored by
    /// [`ArcParameters::on_conn_error`], if one was stored.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    #[inline]
    pub fn lock_guard(&self) -> Result<ParametersGuard<'_>, Error> {
        let guard = self.0.lock().unwrap();
        match guard.as_ref() {
            Ok(_) => Ok(ParametersGuard(guard)),
            Err(e) => Err(e.clone()),
        }
    }

    /// Waits until the peer's parameters have been received and authenticated,
    /// then returns them locked.
    ///
    /// # Errors
    /// Returns the connection error if one is stored before or during the wait.
    /// Storing it drops the old [`Parameters`], whose `Drop` wakes this task,
    /// and the next poll then sees the error through `lock_guard`.
    ///
    /// The returned guard holds a `std::sync::Mutex`; avoid keeping it across
    /// an `.await`.
    #[inline]
    pub async fn remote_ready(&self) -> Result<ParametersGuard<'_>, Error> {
        std::future::poll_fn(|cx| {
            let mut parameters = self.lock_guard()?;
            parameters.poll_ready(cx).map(|()| Ok(parameters))
        })
        .await
    }

    /// Creates a [`MaxIdleTimer`] from these shared parameters.
    #[inline]
    pub fn max_idle_timer(&self) -> MaxIdleTimer {
        MaxIdleTimer::new(self)
    }

    /// Records a connection error; later `lock_guard` calls return it.
    ///
    /// Only the first error is kept. The replaced [`Parameters`] is dropped
    /// after the mutex has been released. Its `Drop` wakes the waiting tasks,
    /// and releasing first lets them lock immediately instead of contending
    /// with this call.
    pub fn on_conn_error(&self, error: &Error) {
        let replaced = {
            let mut guard = self.0.lock().unwrap();
            if guard.is_err() {
                return;
            }
            std::mem::replace(&mut *guard, Err(error.clone()))
        };
        // Runs `Parameters::drop` (and thus the wake-ups) outside the lock.
        drop(replaced);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::varint::VarInt;

    /// Client parameters carrying only `InitialSourceConnectionId = client_test`.
    fn create_test_client_params() -> ClientParameters {
        let mut params = ClientParameters::default();
        params
            .set(
                ParameterId::InitialSourceConnectionId,
                ConnectionId::from_slice(b"client_test"),
            )
            .unwrap();
        params
    }

    /// Server parameters with `InitialSourceConnectionId = server_test` and
    /// `OriginalDestinationConnectionId = original`.
    fn create_test_server_params() -> ServerParameters {
        server_params_with(
            ConnectionId::from_slice(b"server_test"),
            ConnectionId::from_slice(b"original"),
        )
    }

    /// Server parameters whose two mandatory connection-ID fields are set to
    /// the given values.
    fn server_params_with(initial_scid: ConnectionId, origin_dcid: ConnectionId) -> ServerParameters {
        let mut params = ServerParameters::default();
        params
            .set(ParameterId::InitialSourceConnectionId, initial_scid)
            .unwrap();
        params
            .set(ParameterId::OriginalDestinationConnectionId, origin_dcid)
            .unwrap();
        params
    }

    #[test]
    fn test_parameters_new() {
        let client_params = create_test_client_params();
        let params =
            Parameters::new_client(client_params, None, ConnectionId::from_slice(b"odcid"));
        assert_eq!(params.role(), Role::Client);
        assert_eq!(params.state, Parameters::CLIENT_READY);
        let server_params = create_test_server_params();
        let params = Parameters::new_server(server_params);
        assert_eq!(params.role(), Role::Server);
        assert_eq!(params.state, Parameters::SERVER_READY);
    }

    #[test]
    fn test_authenticate_cids() {
        let client_params = create_test_client_params();
        let odcid = ConnectionId::from_slice(b"odcid");
        let mut params = Parameters::new_client(client_params, None, odcid);
        let server_cid = ConnectionId::from_slice(b"server_test");
        params
            .initial_scid_from_peer_need_equal(server_cid)
            .unwrap();
        params.server = Arc::new(server_params_with(server_cid, odcid));
        // `Ok(false)` would mean "not authenticated yet", so require `Ok(true)`.
        assert!(matches!(params.authenticate_cids(), Ok(true)));
    }

    #[test]
    fn test_client_becomes_ready_after_authentication() {
        let odcid = ConnectionId::from_slice(b"odcid");
        let server_cid = ConnectionId::from_slice(b"server_test");
        let mut params = Parameters::new_client(
            create_test_client_params(),
            Some(create_test_server_params()),
            odcid,
        );
        assert!(params.remembered().is_some());
        assert!(!params.is_remote_params_ready());

        params.initial_scid_from_peer_need_equal(server_cid).unwrap();
        // The server's parameters have not arrived yet, so nothing is ready.
        assert!(!params.is_remote_params_received());
        assert!(!params.is_remote_params_ready());
        assert!(params.server().is_none());

        params
            .recv_remote_params(PeerParameters::Server(server_params_with(server_cid, odcid)))
            .unwrap();
        assert!(params.is_remote_params_ready());
        assert!(params.server().is_some());
        assert!(params.remembered().is_none());
    }

    #[test]
    fn test_mismatched_initial_scid_is_rejected() {
        let odcid = ConnectionId::from_slice(b"odcid");
        let mut params = Parameters::new_client(create_test_client_params(), None, odcid);
        params
            .initial_scid_from_peer_need_equal(ConnectionId::from_slice(b"server_test"))
            .unwrap();
        let forged = server_params_with(ConnectionId::from_slice(b"impostor"), odcid);
        assert!(params
            .recv_remote_params(PeerParameters::Server(forged))
            .is_err());
        assert!(!params.is_remote_params_ready());
        assert!(params.server().is_none());
    }

    #[test]
    fn test_parameters_as_client() {
        let client_params = create_test_client_params();
        let arc_params = ArcParameters::from(Parameters::new_client(
            client_params,
            None,
            ConnectionId::from_slice(b"odcid"),
        ));
        let guard = arc_params.lock_guard().unwrap();
        assert!(matches!(
            guard.get_local::<VarInt>(ParameterId::MaxUdpPayloadSize),
            Some(value) if value.into_inner() >= 1200
        ));
        assert!(guard.remembered().is_none());
    }

    #[test]
    fn test_validate_remote_params() {
        assert_eq!(
            ClientParameters::parse_from_bytes(&[
                1, 1, 0, 3, 2, 0x43, 0xE8, 4, 1, 0, 5, 1, 0, 6, 1, 0, 7, 1, 0, 8, 1, 0, 9, 1, 0,
                10, 1, 3, 11, 1, 25, 14, 1, 2, 15, 0, 32, 4, 128, 0, 255, 255,
            ]),
            Err(QuicError::new(
                ErrorKind::TransportParameter,
                FrameType::Crypto.into(),
                "MaxUdpPayloadSize's value 1000 is out of bounds 1200..=65527",
            ))
        );
    }

    #[test]
    fn test_write_parameters() {
        let client_params = create_test_client_params();
        let params = ArcParameters::from(Parameters::new_client(
            client_params,
            None,
            ConnectionId::from_slice(b"odcid"),
        ));
        let guard = params.lock_guard().unwrap();
        assert_eq!(guard.role(), Role::Client);
    }

    #[tokio::test]
    async fn test_arc_parameters_error_handling() {
        let arc_params = ArcParameters::from(Parameters::new_client(
            create_test_client_params(),
            None,
            ConnectionId::from_slice(b"odcid"),
        ));
        let error = QuicError::new(
            ErrorKind::TransportParameter,
            FrameType::Crypto.into(),
            "test error",
        )
        .into();
        arc_params.on_conn_error(&error);
        assert!(arc_params.lock_guard().is_err());
    }
}