use std::time::{Duration, Instant};
use rand_core::{OsRng, RngCore};
use sha2::{Digest, Sha256};
use zeroize::Zeroize;
use crate::{
CoreError,
auth::{HandshakeAuth, SessionAuthConfig},
control::ControlMessage,
crypto::{
Direction, EphemeralKeyPair, TrafficKeys, derive_rekey_traffic_keys, derive_traffic_keys,
random_session_salt,
},
};
/// Which side of the handshake a [`Session`] plays.
///
/// The initiator emits `ClientHello`; the responder answers it with `ServerHello`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeRole {
    /// Sends `ClientHello`; its outbound traffic uses `Direction::C2S`.
    Initiator,
    /// Replies with `ServerHello`; its outbound traffic uses `Direction::S2C`.
    Responder,
}
/// Lifecycle phase of a [`Session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// NOTE(review): neither session constructor in this module starts here (both begin in
    /// `WaitingPeerHello`); presumably reserved for callers elsewhere — confirm.
    Init,
    /// Hello sent (initiator) or expected (responder); no traffic keys yet.
    WaitingPeerHello,
    /// Handshake finished: traffic keys are installed and rekeying is permitted.
    Active,
    /// NOTE(review): nothing in this module transitions into this state; presumably set on
    /// teardown by code outside this file — confirm.
    Closed,
}
/// Limits that make a session rotate its traffic keys, plus how many superseded key sets
/// are retained after a rotation.
#[derive(Debug, Clone)]
pub struct RekeyThresholds {
    /// Rekey once the outbound frame counter reaches this value.
    pub max_frames: u64,
    /// Rekey once the outbound plaintext byte counter reaches this value.
    pub max_bytes: u64,
    /// Rekey once this much time has elapsed since the last key installation.
    pub max_age: Duration,
    /// How many previous key sets are kept (newest first) after a rekey; presumably so
    /// frames still in flight under an older key remain decryptable — confirm with callers.
    pub max_previous_keys: usize,
}

impl Default for RekeyThresholds {
    /// 2^20 frames, 2^30 bytes (1 GiB), ten minutes, and two retained previous key sets.
    fn default() -> Self {
        let ten_minutes = Duration::from_secs(10 * 60);
        Self {
            max_frames: 1_048_576,
            max_bytes: 1_073_741_824,
            max_age: ten_minutes,
            max_previous_keys: 2,
        }
    }
}
/// Per-connection state machine: handshake, traffic-key schedule and rekeying.
///
/// `Debug` is implemented by hand (instead of derived) so that the ECDH shared secret,
/// the ephemeral key pair, the local auth configuration and the traffic keys themselves
/// never end up in logs or panic messages; only key ids and bookkeeping are shown.
#[derive(Clone)]
pub struct Session {
    /// Which side of the handshake this session plays.
    role: HandshakeRole,
    /// Current lifecycle phase.
    state: SessionState,
    /// Local ephemeral key pair used for the ECDH exchange.
    local_eph: EphemeralKeyPair,
    /// Peer's ephemeral public key, recorded once its hello is accepted.
    peer_eph_public: Option<[u8; 32]>,
    /// ECDH shared secret; zeroized on drop.
    shared_secret: Option<[u8; 32]>,
    /// Salt chosen by the initiator and carried in `ClientHello`; zeroized on drop.
    session_salt: [u8; 32],
    /// Key set currently used for traffic.
    active_keys: Option<TrafficKeys>,
    /// Superseded key sets, newest first, bounded by `thresholds.max_previous_keys`.
    previous_keys: Vec<TrafficKeys>,
    /// Automatic rekey limits.
    thresholds: RekeyThresholds,
    /// Local/peer identity configuration for handshake signatures.
    auth: SessionAuthConfig,
    /// Whether the peer's hello carried an auth payload that verified.
    peer_authenticated: bool,
    /// Outbound frames counted since the last rekey-budget reset.
    outbound_frames: u64,
    /// Outbound plaintext bytes counted since the last rekey-budget reset.
    outbound_bytes: u64,
    /// When the current key set was installed.
    last_rekey_at: Instant,
}

