fips_core/noise/session.rs
1use super::{CipherState, HandshakeRole, NoiseError, ReplayWindow};
2use ring::aead::LessSafeKey;
3use secp256k1::{PublicKey, XOnlyPublicKey};
4use std::fmt;
5
6/// Completed Noise session for transport encryption.
7///
8/// Provides bidirectional authenticated encryption with replay protection.
9/// The send counter is monotonically incremented; received counters are
10/// validated against a sliding window to prevent replay attacks.
/// Completed Noise session for transport encryption.
///
/// Provides bidirectional authenticated encryption with replay protection.
/// The send counter is monotonically incremented; received counters are
/// validated against a sliding window to prevent replay attacks.
///
/// All fields are private. Instances are built with `from_handshake`,
/// which is only visible to the parent module.
pub struct NoiseSession {
    /// Our role in the original handshake.
    role: HandshakeRole,
    /// Cipher for sending. Its nonce doubles as the outbound packet
    /// counter (read by `current_send_counter`, advanced by
    /// `take_send_counter`).
    send_cipher: CipherState,
    /// Cipher for receiving. The transport-phase decrypt methods hand it
    /// an explicit counter taken from the wire format.
    recv_cipher: CipherState,
    /// Handshake hash for channel binding.
    handshake_hash: [u8; 32],
    /// Remote peer's static public key.
    remote_static: PublicKey,
    /// Replay window for received packets. New counters are recorded only
    /// after a successful decryption or via `accept_replay`.
    replay_window: ReplayWindow,
}
25
26impl NoiseSession {
27 /// Create a new session from completed handshake data.
28 pub(super) fn from_handshake(
29 role: HandshakeRole,
30 send_cipher: CipherState,
31 recv_cipher: CipherState,
32 handshake_hash: [u8; 32],
33 remote_static: PublicKey,
34 ) -> Self {
35 Self {
36 role,
37 send_cipher,
38 recv_cipher,
39 handshake_hash,
40 remote_static,
41 replay_window: ReplayWindow::new(),
42 }
43 }
44
45 /// Encrypt a message for sending (using internal counter).
46 ///
47 /// Returns the ciphertext. The current send counter should be included
48 /// in the wire format before calling this method.
49 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
50 self.send_cipher.encrypt(plaintext)
51 }
52
53 /// Get the current send counter (before incrementing).
54 ///
55 /// Use this to get the counter to include in the wire format.
56 /// The counter will be incremented when `encrypt` is called.
57 pub fn current_send_counter(&self) -> u64 {
58 self.send_cipher.nonce
59 }
60
61 /// Decrypt a received message (using internal counter).
62 ///
63 /// This is for handshake-phase decryption. For transport phase with
64 /// explicit counters, use `decrypt_with_replay_check` instead.
65 pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
66 self.recv_cipher.decrypt(ciphertext)
67 }
68
69 /// Check if a counter passes the replay window.
70 ///
71 /// Returns Ok(()) if the counter is acceptable, Err if it should be rejected.
72 /// Call this before attempting decryption to avoid wasting CPU on replay attacks.
73 pub fn check_replay(&self, counter: u64) -> Result<(), NoiseError> {
74 if self.replay_window.check(counter) {
75 Ok(())
76 } else {
77 Err(NoiseError::ReplayDetected(counter))
78 }
79 }
80
81 /// Decrypt with explicit counter and replay protection.
82 ///
83 /// This is the primary decryption method for transport phase.
84 /// The counter comes from the wire format and is validated against
85 /// the replay window before and after decryption.
86 ///
87 /// On success, the counter is accepted into the replay window.
88 pub fn decrypt_with_replay_check(
89 &mut self,
90 ciphertext: &[u8],
91 counter: u64,
92 ) -> Result<Vec<u8>, NoiseError> {
93 // Check replay window first (cheap)
94 if !self.replay_window.check(counter) {
95 return Err(NoiseError::ReplayDetected(counter));
96 }
97
98 // Attempt decryption (expensive)
99 let plaintext = self.recv_cipher.decrypt_with_counter(ciphertext, counter)?;
100
101 // Only accept into window after successful decryption
102 // This prevents DoS attacks that exhaust the window
103 self.replay_window.accept(counter);
104
105 Ok(plaintext)
106 }
107
108 /// Encrypt a message with Additional Authenticated Data (AAD).
109 ///
110 /// Returns the ciphertext. The current send counter should be included
111 /// in the wire format before calling this method.
112 pub fn encrypt_with_aad(
113 &mut self,
114 plaintext: &[u8],
115 aad: &[u8],
116 ) -> Result<Vec<u8>, NoiseError> {
117 self.send_cipher.encrypt_with_aad(plaintext, aad)
118 }
119
120 /// Decrypt with explicit counter, replay protection, and AAD.
121 ///
122 /// This is the primary decryption method for the FMP transport phase
123 /// with AAD binding. The AAD (typically the 16-byte outer header) must
124 /// match what was used during encryption.
125 pub fn decrypt_with_replay_check_and_aad(
126 &mut self,
127 ciphertext: &[u8],
128 counter: u64,
129 aad: &[u8],
130 ) -> Result<Vec<u8>, NoiseError> {
131 // Check replay window first (cheap)
132 if !self.replay_window.check(counter) {
133 return Err(NoiseError::ReplayDetected(counter));
134 }
135
136 // Attempt decryption with AAD (expensive)
137 let plaintext = self
138 .recv_cipher
139 .decrypt_with_counter_and_aad(ciphertext, counter, aad)?;
140
141 // Only accept into window after successful decryption
142 self.replay_window.accept(counter);
143
144 Ok(plaintext)
145 }
146
147 /// In-place variant of [`Self::decrypt_with_replay_check_and_aad`].
148 ///
149 /// On entry, `buf` holds `ciphertext + 16-byte AEAD tag`. On
150 /// successful return, `buf[..returned_len]` holds the plaintext.
151 /// The caller can then slice into `buf` without paying for an
152 /// extra heap allocation + memcpy per packet — at multi-Gbps
153 /// single-stream the by-value variant's `ciphertext.to_vec()`
154 /// alone is a measurable fraction of the rx_loop's per-packet
155 /// cost.
156 pub fn decrypt_with_replay_check_and_aad_in_place(
157 &mut self,
158 buf: &mut [u8],
159 counter: u64,
160 aad: &[u8],
161 ) -> Result<usize, NoiseError> {
162 if !self.replay_window.check(counter) {
163 return Err(NoiseError::ReplayDetected(counter));
164 }
165 let plaintext_len = self
166 .recv_cipher
167 .decrypt_with_counter_and_aad_in_place(buf, counter, aad)?;
168 self.replay_window.accept(counter);
169 Ok(plaintext_len)
170 }
171
172 /// Get the highest received counter.
173 pub fn highest_received_counter(&self) -> u64 {
174 self.replay_window.highest()
175 }
176
177 /// Clone the recv-side AEAD instance, for off-task decrypt.
178 ///
179 /// Returns `None` if the recv cipher has no key (transport phase has
180 /// not begun). The cloned cipher pairs with `decrypt_with_counter[_and_aad]`
181 /// on `CipherState`: a dispatcher can `check_replay` here, fan the
182 /// AEAD work out to a worker holding the clone + counter + aad, then
183 /// call `accept_replay` here once the worker reports success.
184 pub fn recv_cipher_clone(&self) -> Option<LessSafeKey> {
185 self.recv_cipher.cipher_clone()
186 }
187
188 /// Snapshot the current replay-window state as an **owned**
189 /// `ReplayWindow` value, for hand-off to a shard-owning decrypt
190 /// worker.
191 ///
192 /// **The worker becomes the sole authority for replay protection
193 /// on this session after this snapshot.** The local
194 /// `self.replay_window` is no longer the source of truth — it
195 /// only matters for rare-slow-path uses (rekey, drain-window
196 /// fallback). The worker keeps its copy in its own
197 /// thread-local `HashMap`, so there's no Mutex / no Arc / no
198 /// sharing — direct `&mut` access on every packet.
199 ///
200 /// (Previously this returned an `Arc<Mutex<ReplayWindow>>` for
201 /// concurrent access; the data-plane shard restructure now hands
202 /// the worker exclusive ownership instead.)
203 pub fn recv_replay_snapshot_owned(&self) -> crate::noise::ReplayWindow {
204 self.replay_window.clone()
205 }
206
207 /// Clone the send-side AEAD instance, for off-task encrypt.
208 ///
209 /// Returns `None` if the send cipher has no key. Pairs with
210 /// `encrypt_with_counter[_and_aad]` on `CipherState`. The caller must
211 /// own counter sequencing — `take_send_counter` hands out monotonic
212 /// counters under the session's own &mut.
213 pub fn send_cipher_clone(&self) -> Option<LessSafeKey> {
214 self.send_cipher.cipher_clone()
215 }
216
217 /// Reserve and return the next send counter, advancing the internal
218 /// nonce. For pipelined encrypt paths that call `encrypt_with_counter`
219 /// on a cloned cipher: the dispatcher pre-assigns the counter here
220 /// (under the session's &mut) and the worker performs the AEAD with
221 /// no further mutation of session state.
222 pub fn take_send_counter(&mut self) -> Result<u64, NoiseError> {
223 if self.send_cipher.nonce == u64::MAX {
224 return Err(NoiseError::NonceOverflow);
225 }
226 let counter = self.send_cipher.nonce;
227 self.send_cipher.nonce += 1;
228 Ok(counter)
229 }
230
231 /// Accept a counter into the replay window after a successful out-of-task
232 /// decrypt. Caller is responsible for verifying decrypt success first.
233 pub fn accept_replay(&mut self, counter: u64) {
234 self.replay_window.accept(counter);
235 }
236
237 /// Reset the replay window (use when rekeying).
238 pub fn reset_replay_window(&mut self) {
239 self.replay_window.reset();
240 }
241
242 /// Get the handshake hash for channel binding.
243 pub fn handshake_hash(&self) -> &[u8; 32] {
244 &self.handshake_hash
245 }
246
247 /// Get the remote peer's static public key.
248 pub fn remote_static(&self) -> &PublicKey {
249 &self.remote_static
250 }
251
252 /// Get the remote peer's x-only public key.
253 pub fn remote_static_xonly(&self) -> XOnlyPublicKey {
254 self.remote_static.x_only_public_key().0
255 }
256
257 /// Get our role in the handshake.
258 pub fn role(&self) -> HandshakeRole {
259 self.role
260 }
261
262 /// Get the send nonce (for debugging).
263 pub fn send_nonce(&self) -> u64 {
264 self.send_cipher.nonce()
265 }
266
267 /// Get the receive nonce (for debugging).
268 pub fn recv_nonce(&self) -> u64 {
269 self.recv_cipher.nonce()
270 }
271}
272
273impl fmt::Debug for NoiseSession {
274 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275 f.debug_struct("NoiseSession")
276 .field("role", &self.role)
277 .field("send_nonce", &self.send_cipher.nonce())
278 .field("recv_nonce", &self.recv_cipher.nonce())
279 .field("handshake_hash", &hex::encode(&self.handshake_hash[..8]))
280 .finish()
281 }
282}