use std::io::Cursor;
use tls_codec::{
Deserialize, Serialize, SerializeBytes, Size, TlsDeserialize, TlsSerialize, TlsSerializeBytes,
TlsSize, VLByteSlice,
};
use super::{Session, SessionError as Error};
#[cfg(feature = "nonce-control")]
use crate::aead::NONCE_LEN;
use crate::{aead::AEADKeyNonce, traits::Channel};
/// Borrowing, serialize-only form of a transport message.
///
/// `write_message` uses it to encode straight from a ciphertext slice without
/// building an owned `TransportMessage`. Field order defines the wire layout
/// (channel identifier, `VLByteSlice` ciphertext, 16-byte tag) and must match
/// `TransportMessage`, which is what `read_message` decodes.
/// NOTE(review): this relies on `VLByteSlice` and `Vec<u8>` sharing the same tls_codec
/// length-prefix encoding — confirm if either type changes.
#[derive(TlsSerialize, TlsSize)]
struct TransportMessageOut<'a> {
    // Must equal the receiver's `Transport::channel_identifier` or the message is rejected.
    channel_identifier: u64,
    // Encrypted payload; same length as the plaintext.
    ciphertext: VLByteSlice<'a>,
    // Tag returned by the send key's `encrypt`.
    tag: [u8; 16],
}
/// Owned transport message: a channel identifier, the ciphertext, and its 16-byte tag.
///
/// `read_message` decodes it from bytes. `write_message_external_encoding` produces it and
/// `read_message_external_encoding` consumes it, for callers that handle encoding themselves.
///
/// Field order defines the decoded wire layout and must match `TransportMessageOut`.
/// `Debug` and `Clone` are derived so callers can log or buffer messages; none of the
/// fields carry key material.
#[derive(Debug, Clone, TlsDeserialize, TlsSize)]
pub struct TransportMessage {
    // Compared against the receiving `Transport`'s identifier before decryption.
    channel_identifier: u64,
    // Encrypted payload; plaintext has the same length.
    ciphertext: Vec<u8>,
    // Tag checked by the receive key's `decrypt_out`.
    tag: [u8; 16],
}
/// One end of an encrypted channel derived from a `Session`.
///
/// Holds one key per direction plus the channel identifier that every outgoing
/// message carries and every incoming message must match.
pub struct Transport {
    /// Key (with its nonce state) used to encrypt messages this side sends.
    send_key: AEADKeyNonce,
    /// Key (with its nonce state) used to decrypt messages this side receives.
    recv_key: AEADKeyNonce,
    /// Identifier of this channel, copied from `session.channel_counter` at construction.
    channel_identifier: u64,
}
impl Transport {
pub(crate) fn new(session: &Session, is_initiator: bool) -> Result<Self, Error> {
if is_initiator {
Ok(Self {
send_key: derive_i2r_channel_key(session)?,
recv_key: derive_r2i_channel_key(session)?,
channel_identifier: session.channel_counter,
})
} else {
Ok(Self {
send_key: derive_r2i_channel_key(session)?,
recv_key: derive_i2r_channel_key(session)?,
channel_identifier: session.channel_counter,
})
}
}
pub fn identifier(&self) -> u64 {
self.channel_identifier
}
#[cfg(feature = "nonce-control")]
pub fn set_sender_nonce(&mut self, nonce: &[u8; NONCE_LEN]) {
self.send_key.set_nonce(nonce);
}
#[cfg(feature = "nonce-control")]
pub fn set_receiver_nonce(&mut self, nonce: &[u8; NONCE_LEN]) {
self.recv_key.set_nonce(nonce);
}
#[cfg(feature = "nonce-control")]
pub fn sender_nonce(&self) -> &[u8; NONCE_LEN] {
self.send_key.nonce()
}
#[cfg(feature = "nonce-control")]
pub fn receiver_nonce(&self) -> &[u8; NONCE_LEN] {
self.recv_key.nonce()
}
fn prepare_message_contents(&mut self, payload: &[u8]) -> Result<(Vec<u8>, [u8; 16]), Error> {
if payload.len() > 65535 {
return Err(Error::PayloadTooLong(payload.len()));
}
let mut ciphertext = vec![0u8; payload.len()];
let tag = self.send_key.encrypt(payload, &[], &mut ciphertext)?;
Ok((ciphertext, tag))
}
fn process_message(&mut self, message: &TransportMessage, out: &mut [u8]) -> Result<(), Error> {
if self.channel_identifier != message.channel_identifier {
return Err(Error::IdentifierMismatch);
}
if out.len() < message.ciphertext.as_slice().len() {
return Err(Error::OutputBufferShort);
}
self.recv_key.decrypt_out(
message.ciphertext.as_slice(),
&message.tag,
&[],
&mut out[..message.ciphertext.as_slice().len()],
)?;
Ok(())
}
}
impl Channel<Error, TransportMessage> for Transport {
fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<usize, Error> {
let (ciphertext, tag) = self.prepare_message_contents(payload)?;
let message = TransportMessageOut {
channel_identifier: self.channel_identifier,
ciphertext: VLByteSlice(ciphertext.as_ref()),
tag,
};
message
.tls_serialize(&mut &mut out[..])
.map_err(Error::Serialize)
}
fn read_message(&mut self, message: &[u8], out: &mut [u8]) -> Result<(usize, usize), Error> {
let message = TransportMessage::tls_deserialize(&mut Cursor::new(message))
.map_err(Error::Deserialize)?;
let bytes_deserialized = message.tls_serialized_len();
self.process_message(&message, out)?;
let out_bytes_written = message.ciphertext.as_slice().len();
Ok((bytes_deserialized, out_bytes_written))
}
fn write_message_external_encoding(
&mut self,
payload: &[u8],
) -> Result<TransportMessage, Error> {
let (ciphertext, tag) = self.prepare_message_contents(payload)?;
Ok(TransportMessage {
channel_identifier: self.channel_identifier,
ciphertext,
tag,
})
}
fn read_message_external_encoding(
&mut self,
message: &TransportMessage,
) -> Result<Vec<u8>, Error> {
let mut out = vec![0; message.ciphertext.len()];
self.process_message(&message, &mut out)?;
Ok(out)
}
}
/// Domain separator placed in the key-derivation info for the initiator-to-responder key.
const I2R_CHANNEL_KEY_LABEL: &[u8] = b"i2r channel key";
/// Domain separator placed in the key-derivation info for the responder-to-initiator key.
const R2I_CHANNEL_KEY_LABEL: &[u8] = b"r2i channel key";
fn derive_channel_key<const IS_INITIATOR: bool>(session: &Session) -> Result<AEADKeyNonce, Error> {
#[derive(TlsSerializeBytes, TlsSize)]
struct ChannelKeyInfo<'a> {
domain_separator: &'static [u8],
pk_binder: Option<&'a [u8]>,
counter: u64,
}
AEADKeyNonce::new(
&session.session_key.key,
&ChannelKeyInfo {
domain_separator: if IS_INITIATOR {
I2R_CHANNEL_KEY_LABEL
} else {
R2I_CHANNEL_KEY_LABEL
},
pk_binder: session
.pk_binder
.as_ref()
.map(|pk_binder| pk_binder.as_slice()),
counter: session.channel_counter,
}
.tls_serialize()
.map_err(Error::Serialize)?,
session.aead_type,
)
.map_err(|e| e.into())
}
/// Derives the key protecting messages sent from the initiator to the responder.
pub(super) fn derive_i2r_channel_key(session: &Session) -> Result<AEADKeyNonce, Error> {
    const INITIATOR_SENDS: bool = true;
    derive_channel_key::<INITIATOR_SENDS>(session)
}
/// Derives the key protecting messages sent from the responder to the initiator.
pub(super) fn derive_r2i_channel_key(session: &Session) -> Result<AEADKeyNonce, Error> {
    const INITIATOR_SENDS: bool = false;
    derive_channel_key::<INITIATOR_SENDS>(session)
}