use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Tls13Error {
#[error("Invalid key length: expected {expected}, got {actual}")]
InvalidLength { expected: usize, actual: usize },
#[error("Key schedule not initialized")]
NotInitialized,
#[error("Invalid state: {0}")]
InvalidState(String),
}
pub type Tls13Result<T> = Result<T, Tls13Error>;
/// Stage secrets of a SHA-256 TLS 1.3 key schedule (RFC 8446 §7.1).
///
/// `new` fills in the early and handshake secrets. The master secret is filled
/// in by `derive_handshake_secrets`.
///
/// NOTE(review): this type is `Clone` and serde-serializable, so raw secrets can
/// be copied or persisted, and nothing zeroes them on drop. Treat any serialized
/// form as key material.
#[derive(Serialize, Deserialize, Clone)]
pub struct Tls13KeySchedule {
    /// `HKDF-Extract(0^32, 0^32)`; only written by `new`, not read by the
    /// derivation methods.
    early_secret: [u8; 32],
    /// Handshake Secret; `new` always sets this to `Some`.
    handshake_secret: Option<[u8; 32]>,
    /// Master Secret; `None` until the handshake secrets have been derived.
    master_secret: Option<[u8; 32]>,
}
impl Tls13KeySchedule {
pub fn new(shared_secret: &[u8]) -> Self {
let zero_salt = [0u8; 32];
let early_secret = hkdf_extract(&zero_salt, &zero_salt);
let handshake_secret = derive_secret(&early_secret, b"derived", &[]);
let handshake_secret = hkdf_extract(&handshake_secret, shared_secret);
Self {
early_secret,
handshake_secret: Some(handshake_secret),
master_secret: None,
}
}
pub fn derive_handshake_secrets(
&mut self,
client_hello: &[u8],
server_hello: &[u8],
) -> ([u8; 32], [u8; 32]) {
let handshake_secret = self
.handshake_secret
.expect("Handshake secret not initialized");
let mut hasher = Sha256::new();
hasher.update(client_hello);
hasher.update(server_hello);
let transcript_hash = hasher.finalize();
let client_hs_traffic_secret =
derive_secret(&handshake_secret, b"c hs traffic", &transcript_hash);
let server_hs_traffic_secret =
derive_secret(&handshake_secret, b"s hs traffic", &transcript_hash);
let derived = derive_secret(&handshake_secret, b"derived", &[]);
let master_secret = hkdf_extract(&derived, &[0u8; 32]);
self.master_secret = Some(master_secret);
(client_hs_traffic_secret, server_hs_traffic_secret)
}
pub fn derive_application_secrets(&self) -> Tls13Result<([u8; 32], [u8; 32])> {
let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
let empty_hash = Sha256::digest([]);
let client_app_traffic_secret = derive_secret(&master_secret, b"c ap traffic", &empty_hash);
let server_app_traffic_secret = derive_secret(&master_secret, b"s ap traffic", &empty_hash);
Ok((client_app_traffic_secret, server_app_traffic_secret))
}
pub fn derive_exporter_secret(&self) -> Tls13Result<[u8; 32]> {
let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
let empty_hash = Sha256::digest([]);
Ok(derive_secret(&master_secret, b"exp master", &empty_hash))
}
pub fn derive_resumption_secret(&self, transcript_hash: &[u8]) -> Tls13Result<[u8; 32]> {
let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
Ok(derive_secret(
&master_secret,
b"res master",
transcript_hash,
))
}
pub fn update_traffic_secret(current_secret: &[u8; 32]) -> [u8; 32] {
derive_secret(current_secret, b"traffic upd", &[])
}
}
/// HKDF-Extract (RFC 5869 §2.2) with HMAC-SHA256.
///
/// Returns `PRK = HMAC-SHA256(key = salt, message = ikm)`. HMAC accepts keys of
/// any length, so `salt` may be empty or arbitrarily long.
fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
    use hmac::digest::KeyInit;
    use hmac::{Hmac, Mac};

    let mut keyed_hmac =
        <Hmac<Sha256> as KeyInit>::new_from_slice(salt).expect("HMAC can take key of any size");
    keyed_hmac.update(ikm);

    let mut prk = [0u8; 32];
    prk.copy_from_slice(&keyed_hmac.finalize().into_bytes());
    prk
}
/// TLS 1.3 `HKDF-Expand-Label` (RFC 8446 §7.1) over HKDF-SHA256.
///
/// Serializes the `HkdfLabel` structure and uses it as the HKDF-Expand `info`:
///
/// ```text
/// struct {
///     uint16 length;
///     opaque label<7..255>;   // "tls13 " + label
///     opaque context<0..255>;
/// } HkdfLabel;
/// ```
///
/// Returns `length` bytes of output keying material.
///
/// # Panics
///
/// Panics if `"tls13 " + label` or `context` is longer than 255 bytes. The
/// one-byte length prefixes cannot encode them, and truncating the prefix
/// would yield a malformed `HkdfLabel`.
fn hkdf_expand_label(secret: &[u8], label: &[u8], context: &[u8], length: u16) -> Vec<u8> {
    const LABEL_PREFIX: &[u8] = b"tls13 ";

    let label_len = u8::try_from(LABEL_PREFIX.len() + label.len())
        .expect("HKDF-Expand-Label: \"tls13 \" + label exceeds 255 bytes");
    let context_len =
        u8::try_from(context.len()).expect("HKDF-Expand-Label: context exceeds 255 bytes");

    // 2-byte length + 1-byte label length + label + 1-byte context length + context.
    let mut hkdf_label =
        Vec::with_capacity(2 + 1 + usize::from(label_len) + 1 + context.len());
    hkdf_label.extend_from_slice(&length.to_be_bytes());
    hkdf_label.push(label_len);
    hkdf_label.extend_from_slice(LABEL_PREFIX);
    hkdf_label.extend_from_slice(label);
    hkdf_label.push(context_len);
    hkdf_label.extend_from_slice(context);

    hkdf_expand(secret, &hkdf_label, usize::from(length))
}
/// HKDF-Expand (RFC 5869 §2.3) with HMAC-SHA256.
///
/// Produces `length` bytes of output keying material:
/// `T(i) = HMAC(prk, T(i-1) || info || i)` for `i = 1, 2, ...`, with `T(0)`
/// empty. The blocks are concatenated and truncated to `length`.
///
/// # Panics
///
/// Panics if `length > 255 * 32` (8160 bytes). RFC 5869 caps the output at
/// 255 hash blocks because the block counter is a single octet.
fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Vec<u8> {
    use hmac::digest::KeyInit;
    use hmac::{Hmac, Mac};
    type HmacSha256 = Hmac<Sha256>;
    const HASH_LEN: usize = 32;
    const MAX_OUTPUT: usize = 255 * HASH_LEN;

    assert!(
        length <= MAX_OUTPUT,
        "HKDF-Expand output length {} exceeds the RFC 5869 limit of {} bytes",
        length,
        MAX_OUTPUT
    );

    let mut output = Vec::with_capacity(length);
    let mut t = Vec::new();
    let mut counter: u8 = 0;
    while output.len() < length {
        // Incrementing before use keeps the counter in 1..=255. The assert above
        // bounds the loop to at most 255 iterations, so this cannot overflow.
        // (Incrementing after use overflowed at the maximum valid length.)
        counter += 1;
        let mut mac =
            <HmacSha256 as KeyInit>::new_from_slice(prk).expect("HMAC can take key of any size");
        mac.update(&t);
        mac.update(info);
        mac.update(&[counter]);
        t = mac.finalize().into_bytes().to_vec();
        output.extend_from_slice(&t);
    }
    output.truncate(length);
    output
}
/// TLS 1.3 `Derive-Secret` (RFC 8446 §7.1) over HKDF-SHA256, returning 32 bytes.
///
/// Despite its name, `messages` is NOT hashed here. A non-empty value is
/// treated as an already-computed transcript hash and used verbatim as the
/// HKDF-Expand-Label context. An empty value is replaced by `SHA-256("")`,
/// the transcript hash of an empty message sequence.
///
/// NOTE(review): because of that substitution, this helper can never expand
/// with an empty context. Derivations whose context must be `""` have to call
/// `hkdf_expand_label` directly.
fn derive_secret(secret: &[u8], label: &[u8], messages: &[u8]) -> [u8; 32] {
    // Borrow the caller's hash instead of copying it; only the empty case needs
    // a locally computed value.
    let empty_transcript_hash;
    let context: &[u8] = if messages.is_empty() {
        empty_transcript_hash = Sha256::digest(b"");
        &empty_transcript_hash[..]
    } else {
        messages
    };

    let expanded = hkdf_expand_label(secret, label, context, 32);
    let mut output = [0u8; 32];
    output.copy_from_slice(&expanded);
    output
}
/// Derives the AEAD write key and IV from a traffic secret (RFC 8446 §7.3).
///
/// ```text
/// key = HKDF-Expand-Label(traffic_secret, "key", "", 32)
/// iv  = HKDF-Expand-Label(traffic_secret, "iv",  "", 12)
/// ```
///
/// Returns `(key, iv)`.
///
/// NOTE(review): the 32-byte key length is hard-coded; confirm it matches the
/// negotiated AEAD.
pub fn derive_traffic_keys(traffic_secret: &[u8; 32]) -> ([u8; 32], [u8; 12]) {
    let mut write_key = [0u8; 32];
    write_key.copy_from_slice(&hkdf_expand_label(traffic_secret, b"key", &[], 32)[..32]);

    let mut write_iv = [0u8; 12];
    write_iv.copy_from_slice(&hkdf_expand_label(traffic_secret, b"iv", &[], 12)[..12]);

    (write_key, write_iv)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into bytes (test-only helper).
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "odd-length hex string");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_key_schedule_creation() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);
        assert!(schedule.handshake_secret.is_some());
        assert!(schedule.master_secret.is_none());
    }

    #[test]
    fn test_handshake_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);
        let client_hello = b"client hello message";
        let server_hello = b"server hello message";
        let (client_hs, server_hs) = schedule.derive_handshake_secrets(client_hello, server_hello);
        assert_ne!(client_hs, server_hs);
        assert!(schedule.master_secret.is_some());
    }

    #[test]
    fn test_application_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);
        let client_hello = b"client hello";
        let server_hello = b"server hello";
        schedule.derive_handshake_secrets(client_hello, server_hello);
        let result = schedule.derive_application_secrets();
        assert!(result.is_ok());
        let (client_app, server_app) = result.unwrap();
        assert_ne!(client_app, server_app);
    }

    #[test]
    fn test_application_secrets_before_handshake() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);
        let result = schedule.derive_application_secrets();
        assert!(result.is_err());
    }

    #[test]
    fn test_exporter_and_resumption_before_handshake() {
        let schedule = Tls13KeySchedule::new(&[0x42u8; 32]);
        assert_eq!(schedule.derive_exporter_secret(), Err(Tls13Error::NotInitialized));
        assert_eq!(
            schedule.derive_resumption_secret(&[0u8; 32]),
            Err(Tls13Error::NotInitialized)
        );
    }

    #[test]
    fn test_exporter_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);
        schedule.derive_handshake_secrets(b"client hello", b"server hello");
        let exporter_secret = schedule.derive_exporter_secret();
        assert!(exporter_secret.is_ok());
        assert_eq!(exporter_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_resumption_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);
        schedule.derive_handshake_secrets(b"client hello", b"server hello");
        let transcript = Sha256::digest(b"full handshake transcript");
        let resumption_secret = schedule.derive_resumption_secret(&transcript);
        assert!(resumption_secret.is_ok());
        assert_eq!(resumption_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_traffic_key_update() {
        let current_secret = [0x42u8; 32];
        let new_secret = Tls13KeySchedule::update_traffic_secret(&current_secret);
        assert_ne!(current_secret, new_secret);
    }

    #[test]
    fn test_derive_traffic_keys() {
        let traffic_secret = [0x42u8; 32];
        let (key, iv) = derive_traffic_keys(&traffic_secret);
        assert_eq!(key.len(), 32);
        assert_eq!(iv.len(), 12);
    }

    #[test]
    fn test_hkdf_extract() {
        let salt = [0x01u8; 32];
        let ikm = [0x02u8; 32];
        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(prk.len(), 32);
        let prk2 = hkdf_extract(&salt, &ikm);
        assert_eq!(prk, prk2);
    }

    #[test]
    fn test_hkdf_expand() {
        let prk = [0x42u8; 32];
        let info = b"test info";
        let okm = hkdf_expand(&prk, info, 64);
        assert_eq!(okm.len(), 64);
        let okm2 = hkdf_expand(&prk, info, 64);
        assert_eq!(okm, okm2);
        // Shorter outputs are prefixes of longer ones.
        assert_eq!(hkdf_expand(&prk, info, 10), okm[..10].to_vec());
    }

    /// RFC 5869 Appendix A.1 (Test Case 1, SHA-256).
    #[test]
    fn test_hkdf_rfc5869_case1() {
        let ikm = [0x0bu8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");
        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(
            prk.to_vec(),
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );
        let okm = hkdf_expand(&prk, &info, 42);
        assert_eq!(
            okm,
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865")
        );
    }

    #[test]
    fn test_hkdf_expand_label() {
        let secret = [0x42u8; 32];
        let label = b"test label";
        let context = b"test context";
        let output = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output.len(), 32);
        let output2 = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output, output2);
    }

    #[test]
    fn test_derive_secret() {
        let secret = [0x42u8; 32];
        let label = b"test";
        let messages = b"messages";
        let derived = derive_secret(&secret, label, messages);
        assert_eq!(derived.len(), 32);
        let derived2 = derive_secret(&secret, label, messages);
        assert_eq!(derived, derived2);
    }

    #[test]
    fn test_serialization() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);
        let serialized = crate::codec::encode(&schedule).unwrap();
        let deserialized: Tls13KeySchedule = crate::codec::decode(&serialized).unwrap();
        assert_eq!(deserialized.early_secret, schedule.early_secret);
        assert_eq!(deserialized.handshake_secret, schedule.handshake_secret);
    }
}