1use zeroize::{Zeroize, Zeroizing};
2use std::ops::Deref;
3
4use crate::consts;
5use crate::crypto;
6use crate::wire::{ContactOutput, WireMessage, UserPrompt, UserAnswer, NewMessage};
7use crate::error::Error;
8
9mod export_import;
10mod clone;
11mod smp;
12mod pfs;
13mod msgs;
14mod ratchet;
15pub(crate) use smp::normalize_smp_answer;
16
17
/// Progress of a `Contact` through the verification handshake.
///
/// `Contact::process` dispatches every incoming payload on the current state; the
/// handler used for each state is noted on the variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ContactState {
    /// State set by `Contact::new`; the next payload goes to `do_smp_step_2`.
    Uninitialized,
    /// The next payload goes to `do_smp_step_3`.
    SMPInit,
    /// The next payload goes to `do_smp_step_4_request_answer`.
    SMPStep2,
    /// The next payload goes to `do_smp_step_5`.
    SMPStep3,
    /// Handshake finished; payloads go to `process_verified`.
    Verified,
}
27
28#[derive(Zeroize, Debug)]
30#[zeroize(drop)]
31pub struct Contact {
32 #[zeroize(skip)]
33 state: ContactState,
34
35 message_locked: bool,
36
37 our_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
40 our_signing_secret_key: Option<Zeroizing<Vec<u8>>>,
41 contact_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
42
43
44 our_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
45 our_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
46 contact_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
47
48 our_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
49 our_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
50 contact_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
51
52 our_staged_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
53 our_staged_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
54
55 our_staged_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
56 our_staged_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
57
58 our_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
59 our_smp_tmp_secret_key: Option<Zeroizing<Vec<u8>>>,
60 contact_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
61
62 our_next_strand_key : Option<Zeroizing<Vec<u8>>>,
63 our_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
64
65 contact_next_strand_key : Option<Zeroizing<Vec<u8>>>,
66 contact_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
67
68 our_smp_nonce: Option<Zeroizing<Vec<u8>>>,
69 contact_smp_nonce: Option<Zeroizing<Vec<u8>>>,
70 contact_smp_proof: Option<Zeroizing<Vec<u8>>>,
71
72 smp_answer: Option<Zeroizing<String>>,
73 smp_question: Option<String>,
74
75 our_pads: Option<Zeroizing<Vec<u8>>>,
76 contact_pads: Option<Zeroizing<Vec<u8>>>,
77
78 our_hash_chain: Option<Zeroizing<Vec<u8>>>,
79 contact_hash_chain: Option<Zeroizing<Vec<u8>>>,
80
81 #[zeroize(skip)]
82 backup: Option<Box<Contact>>
83}
84
85
86impl Contact {
87 pub fn new() -> Result<Self, Error> {
89 let mut contact = Contact {
90 state: ContactState::Uninitialized,
91
92 message_locked: false,
93
94 our_smp_tmp_pub_key: None,
95 our_smp_tmp_secret_key: None,
96 contact_smp_tmp_pub_key: None,
97
98 our_ml_kem_pub_key: None,
99 our_ml_kem_secret_key: None,
100 contact_ml_kem_pub_key: None,
101
102 our_mceliece_pub_key: None,
103 our_mceliece_secret_key: None,
104 contact_mceliece_pub_key: None,
105
106
107 our_staged_ml_kem_pub_key: None,
108 our_staged_ml_kem_secret_key: None,
109
110 our_staged_mceliece_pub_key: None,
111 our_staged_mceliece_secret_key: None,
112
113
114 our_signing_pub_key: None,
115 our_signing_secret_key: None,
116
117 contact_signing_pub_key: None,
118
119 our_next_strand_key: None,
120 our_next_strand_nonce: None,
121 contact_next_strand_key: None,
122 contact_next_strand_nonce: None,
123 our_smp_nonce: None,
124 contact_smp_nonce: None,
125 contact_smp_proof: None,
126
127 smp_answer: None,
128 smp_question: None,
129
130 our_pads: None,
131 contact_pads: None,
132
133 our_hash_chain: None,
134 contact_hash_chain: None,
135
136 backup: None,
137 };
138
139 contact.init_lt_sign_keypair()?;
140 Ok(contact)
141 }
142
143
144
145 pub fn process(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
147 self.save_backup();
148 let result = match self.state {
149 ContactState::Uninitialized => self.do_smp_step_2(data),
150 ContactState::SMPInit => self.do_smp_step_3(data),
151 ContactState::SMPStep2 => self.do_smp_step_4_request_answer(data),
152 ContactState::SMPStep3 => self.do_smp_step_5(data),
153 ContactState::Verified => self.process_verified(data)
154 };
155
156 if self.state != ContactState::Verified {
157 if result.is_ok() {
160 return result;
161 } else {
162 return self.do_smp_failure();
163 }
164 }
165
166 result
167 }
168
169 pub fn process_verified(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
170 let data_plaintext = self.decrypt_incoming_data(data)?;
171
172 let type_byte = data_plaintext.get(0)
173 .ok_or(Error::InvalidDataPlaintextLength)?;
174
175 if type_byte == &consts::PFS_TYPE_PFS_NEW {
176 let pfs_plaintext = data_plaintext.get(1..)
177 .ok_or(Error::InvalidPfsPlaintextLength)?;
178 return self.do_pfs_new(pfs_plaintext);
179
180 } else if type_byte == &consts::PFS_TYPE_PFS_ACK {
181 let pfs_plaintext = data_plaintext.get(1..)
182 .ok_or(Error::InvalidPfsPlaintextLength)?;
183 return self.do_pfs_ack(pfs_plaintext);
184
185 } else if type_byte == &consts::MSG_TYPE_MSG_BATCH {
186 let msgs_plaintext = data_plaintext.get(1..)
187 .ok_or(Error::InvalidMsgsPlaintextLength)?;
188
189 return self.do_process_otp_batch(msgs_plaintext);
190
191 } else if type_byte == &consts::MSG_TYPE_MSG_NEW {
192 let msgs_plaintext = data_plaintext.get(1..)
193 .ok_or(Error::InvalidMsgsPlaintextLength)?;
194
195 return self.do_process_new_msg(msgs_plaintext);
196 }
197
198 Err(Error::InvalidDataType)
199 }
200
201
202 fn init_lt_sign_keypair(&mut self) -> Result<(), Error> {
203 let (pk, sk) = crypto::generate_ml_dsa_87_keypair()
204 .map_err(|_| Error::CryptoFail)?;
205
206
207 self.our_signing_pub_key = Some(pk);
208 self.our_signing_secret_key = Some(sk);
209
210 Ok(())
211 }
212
213
214 fn init_tmp_kem_keypair(&mut self) -> Result<(), Error> {
215 let (pk, sk) = crypto::generate_kem_keypair(oqs::kem::Algorithm::MlKem1024)
216 .map_err(|_| Error::CryptoFail)?;
217
218
219 self.our_smp_tmp_pub_key = Some(pk);
220 self.our_smp_tmp_secret_key = Some(sk);
221
222 Ok(())
223 }
224
225}
226
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `$res` is `Ok(ContactOutput::$variant(inner))` and evaluates to `inner`.
    ///
    /// Any other outcome panics with the full result, so a failing protocol step
    /// shows the underlying `Error` rather than a bare `is_ok()` failure.
    macro_rules! expect_output {
        ($res:expr, $variant:ident) => {
            match $res {
                Ok(ContactOutput::$variant(inner)) => inner,
                other => panic!("Expected {} output, got {:?}", stringify!($variant), other),
            }
        };
    }

    /// Asserts that `$res` is `Ok(ContactOutput::None)`.
    macro_rules! expect_no_output {
        ($res:expr) => {
            match $res {
                Ok(ContactOutput::None) => {}
                other => panic!("Expected None output, got {:?}", other),
            }
        };
    }

    #[test]
    fn test_export_import() {
        let alice = Contact::new().expect("Failed to create new contact instance");
        let bob = Contact::new().expect("Failed to create new contact instance");

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(alice_plain, bob_plain, "Bob exported plaintext is equal to Alice exported plaintext.");
    }

    #[test]
    fn test_full_two_sessions() {
        let alice_question = String::from("This is a question");
        let alice_answer = String::from("This is an answer");

        // SMP step 1: Alice initiates with her question/answer pair.
        let mut alice = Contact::new().expect("Failed to create new contact instance");
        let wire = expect_output!(alice.init_smp(alice_question.clone(), alice_answer.clone()), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");
        assert_eq!(wire[0].len(), 1 + consts::ML_KEM_1024_PK_SIZE, "SMP init output length mismatch");
        assert_eq!(wire[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // SMP step 2: Bob processes Alice's init.
        let mut bob = Contact::new().expect("Failed to create new contact instance");
        let wire = expect_output!(bob.process(wire[0].as_ref()), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");
        let min_step_2_len = 1
            + consts::ML_KEM_1024_CT_SIZE * 2
            + 16
            + consts::CHACHA20POLY1305_NONCE_SIZE * 3
            + 32
            + consts::SMP_NONCE_SIZE
            + consts::ML_DSA_87_PK_SIZE;
        assert!(wire[0].len() >= min_step_2_len, "SMP step 2 output length mismatch");
        assert_eq!(wire[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // SMP step 3: Alice processes Bob's reply.
        let wire = expect_output!(alice.process(wire[0].as_ref()), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");

        // SMP step 4: Bob is prompted with Alice's question and answers it.
        let prompt = expect_output!(bob.process(wire[0].as_ref()), Prompt);
        assert_eq!(prompt.question, alice_question, "Bob question and Alice question do not match");

        let bob_user_answer = UserAnswer::new(String::from("This is an answer"))
            .expect("Failed to create new UserAnswer instance");
        let wire = expect_output!(bob.provide_smp_answer(bob_user_answer), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");

        // SMP step 5 and completion.
        let wire = expect_output!(alice.process(wire[0].as_ref()), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");

        let wire = expect_output!(bob.process(wire[0].as_ref()), Wire);
        assert_eq!(wire.len(), 2, "Expected exactly 2 wire messages");

        expect_no_output!(alice.process(wire[0].as_ref()));
        let reply = expect_output!(alice.process(wire[1].as_ref()), Wire);
        assert_eq!(reply.len(), 1, "Expected exactly one wire message");

        expect_no_output!(bob.process(reply[0].as_ref()));

        // First message: Alice's send produces two wire messages.
        let alice_message_1 = Zeroizing::new(String::from("Hello, World!"));
        let wire = expect_output!(alice.send_message(&alice_message_1), Wire);
        assert_eq!(wire.len(), 2, "Expected exactly 2 wire messages");
        alice
            .i_confirm_message_has_been_sent()
            .expect("Failed to confirm sent message");

        // Bob's reply to the first wire message is not used further in this test.
        let _ = expect_output!(bob.process(wire[0].as_ref()), Wire);
        let received = expect_output!(bob.process(wire[1].as_ref()), Message);
        assert_eq!(alice_message_1, received.message, "Decrypted message not equal to original message");

        // Second message: a single wire message.
        let alice_message_2 = Zeroizing::new(String::from("Hi Bob!!"));
        let wire = expect_output!(alice.send_message(&alice_message_2), Wire);
        assert_eq!(wire.len(), 1, "Expected exactly one wire message");
        alice
            .i_confirm_message_has_been_sent()
            .expect("Failed to confirm sent message");

        let received = expect_output!(bob.process(wire[0].as_ref()), Message);
        assert_eq!(alice_message_2, received.message, "Decrypted message not equal to original message");

        // Confirming again without a pending send must be rejected.
        assert!(
            alice.i_confirm_message_has_been_sent().is_err(),
            "Confirmation over use did not cause an error"
        );

        // Exports differ between parties and are stable across repeated calls.
        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();
        assert_ne!(alice_plain, bob_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();
        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");
        assert_ne!(alice_plain_2, bob_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");

        // Re-import each party from its own export and repeat the export checks.
        let alice = Contact::import_plain(alice_plain.as_slice()).unwrap();
        let bob = Contact::import_plain(bob_plain.as_slice()).unwrap();

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();
        assert_ne!(alice_plain, bob_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();
        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");
        assert_ne!(alice_plain_2, bob_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");
    }
}
519
520