mod session_key;
mod transport;
use std::io::Cursor;
use session_key::{derive_session_key, SessionKey, SESSION_ID_LENGTH};
use tls_codec::{
Deserialize, Serialize, SerializeBytes, Size, TlsDeserialize, TlsSerialize, TlsSize,
};
pub use transport::Transport;
use crate::{
aead::{AEADError, AEADKeyNonce, AeadType},
handshake::{
ciphersuite::types::PQEncapsulationKey, dhkem::DHPublicKey, transcript::Transcript,
types::Authenticator,
},
session::session_key::derive_import_key,
};
/// Errors reported by the session layer.
///
/// Several variants are not produced anywhere in this file; their
/// descriptions below are marked as assumptions where that is the case.
#[derive(Debug, PartialEq)]
pub enum SessionError {
/// NOTE(review): not produced in this file; presumably signals a failed
/// conversion of handshake state into a session — confirm at the call site.
IntoSession,
/// TLS-codec serialization failed; carries the codec error.
Serialize(tls_codec::Error),
/// TLS-codec deserialization failed; carries the codec error.
Deserialize(tls_codec::Error),
/// NOTE(review): not produced in this file; presumably the offending payload
/// length — confirm against the transport code.
PayloadTooLong(usize),
/// A cryptographic operation failed (HKDF failure, or an AEAD error mapped
/// through `From<AEADError>`).
CryptoError,
/// Serializing or deserializing a session was refused because the supplied
/// public-key binding did not agree with the binder stored in the session.
Storage,
/// The session's channel counter cannot be incremented any further.
ReachedMaxChannels,
/// NOTE(review): not produced in this file — confirm meaning at the call site.
IdentifierMismatch,
/// NOTE(review): not produced in this file — confirm meaning at the call site.
OutputBufferShort,
/// Importing a secret into a session failed (PSK too short, binding check
/// failed, or the transcript could not be extended).
Import,
}
/// Maps AEAD-layer failures onto the session error type.
///
/// Codec errors keep their payload; every other AEAD failure (including an
/// expired key) is reported as [`SessionError::CryptoError`].
impl From<AEADError> for SessionError {
    fn from(value: AEADError) -> Self {
        match value {
            AEADError::Serialize(codec_error) => Self::Serialize(codec_error),
            AEADError::Deserialize(codec_error) => Self::Deserialize(codec_error),
            AEADError::KeyExpired | AEADError::CryptoError => Self::CryptoError,
        }
    }
}
/// Length in bytes of the public-key binder stored in a [`Session`].
pub(crate) const PK_BINDER_LEN: usize = 8;
/// Role of the local party in a session.
///
/// Serialized as part of [`Session`] via the derived TLS codec impls.
#[derive(TlsSerialize, TlsDeserialize, TlsSize)]
#[repr(u8)]
pub(crate) enum SessionPrincipal {
/// The local party is the initiator.
Initiator,
/// The local party is the responder.
Responder,
}
/// Long-lived session state.
///
/// The whole struct is TLS-serializable; see `Session::serialize` and
/// `Session::deserialize` for the checked entry points.
#[derive(TlsSerialize, TlsDeserialize, TlsSize)]
pub struct Session {
// Role of the local party (initiator or responder).
pub(crate) principal: SessionPrincipal,
// Session key material together with the session identifier.
pub(crate) session_key: SessionKey,
// Binder over the bound public keys, derived from the session key;
// `None` when the session was created without a binding.
pub(crate) pk_binder: Option<[u8; PK_BINDER_LEN]>,
// Starts at 0 and is incremented for every transport channel created.
pub(crate) channel_counter: u64,
// AEAD algorithm selected for this session.
pub(crate) aead_type: AeadType,
// Transcript the session key was derived over.
pub(crate) transcript: Transcript,
}
fn derive_pk_binder(
key: &SessionKey,
initiator_authenticator: &Authenticator,
responder_ecdh_pk: &DHPublicKey,
responder_pq_pk: &Option<PQEncapsulationKey>,
) -> Result<[u8; PK_BINDER_LEN], SessionError> {
#[derive(TlsSerialize, TlsSize)]
struct PkBinderInfo<'a> {
initiator_authenticator: &'a Authenticator,
responder_ecdh_pk: &'a DHPublicKey,
responder_pq_pk: &'a Option<PQEncapsulationKey<'a>>,
}
let info = PkBinderInfo {
initiator_authenticator,
responder_ecdh_pk,
responder_pq_pk,
};
let mut info_buf = vec![0u8; info.tls_serialized_len()];
info.tls_serialize(&mut &mut info_buf[..])
.map_err(SessionError::Serialize)?;
let mut binder = [0u8; PK_BINDER_LEN];
libcrux_hkdf::sha2_256::hkdf(
&mut binder,
&[],
&SerializeBytes::tls_serialize(&key.key).map_err(SessionError::Serialize)?,
&info_buf,
)
.map_err(|_| SessionError::CryptoError)?;
Ok(binder)
}
/// The set of public keys a session can be bound to.
///
/// Passed to the session constructor, import, serialization and
/// deserialization, where it is fed into the binder derivation.
pub struct SessionBinding<'a> {
/// The initiator's authenticator.
pub initiator_authenticator: &'a Authenticator,
/// The responder's ECDH public key.
pub responder_ecdh_pk: &'a DHPublicKey,
/// The responder's post-quantum encapsulation key, if any.
pub responder_pq_pk: Option<PQEncapsulationKey<'a>>,
}
impl Session {
/// Creates a session from a transcript and key material.
///
/// The session key is derived from `k2` and `tx2` for the given `aead_type`.
/// When `session_binding` is supplied, a public-key binder is derived from
/// the fresh session key and stored; otherwise the session is unbound.
/// The channel counter starts at zero.
///
/// # Errors
///
/// Propagates failures of the session-key derivation and of the binder
/// derivation.
pub(crate) fn new<'a>(
    tx2: Transcript,
    k2: AEADKeyNonce,
    session_binding: Option<SessionBinding<'a>>,
    is_initiator: bool,
    aead_type: AeadType,
) -> Result<Self, SessionError> {
    let session_key = derive_session_key(k2, &tx2, aead_type)?;

    let pk_binder = match session_binding {
        Some(binding) => Some(derive_pk_binder(
            &session_key,
            binding.initiator_authenticator,
            binding.responder_ecdh_pk,
            &binding.responder_pq_pk,
        )?),
        None => None,
    };

    let principal = if is_initiator {
        SessionPrincipal::Initiator
    } else {
        SessionPrincipal::Responder
    };

    Ok(Self {
        principal,
        session_key,
        pk_binder,
        channel_counter: 0,
        aead_type,
        transcript: tx2,
    })
}
/// Imports an external secret, producing a new session that replaces this one.
///
/// The old session is consumed. The new session is built from
/// * a transcript extended with the old session identifier, and
/// * an import key derived from the old session key and `psk`,
///
/// and keeps the principal and AEAD type of the old session.
///
/// `session_binding` must agree with the binding state of the old session:
/// an unbound session must be imported without a binding, and a bound
/// session must be imported with a binding whose derived binder equals the
/// stored one. In particular, a bound session can no longer be imported
/// without its binding; previously that silently produced an *unbound*
/// session, which `serialize` and `deserialize` both refuse to do.
///
/// # Errors
///
/// Returns [`SessionError::Import`] if
/// * `psk` is shorter than 32 bytes,
/// * the presence of `session_binding` does not match the session, or the
///   derived binder differs from the stored one,
/// * the transcript cannot be extended.
///
/// Failures of the binder derivation, the import-key derivation and the
/// construction of the new session are propagated unchanged.
pub fn import<'a>(
    self,
    psk: &[u8],
    session_binding: impl Into<Option<SessionBinding<'a>>>,
) -> Result<Self, SessionError> {
    // The imported secret must be at least 32 bytes long.
    if psk.len() < 32 {
        return Err(SessionError::Import);
    }

    let session_binding = session_binding.into();
    match (self.pk_binder, &session_binding) {
        // Unbound session, no binding supplied: nothing to check.
        (None, None) => (),
        // Binding state and supplied binding disagree.
        //
        // `(Some(_), None)` must be rejected: accepting it would hand `None`
        // to `Self::new` below and strip the public-key binding.
        //
        // NOTE(review): `(None, Some(_))` was already rejected; if binding a
        // previously unbound session on import is intended, relax only that
        // half of this arm.
        (None, Some(_)) | (Some(_), None) => return Err(SessionError::Import),
        // Bound session: the supplied keys must reproduce the stored binder.
        (
            Some(pk_binder),
            Some(SessionBinding {
                initiator_authenticator,
                responder_ecdh_pk,
                responder_pq_pk,
            }),
        ) => {
            if derive_pk_binder(
                &self.session_key,
                initiator_authenticator,
                responder_ecdh_pk,
                responder_pq_pk,
            )? != pk_binder
            {
                return Err(SessionError::Import);
            }
        }
    };

    // Extend the transcript with the identifier of the session being replaced.
    let transcript =
        Transcript::add_hash::<3>(Some(&self.transcript), self.session_key.identifier)
            .map_err(|_| SessionError::Import)?;

    // Mix the old session key with the imported secret.
    let import_key = derive_import_key(self.session_key.key, psk, self.aead_type)?;

    Self::new(
        transcript,
        import_key,
        session_binding,
        matches!(self.principal, SessionPrincipal::Initiator),
        self.aead_type,
    )
}
/// Serializes the session into `out` after validating the binding.
///
/// The session is consumed. Serialization only happens when the supplied
/// `session_binding` agrees with the session: either the session is unbound
/// and no binding is given, or the binder derived from the given binding
/// equals the stored binder.
///
/// Returns the value reported by the TLS codec on success (the number of
/// bytes written).
///
/// # Errors
///
/// * [`SessionError::Storage`] if the binding does not agree with the session.
/// * [`SessionError::Serialize`] if the TLS codec fails to write into `out`.
/// * Any error from the binder derivation.
pub fn serialize<'a>(
    self,
    out: &mut [u8],
    session_binding: impl Into<Option<SessionBinding<'a>>>,
) -> Result<usize, SessionError> {
    let binding_accepted = match (self.pk_binder, session_binding.into()) {
        (None, None) => true,
        (Some(stored_binder), Some(binding)) => {
            let derived = derive_pk_binder(
                &self.session_key,
                binding.initiator_authenticator,
                binding.responder_ecdh_pk,
                &binding.responder_pq_pk,
            )?;
            derived == stored_binder
        }
        // Exactly one of stored binder / supplied binding is present.
        (Some(_), None) | (None, Some(_)) => false,
    };

    if !binding_accepted {
        return Err(SessionError::Storage);
    }

    self.tls_serialize(&mut &mut out[..])
        .map_err(SessionError::Serialize)
}
/// Exports a secret derived from the session key into `out`.
///
/// The output is produced by HKDF-SHA2-256 over the session key with an
/// empty salt; the `info` input is the TLS serialization of the caller's
/// `context` followed by the fixed label `"PSQ secret export"`. The number
/// of bytes produced is determined by the length of `out`.
///
/// # Errors
///
/// * [`SessionError::Serialize`] if the info structure cannot be serialized.
/// * [`SessionError::CryptoError`] if the HKDF call fails.
pub fn export_secret(&self, context: &[u8], out: &mut [u8]) -> Result<(), SessionError> {
    use tls_codec::TlsSerializeBytes;

    const PSQ_EXPORT_CONTEXT: &[u8; 17] = b"PSQ secret export";

    // Wire format of the HKDF `info` input; field order is significant.
    #[derive(TlsSerializeBytes, TlsSize)]
    struct ExportInfo<'a> {
        context: &'a [u8],
        separator: [u8; 17],
    }

    let export_info = ExportInfo {
        context,
        separator: *PSQ_EXPORT_CONTEXT,
    };
    let info_bytes = export_info
        .tls_serialize()
        .map_err(SessionError::Serialize)?;

    libcrux_hkdf::sha2_256::hkdf(out, b"", self.session_key.key.as_ref(), &info_bytes)
        .map_err(|_| SessionError::CryptoError)
}
/// Restores a session from `bytes` and validates it against the binding.
///
/// The restored session is returned only when the supplied `session_binding`
/// agrees with it: either the session carries no binder and no binding is
/// given, or the binder derived from the given binding equals the stored one.
///
/// # Errors
///
/// * [`SessionError::Deserialize`] if `bytes` cannot be decoded.
/// * [`SessionError::Storage`] if the binding does not agree with the session.
/// * Any error from the binder derivation.
pub fn deserialize<'a>(
    bytes: &[u8],
    session_binding: impl Into<Option<SessionBinding<'a>>>,
) -> Result<Self, SessionError> {
    let session_binding = session_binding.into();

    let mut reader = Cursor::new(bytes);
    let mut restored =
        Session::tls_deserialize(&mut reader).map_err(SessionError::Deserialize)?;

    // NOTE(review): `expire()` is defined outside this file; presumably it
    // marks the restored key as expired — confirm the intended effect.
    restored.session_key.key.expire();

    let binding_accepted = match (restored.pk_binder, session_binding) {
        (None, None) => true,
        (Some(stored_binder), Some(binding)) => {
            let derived = derive_pk_binder(
                &restored.session_key,
                binding.initiator_authenticator,
                binding.responder_ecdh_pk,
                &binding.responder_pq_pk,
            )?;
            derived == stored_binder
        }
        // Exactly one of stored binder / supplied binding is present.
        (Some(_), None) | (None, Some(_)) => false,
    };

    if binding_accepted {
        Ok(restored)
    } else {
        Err(SessionError::Storage)
    }
}
/// Creates a new transport channel for this session.
///
/// The channel is constructed first; the channel counter is advanced only
/// once construction has succeeded.
///
/// # Errors
///
/// * Any error from `Transport::new`.
/// * [`SessionError::ReachedMaxChannels`] if the counter cannot be
///   incremented without overflowing; the counter is left unchanged.
pub fn transport_channel(&mut self) -> Result<Transport, SessionError> {
    let is_initiator = matches!(self.principal, SessionPrincipal::Initiator);
    let channel = Transport::new(self, is_initiator)?;

    let Some(next_counter) = self.channel_counter.checked_add(1) else {
        return Err(SessionError::ReachedMaxChannels);
    };
    self.channel_counter = next_counter;

    Ok(channel)
}
/// Returns the session identifier stored alongside the session key.
pub fn identifier(&self) -> &[u8; SESSION_ID_LENGTH] {
&self.session_key.identifier
}
}