use crate::callback::{self, Callback};
use bytes::BytesMut;
use core::{marker::PhantomData, task::Poll};
use s2n_quic_core::{
application::ServerName,
crypto::{tls, tls::CipherSuite, CryptoSuite},
endpoint, ensure, transport,
};
use s2n_quic_crypto::Suite;
use s2n_tls::{
config::Config,
connection::Connection,
enums::{Blinding, Mode},
error::{Error, ErrorType},
};
/// QUIC TLS session backed by an s2n-tls [`Connection`].
#[derive(Debug)]
pub struct Session {
    /// Whether this session acts as the client or the server; selects the s2n-tls `Mode`.
    endpoint: endpoint::Type,
    /// The underlying s2n-tls connection (QUIC mode is enabled on it by `Session::new`).
    pub(crate) connection: Connection,
    /// State lent to the s2n-tls callbacks; also queried for the negotiated cipher suite
    /// and notified when the handshake completes.
    state: callback::State,
    /// Set after the handshake-complete notifications have been delivered, so they are
    /// only delivered once.
    handshake_complete: bool,
    /// Buffer lent mutably to the callbacks while the connection is driven.
    /// NOTE(review): presumably accumulates outgoing TLS bytes — confirm in `callback.rs`.
    send_buffer: BytesMut,
    /// Flag lent mutably to the callbacks.
    /// NOTE(review): presumably records whether `server_name` was already reported to the
    /// context — confirm in `callback.rs`.
    emitted_server_name: bool,
    /// Server name set on the connection at construction time (when provided) and
    /// exposed read-only to the callbacks.
    server_name: Option<ServerName>,
    /// Set when a post-handshake message is processed successfully; consulted by
    /// `should_discard_session`.
    received_ticket: bool,
    /// For servers, a copy of the transport parameters passed to `Session::new`; empty
    /// for clients (which set theirs on the connection immediately). Lent mutably to
    /// the callbacks.
    server_params: Vec<u8>,
}
impl Session {
    /// Creates a TLS session for `endpoint` backed by a new s2n-tls connection.
    ///
    /// The connection is configured with `config`, switched into QUIC mode, and set to
    /// self-service blinding. A client sets `params` as its QUIC transport parameters
    /// immediately. A server keeps a copy of `params` on the session and hands it to the
    /// callbacks later. If `server_name` is provided, it is set on the connection.
    ///
    /// # Errors
    ///
    /// Returns the s2n-tls error if any connection setup step fails. This includes
    /// s2n-tls rejecting `server_name`, which previously caused a panic.
    pub fn new(
        endpoint: endpoint::Type,
        config: Config,
        params: &[u8],
        server_name: Option<ServerName>,
    ) -> Result<Self, Error> {
        let mode = match endpoint {
            endpoint::Type::Server => Mode::Server,
            endpoint::Type::Client => Mode::Client,
        };
        let mut connection = Connection::new(mode);
        connection.set_config(config)?;
        connection.enable_quic()?;

        // Clients hand their transport parameters to s2n-tls right away. Servers keep
        // a copy so the callbacks can use it during the handshake.
        let server_params = if endpoint.is_client() {
            connection.set_quic_transport_parameters(params)?;
            Vec::new()
        } else {
            params.to_vec()
        };

        // With self-service blinding, s2n-tls does not sleep internally on errors.
        // The application is responsible for any blinding delay.
        connection.set_blinding(Blinding::SelfService)?;

        if let Some(server_name) = server_name.as_ref() {
            // Surface a rejected name through the `Result` instead of panicking.
            connection.set_server_name(server_name)?;
        }

        Ok(Self {
            endpoint,
            connection,
            state: Default::default(),
            handshake_complete: false,
            send_buffer: BytesMut::new(),
            emitted_server_name: false,
            server_name,
            received_ticket: false,
            server_params,
        })
    }
}
/// Every packet-protection key type is delegated to the `s2n_quic_crypto` [`Suite`].
/// The TLS session only negotiates secrets and does not provide its own primitives.
impl CryptoSuite for Session {
    // Initial packet space.
    type InitialKey = <Suite as CryptoSuite>::InitialKey;
    type InitialHeaderKey = <Suite as CryptoSuite>::InitialHeaderKey;

