use crate::{
application,
crypto::{tls, CryptoSuite},
transport,
};
use alloc::{boxed::Box, vec::Vec};
use core::{any::Any, task::Poll};
/// Number of consecutive `poll` calls that a `SlowSession` answers with
/// `Poll::Pending` before it forwards one poll to the wrapped session.
const DEFER_COUNT: u8 = 3;
/// TLS endpoint adapter that wraps another [`tls::Endpoint`] and returns the
/// sessions created by that endpoint wrapped in [`SlowSession`].
pub struct SlowEndpoint<E>
where
    E: tls::Endpoint,
{
    /// The wrapped endpoint that actually creates the sessions.
    endpoint: E,
}
impl<E> SlowEndpoint<E>
where
    E: tls::Endpoint,
{
    /// Wraps `endpoint`; the returned value owns it and delegates session
    /// creation to it.
    pub fn new(endpoint: E) -> Self {
        Self { endpoint }
    }
}
/// Delegates to the wrapped endpoint and wraps every session it creates in a
/// [`SlowSession`] whose defer counter starts at `DEFER_COUNT`.
impl<E> tls::Endpoint for SlowEndpoint<E>
where
    E: tls::Endpoint,
{
    type Session = SlowSession<E::Session>;

    /// Creates a server session on the wrapped endpoint and wraps it.
    fn new_server_session<Params>(
        &mut self,
        transport_parameters: &Params,
        connection_info: tls::ConnectionInfo,
    ) -> Self::Session
    where
        Params: s2n_codec::EncoderValue,
    {
        SlowSession {
            defer: DEFER_COUNT,
            inner_session: self
                .endpoint
                .new_server_session(transport_parameters, connection_info),
        }
    }

    /// Creates a client session on the wrapped endpoint and wraps it.
    fn new_client_session<Params>(
        &mut self,
        transport_parameters: &Params,
        server_name: application::ServerName,
    ) -> Self::Session
    where
        Params: s2n_codec::EncoderValue,
    {
        SlowSession {
            defer: DEFER_COUNT,
            inner_session: self
                .endpoint
                .new_client_session(transport_parameters, server_name),
        }
    }

    /// Returns the wrapped endpoint's maximum tag length unchanged.
    fn max_tag_length(&self) -> usize {
        self.endpoint.max_tag_length()
    }
}
/// TLS session wrapper that postpones polling of the wrapped session.
#[derive(Debug)]
pub struct SlowSession<S>
where
    S: tls::Session,
{
    /// Remaining number of `poll` calls to answer with `Poll::Pending` before
    /// the wrapped session is polled again.
    defer: u8,
    /// The wrapped session that makes the actual progress.
    inner_session: S,
}
impl<S: tls::Session> tls::Session for SlowSession<S> {
#[inline]
fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
where
W: tls::Context<Self>,
{
if let Some(d) = self.defer.checked_sub(1) {
self.defer = d;
context.waker().wake_by_ref();
return Poll::Pending;
}
self.defer = DEFER_COUNT;
self.inner_session.poll(&mut SlowContext(context))
}
}
/// Re-exports the wrapped session's key types unchanged, so a `SlowSession<S>`
/// uses exactly the same crypto suite as `S`.
impl<S> CryptoSuite for SlowSession<S>
where
    S: tls::Session,
{
    type HandshakeKey = S::HandshakeKey;
    type HandshakeHeaderKey = S::HandshakeHeaderKey;
    type InitialKey = S::InitialKey;
    type InitialHeaderKey = S::InitialHeaderKey;
    type ZeroRttKey = S::ZeroRttKey;
    type ZeroRttHeaderKey = S::ZeroRttHeaderKey;
    type OneRttKey = S::OneRttKey;
    type OneRttHeaderKey = S::OneRttHeaderKey;
    type RetryKey = S::RetryKey;
}
/// Adapter holding a mutable borrow of an outer context, so that the wrapped
/// session can be polled with a context typed for it instead of for
/// [`SlowSession`].
struct SlowContext<'ctx, C>(&'ctx mut C);
/// Presents the borrowed outer context (a context for `SlowSession<S>`) as a
/// context for the wrapped session `S`. Every method forwards its arguments
/// to the outer context and returns that result unchanged.
impl<I, S> tls::Context<S> for SlowContext<'_, I>
where
    S: tls::Session,
    I: tls::Context<SlowSession<S>>,
{
    /// Forwards the client application parameters to the outer context.
    fn on_client_application_params(
        &mut self,
        client_params: tls::ApplicationParameters,
        server_params: &mut Vec<u8>,
    ) -> Result<(), transport::Error> {
        self.0
            .on_client_application_params(client_params, server_params)
    }

    /// Forwards the handshake keys to the outer context.
    fn on_handshake_keys(
        &mut self,
        key: <S as CryptoSuite>::HandshakeKey,
        header_key: <S as CryptoSuite>::HandshakeHeaderKey,
    ) -> Result<(), transport::Error> {
        self.0.on_handshake_keys(key, header_key)
    }

    /// Forwards the 0-RTT keys and application parameters to the outer context.
    fn on_zero_rtt_keys(
        &mut self,
        key: <S as CryptoSuite>::ZeroRttKey,
        header_key: <S as CryptoSuite>::ZeroRttHeaderKey,
        application_parameters: tls::ApplicationParameters,
    ) -> Result<(), transport::Error> {
        self.0
            .on_zero_rtt_keys(key, header_key, application_parameters)
    }

    /// Forwards the 1-RTT keys and application parameters to the outer context.
    fn on_one_rtt_keys(
        &mut self,
        key: <S as CryptoSuite>::OneRttKey,
        header_key: <S as CryptoSuite>::OneRttHeaderKey,
        application_parameters: tls::ApplicationParameters,
    ) -> Result<(), transport::Error> {
        self.0
            .on_one_rtt_keys(key, header_key, application_parameters)
    }

    /// Forwards the server name to the outer context.
    fn on_server_name(
        &mut self,
        server_name: application::ServerName,
    ) -> Result<(), transport::Error> {
        self.0.on_server_name(server_name)
    }

    /// Forwards the application protocol to the outer context.
    fn on_application_protocol(
        &mut self,
        application_protocol: tls::Bytes,
    ) -> Result<(), transport::Error> {
        self.0.on_application_protocol(application_protocol)
    }

    /// Forwards the handshake-complete notification to the outer context.
    fn on_handshake_complete(&mut self) -> Result<(), transport::Error> {
        self.0.on_handshake_complete()
    }

    /// Forwards the exporter-ready notification to the outer context.
    fn on_tls_exporter_ready(
        &mut self,
        session: &impl tls::TlsSession,
    ) -> Result<(), transport::Error> {
        self.0.on_tls_exporter_ready(session)
    }

    /// Forwards the handshake-failure notification and its error to the outer
    /// context.
    fn on_tls_handshake_failed(
        &mut self,
        session: &impl tls::TlsSession,
        e: &(dyn core::error::Error + Send + Sync + 'static),
    ) -> Result<(), transport::Error> {
        self.0.on_tls_handshake_failed(session, e)
    }

    /// Forwards the initial-space receive request to the outer context.
    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
        self.0.receive_initial(max_len)
    }

    /// Forwards the handshake-space receive request to the outer context.
    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
        self.0.receive_handshake(max_len)
    }

    /// Forwards the application-space receive request to the outer context.
    fn receive_application(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
        self.0.receive_application(max_len)
    }

    /// Returns the outer context's answer for the initial space.
    fn can_send_initial(&self) -> bool {
        self.0.can_send_initial()
    }

    /// Hands `transmission` to the outer context for the initial space.
    fn send_initial(&mut self, transmission: tls::Bytes) {
        self.0.send_initial(transmission);
    }

    /// Returns the outer context's answer for the handshake space.
    fn can_send_handshake(&self) -> bool {
        self.0.can_send_handshake()
    }

    /// Hands `transmission` to the outer context for the handshake space.
    fn send_handshake(&mut self, transmission: tls::Bytes) {
        self.0.send_handshake(transmission);
    }

    /// Returns the outer context's answer for the application space.
    fn can_send_application(&self) -> bool {
        self.0.can_send_application()
    }

    /// Hands `transmission` to the outer context for the application space.
    fn send_application(&mut self, transmission: tls::Bytes) {
        self.0.send_application(transmission);
    }

    /// Returns the outer context's waker.
    fn waker(&self) -> &core::task::Waker {
        self.0.waker()
    }

    /// Forwards the key exchange group to the outer context.
    fn on_key_exchange_group(
        &mut self,
        named_group: tls::NamedGroup,
    ) -> Result<(), transport::Error> {
        self.0.on_key_exchange_group(named_group)
    }

    /// Forwards the boxed TLS context value to the outer context.
    fn on_tls_context(&mut self, context: Box<dyn Any + Send>) {
        self.0.on_tls_context(context);
    }
}