use std::mem::MaybeUninit;
use crate::aead::TlsAead;
use crate::client_hello::client_hello_client;
use crate::config::TurtlsConfig;
use crate::error::{FullError, TurtlsError};
use crate::extensions::parse_encrypted_extensions;
use crate::record::{ContentType, RecordLayer};
use crate::server_hello::server_hello_client;
use crate::state::{
GlobalState, MaybeProt, ProtShakeMsg, ShakeState, TranscriptHasher, UnprotShakeMsg,
};
use crate::TurtlsAlert;
pub(crate) fn handshake_client(
shake_state: &mut ShakeState,
global_state: &mut GlobalState,
config: &TurtlsConfig,
) -> Result<(), ()> {
loop {
match shake_state.state {
MaybeProt::Unprot {
ref mut next,
ref mut unprot_state,
} => match next {
UnprotShakeMsg::ClientHello => {
client_hello_client(unprot_state, &mut shake_state.buf, config)
.map_err(|err| global_state.error.turtls_error = err)?;
shake_state
.buf
.write_raw(&mut global_state.rl, &mut global_state.transcript)
.map_err(|err| global_state.error.turtls_error = err)?;
*next = UnprotShakeMsg::ServerHello;
},
UnprotShakeMsg::ServerHello => {
shake_state
.buf
.read_raw(&mut global_state.rl, &mut global_state.transcript)
.map_err(|err| {
if let TurtlsError::Tls = err.turtls_error {
global_state.rl.close_raw(err.alert);
}
global_state.error = err
})?;
let aead =
server_hello_client(shake_state.buf.data(), unprot_state, global_state)
.map_err(|alert| {
global_state.rl.close_raw(alert);
global_state.error = FullError::sending_alert(alert);
})?;
shake_state.state = MaybeProt::Prot {
next: ProtShakeMsg::EncryptedExtensions,
aead,
};
},
},
MaybeProt::Prot {
ref mut next,
ref mut aead,
} => match next {
ProtShakeMsg::EncryptedExtensions => {
shake_state
.buf
.read(&mut global_state.rl, aead, &mut global_state.transcript)
.map_err(|err| {
if let TurtlsError::Tls = err.turtls_error {
global_state.rl.close(err.alert, aead);
}
global_state.error = err
})?;
parse_encrypted_extensions(shake_state.buf.data(), global_state).map_err(
|alert| {
global_state.rl.close(alert, aead);
global_state.error = FullError::sending_alert(alert);
},
)?;
*next = ProtShakeMsg::Certificate;
},
ProtShakeMsg::Certificate => {
shake_state
.buf
.read(&mut global_state.rl, aead, &mut global_state.transcript)
.map_err(|err| {
if let TurtlsError::Tls = err.turtls_error {
global_state.rl.close(err.alert, aead);
}
global_state.error = err
})?;
let cert_msg = shake_state.buf.data();
let context_len = cert_msg[0] as usize;
let certs = &cert_msg[1 + context_len as usize..];
let data_len = u32::from_be_bytes([0, certs[0], certs[1], certs[2]]) as usize;
for cert in crate::certificates::CertIter::new(&certs[3..][..data_len]) {
crate::certificates::x509::validate_cert(cert.data).unwrap();
}
todo!("validate certificates");
},
_ => todo!("Finish handshake"),
},
};
}
}
/// Handshake message types, with each variant's discriminant equal to the `msg_type` byte
/// used on the wire (RFC 8446 §4 `HandshakeType`).
///
/// `#[repr(u8)]` guarantees the type is exactly one byte wide. `ShakeBuf::HEADER_SIZE` is
/// computed from `size_of::<ShakeType>()` and relies on this.
#[derive(Debug, PartialEq, Eq)]
#[expect(unused, reason = "not all handshake messages are implemented yet")]
#[repr(u8)]
pub(crate) enum ShakeType {
    ClientHello = 1,
    ServerHello = 2,
    NewSessionTicket = 4,
    EndOfEarlyData = 5,
    EncryptedExtensions = 8,
    Certificate = 11,
    CertificateRequest = 13,
    CertificateVerify = 15,
    Finished = 20,
    KeyUpdate = 24,
    MessageHash = 254,
}
impl ShakeType {
    /// Converts this message type into the one-byte `msg_type` value that starts a
    /// handshake header (the variant's discriminant).
    pub const fn to_byte(self) -> u8 {
        self as u8
    }
}
/// Holds a single handshake message: a header (`msg_type` byte plus a 3-byte length)
/// followed by the message body.
pub(crate) struct ShakeBuf {
    /// Backing storage. The header occupies the first `HEADER_SIZE` bytes and the body
    /// follows it.
    buf: Box<[u8]>,
    /// Number of body bytes, not counting the header.
    len: usize,
    /// The `max_len` given to `ShakeBuf::new`.
    /// NOTE(review): no method in this file reads it. It is presumably meant to cap message
    /// size; confirm.
    max_size: usize,
    /// How far an incremental read of the current incoming message has progressed.
    status: ReadStatus,
}
/// Progress of `ShakeBuf` through the incoming message it is currently reading.
///
/// The value is stored in the buffer between calls. A later read therefore continues from
/// the recorded position instead of starting over.
enum ReadStatus {
    /// Still reading the header; holds the number of header bytes received so far.
    NeedsHeader(usize),
    /// Header complete, reading the body; holds the number of body bytes received so far.
    NeedsData(usize),
}
impl ReadStatus {
const fn new() -> Self {
Self::NeedsHeader(0)
}
}
impl ShakeBuf {
const INIT_SIZE: usize = 0x4000;
pub(crate) const LEN_SIZE: usize = 3;
pub(crate) const HEADER_SIZE: usize = size_of::<ShakeType>() + Self::LEN_SIZE;
pub(crate) fn new(max_len: usize) -> Self {
let mut buf = Box::new_uninit_slice(Self::INIT_SIZE);
buf.fill(MaybeUninit::zeroed());
Self {
buf: unsafe { buf.assume_init() },
len: 0,
max_size: max_len,
status: ReadStatus::new(),
}
}
pub(crate) fn start(&mut self, msg_type: ShakeType) {
self.len = 0;
self.buf[0] = msg_type.to_byte();
self.buf[1..][..Self::LEN_SIZE].copy_from_slice(&[0; Self::LEN_SIZE]);
}
pub(crate) fn push(&mut self, value: u8) {
if self.len + Self::HEADER_SIZE + 1 > self.buf.len() {
todo!("grow handshake buffer");
}
self.buf[self.len + Self::HEADER_SIZE] = value;
self.len += 1;
}
pub(crate) fn extend_from_slice(&mut self, slice: &[u8]) {
if self.len + Self::HEADER_SIZE + slice.len() > self.buf.len() {
todo!("grow handshake buffer");
}
self.buf[Self::HEADER_SIZE + self.len..][..slice.len()].copy_from_slice(slice);
self.len += slice.len();
}
pub(crate) fn data(&self) -> &[u8] {
&self.buf[Self::HEADER_SIZE..][..self.len]
}
fn read_inner(
&mut self,
rl: &mut RecordLayer,
mut get_fn: impl FnMut(&mut RecordLayer) -> Result<(), FullError>,
transcript: &mut TranscriptHasher,
) -> Result<(), FullError> {
loop {
match self.status {
ReadStatus::NeedsHeader(ref mut amt) => {
if *amt == 0 {
get_fn(rl)?;
if rl.msg_type() == ContentType::ChangeCipherSpec.to_byte() {
rl.discard();
get_fn(rl)?;
}
rl.check_alert()
.map_err(|alert| FullError::recving_alert(alert))?;
if rl.msg_type() != ContentType::Handshake.to_byte() {
return Err(FullError::sending_alert(TurtlsAlert::UnexpectedMessage));
}
}
while *amt < Self::HEADER_SIZE {
get_fn(rl)?;
let new_bytes = rl.read_remaining(&mut self.buf[*amt..Self::HEADER_SIZE]);
if new_bytes == 0 {
return Err(FullError::sending_alert(TurtlsAlert::IllegalParam));
}
*amt += new_bytes;
}
self.len =
u32::from_be_bytes([0, self.buf[1], self.buf[2], self.buf[3]]) as usize;
if self.len > self.buf.len() {
todo!("grow handshake buffer");
}
self.status = ReadStatus::NeedsData(0);
},
ReadStatus::NeedsData(ref mut amt) => {
while *amt < self.len {
get_fn(rl)?;
let new_bytes = rl.read_remaining(
&mut self.buf[Self::HEADER_SIZE + *amt..Self::HEADER_SIZE + self.len],
);
if new_bytes == 0 {
return Err(FullError::sending_alert(TurtlsAlert::IllegalParam));
}
*amt += new_bytes;
}
self.status = ReadStatus::new();
transcript.update_with(&self.buf[..Self::HEADER_SIZE + self.len]);
return Ok(());
},
}
}
}
pub(crate) fn read_raw(
&mut self,
rl: &mut RecordLayer,
transcript: &mut TranscriptHasher,
) -> Result<(), FullError> {
let get_fn = RecordLayer::get_raw;
self.read_inner(rl, get_fn, transcript)
}
pub(crate) fn read(
&mut self,
rl: &mut RecordLayer,
aead: &mut TlsAead,
transcript: &mut TranscriptHasher,
) -> Result<(), FullError> {
let get_fn = |rl: &mut RecordLayer| RecordLayer::get(rl, aead);
self.read_inner(rl, get_fn, transcript)
}
pub(crate) fn write_raw(
&mut self,
rl: &mut RecordLayer,
transcript: &mut TranscriptHasher,
) -> Result<(), TurtlsError> {
self.encode_len();
transcript.update_with(&self.buf[..Self::HEADER_SIZE + self.len]);
rl.write_raw(
&self.buf[..Self::HEADER_SIZE + self.len],
ContentType::Handshake,
)
}
pub(crate) fn write(
&mut self,
rl: &mut RecordLayer,
transcript: &mut TranscriptHasher,
aead: &mut TlsAead,
) -> Result<(), TurtlsError> {
self.encode_len();
transcript.update_with(&self.buf);
rl.write(
&self.buf[..Self::HEADER_SIZE + self.len],
ContentType::Handshake,
aead,
)
}
pub(crate) fn msg_type(&self) -> u8 {
self.buf[0]
}
pub(crate) fn encode_len(&mut self) {
let len = (self.len as u32).to_be_bytes();
self.buf[1..][..Self::LEN_SIZE].copy_from_slice(&len[1..]);
}
}