    // Handshake packet space.
    type HandshakeKey = <Suite as CryptoSuite>::HandshakeKey;
    type HandshakeHeaderKey = <Suite as CryptoSuite>::HandshakeHeaderKey;

    // 0-RTT (early data).
    type ZeroRttKey = <Suite as CryptoSuite>::ZeroRttKey;
    type ZeroRttHeaderKey = <Suite as CryptoSuite>::ZeroRttHeaderKey;

    // 1-RTT (application data).
    type OneRttKey = <Suite as CryptoSuite>::OneRttKey;
    type OneRttHeaderKey = <Suite as CryptoSuite>::OneRttHeaderKey;

    // Retry packets.
    type RetryKey = <Suite as CryptoSuite>::RetryKey;
}
/// Keying-material export and cipher-suite reporting for the session.
impl tls::TlsSession for Session {
    /// Exports keying material for `label` and `context` into `output` via the
    /// s2n-tls connection.
    ///
    /// Any s2n-tls failure is reported as [`tls::TlsExportError::failure`].
    fn tls_exporter(
        &self,
        label: &[u8],
        context: &[u8],
        output: &mut [u8],
    ) -> Result<(), tls::TlsExportError> {
        match self.connection.tls_exporter(label, context, output) {
            Ok(exported) => Ok(exported),
            Err(_) => Err(tls::TlsExportError::failure()),
        }
    }

    /// Returns the cipher suite recorded in the callback state.
    fn cipher_suite(&self) -> CipherSuite {
        let state = &self.state;
        state.cipher_suite()
    }
}
impl tls::Session for Session {
fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
where
W: tls::Context<Self>,
{
let mut callback: Callback<W, Self> = Callback {
context,
endpoint: self.endpoint,
state: &mut self.state,
suite: PhantomData,
err: None,
send_buffer: &mut self.send_buffer,
emitted_server_name: &mut self.emitted_server_name,
server_name: &self.server_name,
server_params: &mut self.server_params,
};
unsafe {
callback.set(&mut self.connection);
}
let result = self.connection.poll_negotiate().map_ok(|_| ());
callback.unset(&mut self.connection)?;
match result {
Poll::Ready(Ok(())) => {
if !self.handshake_complete {
self.state.on_handshake_complete();
context.on_handshake_complete()?;
context.on_tls_exporter_ready(self)?;
self.handshake_complete = true;
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e
.alert()
.map(tls::Error::new)
.unwrap_or(tls::Error::HANDSHAKE_FAILURE)
.into())),
Poll::Pending => Poll::Pending,
}
}
fn process_post_handshake_message<W>(&mut self, context: &mut W) -> Result<(), transport::Error>
where
W: tls::Context<Self>,
{
let mut callback: Callback<W, Self> = Callback {
context,
endpoint: self.endpoint,
state: &mut self.state,
suite: PhantomData,
err: None,
send_buffer: &mut self.send_buffer,
emitted_server_name: &mut self.emitted_server_name,
server_name: &self.server_name,
server_params: &mut self.server_params,
};
unsafe {
callback.set(&mut self.connection);
}
let result = self
.connection
.quic_process_post_handshake_message()
.map(|_| ());
callback.unset(&mut self.connection)?;
match result {
Ok(_) => {
self.received_ticket = true;
Ok(())
}
Err(e) => {
if matches!(e.kind(), ErrorType::Blocked) {
Ok(())
} else {
Err(e
.alert()
.map(tls::Error::new)
.unwrap_or(tls::Error::HANDSHAKE_FAILURE)
.into())
}
}
}
}
fn should_discard_session(&self) -> bool {
ensure!(self.endpoint.is_client(), true);
ensure!(self.connection.are_session_tickets_enabled(), true);
self.received_ticket
}
}