1use std::collections::VecDeque;
23
24use cypher::{Digest, EcPk, Ecdh};
25
26use crate::cipher::{decrypt, encrypt, rekey, SharedSecret};
27use crate::error::{EncryptionError, NoiseError};
28use crate::hkdf::{hkdf_2, hkdf_3};
29use crate::patterns::{HandshakePattern, Keyset, MessagePattern};
30use crate::{ChainingKey, HandshakeHash, NoiseNonce};
31
32trait WithTruncated {
33 fn with_truncated(temp_key: impl AsRef<[u8]>) -> Self;
34}
35
36impl WithTruncated for SharedSecret {
37 fn with_truncated(temp_key: impl AsRef<[u8]>) -> Self {
38 let mut key = [0u8; 32];
39 match temp_key.as_ref().len() {
40 32 => {
41 key.copy_from_slice(temp_key.as_ref());
42 }
43 64 => {
44 key.copy_from_slice(&temp_key.as_ref()[..32]);
45 }
46 x => {
47 panic!(
48 "Noise protocol requires HASH function with output length either 32 or 64 \
49 bytes (a function outputting {x} bytes were given)"
50 )
51 }
52 }
53 key
54 }
55}
56
57#[derive(Clone, Eq, PartialEq, Debug, Default)]
58pub struct CipherState {
59 k: SharedSecret,
60 n: NoiseNonce,
61}
62
63impl CipherState {
64 pub fn new() -> Self { CipherState { k: [0u8; 32], n: 0 } }
65
66 pub fn initialize_key(&mut self, key: SharedSecret) {
67 self.k = key;
68 self.n = 0;
69 }
70
71 pub fn has_key(&self) -> bool { self.k != [0u8; 32] }
72
73 pub fn nonce(&self) -> NoiseNonce { self.n }
74
75 pub fn set_nonce(&mut self, nonce: NoiseNonce) { self.n = nonce; }
76
77 pub fn encrypt_with_ad(
78 &mut self,
79 ad: &[u8],
80 plaintext: &[u8],
81 ) -> Result<Vec<u8>, EncryptionError> {
82 if self.k.is_empty() {
83 Ok(plaintext.to_vec())
84 } else {
85 let ciphertext = encrypt(self.k, self.n, ad, plaintext);
87 self.n += 1;
88 ciphertext
89 }
90 }
91 pub fn decrypt_with_ad(
92 &mut self,
93 ad: &[u8],
94 ciphertext: &[u8],
95 ) -> Result<Vec<u8>, EncryptionError> {
96 if self.k.is_empty() {
97 Ok(ciphertext.to_vec())
98 } else {
99 let plaintext = decrypt(self.k, self.n, ad, ciphertext)?;
103 self.n += 1;
104 Ok(plaintext)
105 }
106 }
107
108 pub fn rekey(&mut self) { self.k = rekey(self.k); }
109}
110
111#[derive(Clone, Eq, PartialEq, Debug)]
112pub struct SymmetricState<D: Digest> {
113 cipher: CipherState,
114 ck: ChainingKey<D>,
115 h: HandshakeHash<D>,
116 was_split: bool,
117}
118
119impl<D: Digest> SymmetricState<D> {
120 pub fn with<const HASHLEN: usize>(protocol_name: String) -> Self {
121 debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
122 let len = protocol_name.len();
123 let h = if len <= HASHLEN {
124 let mut h = [0u8; HASHLEN];
125 h[..len].copy_from_slice(protocol_name.as_bytes());
126 D::Output::try_from(&h).unwrap_or_else(|_| unreachable!())
127 } else {
128 D::digest(protocol_name.as_bytes())
129 };
130 let cipher = CipherState::new();
131 Self {
132 cipher,
133 h,
134 ck: h,
135 was_split: false,
136 }
137 }
138
139 pub fn mix_key(&mut self, input_key_material: impl AsRef<[u8]>) {
140 let (ck, temp_key) = hkdf_2::<D>(self.ck, input_key_material);
141 self.ck = ck;
142 self.cipher.initialize_key(SharedSecret::with_truncated(temp_key));
143 }
144
145 pub fn mix_hash(&mut self, data: impl AsRef<[u8]>) {
146 self.h = D::digest_concat([self.h.as_ref(), data.as_ref()]);
147 }
148
149 #[allow(dead_code)]
151 pub fn mix_key_and_hash(&mut self, input_key_material: impl AsRef<[u8]>) {
152 let (ck, temp_h, temp_k) = hkdf_3::<D>(self.ck, input_key_material);
154 self.ck = ck;
155
156 self.mix_hash(temp_h);
158
159 self.cipher.initialize_key(SharedSecret::with_truncated(temp_k));
162 }
163
164 pub fn get_handshake_hash(&self) -> HandshakeHash<D> {
165 if !self.was_split {
166 panic!(
167 "SymmetricState::get_handshake_hash must be called only after \
168 SymmetricState::split"
169 )
170 }
171 self.h
172 }
173
174 pub fn encrypt_and_hash(
175 &mut self,
176 plaintext: impl AsRef<[u8]>,
177 ) -> Result<Vec<u8>, EncryptionError> {
178 let ciphertext = self.cipher.encrypt_with_ad(self.h.as_ref(), plaintext.as_ref())?;
182 self.mix_hash(&ciphertext);
183 Ok(ciphertext)
184 }
185
186 pub fn decrypt_and_hash(
187 &mut self,
188 ciphertext: impl AsRef<[u8]>,
189 ) -> Result<Vec<u8>, EncryptionError> {
190 let plaintext = self.cipher.decrypt_with_ad(self.h.as_ref(), ciphertext.as_ref())?;
194 self.mix_hash(&ciphertext);
195 Ok(plaintext)
196 }
197
198 pub fn split(&mut self) -> (CipherState, CipherState) {
199 let (temp_k1, temp_k2) = hkdf_2::<D>(self.ck, []);
201
202 let k1 = SharedSecret::with_truncated(temp_k1);
204 let k2 = SharedSecret::with_truncated(temp_k2);
205
206 let mut c1 = CipherState::new();
208 let mut c2 = CipherState::new();
209
210 c1.initialize_key(k1);
212 c2.initialize_key(k2);
213
214 self.was_split = true;
215
216 (c1, c2)
218 }
219}
220
221#[derive(Clone, Eq, PartialEq, Debug)]
222pub struct HandshakeState<E: Ecdh, D: Digest> {
223 state: SymmetricState<D>,
224 is_initiator: bool,
225 keyset: Keyset<E>,
226 handshake_pattern: HandshakePattern,
227 read_message_patterns: VecDeque<&'static [MessagePattern]>,
228 write_message_patterns: VecDeque<&'static [MessagePattern]>,
229}
230
231impl<E: Ecdh, D: Digest> HandshakeState<E, D> {
232 pub fn initialize<const HASHLEN: usize>(
264 handshake_pattern: HandshakePattern,
265 is_initiator: bool,
266 prologue: &[u8],
267 keyset: Keyset<E>,
268 ) -> Self {
269 debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
270 let mut name_components = vec![s!("Noise")];
271 let curve_name = match E::Pk::CURVE_NAME {
272 "Curve25519" => "25519",
273 "Secp256k1" => "secp256k1",
274 "Edwards25519" => "Edwards25519", unsupported => {
276 unimplemented!("curve {unsupported} is not supported by the Noise library")
277 }
278 };
279 name_components.push(handshake_pattern.to_string());
280 name_components.push(curve_name.to_owned());
281 name_components.push("ChaChaPoly".to_owned());
282 name_components.push(D::DIGEST_NAME.to_owned());
283 let protocol_name = name_components.join("_");
284
285 let mut state = SymmetricState::<D>::with::<HASHLEN>(protocol_name);
286 state.mix_hash(prologue);
287
288 for pre_msg in handshake_pattern.pre_messages() {
289 if let Some(key) = keyset.pre_message_key(*pre_msg, is_initiator) {
290 state.mix_hash(key.to_pk_compressed().as_ref())
291 }
292 }
293
294 let write_message_patterns =
295 VecDeque::from_iter(handshake_pattern.message_patterns(is_initiator).iter().copied());
296 let read_message_patterns =
297 VecDeque::from_iter(handshake_pattern.message_patterns(!is_initiator).iter().copied());
298
299 Self {
300 handshake_pattern,
301 read_message_patterns,
302 write_message_patterns,
303 is_initiator,
304 keyset,
305 state,
306 }
307 }
308
309 fn write_message(&mut self, payload: &[u8]) -> Result<HandshakeAct, EncryptionError> {
315 match self.write_message_patterns.pop_front() {
316 Some(seq) => {
317 let mut message_buffer = Vec::<u8>::new();
318 for pat in seq {
331 match pat {
332 MessagePattern::E => {
333 let (e, pubkey) = E::generate_keypair();
334 let pk = pubkey.to_pk_compressed();
335 message_buffer.extend(pk.as_ref());
336
337 self.state.mix_hash(pk);
338 self.keyset.e = e;
339 }
340 MessagePattern::S => {
341 let s = self.keyset.expect_s().to_pk()?.to_pk_compressed();
342 let enc = self.state.encrypt_and_hash(s)?;
343 message_buffer.extend(&enc)
344 }
345 MessagePattern::EE => {
346 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_re())?)
347 }
348 MessagePattern::ES if self.is_initiator => {
349 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?)
350 }
351 MessagePattern::ES => self
352 .state
353 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?),
354 MessagePattern::SE if self.is_initiator => self
355 .state
356 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?),
357 MessagePattern::SE => {
358 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?)
359 }
360 MessagePattern::SS => self
361 .state
362 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_rs())?),
363 };
364 }
365
366 message_buffer.extend(self.state.encrypt_and_hash(payload)?);
368
369 Ok(HandshakeAct::Buffer(message_buffer))
370 }
371 None => {
372 let (c1, c2) = self.state.split();
375 Ok(HandshakeAct::Split(c1, c2))
376 }
377 }
378 }
379
380 fn read_message(&mut self, message: &[u8]) -> Result<HandshakeAct, EncryptionError> {
387 match self.read_message_patterns.pop_front() {
388 Some(seq) => {
389 let mut payload_buffer = Vec::new();
390 let mut pos = 0usize;
391
392 for pat in seq {
408 match pat {
409 MessagePattern::E => {
410 debug_assert!(self.keyset.re.is_none());
411
412 let next_pos = pos + E::Pk::COMPRESSED_LEN;
413 let re = E::Pk::from_pk_compressed_slice(&message[pos..next_pos])?;
414 pos = next_pos;
415
416 self.state.mix_hash(re.to_pk_compressed());
417 self.keyset.re = Some(re);
418 }
419 MessagePattern::S => {
420 debug_assert!(self.keyset.rs.is_none());
421 let next_pos = match self.state.cipher.has_key() {
422 true => D::OUTPUT_LEN + 16,
423 false => D::OUTPUT_LEN,
424 };
425 let temp = self.state.decrypt_and_hash(&message[pos..next_pos])?;
426 self.keyset.rs = Some(E::Pk::from_pk_compressed_slice(&temp)?);
427 pos = next_pos;
428 }
429 MessagePattern::EE => {
430 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_re())?);
431 }
432 MessagePattern::ES if self.is_initiator => {
433 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?);
434 }
435 MessagePattern::ES => {
436 self.state
437 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?);
438 }
439 MessagePattern::SE if self.is_initiator => {
440 self.state
441 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_re())?);
442 }
443 MessagePattern::SE => {
444 self.state.mix_key(self.keyset.e.ecdh(self.keyset.expect_rs())?);
445 }
446 MessagePattern::SS => {
447 self.state
448 .mix_key(self.keyset.expect_s().ecdh(self.keyset.expect_rs())?);
449 }
450 }
451 }
452
453 let output = self.state.decrypt_and_hash(&message[pos..])?;
456 payload_buffer.extend(output);
457
458 Ok(HandshakeAct::Buffer(payload_buffer))
459 }
460 None => {
461 let (c1, c2) = self.state.split();
464 Ok(HandshakeAct::Split(c1, c2))
465 }
466 }
467 }
468
469 fn next_read_len(&self) -> usize {
471 if self.read_message_patterns.is_empty() {
472 return 0;
473 }
474 let mut pos = 0;
475 let seq = self.read_message_patterns[0];
476 for pat in seq {
477 match pat {
478 MessagePattern::E => {
479 pos += E::Pk::COMPRESSED_LEN;
480 }
481 MessagePattern::S if self.state.cipher.has_key() => {
482 pos += D::OUTPUT_LEN + 16;
483 }
484 MessagePattern::S => {
485 pos += D::OUTPUT_LEN;
486 }
487 _ => {}
488 }
489 }
490 pos
491 }
492}
493
494#[derive(Clone, Eq, PartialEq, Debug)]
495pub enum HandshakeAct {
496 Buffer(Vec<u8>),
497 Split(CipherState, CipherState),
498}
499
500#[derive(Clone, Eq, PartialEq, Debug)]
501pub enum NoiseState<E: Ecdh, D: Digest> {
502 AwaitWrite(HandshakeState<E, D>),
503 Handshake(HandshakeState<E, D>),
504 Active {
505 sending_cipher: CipherState,
506 receiving_cipher: CipherState,
507 handshake_hash: HandshakeHash<D>,
508 remote_static_pubkey: Option<E::Pk>,
509 },
510}
511
512impl<E: Ecdh, D: Digest> NoiseState<E, D> {
513 pub fn initialize<const HASHLEN: usize>(
514 handshake_pattern: HandshakePattern,
515 is_initiator: bool,
516 prologue: &[u8],
517 keyset: Keyset<E>,
518 ) -> Self {
519 debug_assert_eq!(HASHLEN, D::OUTPUT_LEN);
520
521 let handshake = HandshakeState::initialize::<HASHLEN>(
522 handshake_pattern,
523 is_initiator,
524 prologue,
525 keyset,
526 );
527 match is_initiator {
528 true => Self::AwaitWrite(handshake),
529 false => Self::Handshake(handshake),
530 }
531 }
532
533 pub fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, NoiseError> {
538 let (output, payload) = self.advance_with_payload(input, &[])?;
539 if !payload.is_empty() {
540 Err(NoiseError::PayloadNotEmpty)
541 } else {
542 Ok(output)
543 }
544 }
545
546 pub fn advance_with_payload(
547 &mut self,
548 input: &[u8],
549 payload: &[u8],
550 ) -> Result<(Vec<u8>, Vec<u8>), NoiseError> {
551 match self {
552 NoiseState::AwaitWrite(handshake) => {
553 let act = handshake.write_message(payload)?;
554 match act {
555 HandshakeAct::Buffer(buffer) => {
556 *self = NoiseState::Handshake(handshake.clone());
557 Ok((buffer, vec![]))
558 }
559 _ => panic!("single-act handshake doesn't exist for noise protocol"),
560 }
561 }
562 NoiseState::Handshake(handshake) => {
563 let act = handshake.read_message(input)?;
564 let read_payload = match act {
565 HandshakeAct::Buffer(payload) => payload,
566 HandshakeAct::Split(sending_cipher, receiving_cipher) => {
567 *self = NoiseState::Active {
568 sending_cipher,
569 receiving_cipher,
570 handshake_hash: handshake.state.get_handshake_hash(),
571 remote_static_pubkey: handshake.keyset.rs.clone(),
572 };
573 return Ok((vec![], vec![]));
574 }
575 };
576 let act = handshake.write_message(payload)?;
577 match act {
578 HandshakeAct::Buffer(buffer) => Ok((buffer, read_payload)),
579 HandshakeAct::Split(sending_cipher, receiving_cipher) => {
580 *self = NoiseState::Active {
581 sending_cipher,
582 receiving_cipher,
583 handshake_hash: handshake.state.get_handshake_hash(),
584 remote_static_pubkey: handshake.keyset.rs.clone(),
585 };
586 Ok((vec![], vec![]))
587 }
588 }
589 }
590 NoiseState::Active { .. } => Err(NoiseError::HandshakeComplete),
591 }
592 }
593
594 pub fn next_read_len(&self) -> usize {
596 match self {
597 NoiseState::AwaitWrite(_) => 0,
598 NoiseState::Handshake(handshake) => handshake.next_read_len(),
599 NoiseState::Active { .. } => 0,
600 }
601 }
602
603 pub fn get_handshake_hash(&self) -> Option<HandshakeHash<D>> {
604 match self {
605 NoiseState::AwaitWrite(_) | NoiseState::Handshake(_) => None,
606 NoiseState::Active { handshake_hash, .. } => Some(*handshake_hash),
607 }
608 }
609
610 pub fn get_remote_static_key(&self) -> Option<E::Pk> {
611 match self {
612 NoiseState::AwaitWrite(_) | NoiseState::Handshake(_) => None,
613 NoiseState::Active {
614 remote_static_pubkey,
615 ..
616 } => remote_static_pubkey.clone(),
617 }
618 }
619}