noise/
state.rs

1// Set of libraries for privacy-preserving networking apps
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Written in 2023 by
6//     Dr. Maxim Orlovsky <orlovsky@cyphernet.org>
7//
8// Copyright 2023 Cyphernet DAO, Switzerland
9//
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13//
14//     http://www.apache.org/licenses/LICENSE-2.0
15//
16// Unless required by applicable law or agreed to in writing, software
17// distributed under the License is distributed on an "AS IS" BASIS,
18// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19// See the License for the specific language governing permissions and
20// limitations under the License.
21
22use std::collections::VecDeque;
23
24use cypher::{Digest, EcPk, Ecdh};
25
26use crate::cipher::{decrypt, encrypt, rekey, SharedSecret};
27use crate::error::{EncryptionError, NoiseError};
28use crate::hkdf::{hkdf_2, hkdf_3};
29use crate::patterns::{HandshakePattern, Keyset, MessagePattern};
30use crate::{ChainingKey, HandshakeHash, NoiseNonce};
31
32trait WithTruncated {
33    fn with_truncated(temp_key: impl AsRef<[u8]>) -> Self;
34}
35
36impl WithTruncated for SharedSecret {
37    fn with_truncated(temp_key: impl AsRef<[u8]>) -> Self {
38        let mut key = [0u8; 32];
39        match temp_key.as_ref().len() {
40            32 => {
41                key.copy_from_slice(temp_key.as_ref());
42            }
43            64 => {
44                key.copy_from_slice(&temp_key.as_ref()[..32]);
45            }
46            x => {
47                panic!(
48                    "Noise protocol requires HASH function with output length either 32 or 64 \
49                     bytes (a function outputting {x} bytes were given)"
50                )
51            }
52        }
53        key
54    }
55}
56
57#[derive(Clone, Eq, PartialEq, Debug, Default)]
58pub struct CipherState {
59    k: SharedSecret,
60    n: NoiseNonce,
61}
62
63impl CipherState {
64    pub fn new() -> Self { CipherState { k: [0u8; 32], n: 0 } }
65
66    pub fn initialize_key(&mut self, key: SharedSecret) {
67        self.k = key;
68        self.n = 0;
69    }
70
71    pub fn has_key(&self) -> bool { self.k != [0u8; 32] }
72
73    pub fn nonce(&self) -> NoiseNonce { self.n }
74
75    pub fn set_nonce(&mut self, nonce: NoiseNonce) { self.n = nonce; }
76
77    pub fn encrypt_with_ad(
78        &mut self,
79        ad: &[u8],
80        plaintext: &[u8],
81    ) -> Result<Vec<u8>, EncryptionError> {
82        if self.k.is_empty() {
83            Ok(plaintext.to_vec())
84        } else {
85            // If k is non-empty returns ENCRYPT(k, n++, ad, plaintext).
86            let ciphertext = encrypt(self.k, self.n, ad, plaintext);
87            self.n += 1;
88            ciphertext
89        }
90    }
91    pub fn decrypt_with_ad(
92        &mut self,
93        ad: &[u8],
94        ciphertext: &[u8],
95    ) -> Result<Vec<u8>, EncryptionError> {
96        if self.k.is_empty() {
97            Ok(ciphertext.to_vec())
98        } else {
99            // If k is non-empty returns DECRYPT(k, n++, ad, ciphertext).
100            // If an authentication failure occurs in DECRYPT() then n is not incremented and an
101            // error is signaled to the caller.
102            let plaintext = decrypt(self.k, self.n, ad, ciphertext)?;
103            self.n += 1;
104            Ok(plaintext)
105        }
106    }
107
108    pub fn rekey(&mut self) { self.k = rekey(self.k); }
109}
110
111#[derive(Clone, Eq, PartialEq, Debug)]
112pub struct SymmetricState<D: Digest> {
113    cipher: CipherState,
114    ck: ChainingKey<D>,
115    h: HandshakeHash<D>,
116    was_split: bool,
117}
118
119impl<D: Digest> SymmetricState<D> {
120    pub fn with<const HASHLEN: usize>(protocol_name: String) -> Self {
121        debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
122        let len = protocol_name.len();
123        let h = if len <= HASHLEN {
124            let mut h = [0u8; HASHLEN];
125            h[..len].copy_from_slice(protocol_name.as_bytes());
126            D::Output::try_from(&h).unwrap_or_else(|_| unreachable!())
127        } else {
128            D::digest(protocol_name.as_bytes())
129        };
130        let cipher = CipherState::new();
131        Self {
132            cipher,
133            h,
134            ck: h,
135            was_split: false,
136        }
137    }
138
139    pub fn mix_key(&mut self, input_key_material: impl AsRef<[u8]>) {
140        let (ck, temp_key) = hkdf_2::<D>(self.ck, input_key_material);
141        self.ck = ck;
142        self.cipher.initialize_key(SharedSecret::with_truncated(temp_key));
143    }
144
145    pub fn mix_hash(&mut self, data: impl AsRef<[u8]>) {
146        self.h = D::digest_concat([self.h.as_ref(), data.as_ref()]);
147    }
148
149    // TODO: Use in PSK
150    #[allow(dead_code)]
151    pub fn mix_key_and_hash(&mut self, input_key_material: impl AsRef<[u8]>) {
152        // Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3).
153        let (ck, temp_h, temp_k) = hkdf_3::<D>(self.ck, input_key_material);
154        self.ck = ck;
155
156        // Calls MixHash(temp_h).
157        self.mix_hash(temp_h);
158
159        // If HASHLEN is 64, then truncates temp_k to 32 bytes.
160        // Calls InitializeKey(temp_k).
161        self.cipher.initialize_key(SharedSecret::with_truncated(temp_k));
162    }
163
164    pub fn get_handshake_hash(&self) -> HandshakeHash<D> {
165        if !self.was_split {
166            panic!(
167                "SymmetricState::get_handshake_hash must be called only after \
168                 SymmetricState::split"
169            )
170        }
171        self.h
172    }
173
174    pub fn encrypt_and_hash(
175        &mut self,
176        plaintext: impl AsRef<[u8]>,
177    ) -> Result<Vec<u8>, EncryptionError> {
178        // ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns
179        // ciphertext. Note that if k is empty, the EncryptWithAd() call will set ciphertext equal
180        // to plaintext.
181        let ciphertext = self.cipher.encrypt_with_ad(self.h.as_ref(), plaintext.as_ref())?;
182        self.mix_hash(&ciphertext);
183        Ok(ciphertext)
184    }
185
186    pub fn decrypt_and_hash(
187        &mut self,
188        ciphertext: impl AsRef<[u8]>,
189    ) -> Result<Vec<u8>, EncryptionError> {
190        // plaintext = DecryptWithAd(h, ciphertext), calls MixHash(ciphertext), and returns
191        // plaintext. Note that if k is empty, the DecryptWithAd() call will set plaintext equal to
192        // ciphertext.
193        let plaintext = self.cipher.decrypt_with_ad(self.h.as_ref(), ciphertext.as_ref())?;
194        self.mix_hash(&ciphertext);
195        Ok(plaintext)
196    }
197
198    pub fn split(&mut self) -> (CipherState, CipherState) {
199        // Sets temp_k1, temp_k2 = HKDF(ck, zerolen, 2).
200        let (temp_k1, temp_k2) = hkdf_2::<D>(self.ck, []);
201
202        // If HASHLEN is 64, then truncates temp_k1 and temp_k2 to 32 bytes.
203        let k1 = SharedSecret::with_truncated(temp_k1);
204        let k2 = SharedSecret::with_truncated(temp_k2);
205
206        // Creates two new CipherState objects c1 and c2.
207        let mut c1 = CipherState::new();
208        let mut c2 = CipherState::new();
209
210        // Calls c1.InitializeKey(temp_k1) and c2.InitializeKey(temp_k2).
211        c1.initialize_key(k1);
212        c2.initialize_key(k2);
213
214        self.was_split = true;
215
216        // Returns the pair (c1, c2).
217        (c1, c2)
218    }
219}
220
221#[derive(Clone, Eq, PartialEq, Debug)]
222pub struct HandshakeState<E: Ecdh, D: Digest> {
223    state: SymmetricState<D>,
224    is_initiator: bool,
225    keyset: Keyset<E>,
226    handshake_pattern: HandshakePattern,
227    read_message_patterns: VecDeque<&'static [MessagePattern]>,
228    write_message_patterns: VecDeque<&'static [MessagePattern]>,
229}
230
231impl<E: Ecdh, D: Digest> HandshakeState<E, D> {
232    /// Initialize(handshake_pattern, initiator, prologue, s, e, rs, re): Takes a valid
233    /// handshake_pattern (see Section 7) and an initiator boolean specifying this party's role
234    /// as either initiator or responder.
235    ///
236    /// Takes a prologue byte sequence which may be zero-length, or which may contain context
237    /// information that both parties want to confirm is identical (see Section 6).
238    ///
239    /// Takes a set of DH key pairs (s, e) and public keys (rs, re) for initializing local
240    /// variables, any of which may be empty. Public keys are only passed in if the
241    /// handshake_pattern uses pre-messages (see Section 7). The ephemeral values (e, re) are
242    /// typically left empty, since they are created and exchanged during the handshake; but
243    /// there are exceptions (see Section 10).
244    ///
245    /// Performs the following steps:
246    ///
247    /// Derives a protocol_name byte sequence by combining the names for the handshake pattern
248    /// and crypto functions, as specified in Section 8.
249    ///
250    /// Calls InitializeSymmetric(protocol_name).
251    ///
252    /// Calls MixHash(prologue).
253    ///
254    /// Sets the initiator, s, e, rs, and re variables to the corresponding arguments.
255    ///
256    /// Calls MixHash() once for each public key listed in the pre-messages from
257    /// handshake_pattern, with the specified public key as input (see Section 7 for an
258    /// explanation of pre-messages). If both initiator and responder have pre-messages, the
259    /// initiator's public keys are hashed first. If multiple public keys are listed in either
260    /// party's pre-message, the public keys are hashed in the order that they are listed.
261    ///
262    /// Sets message_patterns to the message patterns from handshake_pattern.
263    pub fn initialize<const HASHLEN: usize>(
264        handshake_pattern: HandshakePattern,
265        is_initiator: bool,
266        prologue: &[u8],
267        keyset: Keyset<E>,
268    ) -> Self {
269        debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
270        let mut name_components = vec![s!("Noise")];
271        let curve_name = match E::Pk::CURVE_NAME {
272            "Curve25519" => "25519",
273            "Secp256k1" => "secp256k1",
274            "Edwards25519" => "Edwards25519", // ECDH over Edwards-coordinate version of Curve25519
275            unsupported => {
276                unimplemented!("curve {unsupported} is not supported by the Noise library")
277            }
278        };
279        name_components.push(handshake_pattern.to_string());
280        name_components.push(curve_name.to_owned());
281        name_components.push("ChaChaPoly".to_owned());
282        name_components.push(D::DIGEST_NAME.to_owned());
283        let protocol_name = name_components.join("_");
284
285        let mut state = SymmetricState::<D>::with::<HASHLEN>(protocol_name);
286        state.mix_hash(prologue);
287
288        for pre_msg in handshake_pattern.pre_messages() {
289            if let Some(key) = keyset.pre_message_key(*pre_msg, is_initiator) {
290                state.mix_hash(key.to_pk_compressed().as_ref())
291            }
292        }
293
294        let write_message_patterns =
295            VecDeque::from_iter(handshake_pattern.message_patterns(is_initiator).iter().copied());
296        let read_message_patterns =
297            VecDeque::from_iter(handshake_pattern.message_patterns(!is_initiator).iter().copied());
298
299        Self {
300            handshake_pattern,
301            read_message_patterns,
302            write_message_patterns,
303            is_initiator,
304            keyset,
305            state,
306        }
307    }
308
309    /// Takes a payload byte sequence which may be zero-length
310    ///
311    /// # Errors
312    ///
313    /// If any EncryptAndHash() call returns an error
314    fn write_message(&mut self, payload: &[u8]) -> Result<HandshakeAct, EncryptionError> {
315        match self.write_message_patterns.pop_front() {
316            Some(seq) => {
317                let mut message_buffer = Vec::<u8>::new();
318                // 1. Fetches and deletes the next message pattern from message_patterns, then
319                // sequentially processes each token from the message pattern:
320                //   - For "e": Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends
321                //     e.public_key to the buffer. Calls MixHash(e.public_key).
322                //   - For "s": Appends EncryptAndHash(s.public_key) to the buffer.
323                //   - For "ee": Calls MixKey(DH(e, re)).
324                //   - For "es": Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if
325                //     responder.
326                //   - For "se": Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if
327                //     responder.
328                //   - For "ss": Calls MixKey(DH(s, rs)).
329
330                for pat in seq {
331                    match pat {
332                        MessagePattern::E => {
333                            let (e, pubkey) = E::generate_keypair();
334                            let pk = pubkey.to_pk_compressed();
335                            message_buffer.extend(pk.as_ref());
336
337                            self.state.mix_hash(pk);
338                            self.keyset.e = e;
339                        }
340                        MessagePattern::S => {
341                            let s = self.keyset.expect_s().to_pk()?.to_pk_compressed();
342                            let enc = self.state.encrypt_and_hash(s)?;
343                            message_buffer.extend(&enc)
344                        }
345                        MessagePattern::EE => {
346                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_re())?)
347                        }
348                        MessagePattern::ES if self.is_initiator => {
349                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?)
350                        }
351                        MessagePattern::ES => self
352                            .state
353                            .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?),
354                        MessagePattern::SE if self.is_initiator => self
355                            .state
356                            .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?),
357                        MessagePattern::SE => {
358                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?)
359                        }
360                        MessagePattern::SS => self
361                            .state
362                            .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_rs())?),
363                    };
364                }
365
366                // 2. Appends EncryptAndHash(payload) to the buffer.
367                message_buffer.extend(self.state.encrypt_and_hash(payload)?);
368
369                Ok(HandshakeAct::Buffer(message_buffer))
370            }
371            None => {
372                // 3. If there are no more message patterns returns two new CipherState objects by
373                // calling Split().
374                let (c1, c2) = self.state.split();
375                Ok(HandshakeAct::Split(c1, c2))
376            }
377        }
378    }
379
380    /// Takes a byte sequence containing a Noise handshake message, and a payload_buffer to write
381    /// the message's plaintext payload into.
382    ///
383    /// # Errors
384    ///
385    /// If any DecryptAndHash() call returns an error
386    fn read_message(&mut self, message: &[u8]) -> Result<HandshakeAct, EncryptionError> {
387        match self.read_message_patterns.pop_front() {
388            Some(seq) => {
389                let mut payload_buffer = Vec::new();
390                let mut pos = 0usize;
391
392                // Performs the following steps:
393                //
394                // 1. Fetches and deletes the next message pattern from message_patterns, then
395                // sequentially processes each token from the message pattern:
396                //   - For "e": Sets re (which must be empty) to the next DHLEN bytes from the
397                //     message. Calls MixHash(re.public_key).
398                //   - For "s": Sets temp to the next DHLEN + 16 bytes of the message if HasKey() ==
399                //     True, or to the next DHLEN bytes otherwise. Sets rs (which must be empty) to
400                //     DecryptAndHash(temp).
401                //   - For "ee": Calls MixKey(DH(e, re)).
402                //   - For "es": Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if
403                //     responder.
404                //   - For "se": Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if
405                //     responder.
406                //   - For "ss": Calls MixKey(DH(s, rs)).
407                for pat in seq {
408                    match pat {
409                        MessagePattern::E => {
410                            debug_assert!(self.keyset.re.is_none());
411
412                            let next_pos = pos + E::Pk::COMPRESSED_LEN;
413                            let re = E::Pk::from_pk_compressed_slice(&message[pos..next_pos])?;
414                            pos = next_pos;
415
416                            self.state.mix_hash(re.to_pk_compressed());
417                            self.keyset.re = Some(re);
418                        }
419                        MessagePattern::S => {
420                            debug_assert!(self.keyset.rs.is_none());
421                            let next_pos = match self.state.cipher.has_key() {
422                                true => D::OUTPUT_LEN + 16,
423                                false => D::OUTPUT_LEN,
424                            };
425                            let temp = self.state.decrypt_and_hash(&message[pos..next_pos])?;
426                            self.keyset.rs = Some(E::Pk::from_pk_compressed_slice(&temp)?);
427                            pos = next_pos;
428                        }
429                        MessagePattern::EE => {
430                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_re())?);
431                        }
432                        MessagePattern::ES if self.is_initiator => {
433                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?);
434                        }
435                        MessagePattern::ES => {
436                            self.state
437                                .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?);
438                        }
439                        MessagePattern::SE if self.is_initiator => {
440                            self.state
441                                .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?);
442                        }
443                        MessagePattern::SE => {
444                            self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?);
445                        }
446                        MessagePattern::SS => {
447                            self.state
448                                .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_rs())?);
449                        }
450                    }
451                }
452
453                // 2. Calls DecryptAndHash() on the remaining bytes of the message and stores the
454                // output into payload_buffer.
455                let output = self.state.decrypt_and_hash(&message[pos..])?;
456                payload_buffer.extend(output);
457
458                Ok(HandshakeAct::Buffer(payload_buffer))
459            }
460            None => {
461                // 3. If there are no more message patterns returns two new CipherState objects by
462                // calling Split().
463                let (c1, c2) = self.state.split();
464                Ok(HandshakeAct::Split(c1, c2))
465            }
466        }
467    }
468
469    /// Provides information about next message length which should be read from a network stream.
470    fn next_read_len(&self) -> usize {
471        if self.read_message_patterns.is_empty() {
472            return 0;
473        }
474        let mut pos = 0;
475        let seq = self.read_message_patterns[0];
476        for pat in seq {
477            match pat {
478                MessagePattern::E => {
479                    pos += E::Pk::COMPRESSED_LEN;
480                }
481                MessagePattern::S if self.state.cipher.has_key() => {
482                    pos += D::OUTPUT_LEN + 16;
483                }
484                MessagePattern::S => {
485                    pos += D::OUTPUT_LEN;
486                }
487                _ => {}
488            }
489        }
490        pos
491    }
492}
493
494#[derive(Clone, Eq, PartialEq, Debug)]
495pub enum HandshakeAct {
496    Buffer(Vec<u8>),
497    Split(CipherState, CipherState),
498}
499
500#[derive(Clone, Eq, PartialEq, Debug)]
501pub enum NoiseState<E: Ecdh, D: Digest> {
502    AwaitWrite(HandshakeState<E, D>),
503    Handshake(HandshakeState<E, D>),
504    Active {
505        sending_cipher: CipherState,
506        receiving_cipher: CipherState,
507        handshake_hash: HandshakeHash<D>,
508        remote_static_pubkey: Option<E::Pk>,
509    },
510}
511
512impl<E: Ecdh, D: Digest> NoiseState<E, D> {
513    pub fn initialize<const HASHLEN: usize>(
514        handshake_pattern: HandshakePattern,
515        is_initiator: bool,
516        prologue: &[u8],
517        keyset: Keyset<E>,
518    ) -> Self {
519        debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
520
521        let handshake = HandshakeState::initialize::<HASHLEN>(
522            handshake_pattern,
523            is_initiator,
524            prologue,
525            keyset,
526        );
527        match is_initiator {
528            true => Self::AwaitWrite(handshake),
529            false => Self::Handshake(handshake),
530        }
531    }
532
533    /// Takes incoming data from the remote peer, advances internal state machine
534    /// and returns a data to be sent to the remote peer for the next handshake
535    /// act. If the handshake is over, returns an empty vector. On subsequent
536    /// calls return [`NoiseError::HandshakeComplete`] error.
537    pub fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, NoiseError> {
538        let (output, payload) = self.advance_with_payload(input, &[])?;
539        if !payload.is_empty() {
540            Err(NoiseError::PayloadNotEmpty)
541        } else {
542            Ok(output)
543        }
544    }
545
546    pub fn advance_with_payload(
547        &mut self,
548        input: &[u8],
549        payload: &[u8],
550    ) -> Result<(Vec<u8>, Vec<u8>), NoiseError> {
551        match self {
552            NoiseState::AwaitWrite(handshake) => {
553                let act = handshake.write_message(payload)?;
554                match act {
555                    HandshakeAct::Buffer(buffer) => {
556                        *self = NoiseState::Handshake(handshake.clone());
557                        Ok((buffer, vec![]))
558                    }
559                    _ => panic!("single-act handshake doesn't exist for noise protocol"),
560                }
561            }
562            NoiseState::Handshake(handshake) => {
563                let act = handshake.read_message(input)?;
564                let read_payload = match act {
565                    HandshakeAct::Buffer(payload) => payload,
566                    HandshakeAct::Split(sending_cipher, receiving_cipher) => {
567                        *self = NoiseState::Active {
568                            sending_cipher,
569                            receiving_cipher,
570                            handshake_hash: handshake.state.get_handshake_hash(),
571                            remote_static_pubkey: handshake.keyset.rs.clone(),
572                        };
573                        return Ok((vec![], vec![]));
574                    }
575                };
576                let act = handshake.write_message(payload)?;
577                match act {
578                    HandshakeAct::Buffer(buffer) => Ok((buffer, read_payload)),
579                    HandshakeAct::Split(sending_cipher, receiving_cipher) => {
580                        *self = NoiseState::Active {
581                            sending_cipher,
582                            receiving_cipher,
583                            handshake_hash: handshake.state.get_handshake_hash(),
584                            remote_static_pubkey: handshake.keyset.rs.clone(),
585                        };
586                        Ok((vec![], vec![]))
587                    }
588                }
589            }
590            NoiseState::Active { .. } => Err(NoiseError::HandshakeComplete),
591        }
592    }
593
594    /// Provides information about next message length which should be read from a network stream.
595    pub fn next_read_len(&self) -> usize {
596        match self {
597            NoiseState::AwaitWrite(_) => 0,
598            NoiseState::Handshake(handshake) => handshake.next_read_len(),
599            NoiseState::Active { .. } => 0,
600        }
601    }
602
603    pub fn get_handshake_hash(&self) -> Option<HandshakeHash<D>> {
604        match self {
605            NoiseState::AwaitWrite(_) | NoiseState::Handshake(_) => None,
606            NoiseState::Active { handshake_hash, .. } => Some(*handshake_hash),
607        }
608    }
609
610    pub fn get_remote_static_key(&self) -> Option<E::Pk> {
611        match self {
612            NoiseState::AwaitWrite(_) | NoiseState::Handshake(_) => None,
613            NoiseState::Active {
614                remote_static_pubkey,
615                ..
616            } => remote_static_pubkey.clone(),
617        }
618    }
619}