Skip to main content

libcrux_psq/
session.rs

//! # Long-term Session
//!
//! This module implements long-term sessions derived from PSQ handshakes.
4
5mod session_key;
6mod transport;
7
8use std::io::Cursor;
9
10use session_key::{derive_session_key, SessionKey, SESSION_ID_LENGTH};
11use tls_codec::{
12    Deserialize, Serialize, SerializeBytes, Size, TlsDeserialize, TlsSerialize, TlsSize,
13};
14use transport::Transport;
15
16use crate::{
17    aead::{AEADError, AEADKeyNonce, AeadType},
18    handshake::{
19        ciphersuite::types::PQEncapsulationKey, dhkem::DHPublicKey, transcript::Transcript,
20        types::Authenticator,
21    },
22    session::session_key::derive_import_key,
23};
24
25/// Session related errors
26#[derive(Debug, PartialEq)]
27pub enum SessionError {
28    /// An error during session creation
29    IntoSession,
30    /// An error during serialization using TLSCodec
31    Serialize(tls_codec::Error),
32    /// An error during deserialization using TLSCodec
33    Deserialize(tls_codec::Error),
34    /// The given channel payload exceeded the maximum length
35    PayloadTooLong(usize),
36    /// An error in an underlying cryptographic primitive
37    CryptoError,
38    /// An error in storing or loading a session state
39    Storage,
40    /// The maxmium number of derivable channels has been reached
41    ReachedMaxChannels,
42    /// A channel message contains an inappropriate channel identifier
43    IdentifierMismatch,
44    /// The given payload exceeds the available output buffer
45    OutputBufferShort,
46    /// An error arising during the import of an external secret
47    Import,
48}
49
50impl From<AEADError> for SessionError {
51    fn from(value: AEADError) -> Self {
52        match value {
53            AEADError::CryptoError => SessionError::CryptoError,
54            AEADError::Serialize(error) => SessionError::Serialize(error),
55            AEADError::Deserialize(error) => SessionError::Deserialize(error),
56        }
57    }
58}
/// Size, in bytes, of the binder that ties a session to public key
/// material.
pub(crate) const PK_BINDER_LEN: usize = 8;
61
62#[derive(TlsSerialize, TlsDeserialize, TlsSize)]
63#[repr(u8)]
64/// The holder of a `Session` struct.
65///
66/// Note that this refers only to the roles during the initial handshake,
67/// as either peer can initiate a secure channel from a session.
68pub(crate) enum SessionPrincipal {
69    /// The handshake initiator
70    Initiator,
71    /// The handshake responder
72    Responder,
73}
74
75#[derive(TlsSerialize, TlsDeserialize, TlsSize)]
76/// A long-term session derived from the final handshake state.
77///
78/// Allows the creation of up to `u64::MAX` distinct, bi-directional
79/// secure `Channel`s between initiator and responder, via
80/// `Session::channel`.
81///
82/// The `Session` can be stored using `Session::serialize` and loaded
83/// using `Session::deserialize`, which expects references to the same
84/// public keys that were used in session creation to succeed.
85///
86/// **Warning**: Session state must only be stored and loaded using
87/// `Session::serialize` and `Session::deserialize`.  While `Session`
88/// implements `tls_codec::{Serialize, Deserialize}`, the associated
89/// methods should not be called directly, since they do not consume the
90/// `Session`. This opens up the possibility of continual use of a session
91/// that has also been serialized. If the serialized session is then
92/// deserialized, the deserialized version is stale and using it to
93/// re-derive `Channel`s will result in nonce re-use with the potential
94/// for loss of confidentiality.
95pub struct Session {
96    /// Which handshake party holds the `Session`
97    pub(crate) principal: SessionPrincipal,
98    /// The long-term session key derived from the handshake
99    pub(crate) session_key: SessionKey,
100    /// Binds this `Session` to the long-term public key material of both
101    /// parties that was used during the handshake
102    pub(crate) pk_binder: Option<[u8; PK_BINDER_LEN]>,
103    /// An increasing counter of derived secure channels
104    pub(crate) channel_counter: u64,
105    pub(crate) aead_type: AeadType,
106    pub(crate) transcript: Transcript,
107}
108
109// pkBinder = KDF(skCS, g^c | g^s | [pkS])
110fn derive_pk_binder(
111    key: &SessionKey,
112    initiator_authenticator: &Authenticator,
113    responder_ecdh_pk: &DHPublicKey,
114    responder_pq_pk: &Option<PQEncapsulationKey>,
115) -> Result<[u8; PK_BINDER_LEN], SessionError> {
116    #[derive(TlsSerialize, TlsSize)]
117    struct PkBinderInfo<'a> {
118        initiator_authenticator: &'a Authenticator,
119        responder_ecdh_pk: &'a DHPublicKey,
120        responder_pq_pk: &'a Option<PQEncapsulationKey<'a>>,
121    }
122
123    let info = PkBinderInfo {
124        initiator_authenticator,
125        responder_ecdh_pk,
126        responder_pq_pk,
127    };
128    let mut info_buf = vec![0u8; info.tls_serialized_len()];
129    info.tls_serialize(&mut &mut info_buf[..])
130        .map_err(SessionError::Serialize)?;
131
132    let mut binder = [0u8; PK_BINDER_LEN];
133
134    libcrux_hkdf::sha2_256::hkdf(
135        &mut binder,
136        &[],
137        &SerializeBytes::tls_serialize(&key.key).map_err(SessionError::Serialize)?,
138        &info_buf,
139    )
140    .map_err(|_| SessionError::CryptoError)?;
141
142    Ok(binder)
143}
144
145/// Wraps public key material that is bound to a session.
146pub struct SessionBinding<'a> {
147    /// The initiator's authenticator value, i.e. a long-term DH public value or signature verification key.
148    pub initiator_authenticator: &'a Authenticator,
149    /// The responder's long term DH public value.
150    pub responder_ecdh_pk: &'a DHPublicKey,
151    /// The responder's long term PQ-KEM public key (if any).
152    pub responder_pq_pk: Option<PQEncapsulationKey<'a>>,
153}
154
155impl Session {
156    /// Create a new `Session`.
157    ///
158    /// This will derive the long-term session key, and optionally compute a binder tying
159    /// the session key to any long-term public key material that was used during the
160    /// handshake.
161    pub(crate) fn new<'a>(
162        tx2: Transcript,
163        k2: AEADKeyNonce,
164        session_binding: Option<SessionBinding<'a>>,
165        is_initiator: bool,
166        aead_type: AeadType,
167    ) -> Result<Self, SessionError> {
168        let session_key = derive_session_key(k2, &tx2, aead_type)?;
169
170        let pk_binder = session_binding
171            .map(|session_binding| {
172                derive_pk_binder(
173                    &session_key,
174                    session_binding.initiator_authenticator,
175                    session_binding.responder_ecdh_pk,
176                    &session_binding.responder_pq_pk,
177                )
178            })
179            .transpose()?;
180
181        Ok(Self {
182            principal: if is_initiator {
183                SessionPrincipal::Initiator
184            } else {
185                SessionPrincipal::Responder
186            },
187            session_key,
188            pk_binder,
189            channel_counter: 0,
190            aead_type,
191            transcript: tx2,
192        })
193    }
194
195    /// Import a secret, replacing the main session secret
196    ///
197    /// A secret `psk` that is at least 32 bytes long can be imported
198    /// into the session, replacing the original main session secret
199    /// `K_S` with a fresh secret derived from the `psk` and `K_S`.
200    /// If a public key binding is provided it must be the same as for
201    /// the original session, but the new session will derive a fresh
202    /// session ID and update the running transcript `tx` for future
203    /// imports.
204    ///
205    /// In detail:
206    /// ```text
207    /// K_import = KDF(K_S || psk, "secret import")
208    /// tx' = Hash(tx || session_ID)
209    ///
210    /// // From here: treat K_import as though it was the outcome of a handshake
211    /// K_S' = KDF(K_import, "session secret" | tx')
212    /// session_ID' = KDF(K_S', "shared key id")
213    /// ```
214    ///
215    /// # Example
216    ///
217    /// ```ignore
218    /// // Initiator and responder finish the handshake resulting in
219    /// // `initiator_session` and `responder_session` both bound to public
220    /// // keys in `session_binding`.
221    ///
222    /// // WARN: In real usage, `psk` should be at least 32 bytes of high
223    /// // entropy randomness.
224    /// let psk = [0xab; 32];
225    ///
226    /// // Re-key the initiator session, providing the old session
227    /// // binding. This sustains the binding to these public keys.
228    /// // `session_binding`, if provided must match the binding of the
229    /// // original session.
230    /// let initiator_session = initiator_session.import(psk.as_slice(), session_binding).unwrap();
231    ///
232    /// // Re-key the responder session, providing the old session
233    /// // binding. This sustains the binding to these public keys.
234    /// // `session_binding`, if provided must match the binding of the
235    /// // original session.
236    /// let responder_session = initiator_session.import(psk.as_slice(), session_binding).unwrap();
237    ///
238    /// // [.. If `psk` was the same on both sides, you can now derive
239    /// // transport channels from the re-keyed session as before ..]
240    ///
241    /// // WARN: In real usage, `another_psk` should be at least 32 bytes of high
242    /// // entropy randomness.
243    /// let another_psk = [0xcd; 32];
244    ///
245    /// // Re-key the initiator session, stripping the binding to the original
246    /// // handshake public keys. Once the binding has been stripped it cannot
247    /// // be re-established without performing a fresh handshake. Exercise
248    /// // with caution to avoid session misbinding attacks.
249    /// let unbound_initiator_session = initiator_session.import(another_psk.as_slice(), None).unwrap();
250    ///
251    /// // Re-key the responder session, stripping the binding to the original
252    /// // handshake public keys. Once the binding has been stripped it cannot
253    /// // be re-established without performing a fresh handshake. Exercise
254    /// // with caution to avoid session misbinding attacks.
255    /// let unbound_responder_session = responder_session.import(another_psk.as_slice(), None).unwrap();
256    ///
257    /// // [.. If `psk` was the same on both sides, you can now derive
258    /// // transport channels from the re-keyed session as before ..]
259    /// ```
260    pub fn import<'a>(
261        self,
262        psk: &[u8],
263        session_binding: impl Into<Option<SessionBinding<'a>>>,
264    ) -> Result<Self, SessionError> {
265        // We require that the psk is at least 32 bytes long.
266        if psk.len() < 32 {
267            return Err(SessionError::Import);
268        }
269
270        let session_binding = session_binding.into();
271
272        match (self.pk_binder, &session_binding) {
273            // No binder was present, no new binder was provided.
274            (None, None) => (),
275            // No binder was present, but a new binder was provided => We disallow re-binding a session after the binding has been lost.
276            (None, Some(_)) => return Err(SessionError::Import),
277            // Some binder was present and no other binder was provided => This removes the binding from the session for good
278            (Some(_), None) => (),
279            // Some binder was present and a binder was provided for
280            // validation => The new session will be bound to the same
281            // keys as the original, if the binder is valid
282            (
283                Some(pk_binder),
284                Some(SessionBinding {
285                    initiator_authenticator,
286                    responder_ecdh_pk,
287                    responder_pq_pk,
288                }),
289            ) => {
290                if derive_pk_binder(
291                    &self.session_key,
292                    initiator_authenticator,
293                    responder_ecdh_pk,
294                    responder_pq_pk,
295                )? != pk_binder
296                {
297                    return Err(SessionError::Import);
298                }
299            }
300        };
301
302        let transcript =
303            Transcript::add_hash::<3>(Some(&self.transcript), self.session_key.identifier)
304                .map_err(|_| SessionError::Import)?;
305
306        let import_key = derive_import_key(self.session_key.key, psk, self.aead_type)?;
307
308        Self::new(
309            transcript,
310            import_key,
311            session_binding,
312            matches!(self.principal, SessionPrincipal::Initiator),
313            self.aead_type,
314        )
315    }
316
317    /// Serializes the session state for storage.
318    ///
319    /// We require the caller to input the public keys (if any) that
320    /// were used to create the session, in order to enforce they have
321    /// access to all keys necessary to deserialize the session later
322    /// on.
323    ///
324    /// WARN: `tls_serialize`
325    /// should not be called directly, since it does not consume
326    /// `Session`. This opens the possibility for nonce re-use by
327    /// deserializing a stale `Session` since the original could be
328    /// used after serialization.
329    pub fn serialize<'a>(
330        self,
331        out: &mut [u8],
332        session_binding: impl Into<Option<SessionBinding<'a>>>,
333    ) -> Result<usize, SessionError> {
334        let session_binding = session_binding.into();
335        match (self.pk_binder, session_binding) {
336            (None, None) => self
337                .tls_serialize(&mut &mut out[..])
338                .map_err(SessionError::Serialize),
339            (None, Some(_)) | (Some(_), None) => Err(SessionError::Storage),
340            (
341                Some(pk_binder),
342                Some(SessionBinding {
343                    initiator_authenticator,
344                    responder_ecdh_pk,
345                    responder_pq_pk,
346                }),
347            ) => {
348                if derive_pk_binder(
349                    &self.session_key,
350                    initiator_authenticator,
351                    responder_ecdh_pk,
352                    &responder_pq_pk,
353                )? != pk_binder
354                {
355                    Err(SessionError::Storage)
356                } else {
357                    self.tls_serialize(&mut &mut out[..])
358                        .map_err(SessionError::Serialize)
359                }
360            }
361        }
362    }
363
364    /// Export a secret derived from the main session key.
365    ///
366    /// Derives a secret `K` from the main session key as
367    /// `K = KDF(K_session, context || "PSQ secret export")`.
368    pub fn export_secret(&self, context: &[u8], out: &mut [u8]) -> Result<(), SessionError> {
369        use tls_codec::TlsSerializeBytes;
370        const PSQ_EXPORT_CONTEXT: &[u8; 17] = b"PSQ secret export";
371        #[derive(TlsSerializeBytes, TlsSize)]
372        struct ExportInfo<'a> {
373            context: &'a [u8],
374            separator: [u8; 17],
375        }
376
377        libcrux_hkdf::sha2_256::hkdf(
378            out,
379            b"",
380            self.session_key.key.as_ref(),
381            &ExportInfo {
382                context,
383                separator: *PSQ_EXPORT_CONTEXT,
384            }
385            .tls_serialize()
386            .map_err(SessionError::Serialize)?,
387        )
388        .map_err(|_| SessionError::CryptoError)
389    }
390
391    /// Deserialize a session state.
392    ///
393    /// If the session was bound to a set of public keys, those same public keys must be provided to validate the binding on deserialization.
394    // XXX: Use `tls_codec::conditional_deserializable` to implement
395    // the validation.
396    pub fn deserialize<'a>(
397        bytes: &[u8],
398        session_binding: impl Into<Option<SessionBinding<'a>>>,
399    ) -> Result<Self, SessionError> {
400        let session_binding = session_binding.into();
401        let session =
402            Session::tls_deserialize(&mut Cursor::new(bytes)).map_err(SessionError::Deserialize)?;
403
404        match (session.pk_binder, session_binding) {
405            // No binder was expected and none was provided.
406            (None, None) => Ok(session),
407            // No binder was expected, but a binder was provided =>
408            // Error to signal that this session is not bound to the
409            // provided binder.
410            (None, Some(_)) => Err(SessionError::Storage),
411            // Some binder was expected but none was provided.
412            (Some(_), None) => Err(SessionError::Storage),
413            // Some binder was expected and a binder was provided =>
414            // Deserialization is valid, if binder is valid.
415            (Some(pk_binder), Some(provided_binding)) => {
416                if derive_pk_binder(
417                    &session.session_key,
418                    provided_binding.initiator_authenticator,
419                    provided_binding.responder_ecdh_pk,
420                    &provided_binding.responder_pq_pk,
421                )? == pk_binder
422                {
423                    Ok(session)
424                } else {
425                    Err(SessionError::Storage)
426                }
427            }
428        }
429    }
430
431    /// Derive a new secure transport channel from the session state.
432    ///
433    /// The new transport channel allows both peers to send and receive
434    /// messages AEAD encrypted under a fresh channel key derived from the
435    /// long-term session key.
436    pub fn transport_channel(&mut self) -> Result<Transport, SessionError> {
437        let channel = Transport::new(self, matches!(self.principal, SessionPrincipal::Initiator))?;
438        self.channel_counter = self
439            .channel_counter
440            .checked_add(1)
441            .ok_or(SessionError::ReachedMaxChannels)?;
442        Ok(channel)
443    }
444
445    /// Output the channel identifier.
446    pub fn identifier(&self) -> &[u8; SESSION_ID_LENGTH] {
447        &self.session_key.identifier
448    }
449}