Skip to main content

fips_core/noise/
session.rs

1use super::{CipherState, HandshakeRole, NoiseError, ReplayWindow};
2use ring::aead::LessSafeKey;
3use secp256k1::{PublicKey, XOnlyPublicKey};
4use std::fmt;
5
6/// Completed Noise session for transport encryption.
7///
8/// Provides bidirectional authenticated encryption with replay protection.
9/// The send counter is monotonically incremented; received counters are
10/// validated against a sliding window to prevent replay attacks.
11pub struct NoiseSession {
12    /// Our role in the original handshake.
13    role: HandshakeRole,
14    /// Cipher for sending.
15    send_cipher: CipherState,
16    /// Cipher for receiving.
17    recv_cipher: CipherState,
18    /// Handshake hash for channel binding.
19    handshake_hash: [u8; 32],
20    /// Remote peer's static public key.
21    remote_static: PublicKey,
22    /// Replay window for received packets.
23    replay_window: ReplayWindow,
24}
25
26impl NoiseSession {
27    /// Create a new session from completed handshake data.
28    pub(super) fn from_handshake(
29        role: HandshakeRole,
30        send_cipher: CipherState,
31        recv_cipher: CipherState,
32        handshake_hash: [u8; 32],
33        remote_static: PublicKey,
34    ) -> Self {
35        Self {
36            role,
37            send_cipher,
38            recv_cipher,
39            handshake_hash,
40            remote_static,
41            replay_window: ReplayWindow::new(),
42        }
43    }
44
45    /// Encrypt a message for sending (using internal counter).
46    ///
47    /// Returns the ciphertext. The current send counter should be included
48    /// in the wire format before calling this method.
49    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
50        self.send_cipher.encrypt(plaintext)
51    }
52
53    /// Get the current send counter (before incrementing).
54    ///
55    /// Use this to get the counter to include in the wire format.
56    /// The counter will be incremented when `encrypt` is called.
57    pub fn current_send_counter(&self) -> u64 {
58        self.send_cipher.nonce
59    }
60
61    /// Decrypt a received message (using internal counter).
62    ///
63    /// This is for handshake-phase decryption. For transport phase with
64    /// explicit counters, use `decrypt_with_replay_check` instead.
65    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
66        self.recv_cipher.decrypt(ciphertext)
67    }
68
69    /// Check if a counter passes the replay window.
70    ///
71    /// Returns Ok(()) if the counter is acceptable, Err if it should be rejected.
72    /// Call this before attempting decryption to avoid wasting CPU on replay attacks.
73    pub fn check_replay(&self, counter: u64) -> Result<(), NoiseError> {
74        if self.replay_window.check(counter) {
75            Ok(())
76        } else {
77            Err(NoiseError::ReplayDetected(counter))
78        }
79    }
80
81    /// Decrypt with explicit counter and replay protection.
82    ///
83    /// This is the primary decryption method for transport phase.
84    /// The counter comes from the wire format and is validated against
85    /// the replay window before and after decryption.
86    ///
87    /// On success, the counter is accepted into the replay window.
88    pub fn decrypt_with_replay_check(
89        &mut self,
90        ciphertext: &[u8],
91        counter: u64,
92    ) -> Result<Vec<u8>, NoiseError> {
93        // Check replay window first (cheap)
94        if !self.replay_window.check(counter) {
95            return Err(NoiseError::ReplayDetected(counter));
96        }
97
98        // Attempt decryption (expensive)
99        let plaintext = self.recv_cipher.decrypt_with_counter(ciphertext, counter)?;
100
101        // Only accept into window after successful decryption
102        // This prevents DoS attacks that exhaust the window
103        self.replay_window.accept(counter);
104
105        Ok(plaintext)
106    }
107
108    /// Encrypt a message with Additional Authenticated Data (AAD).
109    ///
110    /// Returns the ciphertext. The current send counter should be included
111    /// in the wire format before calling this method.
112    pub fn encrypt_with_aad(
113        &mut self,
114        plaintext: &[u8],
115        aad: &[u8],
116    ) -> Result<Vec<u8>, NoiseError> {
117        self.send_cipher.encrypt_with_aad(plaintext, aad)
118    }
119
120    /// Decrypt with explicit counter, replay protection, and AAD.
121    ///
122    /// This is the primary decryption method for the FMP transport phase
123    /// with AAD binding. The AAD (typically the 16-byte outer header) must
124    /// match what was used during encryption.
125    pub fn decrypt_with_replay_check_and_aad(
126        &mut self,
127        ciphertext: &[u8],
128        counter: u64,
129        aad: &[u8],
130    ) -> Result<Vec<u8>, NoiseError> {
131        // Check replay window first (cheap)
132        if !self.replay_window.check(counter) {
133            return Err(NoiseError::ReplayDetected(counter));
134        }
135
136        // Attempt decryption with AAD (expensive)
137        let plaintext = self
138            .recv_cipher
139            .decrypt_with_counter_and_aad(ciphertext, counter, aad)?;
140
141        // Only accept into window after successful decryption
142        self.replay_window.accept(counter);
143
144        Ok(plaintext)
145    }
146
147    /// In-place variant of [`Self::decrypt_with_replay_check_and_aad`].
148    ///
149    /// On entry, `buf` holds `ciphertext + 16-byte AEAD tag`. On
150    /// successful return, `buf[..returned_len]` holds the plaintext.
151    /// The caller can then slice into `buf` without paying for an
152    /// extra heap allocation + memcpy per packet — at multi-Gbps
153    /// single-stream the by-value variant's `ciphertext.to_vec()`
154    /// alone is a measurable fraction of the rx_loop's per-packet
155    /// cost.
156    pub fn decrypt_with_replay_check_and_aad_in_place(
157        &mut self,
158        buf: &mut [u8],
159        counter: u64,
160        aad: &[u8],
161    ) -> Result<usize, NoiseError> {
162        if !self.replay_window.check(counter) {
163            return Err(NoiseError::ReplayDetected(counter));
164        }
165        let plaintext_len = self
166            .recv_cipher
167            .decrypt_with_counter_and_aad_in_place(buf, counter, aad)?;
168        self.replay_window.accept(counter);
169        Ok(plaintext_len)
170    }
171
172    /// Get the highest received counter.
173    pub fn highest_received_counter(&self) -> u64 {
174        self.replay_window.highest()
175    }
176
177    /// Clone the recv-side AEAD instance, for off-task decrypt.
178    ///
179    /// Returns `None` if the recv cipher has no key (transport phase has
180    /// not begun). The cloned cipher pairs with `decrypt_with_counter[_and_aad]`
181    /// on `CipherState`: a dispatcher can `check_replay` here, fan the
182    /// AEAD work out to a worker holding the clone + counter + aad, then
183    /// call `accept_replay` here once the worker reports success.
184    pub fn recv_cipher_clone(&self) -> Option<LessSafeKey> {
185        self.recv_cipher.cipher_clone()
186    }
187
188    /// Snapshot the current replay-window state as an **owned**
189    /// `ReplayWindow` value, for hand-off to a shard-owning decrypt
190    /// worker.
191    ///
192    /// **The worker becomes the sole authority for replay protection
193    /// on this session after this snapshot.** The local
194    /// `self.replay_window` is no longer the source of truth — it
195    /// only matters for rare-slow-path uses (rekey, drain-window
196    /// fallback). The worker keeps its copy in its own
197    /// thread-local `HashMap`, so there's no Mutex / no Arc / no
198    /// sharing — direct `&mut` access on every packet.
199    ///
200    /// (Previously this returned an `Arc<Mutex<ReplayWindow>>` for
201    /// concurrent access; the data-plane shard restructure now hands
202    /// the worker exclusive ownership instead.)
203    pub fn recv_replay_snapshot_owned(&self) -> crate::noise::ReplayWindow {
204        self.replay_window.clone()
205    }
206
207    /// Clone the send-side AEAD instance, for off-task encrypt.
208    ///
209    /// Returns `None` if the send cipher has no key. Pairs with
210    /// `encrypt_with_counter[_and_aad]` on `CipherState`. The caller must
211    /// own counter sequencing — `take_send_counter` hands out monotonic
212    /// counters under the session's own &mut.
213    pub fn send_cipher_clone(&self) -> Option<LessSafeKey> {
214        self.send_cipher.cipher_clone()
215    }
216
217    /// Reserve and return the next send counter, advancing the internal
218    /// nonce. For pipelined encrypt paths that call `encrypt_with_counter`
219    /// on a cloned cipher: the dispatcher pre-assigns the counter here
220    /// (under the session's &mut) and the worker performs the AEAD with
221    /// no further mutation of session state.
222    pub fn take_send_counter(&mut self) -> Result<u64, NoiseError> {
223        if self.send_cipher.nonce == u64::MAX {
224            return Err(NoiseError::NonceOverflow);
225        }
226        let counter = self.send_cipher.nonce;
227        self.send_cipher.nonce += 1;
228        Ok(counter)
229    }
230
231    /// Accept a counter into the replay window after a successful out-of-task
232    /// decrypt. Caller is responsible for verifying decrypt success first.
233    pub fn accept_replay(&mut self, counter: u64) {
234        self.replay_window.accept(counter);
235    }
236
237    /// Reset the replay window (use when rekeying).
238    pub fn reset_replay_window(&mut self) {
239        self.replay_window.reset();
240    }
241
242    /// Get the handshake hash for channel binding.
243    pub fn handshake_hash(&self) -> &[u8; 32] {
244        &self.handshake_hash
245    }
246
247    /// Get the remote peer's static public key.
248    pub fn remote_static(&self) -> &PublicKey {
249        &self.remote_static
250    }
251
252    /// Get the remote peer's x-only public key.
253    pub fn remote_static_xonly(&self) -> XOnlyPublicKey {
254        self.remote_static.x_only_public_key().0
255    }
256
257    /// Get our role in the handshake.
258    pub fn role(&self) -> HandshakeRole {
259        self.role
260    }
261
262    /// Get the send nonce (for debugging).
263    pub fn send_nonce(&self) -> u64 {
264        self.send_cipher.nonce()
265    }
266
267    /// Get the receive nonce (for debugging).
268    pub fn recv_nonce(&self) -> u64 {
269        self.recv_cipher.nonce()
270    }
271}
272
273impl fmt::Debug for NoiseSession {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        f.debug_struct("NoiseSession")
276            .field("role", &self.role)
277            .field("send_nonce", &self.send_cipher.nonce())
278            .field("recv_nonce", &self.recv_cipher.nonce())
279            .field("handshake_hash", &hex::encode(&self.handshake_hash[..8]))
280            .finish()
281    }
282}