impl std::fmt::Debug for Session {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let previous_key_ids: Vec<u8> =
            self.previous_keys.iter().map(|keys| keys.key_id).collect();
        f.debug_struct("Session")
            .field("role", &self.role)
            .field("state", &self.state)
            .field("peer_eph_public", &self.peer_eph_public)
            // Only reveal whether a secret exists, never its bytes.
            .field(
                "shared_secret",
                &self.shared_secret.as_ref().map(|_| "<redacted>"),
            )
            .field(
                "active_key_id",
                &self.active_keys.as_ref().map(|keys| keys.key_id),
            )
            .field("previous_key_ids", &previous_key_ids)
            .field("thresholds", &self.thresholds)
            .field("peer_authenticated", &self.peer_authenticated)
            .field("outbound_frames", &self.outbound_frames)
            .field("outbound_bytes", &self.outbound_bytes)
            .field("last_rekey_at", &self.last_rekey_at)
            .finish_non_exhaustive()
    }
}
impl Drop for Session {
    /// Wipes the session salt and, when present, the ECDH shared secret before the
    /// session's memory is released.
    fn drop(&mut self) {
        self.session_salt.zeroize();
        self.shared_secret
            .iter_mut()
            .for_each(|secret| secret.zeroize());
    }
}
impl Session {
pub fn new_initiator(thresholds: RekeyThresholds) -> (Self, ControlMessage) {
Self::new_initiator_with_auth(thresholds, SessionAuthConfig::default())
}
pub fn new_initiator_with_auth(
thresholds: RekeyThresholds,
auth: SessionAuthConfig,
) -> (Self, ControlMessage) {
let local_eph = EphemeralKeyPair::generate();
let session_salt = random_session_salt();
let binding = client_hello_binding(local_eph.public, session_salt);
let auth_payload = auth.local_identity().map(|identity| {
HandshakeAuth::sign(
identity,
&client_auth_message(local_eph.public, session_salt, binding),
)
});
let msg = ControlMessage::ClientHello {
eph_public: local_eph.public,
session_salt,
transcript_binding: binding,
auth: auth_payload,
};
(
Self {
role: HandshakeRole::Initiator,
state: SessionState::WaitingPeerHello,
local_eph,
peer_eph_public: None,
shared_secret: None,
session_salt,
active_keys: None,
previous_keys: Vec::new(),
thresholds,
auth,
peer_authenticated: false,
outbound_frames: 0,
outbound_bytes: 0,
last_rekey_at: Instant::now(),
},
msg,
)
}
pub fn new_responder(thresholds: RekeyThresholds) -> Self {
Self::new_responder_with_auth(thresholds, SessionAuthConfig::default())
}
pub fn new_responder_with_auth(thresholds: RekeyThresholds, auth: SessionAuthConfig) -> Self {
Self {
role: HandshakeRole::Responder,
state: SessionState::WaitingPeerHello,
local_eph: EphemeralKeyPair::generate(),
peer_eph_public: None,
shared_secret: None,
session_salt: [0u8; 32],
active_keys: None,
previous_keys: Vec::new(),
thresholds,
auth,
peer_authenticated: false,
outbound_frames: 0,
outbound_bytes: 0,
last_rekey_at: Instant::now(),
}
}
pub fn state(&self) -> SessionState {
self.state
}
pub fn role(&self) -> HandshakeRole {
self.role
}
pub fn peer_authenticated(&self) -> bool {
self.peer_authenticated
}
pub fn outbound_direction(&self) -> Direction {
match self.role {
HandshakeRole::Initiator => Direction::C2S,
HandshakeRole::Responder => Direction::S2C,
}
}
pub fn inbound_direction(&self) -> Direction {
match self.role {
HandshakeRole::Initiator => Direction::S2C,
HandshakeRole::Responder => Direction::C2S,
}
}
pub fn handle_control(
&mut self,
msg: &ControlMessage,
) -> Result<Option<ControlMessage>, CoreError> {
match (self.role, self.state, msg) {
(
HandshakeRole::Responder,
SessionState::WaitingPeerHello,
ControlMessage::ClientHello {
eph_public,
session_salt,
transcript_binding,
auth,
},
) => {
let expected = client_hello_binding(*eph_public, *session_salt);
if transcript_binding != &expected {
return Err(CoreError::InvalidControlMessage);
}
let peer_authenticated = self.verify_client_auth(
*eph_public,
*session_salt,
*transcript_binding,
auth.as_ref(),
)?;
self.peer_eph_public = Some(*eph_public);
self.session_salt = *session_salt;
let shared = self.local_eph.shared_secret(*eph_public)?;
let keys = derive_traffic_keys(&shared, &self.session_salt, 0)?;
self.shared_secret = Some(shared);
self.active_keys = Some(keys);
self.state = SessionState::Active;
self.peer_authenticated = peer_authenticated;
self.last_rekey_at = Instant::now();
let server_binding =
server_hello_binding(*eph_public, self.local_eph.public, self.session_salt);
let server_auth = self.auth.local_identity().map(|identity| {
HandshakeAuth::sign(
identity,
&server_auth_message(
*eph_public,
self.local_eph.public,
self.session_salt,
server_binding,
),
)
});
Ok(Some(ControlMessage::ServerHello {
eph_public: self.local_eph.public,
transcript_binding: server_binding,
auth: server_auth,
}))
}
(
HandshakeRole::Initiator,
SessionState::WaitingPeerHello,
ControlMessage::ServerHello {
eph_public,
transcript_binding,
auth,
},
) => {
let expected =
server_hello_binding(self.local_eph.public, *eph_public, self.session_salt);
if transcript_binding != &expected {
return Err(CoreError::InvalidControlMessage);
}
let peer_authenticated =
self.verify_server_auth(*eph_public, *transcript_binding, auth.as_ref())?;
self.peer_eph_public = Some(*eph_public);
let shared = self.local_eph.shared_secret(*eph_public)?;
let keys = derive_traffic_keys(&shared, &self.session_salt, 0)?;
self.shared_secret = Some(shared);
self.active_keys = Some(keys);
self.state = SessionState::Active;
self.peer_authenticated = peer_authenticated;
self.last_rekey_at = Instant::now();
Ok(None)
}
(
_,
SessionState::Active,
ControlMessage::Rekey {
old_key_id,
new_key_id,
rekey_salt,
transcript_binding,
},
) => {
let active = self
.active_keys
.as_ref()
.ok_or(CoreError::InvalidSessionState)?;
if *old_key_id != active.key_id {
return Err(CoreError::UnexpectedControlMessage);
}
let expected =
rekey_binding(*old_key_id, *new_key_id, *rekey_salt, self.session_salt);
if transcript_binding != &expected {
return Err(CoreError::InvalidControlMessage);
}
let shared = self.shared_secret.ok_or(CoreError::MissingSessionSecret)?;
let next = derive_rekey_traffic_keys(
&shared,
&self.session_salt,
rekey_salt,
*new_key_id,
)?;
self.install_new_active_key(next);
self.last_rekey_at = Instant::now();
Ok(None)
}
(_, SessionState::Active, ControlMessage::Error { .. }) => Ok(None),
_ => Err(CoreError::UnexpectedControlMessage),
}
}
pub fn active_keys(&self) -> Option<TrafficKeys> {
self.active_keys.clone()
}
pub fn active_and_previous_keys(&self) -> Option<Vec<TrafficKeys>> {
let mut out = Vec::new();
let active = self.active_keys.clone()?;
out.push(active);
out.extend(self.previous_keys.iter().cloned());
Some(out)
}
pub fn key_ring(&self) -> Result<Vec<TrafficKeys>, CoreError> {
self.active_and_previous_keys()
.ok_or(CoreError::InvalidSessionState)
}
pub fn on_outbound_payload(
&mut self,
plaintext_len: usize,
) -> Result<Option<ControlMessage>, CoreError> {
if self.state != SessionState::Active {
return Err(CoreError::InvalidSessionState);
}
self.outbound_frames = self.outbound_frames.saturating_add(1);
self.outbound_bytes = self.outbound_bytes.saturating_add(plaintext_len as u64);
if self.should_rekey() {
let msg = self.force_rekey()?;
return Ok(Some(msg));
}
Ok(None)
}
pub fn force_rekey(&mut self) -> Result<ControlMessage, CoreError> {
if self.state != SessionState::Active {
return Err(CoreError::InvalidSessionState);
}
let active = self
.active_keys
.clone()
.ok_or(CoreError::InvalidSessionState)?;
let old_key_id = active.key_id;
let new_key_id = old_key_id.checked_add(1).ok_or(CoreError::KeyIdExhausted)?;
let mut rekey_salt = [0u8; 32];
OsRng.fill_bytes(&mut rekey_salt);
let shared = self.shared_secret.ok_or(CoreError::MissingSessionSecret)?;
let next = derive_rekey_traffic_keys(&shared, &self.session_salt, &rekey_salt, new_key_id)?;
self.install_new_active_key(next);
self.outbound_frames = 0;
self.outbound_bytes = 0;
self.last_rekey_at = Instant::now();
let transcript_binding =
rekey_binding(old_key_id, new_key_id, rekey_salt, self.session_salt);
Ok(ControlMessage::Rekey {
old_key_id,
new_key_id,
rekey_salt,
transcript_binding,
})
}
fn should_rekey(&self) -> bool {
self.outbound_frames >= self.thresholds.max_frames
|| self.outbound_bytes >= self.thresholds.max_bytes
|| self.last_rekey_at.elapsed() >= self.thresholds.max_age
}
fn install_new_active_key(&mut self, next: TrafficKeys) {
if let Some(current) = self.active_keys.take() {
self.previous_keys.insert(0, current);
if self.previous_keys.len() > self.thresholds.max_previous_keys {
self.previous_keys
.truncate(self.thresholds.max_previous_keys);
}
}
self.active_keys = Some(next);
}
fn verify_client_auth(
&self,
eph_public: [u8; 32],
session_salt: [u8; 32],
transcript_binding: [u8; 32],
auth: Option<&HandshakeAuth>,
) -> Result<bool, CoreError> {
let message = client_auth_message(eph_public, session_salt, transcript_binding);
self.verify_auth_payload(auth, &message)
}
fn verify_server_auth(
&self,
server_public: [u8; 32],
transcript_binding: [u8; 32],
auth: Option<&HandshakeAuth>,
) -> Result<bool, CoreError> {
let message = server_auth_message(
self.local_eph.public,
server_public,
self.session_salt,
transcript_binding,
);
self.verify_auth_payload(auth, &message)
}
fn verify_auth_payload(
&self,
auth: Option<&HandshakeAuth>,
message: &[u8],
) -> Result<bool, CoreError> {
match auth {
Some(auth) => {
auth.verify(message)?;
if let Some(peer_identity) = self.auth.peer_identity()
&& auth.identity_public_key != peer_identity.public_key
{
return Err(CoreError::PeerIdentityMismatch);
}
Ok(true)
}
None if self.auth.requires_peer_authentication()
|| self.auth.peer_identity().is_some() =>
{
Err(CoreError::MissingPeerAuthentication)
}
None => Ok(false),
}
}
}
/// Transcript hash committing the initiator's ephemeral public key to the session salt:
/// `SHA-256("foctet hs client" || client_public || session_salt)`.
fn client_hello_binding(client_public: [u8; 32], session_salt: [u8; 32]) -> [u8; 32] {
    let parts: [&[u8]; 3] = [b"foctet hs client", &client_public, &session_salt];
    let mut digest = Sha256::new();
    for part in parts {
        digest.update(part);
    }
    digest.finalize().into()
}
/// Byte string the initiator signs (and the responder verifies) for handshake auth:
/// `"foctet auth client" || client_public || session_salt || transcript_binding`.
///
/// Every field is fixed-width, so plain concatenation is unambiguous.
fn client_auth_message(
    client_public: [u8; 32],
    session_salt: [u8; 32],
    transcript_binding: [u8; 32],
) -> Vec<u8> {
    const LABEL: &[u8] = b"foctet auth client";
    // Capacity derived from the label itself so it stays exact if the label changes
    // (the label is 18 bytes, not 19).
    let mut out = Vec::with_capacity(LABEL.len() + 3 * 32);
    out.extend_from_slice(LABEL);
    out.extend_from_slice(&client_public);
    out.extend_from_slice(&session_salt);
    out.extend_from_slice(&transcript_binding);
    out
}
/// Transcript hash binding both ephemeral public keys to the session salt:
/// `SHA-256("foctet hs server" || client_public || server_public || session_salt)`.
fn server_hello_binding(
    client_public: [u8; 32],
    server_public: [u8; 32],
    session_salt: [u8; 32],
) -> [u8; 32] {
    let parts: [&[u8]; 4] = [
        b"foctet hs server",
        &client_public,
        &server_public,
        &session_salt,
    ];
    let mut digest = Sha256::new();
    for part in parts {
        digest.update(part);
    }
    digest.finalize().into()
}
/// Byte string the responder signs (and the initiator verifies) for handshake auth:
/// `"foctet auth server" || client_public || server_public || session_salt ||
/// transcript_binding`.
///
/// Every field is fixed-width, so plain concatenation is unambiguous.
fn server_auth_message(
    client_public: [u8; 32],
    server_public: [u8; 32],
    session_salt: [u8; 32],
    transcript_binding: [u8; 32],
) -> Vec<u8> {
    const LABEL: &[u8] = b"foctet auth server";
    // Capacity derived from the label itself so it stays exact if the label changes
    // (the label is 18 bytes, not 19).
    let mut out = Vec::with_capacity(LABEL.len() + 4 * 32);
    out.extend_from_slice(LABEL);
    out.extend_from_slice(&client_public);
    out.extend_from_slice(&server_public);
    out.extend_from_slice(&session_salt);
    out.extend_from_slice(&transcript_binding);
    out
}
/// Transcript hash for a `Rekey` message:
/// `SHA-256("foctet rekey" || old_key_id || new_key_id || rekey_salt || session_salt)`.
///
/// NOTE(review): no secret goes into this hash, and `session_salt` travels in clear in
/// `ClientHello`, so anyone who saw the handshake can recompute it. It only protects
/// integrity if `Rekey` messages are carried inside the authenticated channel — confirm.
fn rekey_binding(
    old_key_id: u8,
    new_key_id: u8,
    rekey_salt: [u8; 32],
    session_salt: [u8; 32],
) -> [u8; 32] {
    let key_ids = [old_key_id, new_key_id];
    let parts: [&[u8]; 4] = [b"foctet rekey", &key_ids, &rekey_salt, &session_salt];
    let mut digest = Sha256::new();
    for part in parts {
        digest.update(part);
    }
    digest.finalize().into()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{IdentityKeyPair, PeerIdentity};

    /// Runs the ClientHello -> ServerHello exchange between `client` and `server`.
    fn run_handshake(client: &mut Session, hello: &ControlMessage, server: &mut Session) {
        let server_hello = server
            .handle_control(hello)
            .expect("server handle client hello")
            .expect("server hello response");
        client
            .handle_control(&server_hello)
            .expect("client handle server hello");
    }

    /// Auth config that signs with `local` and requires `peer` to authenticate.
    fn pinned_auth(local: &IdentityKeyPair, peer: &IdentityKeyPair) -> SessionAuthConfig {
        SessionAuthConfig::new()
            .with_local_identity(local.clone())
            .with_peer_identity(PeerIdentity::new(peer.public_key()))
            .require_peer_authentication(true)
    }

    #[test]
    fn session_handshake_and_rekey() {
        let (mut client, hello) = Session::new_initiator(RekeyThresholds::default());
        let mut server = Session::new_responder(RekeyThresholds::default());
        run_handshake(&mut client, &hello, &mut server);
        for side in [&client, &server] {
            assert_eq!(side.state(), SessionState::Active);
        }

        let rekey = client.force_rekey().expect("client force rekey");
        server.handle_control(&rekey).expect("server handle rekey");

        let client_key = client.active_keys().expect("client active key");
        let server_key = server.active_keys().expect("server active key");
        assert_eq!(client_key.key_id, server_key.key_id);
    }

    #[test]
    fn session_authenticates_pinned_peer_identities() {
        let client_identity = IdentityKeyPair::from_secret_key_bytes([0x41; 32]);
        let server_identity = IdentityKeyPair::from_secret_key_bytes([0x61; 32]);
        let client_auth = pinned_auth(&client_identity, &server_identity);
        let server_auth = pinned_auth(&server_identity, &client_identity);

        let (mut client, hello) =
            Session::new_initiator_with_auth(RekeyThresholds::default(), client_auth);
        let mut server = Session::new_responder_with_auth(RekeyThresholds::default(), server_auth);
        run_handshake(&mut client, &hello, &mut server);

        assert!(client.peer_authenticated());
        assert!(server.peer_authenticated());
    }
}