use crate::config::{TlsCipherSuite, TlsConfig, TlsVerifier};
use crate::handshake::{ClientHandshake, ServerHandshake};
use crate::key_schedule::{KeySchedule, ReadKeySchedule, WriteKeySchedule};
use crate::record::{ClientRecord, ServerRecord};
use crate::record_reader::RecordReader;
use crate::write_buffer::WriteBuffer;
use crate::TlsError;
use crate::{
alert::*,
handshake::{certificate::CertificateRef, certificate_request::CertificateRequest},
};
use core::fmt::Debug;
use embedded_io::Error as _;
use embedded_io::{Read as BlockingRead, Write as BlockingWrite};
use embedded_io_async::{Read as AsyncRead, Write as AsyncWrite};
use rand_core::{CryptoRng, RngCore};
use crate::application_data::ApplicationData;
use crate::buffer::CryptoBuffer;
use digest::generic_array::typenum::Unsigned;
use p256::ecdh::EphemeralSecret;
use crate::content_types::ContentType;
use crate::parse_buffer::ParseBuffer;
use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit};
pub(crate) fn decrypt_record<CipherSuite>(
key_schedule: &mut ReadKeySchedule<CipherSuite>,
record: ServerRecord<'_, CipherSuite>,
mut cb: impl FnMut(
&mut ReadKeySchedule<CipherSuite>,
ServerRecord<'_, CipherSuite>,
) -> Result<(), TlsError>,
) -> Result<(), TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let ServerRecord::ApplicationData(ApplicationData {
header,
data: mut app_data,
}) = record
{
let server_key = key_schedule.get_key()?;
let nonce = key_schedule.get_nonce()?;
let crypto = <CipherSuite::Cipher as KeyInit>::new(&server_key);
crypto
.decrypt_in_place(&nonce, header.data(), &mut app_data)
.map_err(|_| TlsError::CryptoError)?;
let padding = app_data
.as_slice()
.iter()
.enumerate()
.rfind(|(_, b)| **b != 0);
if let Some((index, _)) = padding {
app_data.truncate(index + 1);
};
let content_type =
ContentType::of(*app_data.as_slice().last().unwrap()).ok_or(TlsError::InvalidRecord)?;
trace!("Decrypting: content type = {:?}", content_type);
app_data.truncate(app_data.len() - 1);
let mut buf = ParseBuffer::new(app_data.as_slice());
match content_type {
ContentType::Handshake => {
while buf.remaining() > 0 {
let inner = ServerHandshake::read(&mut buf, key_schedule.transcript_hash())?;
cb(key_schedule, ServerRecord::Handshake(inner))?;
}
}
ContentType::ApplicationData => {
let inner = ApplicationData::new(app_data, header);
cb(key_schedule, ServerRecord::ApplicationData(inner))?;
}
ContentType::Alert => {
let alert = Alert::parse(&mut buf)?;
cb(key_schedule, ServerRecord::Alert(alert))?;
}
_ => return Err(TlsError::Unimplemented),
}
key_schedule.increment_counter();
} else {
trace!("Not decrypting: content_type = {:?}", record.content_type());
cb(key_schedule, record)?;
}
Ok(())
}
pub(crate) fn encrypt<CipherSuite>(
key_schedule: &WriteKeySchedule<CipherSuite>,
buf: &mut CryptoBuffer<'_>,
) -> Result<(), TlsError>
where
CipherSuite: TlsCipherSuite,
{
let client_key = key_schedule.get_key()?;
let nonce = key_schedule.get_nonce()?;
let crypto = <CipherSuite::Cipher as KeyInit>::new(&client_key);
let len = buf.len() + <CipherSuite::Cipher as AeadCore>::TagSize::to_usize();
if len > buf.capacity() {
return Err(TlsError::InsufficientSpace);
}
trace!("output size {}", len);
let len_bytes = (len as u16).to_be_bytes();
let additional_data = [
ContentType::ApplicationData as u8,
0x03,
0x03,
len_bytes[0],
len_bytes[1],
];
crypto
.encrypt_in_place(&nonce, &additional_data, buf)
.map_err(|_| TlsError::InvalidApplicationData)
}
pub struct Handshake<CipherSuite, Verifier>
where
CipherSuite: TlsCipherSuite,
{
traffic_hash: Option<CipherSuite::Hash>,
secret: Option<EphemeralSecret>,
certificate_request: Option<CertificateRequest>,
verifier: Verifier,
}
impl<'v, CipherSuite, Verifier> Handshake<CipherSuite, Verifier>
where
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
{
pub fn new(verifier: Verifier) -> Handshake<CipherSuite, Verifier> {
Handshake {
traffic_hash: None,
secret: None,
certificate_request: None,
verifier,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum State {
ClientHello,
ServerHello,
ServerVerify,
ClientCert,
ClientFinished,
ApplicationData,
}
impl<'a> State {
#[allow(clippy::too_many_arguments)]
pub async fn process<'v, Transport, CipherSuite, RNG, Verifier>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CipherSuite, Verifier>,
record_reader: &mut RecordReader<'_, CipherSuite>,
tx_buf: &mut WriteBuffer<'_>,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
rng: &mut RNG,
) -> Result<State, TlsError>
where
Transport: AsyncRead + AsyncWrite + 'a,
RNG: CryptoRng + RngCore + 'a,
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
{
match self {
State::ClientHello => {
let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
respond(tx, transport, key_schedule).await?;
Ok(state)
}
State::ServerHello => {
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;
let result = process_server_hello(handshake, key_schedule, record);
handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ServerVerify => {
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;
let result = process_server_verify(handshake, key_schedule, config, record);
handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
respond(tx, transport, key_schedule).await?;
Ok(state)
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
respond(tx, transport, key_schedule).await?;
client_finished_finalize(key_schedule, handshake)
}
State::ApplicationData => Ok(State::ApplicationData),
}
}
#[allow(clippy::too_many_arguments)]
pub fn process_blocking<'v, Transport, CipherSuite, RNG, Verifier>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CipherSuite, Verifier>,
record_reader: &mut RecordReader<'_, CipherSuite>,
tx_buf: &mut WriteBuffer,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
rng: &mut RNG,
) -> Result<State, TlsError>
where
Transport: BlockingRead + BlockingWrite + 'a,
RNG: CryptoRng + RngCore,
CipherSuite: TlsCipherSuite + 'static,
Verifier: TlsVerifier<'v, CipherSuite>,
{
match self {
State::ClientHello => {
let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
respond_blocking(tx, transport, key_schedule)?;
Ok(state)
}
State::ServerHello => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
let result = process_server_hello(handshake, key_schedule, record);
handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ServerVerify => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
let result = process_server_verify(handshake, key_schedule, config, record);
handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
respond_blocking(tx, transport, key_schedule)?;
Ok(state)
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
respond_blocking(tx, transport, key_schedule)?;
client_finished_finalize(key_schedule, handshake)
}
State::ApplicationData => Ok(State::ApplicationData),
}
}
}
fn handle_processing_error_blocking<CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;
respond_blocking(tx, transport, key_schedule)?;
}
result
}
fn respond_blocking<CipherSuite>(
tx: &[u8],
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
) -> Result<(), TlsError>
where
CipherSuite: TlsCipherSuite,
{
transport
.write_all(tx)
.map_err(|e| TlsError::Io(e.kind()))?;
key_schedule.write_state().increment_counter();
transport.flush().map_err(|e| TlsError::Io(e.kind()))?;
Ok(())
}
async fn handle_processing_error<'a, CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer<'a>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;
respond(tx, transport, key_schedule).await?;
}
result
}
async fn respond<CipherSuite>(
tx: &[u8],
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
) -> Result<(), TlsError>
where
CipherSuite: TlsCipherSuite,
{
transport
.write_all(tx)
.await
.map_err(|e| TlsError::Io(e.kind()))?;
key_schedule.write_state().increment_counter();
transport
.flush()
.await
.map_err(|e| TlsError::Io(e.kind()))?;
Ok(())
}
fn client_hello<'r, CipherSuite, RNG, Verifier>(
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<CipherSuite>,
rng: &mut RNG,
tx_buf: &'r mut WriteBuffer,
handshake: &mut Handshake<CipherSuite, Verifier>,
) -> Result<(State, &'r [u8]), TlsError>
where
RNG: CryptoRng + RngCore,
CipherSuite: TlsCipherSuite,
{
key_schedule.initialize_early_secret(config.psk.as_ref().map(|p| p.0))?;
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let client_hello = ClientRecord::client_hello(config, rng);
let slice = tx_buf.write_record(&client_hello, write_key_schedule, Some(read_key_schedule))?;
if let ClientRecord::Handshake(ClientHandshake::ClientHello(client_hello), _) = client_hello {
handshake.secret.replace(client_hello.secret);
Ok((State::ServerHello, slice))
} else {
Err(TlsError::EncodeError)
}
}
fn process_server_hello<CipherSuite, Verifier>(
handshake: &mut Handshake<CipherSuite, Verifier>,
key_schedule: &mut KeySchedule<CipherSuite>,
record: ServerRecord<'_, CipherSuite>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
match record {
ServerRecord::Handshake(server_handshake) => match server_handshake {
ServerHandshake::ServerHello(server_hello) => {
trace!("********* ServerHello");
let secret = handshake.secret.take().ok_or(TlsError::InvalidHandshake)?;
let shared = server_hello
.calculate_shared_secret(&secret)
.ok_or(TlsError::InvalidKeyShare)?;
key_schedule.initialize_handshake_secret(shared.raw_secret_bytes())?;
Ok(State::ServerVerify)
}
_ => Err(TlsError::InvalidHandshake),
},
ServerRecord::Alert(alert) => {
Err(TlsError::HandshakeAborted(alert.level, alert.description))
}
_ => Err(TlsError::InvalidRecord),
}
}
fn process_server_verify<'a, 'v, CipherSuite, Verifier>(
handshake: &mut Handshake<CipherSuite, Verifier>,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
record: ServerRecord<'_, CipherSuite>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
{
let mut state = State::ServerVerify;
decrypt_record(key_schedule.read_state(), record, |key_schedule, record| {
match record {
ServerRecord::Handshake(server_handshake) => {
match server_handshake {
ServerHandshake::EncryptedExtensions(_) => {}
ServerHandshake::Certificate(certificate) => {
let transcript = key_schedule.transcript_hash();
handshake.verifier.verify_certificate(
transcript,
&config.ca,
certificate,
)?;
debug!("Certificate verified!");
}
ServerHandshake::CertificateVerify(verify) => {
handshake.verifier.verify_signature(verify)?;
debug!("Signature verified!");
}
ServerHandshake::CertificateRequest(request) => {
handshake.certificate_request.replace(request.try_into()?);
}
ServerHandshake::Finished(finished) => {
if !key_schedule.verify_server_finished(&finished)? {
warn!("Server signature verification failed");
return Err(TlsError::InvalidSignature);
}
state = if handshake.certificate_request.is_some() {
State::ClientCert
} else {
handshake
.traffic_hash
.replace(key_schedule.transcript_hash().clone());
State::ClientFinished
};
}
_ => return Err(TlsError::InvalidHandshake),
}
}
ServerRecord::ChangeCipherSpec(_) => {}
_ => return Err(TlsError::InvalidRecord),
}
Ok(())
})?;
Ok(state)
}
fn client_cert<'r, CipherSuite, Verifier>(
handshake: &mut Handshake<CipherSuite, Verifier>,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<CipherSuite>,
buffer: &'r mut WriteBuffer,
) -> Result<(State, &'r [u8]), TlsError>
where
CipherSuite: TlsCipherSuite,
{
handshake
.traffic_hash
.replace(key_schedule.transcript_hash().clone());
let request_context = &handshake
.certificate_request
.as_ref()
.ok_or(TlsError::InvalidHandshake)?
.request_context;
let mut certificate = CertificateRef::with_context(request_context);
if let Some(cert) = &config.cert {
certificate.add(cert.into())?;
}
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
buffer
.write_record(
&ClientRecord::Handshake(ClientHandshake::ClientCert(certificate), true),
write_key_schedule,
Some(read_key_schedule),
)
.map(|slice| (State::ClientFinished, slice))
}
fn client_finished<'r, CipherSuite>(
key_schedule: &mut KeySchedule<CipherSuite>,
buffer: &'r mut WriteBuffer,
) -> Result<&'r [u8], TlsError>
where
CipherSuite: TlsCipherSuite,
{
let client_finished = key_schedule
.create_client_finished()
.map_err(|_| TlsError::InvalidHandshake)?;
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
buffer.write_record(
&ClientRecord::Handshake(ClientHandshake::Finished(client_finished), true),
write_key_schedule,
Some(read_key_schedule),
)
}
fn client_finished_finalize<CipherSuite, Verifier>(
key_schedule: &mut KeySchedule<CipherSuite>,
handshake: &mut Handshake<CipherSuite, Verifier>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
key_schedule.replace_transcript_hash(
handshake
.traffic_hash
.take()
.ok_or(TlsError::InvalidHandshake)?,
);
key_schedule.initialize_master_secret()?;
Ok(State::ApplicationData)
}