1use std::time::{Duration, Instant};
2
3use rand_core::{OsRng, RngCore};
4use sha2::{Digest, Sha256};
5use zeroize::Zeroize;
6
7use crate::{
8 CoreError,
9 auth::{HandshakeAuth, SessionAuthConfig},
10 control::ControlMessage,
11 crypto::{
12 Direction, EphemeralKeyPair, TrafficKeys, derive_rekey_traffic_keys, derive_traffic_keys,
13 random_session_salt,
14 },
15};
16
/// Which side of the handshake a [`Session`] plays.
///
/// The role decides which hello message the session accepts in
/// `Session::handle_control` and which traffic direction counts as outbound.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Side that emits the `ClientHello`; its outbound direction is `Direction::C2S`.
    Initiator,
    /// Side that answers a `ClientHello` with a `ServerHello`; its outbound
    /// direction is `Direction::S2C`.
    Responder,
}
25
/// Lifecycle state of a [`Session`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Not assigned by any code visible in this file; both constructors start
    /// in [`SessionState::WaitingPeerHello`]. Presumably reserved for a
    /// not-yet-started session — TODO confirm.
    Init,
    /// Waiting for the peer's hello: the responder expects a `ClientHello`,
    /// the initiator expects a `ServerHello`.
    WaitingPeerHello,
    /// Handshake finished; traffic keys are installed and rekeying is allowed.
    Active,
    /// No transition into this state is visible in this file. Presumably a
    /// terminal state set by a caller or a later revision — TODO confirm.
    Closed,
}
38
/// Limits that control when a session rotates its traffic keys and how many
/// superseded keys it keeps around.
#[derive(Clone, Debug)]
pub struct RekeyThresholds {
    /// A rekey is triggered once the outbound frame count since the last
    /// locally initiated rekey reaches this value.
    pub max_frames: u64,
    /// A rekey is triggered once the outbound plaintext byte count since the
    /// last locally initiated rekey reaches this value.
    pub max_bytes: u64,
    /// A rekey is triggered once this much time has elapsed since the last
    /// handshake completion or rekey.
    pub max_age: Duration,
    /// Maximum number of superseded traffic keys retained after a rekey;
    /// older keys beyond this count are dropped.
    pub max_previous_keys: usize,
}
51
52impl Default for RekeyThresholds {
53 fn default() -> Self {
54 Self {
55 max_frames: 1 << 20,
56 max_bytes: 1 << 30,
57 max_age: Duration::from_secs(600),
58 max_previous_keys: 2,
59 }
60 }
61}
62
63#[derive(Clone, Debug)]
65pub struct Session {
66 role: HandshakeRole,
67 state: SessionState,
68 local_eph: EphemeralKeyPair,
69 peer_eph_public: Option<[u8; 32]>,
70 shared_secret: Option<[u8; 32]>,
71 session_salt: [u8; 32],
72 active_keys: Option<TrafficKeys>,
73 previous_keys: Vec<TrafficKeys>,
74 thresholds: RekeyThresholds,
75 auth: SessionAuthConfig,
76 peer_authenticated: bool,
77 outbound_frames: u64,
78 outbound_bytes: u64,
79 last_rekey_at: Instant,
80}
81
82impl Drop for Session {
83 fn drop(&mut self) {
84 if let Some(shared) = &mut self.shared_secret {
85 shared.zeroize();
86 }
87 self.session_salt.zeroize();
88 }
89}
90
91impl Session {
92 pub fn new_initiator(thresholds: RekeyThresholds) -> (Self, ControlMessage) {
94 Self::new_initiator_with_auth(thresholds, SessionAuthConfig::default())
95 }
96
97 pub fn new_initiator_with_auth(
99 thresholds: RekeyThresholds,
100 auth: SessionAuthConfig,
101 ) -> (Self, ControlMessage) {
102 let local_eph = EphemeralKeyPair::generate();
103 let session_salt = random_session_salt();
104 let binding = client_hello_binding(local_eph.public, session_salt);
105 let auth_payload = auth.local_identity().map(|identity| {
106 HandshakeAuth::sign(
107 identity,
108 &client_auth_message(local_eph.public, session_salt, binding),
109 )
110 });
111
112 let msg = ControlMessage::ClientHello {
113 eph_public: local_eph.public,
114 session_salt,
115 transcript_binding: binding,
116 auth: auth_payload,
117 };
118
119 (
120 Self {
121 role: HandshakeRole::Initiator,
122 state: SessionState::WaitingPeerHello,
123 local_eph,
124 peer_eph_public: None,
125 shared_secret: None,
126 session_salt,
127 active_keys: None,
128 previous_keys: Vec::new(),
129 thresholds,
130 auth,
131 peer_authenticated: false,
132 outbound_frames: 0,
133 outbound_bytes: 0,
134 last_rekey_at: Instant::now(),
135 },
136 msg,
137 )
138 }
139
140 pub fn new_responder(thresholds: RekeyThresholds) -> Self {
142 Self::new_responder_with_auth(thresholds, SessionAuthConfig::default())
143 }
144
145 pub fn new_responder_with_auth(thresholds: RekeyThresholds, auth: SessionAuthConfig) -> Self {
147 Self {
148 role: HandshakeRole::Responder,
149 state: SessionState::WaitingPeerHello,
150 local_eph: EphemeralKeyPair::generate(),
151 peer_eph_public: None,
152 shared_secret: None,
153 session_salt: [0u8; 32],
154 active_keys: None,
155 previous_keys: Vec::new(),
156 thresholds,
157 auth,
158 peer_authenticated: false,
159 outbound_frames: 0,
160 outbound_bytes: 0,
161 last_rekey_at: Instant::now(),
162 }
163 }
164
    /// Returns the session's current lifecycle state.
    pub fn state(&self) -> SessionState {
        self.state
    }
169
    /// Returns the handshake role this session was created with.
    pub fn role(&self) -> HandshakeRole {
        self.role
    }
174
    /// Returns `true` when the peer's hello carried a handshake signature that
    /// verified; `false` before the handshake completes or when the peer sent
    /// no signature and none was required.
    pub fn peer_authenticated(&self) -> bool {
        self.peer_authenticated
    }
179
180 pub fn outbound_direction(&self) -> Direction {
182 match self.role {
183 HandshakeRole::Initiator => Direction::C2S,
184 HandshakeRole::Responder => Direction::S2C,
185 }
186 }
187
188 pub fn inbound_direction(&self) -> Direction {
190 match self.role {
191 HandshakeRole::Initiator => Direction::S2C,
192 HandshakeRole::Responder => Direction::C2S,
193 }
194 }
195
196 pub fn handle_control(
198 &mut self,
199 msg: &ControlMessage,
200 ) -> Result<Option<ControlMessage>, CoreError> {
201 match (self.role, self.state, msg) {
202 (
203 HandshakeRole::Responder,
204 SessionState::WaitingPeerHello,
205 ControlMessage::ClientHello {
206 eph_public,
207 session_salt,
208 transcript_binding,
209 auth,
210 },
211 ) => {
212 let expected = client_hello_binding(*eph_public, *session_salt);
213 if transcript_binding != &expected {
214 return Err(CoreError::InvalidControlMessage);
215 }
216 let peer_authenticated = self.verify_client_auth(
217 *eph_public,
218 *session_salt,
219 *transcript_binding,
220 auth.as_ref(),
221 )?;
222
223 self.peer_eph_public = Some(*eph_public);
224 self.session_salt = *session_salt;
225 let shared = self.local_eph.shared_secret(*eph_public)?;
226 let keys = derive_traffic_keys(&shared, &self.session_salt, 0)?;
227
228 self.shared_secret = Some(shared);
229 self.active_keys = Some(keys);
230 self.state = SessionState::Active;
231 self.peer_authenticated = peer_authenticated;
232 self.last_rekey_at = Instant::now();
233
234 let server_binding =
235 server_hello_binding(*eph_public, self.local_eph.public, self.session_salt);
236 let server_auth = self.auth.local_identity().map(|identity| {
237 HandshakeAuth::sign(
238 identity,
239 &server_auth_message(
240 *eph_public,
241 self.local_eph.public,
242 self.session_salt,
243 server_binding,
244 ),
245 )
246 });
247 Ok(Some(ControlMessage::ServerHello {
248 eph_public: self.local_eph.public,
249 transcript_binding: server_binding,
250 auth: server_auth,
251 }))
252 }
253 (
254 HandshakeRole::Initiator,
255 SessionState::WaitingPeerHello,
256 ControlMessage::ServerHello {
257 eph_public,
258 transcript_binding,
259 auth,
260 },
261 ) => {
262 let expected =
263 server_hello_binding(self.local_eph.public, *eph_public, self.session_salt);
264 if transcript_binding != &expected {
265 return Err(CoreError::InvalidControlMessage);
266 }
267 let peer_authenticated =
268 self.verify_server_auth(*eph_public, *transcript_binding, auth.as_ref())?;
269
270 self.peer_eph_public = Some(*eph_public);
271 let shared = self.local_eph.shared_secret(*eph_public)?;
272 let keys = derive_traffic_keys(&shared, &self.session_salt, 0)?;
273
274 self.shared_secret = Some(shared);
275 self.active_keys = Some(keys);
276 self.state = SessionState::Active;
277 self.peer_authenticated = peer_authenticated;
278 self.last_rekey_at = Instant::now();
279 Ok(None)
280 }
281 (
282 _,
283 SessionState::Active,
284 ControlMessage::Rekey {
285 old_key_id,
286 new_key_id,
287 rekey_salt,
288 transcript_binding,
289 },
290 ) => {
291 let active = self
292 .active_keys
293 .as_ref()
294 .ok_or(CoreError::InvalidSessionState)?;
295 if *old_key_id != active.key_id {
296 return Err(CoreError::UnexpectedControlMessage);
297 }
298
299 let expected =
300 rekey_binding(*old_key_id, *new_key_id, *rekey_salt, self.session_salt);
301 if transcript_binding != &expected {
302 return Err(CoreError::InvalidControlMessage);
303 }
304
305 let shared = self.shared_secret.ok_or(CoreError::MissingSessionSecret)?;
306 let next = derive_rekey_traffic_keys(
307 &shared,
308 &self.session_salt,
309 rekey_salt,
310 *new_key_id,
311 )?;
312 self.install_new_active_key(next);
313 self.last_rekey_at = Instant::now();
314 Ok(None)
315 }
316 (_, SessionState::Active, ControlMessage::Error { .. }) => Ok(None),
317 _ => Err(CoreError::UnexpectedControlMessage),
318 }
319 }
320
    /// Returns a clone of the currently active traffic keys, or `None` when no
    /// keys have been installed yet.
    pub fn active_keys(&self) -> Option<TrafficKeys> {
        self.active_keys.clone()
    }
325
326 pub fn active_and_previous_keys(&self) -> Option<Vec<TrafficKeys>> {
328 let mut out = Vec::new();
329 let active = self.active_keys.clone()?;
330 out.push(active);
331 out.extend(self.previous_keys.iter().cloned());
332 Some(out)
333 }
334
335 pub fn key_ring(&self) -> Result<Vec<TrafficKeys>, CoreError> {
337 self.active_and_previous_keys()
338 .ok_or(CoreError::InvalidSessionState)
339 }
340
341 pub fn on_outbound_payload(
343 &mut self,
344 plaintext_len: usize,
345 ) -> Result<Option<ControlMessage>, CoreError> {
346 if self.state != SessionState::Active {
347 return Err(CoreError::InvalidSessionState);
348 }
349
350 self.outbound_frames = self.outbound_frames.saturating_add(1);
351 self.outbound_bytes = self.outbound_bytes.saturating_add(plaintext_len as u64);
352
353 if self.should_rekey() {
354 let msg = self.force_rekey()?;
355 return Ok(Some(msg));
356 }
357
358 Ok(None)
359 }
360
361 pub fn force_rekey(&mut self) -> Result<ControlMessage, CoreError> {
363 if self.state != SessionState::Active {
364 return Err(CoreError::InvalidSessionState);
365 }
366
367 let active = self
368 .active_keys
369 .clone()
370 .ok_or(CoreError::InvalidSessionState)?;
371 let old_key_id = active.key_id;
372 let new_key_id = old_key_id.checked_add(1).ok_or(CoreError::KeyIdExhausted)?;
373
374 let mut rekey_salt = [0u8; 32];
375 OsRng.fill_bytes(&mut rekey_salt);
376
377 let shared = self.shared_secret.ok_or(CoreError::MissingSessionSecret)?;
378 let next = derive_rekey_traffic_keys(&shared, &self.session_salt, &rekey_salt, new_key_id)?;
379 self.install_new_active_key(next);
380
381 self.outbound_frames = 0;
382 self.outbound_bytes = 0;
383 self.last_rekey_at = Instant::now();
384
385 let transcript_binding =
386 rekey_binding(old_key_id, new_key_id, rekey_salt, self.session_salt);
387 Ok(ControlMessage::Rekey {
388 old_key_id,
389 new_key_id,
390 rekey_salt,
391 transcript_binding,
392 })
393 }
394
395 fn should_rekey(&self) -> bool {
396 self.outbound_frames >= self.thresholds.max_frames
397 || self.outbound_bytes >= self.thresholds.max_bytes
398 || self.last_rekey_at.elapsed() >= self.thresholds.max_age
399 }
400
401 fn install_new_active_key(&mut self, next: TrafficKeys) {
402 if let Some(current) = self.active_keys.take() {
403 self.previous_keys.insert(0, current);
404 if self.previous_keys.len() > self.thresholds.max_previous_keys {
405 self.previous_keys
406 .truncate(self.thresholds.max_previous_keys);
407 }
408 }
409 self.active_keys = Some(next);
410 }
411
412 fn verify_client_auth(
413 &self,
414 eph_public: [u8; 32],
415 session_salt: [u8; 32],
416 transcript_binding: [u8; 32],
417 auth: Option<&HandshakeAuth>,
418 ) -> Result<bool, CoreError> {
419 let message = client_auth_message(eph_public, session_salt, transcript_binding);
420 self.verify_auth_payload(auth, &message)
421 }
422
423 fn verify_server_auth(
424 &self,
425 server_public: [u8; 32],
426 transcript_binding: [u8; 32],
427 auth: Option<&HandshakeAuth>,
428 ) -> Result<bool, CoreError> {
429 let message = server_auth_message(
430 self.local_eph.public,
431 server_public,
432 self.session_salt,
433 transcript_binding,
434 );
435 self.verify_auth_payload(auth, &message)
436 }
437
438 fn verify_auth_payload(
439 &self,
440 auth: Option<&HandshakeAuth>,
441 message: &[u8],
442 ) -> Result<bool, CoreError> {
443 match auth {
444 Some(auth) => {
445 auth.verify(message)?;
446 if let Some(peer_identity) = self.auth.peer_identity()
447 && auth.identity_public_key != peer_identity.public_key
448 {
449 return Err(CoreError::PeerIdentityMismatch);
450 }
451 Ok(true)
452 }
453 None if self.auth.requires_peer_authentication()
454 || self.auth.peer_identity().is_some() =>
455 {
456 Err(CoreError::MissingPeerAuthentication)
457 }
458 None => Ok(false),
459 }
460 }
461}
462
463fn client_hello_binding(client_public: [u8; 32], session_salt: [u8; 32]) -> [u8; 32] {
464 let mut hasher = Sha256::new();
465 hasher.update(b"foctet hs client");
466 hasher.update(client_public);
467 hasher.update(session_salt);
468 hasher.finalize().into()
469}
470
/// Builds the byte string the client signs during the handshake: the label
/// `"foctet auth client"` followed by the client's ephemeral public key, the
/// session salt, and the `ClientHello` transcript binding.
///
/// The buffer capacity is derived from the label itself rather than a
/// hand-counted constant, so it cannot drift from the label's real length.
fn client_auth_message(
    client_public: [u8; 32],
    session_salt: [u8; 32],
    transcript_binding: [u8; 32],
) -> Vec<u8> {
    const LABEL: &[u8] = b"foctet auth client";

    let mut out = Vec::with_capacity(LABEL.len() + 3 * 32);
    out.extend_from_slice(LABEL);
    out.extend_from_slice(&client_public);
    out.extend_from_slice(&session_salt);
    out.extend_from_slice(&transcript_binding);
    out
}
483
484fn server_hello_binding(
485 client_public: [u8; 32],
486 server_public: [u8; 32],
487 session_salt: [u8; 32],
488) -> [u8; 32] {
489 let mut hasher = Sha256::new();
490 hasher.update(b"foctet hs server");
491 hasher.update(client_public);
492 hasher.update(server_public);
493 hasher.update(session_salt);
494 hasher.finalize().into()
495}
496
/// Builds the byte string the server signs during the handshake: the label
/// `"foctet auth server"` followed by the client's ephemeral public key, the
/// server's ephemeral public key, the session salt, and the `ServerHello`
/// transcript binding.
///
/// The buffer capacity is derived from the label itself rather than a
/// hand-counted constant, so it cannot drift from the label's real length.
fn server_auth_message(
    client_public: [u8; 32],
    server_public: [u8; 32],
    session_salt: [u8; 32],
    transcript_binding: [u8; 32],
) -> Vec<u8> {
    const LABEL: &[u8] = b"foctet auth server";

    let mut out = Vec::with_capacity(LABEL.len() + 4 * 32);
    out.extend_from_slice(LABEL);
    out.extend_from_slice(&client_public);
    out.extend_from_slice(&server_public);
    out.extend_from_slice(&session_salt);
    out.extend_from_slice(&transcript_binding);
    out
}
511
512fn rekey_binding(
513 old_key_id: u8,
514 new_key_id: u8,
515 rekey_salt: [u8; 32],
516 session_salt: [u8; 32],
517) -> [u8; 32] {
518 let mut hasher = Sha256::new();
519 hasher.update(b"foctet rekey");
520 hasher.update([old_key_id]);
521 hasher.update([new_key_id]);
522 hasher.update(rekey_salt);
523 hasher.update(session_salt);
524 hasher.finalize().into()
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{IdentityKeyPair, PeerIdentity};

    /// Runs the hello exchange: the server consumes `hello` and must answer
    /// with a `ServerHello`, which the client must then accept.
    fn complete_handshake(client: &mut Session, server: &mut Session, hello: &ControlMessage) {
        let server_hello = server
            .handle_control(hello)
            .expect("server handle client hello")
            .expect("server hello response");

        client
            .handle_control(&server_hello)
            .expect("client handle server hello");
    }

    /// A plain handshake activates both sides, and a client-forced rekey
    /// leaves both sides on the same key id.
    #[test]
    fn session_handshake_and_rekey() {
        let (mut client, hello) = Session::new_initiator(RekeyThresholds::default());
        let mut server = Session::new_responder(RekeyThresholds::default());

        complete_handshake(&mut client, &mut server, &hello);

        assert_eq!(client.state(), SessionState::Active);
        assert_eq!(server.state(), SessionState::Active);

        let rekey = client.force_rekey().expect("client force rekey");
        server.handle_control(&rekey).expect("server handle rekey");

        let client_key = client.active_keys().expect("client active key");
        let server_key = server.active_keys().expect("server active key");
        assert_eq!(client_key.key_id, server_key.key_id);
    }

    /// With both sides signing and pinning each other's identity, the
    /// handshake marks the peer as authenticated on both ends.
    #[test]
    fn session_authenticates_pinned_peer_identities() {
        let client_identity = IdentityKeyPair::from_secret_key_bytes([0x41; 32]);
        let server_identity = IdentityKeyPair::from_secret_key_bytes([0x61; 32]);
        let pinned_server = PeerIdentity::new(server_identity.public_key());
        let pinned_client = PeerIdentity::new(client_identity.public_key());

        let client_auth = SessionAuthConfig::new()
            .with_local_identity(client_identity)
            .with_peer_identity(pinned_server)
            .require_peer_authentication(true);
        let server_auth = SessionAuthConfig::new()
            .with_local_identity(server_identity)
            .with_peer_identity(pinned_client)
            .require_peer_authentication(true);

        let (mut client, hello) =
            Session::new_initiator_with_auth(RekeyThresholds::default(), client_auth);
        let mut server = Session::new_responder_with_auth(RekeyThresholds::default(), server_auth);

        complete_handshake(&mut client, &mut server, &hello);

        assert!(client.peer_authenticated());
        assert!(server.peer_authenticated());
    }
}