1use zeroize::{Zeroize, Zeroizing};
2use std::ops::Deref;
3
4use crate::consts;
5use crate::crypto;
6use crate::wire::{ContactOutput, WireMessage, UserPrompt, UserAnswer, NewMessage};
7use crate::error::Error;
8
9mod export_import;
10mod clone;
11mod smp;
12mod pfs;
13mod msgs;
14mod ratchet;
15pub(crate) use smp::normalize_smp_answer;
16
17
/// Progress of a `Contact` through the SMP verification handshake.
///
/// `Contact::process` routes inbound data on this value:
/// `Uninitialized` -> `do_smp_step_2`, `SMPInit` -> `do_smp_step_3`,
/// `SMPStep2` -> `do_smp_step_4_request_answer`, `SMPStep3` -> `do_smp_step_5`,
/// and `Verified` -> `process_verified`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContactState {
    /// Starting state of every contact created by `Contact::new`.
    Uninitialized,
    /// Intermediate SMP state; inbound data goes to `do_smp_step_3`.
    SMPInit,
    /// Intermediate SMP state; inbound data goes to `do_smp_step_4_request_answer`.
    SMPStep2,
    /// Intermediate SMP state; inbound data goes to `do_smp_step_5`.
    SMPStep3,
    /// SMP finished; inbound data is handled by `process_verified`.
    Verified,
}
27
28#[derive(Zeroize, Debug)]
30#[zeroize(drop)]
31pub struct Contact {
32 #[zeroize(skip)]
33 pub state: ContactState,
34
35 message_locked: bool,
36
37 our_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
40 our_signing_secret_key: Option<Zeroizing<Vec<u8>>>,
41 contact_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
42
43
44 our_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
45 our_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
46 contact_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
47
48 our_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
49 our_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
50 contact_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
51
52 our_staged_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
53 our_staged_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
54
55 our_staged_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
56 our_staged_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
57
58 our_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
59 our_smp_tmp_secret_key: Option<Zeroizing<Vec<u8>>>,
60 contact_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
61
62 our_next_strand_key : Option<Zeroizing<Vec<u8>>>,
63 our_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
64
65 contact_next_strand_key : Option<Zeroizing<Vec<u8>>>,
66 contact_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
67
68 our_smp_nonce: Option<Zeroizing<Vec<u8>>>,
69 contact_smp_nonce: Option<Zeroizing<Vec<u8>>>,
70 contact_smp_proof: Option<Zeroizing<Vec<u8>>>,
71
72 smp_answer: Option<Zeroizing<String>>,
73 smp_question: Option<Zeroizing<String>>,
74
75 our_pads: Option<Zeroizing<Vec<u8>>>,
76 contact_pads: Option<Zeroizing<Vec<u8>>>,
77
78 our_hash_chain: Option<Zeroizing<Vec<u8>>>,
79 contact_hash_chain: Option<Zeroizing<Vec<u8>>>,
80
81
82 pub additional_data: Option<Zeroizing<Vec<u8>>>,
83
84 #[zeroize(skip)]
85 backup: Option<Box<Contact>>
86}
87
88
89impl Contact {
90 pub fn new() -> Result<Self, Error> {
92 let mut contact = Contact {
93 state: ContactState::Uninitialized,
94
95 message_locked: false,
96
97 our_smp_tmp_pub_key: None,
98 our_smp_tmp_secret_key: None,
99 contact_smp_tmp_pub_key: None,
100
101 our_ml_kem_pub_key: None,
102 our_ml_kem_secret_key: None,
103 contact_ml_kem_pub_key: None,
104
105 our_mceliece_pub_key: None,
106 our_mceliece_secret_key: None,
107 contact_mceliece_pub_key: None,
108
109
110 our_staged_ml_kem_pub_key: None,
111 our_staged_ml_kem_secret_key: None,
112
113 our_staged_mceliece_pub_key: None,
114 our_staged_mceliece_secret_key: None,
115
116
117 our_signing_pub_key: None,
118 our_signing_secret_key: None,
119
120 contact_signing_pub_key: None,
121
122 our_next_strand_key: None,
123 our_next_strand_nonce: None,
124 contact_next_strand_key: None,
125 contact_next_strand_nonce: None,
126 our_smp_nonce: None,
127 contact_smp_nonce: None,
128 contact_smp_proof: None,
129
130 smp_answer: None,
131 smp_question: None,
132
133 our_pads: None,
134 contact_pads: None,
135
136 our_hash_chain: None,
137 contact_hash_chain: None,
138
139 additional_data: None,
140
141 backup: None,
142 };
143
144 contact.init_lt_sign_keypair()?;
145 Ok(contact)
146 }
147
148
149
150 pub fn process(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
152 self.save_backup();
153
154 let mut failure = Vec::new();
155 failure.push(consts::SMP_TYPE_INIT_SMP); failure.extend_from_slice(b"failure");
157
158
159 if data == failure.as_slice() {
160 self.uninitialize_contact();
161 return Ok(ContactOutput::None);
162
163 }
164
165
166 let result = match self.state {
167 ContactState::Uninitialized => self.do_smp_step_2(data),
168 ContactState::SMPInit => self.do_smp_step_3(data),
169 ContactState::SMPStep2 => self.do_smp_step_4_request_answer(data),
170 ContactState::SMPStep3 => self.do_smp_step_5(data),
171 ContactState::Verified => self.process_verified(data)
172 };
173
174 if self.state != ContactState::Verified {
175 if result.is_ok() {
178 return result;
179 } else {
180 return self.do_smp_failure();
181 }
182 }
183
184 result
185 }
186
187 pub fn process_verified(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
188 let data_plaintext = self.decrypt_incoming_data(data)?;
189
190 let type_byte = data_plaintext.get(0)
191 .ok_or(Error::InvalidDataPlaintextLength)?;
192
193 if type_byte == &consts::PFS_TYPE_PFS_NEW {
194 let pfs_plaintext = data_plaintext.get(1..)
195 .ok_or(Error::InvalidPfsPlaintextLength)?;
196 return self.do_pfs_new(pfs_plaintext);
197
198 } else if type_byte == &consts::PFS_TYPE_PFS_ACK {
199 let pfs_plaintext = data_plaintext.get(1..)
200 .ok_or(Error::InvalidPfsPlaintextLength)?;
201 return self.do_pfs_ack(pfs_plaintext);
202
203 } else if type_byte == &consts::MSG_TYPE_MSG_BATCH {
204 let msgs_plaintext = data_plaintext.get(1..)
205 .ok_or(Error::InvalidMsgsPlaintextLength)?;
206
207 return self.do_process_otp_batch(msgs_plaintext);
208
209 } else if type_byte == &consts::MSG_TYPE_MSG_NEW {
210 let msgs_plaintext = data_plaintext.get(1..)
211 .ok_or(Error::InvalidMsgsPlaintextLength)?;
212
213 return self.do_process_new_msg(msgs_plaintext);
214 }
215
216 Err(Error::InvalidDataType)
217 }
218
219
220 fn init_lt_sign_keypair(&mut self) -> Result<(), Error> {
221 let (pk, sk) = crypto::generate_ml_dsa_87_keypair()
222 .map_err(|_| Error::CryptoFail)?;
223
224
225 self.our_signing_pub_key = Some(pk);
226 self.our_signing_secret_key = Some(sk);
227
228 Ok(())
229 }
230
231
232 fn init_tmp_kem_keypair(&mut self) -> Result<(), Error> {
233 let (pk, sk) = crypto::generate_kem_keypair(oqs::kem::Algorithm::MlKem1024)
234 .map_err(|_| Error::CryptoFail)?;
235
236
237 self.our_smp_tmp_pub_key = Some(pk);
238 self.our_smp_tmp_secret_key = Some(sk);
239
240 Ok(())
241 }
242
243}
244
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly created contacts must not export identical plaintext state
    /// (each `Contact::new` generates its own signing keypair).
    #[test]
    fn test_export_import() {
        // `export_plain` only needs `&self`, so these bindings need not be mutable.
        let alice = Contact::new().expect("Failed to create new contact instance");
        let bob = Contact::new().expect("Failed to create new contact instance");

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is equal to Alice exported plaintext.");
    }

    /// Drives a full SMP handshake between Alice and Bob, then exchanges
    /// messages in both directions. Finally, checks that plaintext exports are
    /// stable, distinct per side, and survive an import/export cycle.
    #[test]
    fn test_full_two_sessions() {
        let alice_question = String::from("This is a question");
        let alice_answer = String::from("This is an answer");

        // Alice opens SMP with her question/answer pair.
        let mut alice = Contact::new().expect("Failed to create new contact instance");

        let result = alice.init_smp(
            Zeroizing::new(alice_question.clone()),
            Zeroizing::new(alice_answer),
        );

        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // The opening message is exactly 1 + ML_KEM_1024_PK_SIZE bytes and
        // starts with the SMP init type byte.
        assert_eq!(result.len(), 1, "Expected exactly one wire message");
        assert_eq!(result[0].len(), 1 + consts::ML_KEM_1024_PK_SIZE, "SMP init output length mismatch");
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Bob, still uninitialized, replies to Alice's opening message.
        let mut bob = Contact::new().expect("Failed to create new contact instance");

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Only a lower bound on the reply size is checked here.
        assert!(
            result[0].len()
                >= 1 + (consts::ML_KEM_1024_CT_SIZE * 2)
                    + 16
                    + (consts::CHACHA20POLY1305_NONCE_SIZE * 3)
                    + 32
                    + consts::SMP_NONCE_SIZE
                    + consts::ML_DSA_87_PK_SIZE,
            "SMP step 2 output length mismatch"
        );
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Alice processes Bob's reply.
        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Bob receives Alice's question and must be prompted for an answer.
        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Prompt(p) => p,
            _ => panic!("Expected Prompt output"),
        };

        let bob_question = result.question;

        assert_eq!(bob_question, alice_question, "Bob question and Alice question do not match");

        // Bob supplies the same answer Alice chose.
        let bob_answer = Zeroizing::new(String::from("This is an answer"));

        let bob_user_answer = UserAnswer::new(bob_answer).expect("Failed to create new UserAnswer instance");

        let result = bob.provide_smp_answer(bob_user_answer);
        println!("Bob provide_smp_answer: result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Bob's response to this is two wire messages for Alice.
        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 2, "Expected exactly 2 wire messages");

        // The first of the two produces no output; the second yields one reply.
        let result_1 = alice.process(result[0].as_ref());
        println!("Alice result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        match result_1.unwrap() {
            ContactOutput::None => {}
            _ => panic!("Expected None output"),
        };

        let result_2 = alice.process(result[1].as_ref());
        println!("Alice result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result_2.len(), 1, "Expected exactly one wire message");

        // Bob's final handshake step produces no output.
        let result = bob.process(result_2[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        match result.unwrap() {
            ContactOutput::None => {}
            _ => panic!("Expected None output"),
        };

        // Alice -> Bob, first message: two wire messages are produced.
        let alice_message_1 = Zeroizing::new(String::from("Hello, World!"));

        let result = alice.send_message(&alice_message_1);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 2, "Expected exactly 2 wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result_1 = bob.process(result[0].as_ref());
        println!("Bob result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        let result_2 = bob.process(result[1].as_ref());
        println!("Bob result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_1, result_2.message, "Decrypted message not equal to original message");

        // Alice -> Bob, second message: a single wire message.
        let alice_message_2 = Zeroizing::new(String::from("Hi Bob!!"));

        let result = alice.send_message(&alice_message_2);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_2, result.message, "Decrypted message not equal to original message");

        // Both sends are already confirmed; a third confirmation must fail.
        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_err(), "Confirmation over use did not cause an error");

        // Bob -> Alice.
        let bob_message_1 = Zeroizing::new(String::from("Hey Alice!"));

        let result = bob.send_message(&bob_message_1);
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 2, "Expected exactly 2 wire message");

        let r = bob.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result_1 = alice.process(result[0].as_ref());
        println!("Alice result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        let result_2 = alice.process(result[1].as_ref());
        println!("Alice result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(bob_message_1, result_2.message, "Decrypted message not equal to original message");

        // Exports must be stable across repeated calls and differ between sides.
        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");

        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");

        assert_ne!(bob_plain_2, alice_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");

        // Re-import each side from its own export and repeat the same checks.
        let alice = Contact::import_plain(alice_plain.as_slice()).unwrap();
        let bob = Contact::import_plain(bob_plain.as_slice()).unwrap();

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");

        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");

        assert_ne!(bob_plain_2, alice_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");
    }
}
586
587