Skip to main content

libcold/
contact.rs

1use zeroize::{Zeroize, Zeroizing};
2use std::ops::Deref;
3
4use crate::consts;
5use crate::crypto;
6use crate::wire::{ContactOutput, WireMessage, UserPrompt, UserAnswer, NewMessage};
7use crate::error::Error;
8
9mod export_import;
10mod clone;
11mod smp;
12mod pfs;
13mod msgs;
14mod ratchet;
15pub(crate) use smp::normalize_smp_answer;
16
17
/// Verification state of a single contact.
///
/// `Contact::process` dispatches every incoming blob on this value, so each
/// variant effectively names which SMP handler will run next.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContactState {
    /// No SMP exchange in progress; the next incoming blob is handled as SMP step 2.
    Uninitialized,
    /// We started SMP; the next incoming blob is handled as SMP step 3.
    SMPInit,
    /// The next incoming blob is handled by SMP step 4, which asks the user for the answer.
    SMPStep2,
    /// The next incoming blob is handled as SMP step 5.
    SMPStep3,
    /// SMP completed; incoming blobs are decrypted and routed by message type.
    Verified,
}
27
28// Public contact struct (one per contact)
29#[derive(Zeroize, Debug)]
30#[zeroize(drop)]
31pub struct Contact {
32    #[zeroize(skip)]
33    pub state: ContactState,
34
35    message_locked: bool,
36
37    // stored key material
38
39    our_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
40    our_signing_secret_key: Option<Zeroizing<Vec<u8>>>,
41    contact_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
42    
43
44    our_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
45    our_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
46    contact_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
47
48    our_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
49    our_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
50    contact_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
51
52    our_staged_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
53    our_staged_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
54
55    our_staged_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
56    our_staged_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
57
58    our_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
59    our_smp_tmp_secret_key: Option<Zeroizing<Vec<u8>>>,
60    contact_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
61
62    our_next_strand_key  : Option<Zeroizing<Vec<u8>>>,
63    our_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
64
65    contact_next_strand_key  : Option<Zeroizing<Vec<u8>>>,
66    contact_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
67
68    our_smp_nonce: Option<Zeroizing<Vec<u8>>>,
69    contact_smp_nonce: Option<Zeroizing<Vec<u8>>>,
70    contact_smp_proof: Option<Zeroizing<Vec<u8>>>,
71
72    smp_answer: Option<Zeroizing<String>>,
73    smp_question: Option<Zeroizing<String>>,
74
75    our_pads: Option<Zeroizing<Vec<u8>>>,
76    contact_pads: Option<Zeroizing<Vec<u8>>>,
77
78    our_hash_chain: Option<Zeroizing<Vec<u8>>>,
79    contact_hash_chain: Option<Zeroizing<Vec<u8>>>,
80
81
82    pub additional_data: Option<Zeroizing<Vec<u8>>>,
83
84    #[zeroize(skip)]
85    backup: Option<Box<Contact>>
86}
87
88
89impl Contact {
90    /// Create new contact
91    pub fn new() -> Result<Self, Error> {
92        let mut contact = Contact {
93            state: ContactState::Uninitialized,
94
95            message_locked: false,
96
97            our_smp_tmp_pub_key: None,
98            our_smp_tmp_secret_key: None,
99            contact_smp_tmp_pub_key: None,
100
101            our_ml_kem_pub_key: None, 
102            our_ml_kem_secret_key: None, 
103            contact_ml_kem_pub_key: None,
104
105            our_mceliece_pub_key: None,
106            our_mceliece_secret_key: None, 
107            contact_mceliece_pub_key: None,
108
109
110            our_staged_ml_kem_pub_key: None, 
111            our_staged_ml_kem_secret_key: None, 
112            
113            our_staged_mceliece_pub_key: None, 
114            our_staged_mceliece_secret_key: None, 
115           
116
117            our_signing_pub_key: None,
118            our_signing_secret_key: None,
119
120            contact_signing_pub_key: None,
121
122            our_next_strand_key: None,
123            our_next_strand_nonce: None,
124            contact_next_strand_key: None,
125            contact_next_strand_nonce: None,
126            our_smp_nonce: None,
127            contact_smp_nonce: None,
128            contact_smp_proof: None,
129
130            smp_answer: None,
131            smp_question: None,
132
133            our_pads: None,
134            contact_pads: None,
135
136            our_hash_chain: None, 
137            contact_hash_chain: None,
138            
139            additional_data: None,
140
141            backup: None,
142        };
143
144        contact.init_lt_sign_keypair()?;
145        Ok(contact)
146    }
147
148
149   
150    /// Process an incoming blob, returning optional outgoing blob
151    pub fn process(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
152        self.save_backup();
153
154        let mut failure = Vec::new();
155        failure.push(consts::SMP_TYPE_INIT_SMP); // push single u8
156        failure.extend_from_slice(b"failure");
157
158
159        if data == failure.as_slice() {
160            self.uninitialize_contact();
161            return Ok(ContactOutput::None);
162
163        }
164
165
166        let result = match self.state {
167            ContactState::Uninitialized => self.do_smp_step_2(data),
168            ContactState::SMPInit       => self.do_smp_step_3(data),
169            ContactState::SMPStep2      => self.do_smp_step_4_request_answer(data),
170            ContactState::SMPStep3      => self.do_smp_step_5(data),
171            ContactState::Verified      => self.process_verified(data)
172        };
173        
174        if self.state != ContactState::Verified {
175            // If we are not verified, we must still be in SMP, therefore if we encounter any
176            // error, we send failure.
177            if result.is_ok() {
178                return result;
179            } else {
180                return self.do_smp_failure();
181            }
182        }
183
184        result
185    }
186    
187    pub fn process_verified(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
188        let data_plaintext = self.decrypt_incoming_data(data)?;
189
190        let type_byte = data_plaintext.get(0)
191            .ok_or(Error::InvalidDataPlaintextLength)?;
192
193        if type_byte == &consts::PFS_TYPE_PFS_NEW {
194            let pfs_plaintext = data_plaintext.get(1..)
195                .ok_or(Error::InvalidPfsPlaintextLength)?;
196            return self.do_pfs_new(pfs_plaintext);
197        
198        } else if type_byte == &consts::PFS_TYPE_PFS_ACK {
199            let pfs_plaintext = data_plaintext.get(1..)
200                .ok_or(Error::InvalidPfsPlaintextLength)?;
201            return self.do_pfs_ack(pfs_plaintext);
202        
203        } else if type_byte == &consts::MSG_TYPE_MSG_BATCH {
204            let msgs_plaintext = data_plaintext.get(1..)
205                .ok_or(Error::InvalidMsgsPlaintextLength)?;
206
207            return self.do_process_otp_batch(msgs_plaintext);
208
209        } else if type_byte == &consts::MSG_TYPE_MSG_NEW {
210            let msgs_plaintext = data_plaintext.get(1..)
211                .ok_or(Error::InvalidMsgsPlaintextLength)?;
212
213            return self.do_process_new_msg(msgs_plaintext);
214        }
215
216        Err(Error::InvalidDataType)
217    }
218  
219    
220    fn init_lt_sign_keypair(&mut self) -> Result<(), Error> {
221        let (pk, sk) = crypto::generate_ml_dsa_87_keypair()
222            .map_err(|_| Error::CryptoFail)?;
223
224
225        self.our_signing_pub_key = Some(pk);
226        self.our_signing_secret_key = Some(sk);
227
228        Ok(())
229    }
230
231
232    fn init_tmp_kem_keypair(&mut self) -> Result<(), Error> {
233        let (pk, sk) = crypto::generate_kem_keypair(oqs::kem::Algorithm::MlKem1024)
234            .map_err(|_| Error::CryptoFail)?;
235
236
237        self.our_smp_tmp_pub_key = Some(pk);
238        self.our_smp_tmp_secret_key = Some(sk);
239
240        Ok(())
241    }
242
243}
244
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly created contacts must not export identical state: each one
    /// generates its own signing key pair in `Contact::new`.
    #[test]
    fn test_export_import() {
        let alice = Contact::new().expect("Failed to create new contact instance");
        let bob = Contact::new().expect("Failed to create new contact instance");

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(
            alice_plain, bob_plain,
            "Alice exported plaintext is equal to Bob exported plaintext."
        );
    }

    /// Full run between two contacts: SMP verification, the PFS exchange,
    /// messages in both directions, and export/import round trips.
    #[test]
    fn test_full_two_sessions() {
        // ---- SMP ----------------------------------------------------------

        // Alice initiates a new SMP session.
        let alice_question = String::from("This is a question");
        let alice_answer = String::from("This is an answer");

        let mut alice = Contact::new().expect("Failed to create new contact instance");

        let result = alice.init_smp(
            Zeroizing::new(alice_question.clone()),
            Zeroizing::new(alice_answer),
        );
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");
        assert_eq!(
            result[0].len(),
            1 + consts::ML_KEM_1024_PK_SIZE,
            "SMP init output length mismatch"
        );
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Bob processes Alice's SMP init.
        let mut bob = Contact::new().expect("Failed to create new contact instance");

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");
        assert!(
            result[0].len()
                >= 1 + (consts::ML_KEM_1024_CT_SIZE * 2)
                    + 16
                    + (consts::CHACHA20POLY1305_NONCE_SIZE * 3)
                    + 32
                    + consts::SMP_NONCE_SIZE
                    + consts::ML_DSA_87_PK_SIZE,
            "SMP step 2 output length mismatch"
        );
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Alice processes Bob's reply.
        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Bob processes Alice's reply and is prompted with her question.
        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Prompt(p) => p,
            _ => panic!("Expected Prompt output"),
        };

        let bob_question = result.question;
        assert_eq!(bob_question, alice_question, "Bob question and Alice question do not match");

        let bob_answer = Zeroizing::new(String::from("This is an answer"));
        let bob_user_answer =
            UserAnswer::new(bob_answer).expect("Failed to create new UserAnswer instance");

        let result = bob.provide_smp_answer(bob_user_answer);
        println!("Bob provide_smp_answer: result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Alice checks Bob's answer.
        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // ---- PFS ----------------------------------------------------------

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 2, "Expected exactly 2 wire messages");

        let result_1 = alice.process(result[0].as_ref());
        println!("Alice result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        match result_1.unwrap() {
            ContactOutput::None => {}
            _ => panic!("Expected None output"),
        }

        let result_2 = alice.process(result[1].as_ref());
        println!("Alice result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result_2.len(), 1, "Expected exactly one wire message");

        let result = bob.process(result_2[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        match result.unwrap() {
            ContactOutput::None => {}
            _ => panic!("Expected None output"),
        }

        // ---- Messages: Alice -> Bob ---------------------------------------

        let alice_message_1 = Zeroizing::new(String::from("Hello, World!"));

        let result = alice.send_message(&alice_message_1);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // 2 because Alice never sent pads to Bob yet.
        assert_eq!(result.len(), 2, "Expected exactly 2 wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        // The first wire message (presumably the pad batch) only has to be accepted.
        let result_1 = bob.process(result[0].as_ref());
        println!("Bob result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        let result_2 = bob.process(result[1].as_ref());
        println!("Bob result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_1, result_2.message, "Decrypted message not equal to original message");

        let alice_message_2 = Zeroizing::new(String::from("Hi Bob!!"));

        let result = alice.send_message(&alice_message_2);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // 1 because Alice should have sent Bob enough pads by now.
        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_2, result.message, "Decrypted message not equal to original message");

        // Confirming again without a new send must fail.
        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_err(), "Confirmation over use did not cause an error");

        // ---- Messages: Bob -> Alice ---------------------------------------

        let bob_message_1 = Zeroizing::new(String::from("Hey Alice!"));

        let result = bob.send_message(&bob_message_1);
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // 2 because Bob never sent pads to Alice yet.
        assert_eq!(result.len(), 2, "Expected exactly 2 wire message");

        let r = bob.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        // The first wire message (presumably the pad batch) only has to be accepted.
        let result_1 = alice.process(result[0].as_ref());
        println!("Alice result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        let result_2 = alice.process(result[1].as_ref());
        println!("Alice result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(bob_message_1, result_2.message, "Decrypted message not equal to original message");

        // ---- Export / import ----------------------------------------------

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(alice_plain, bob_plain, "Alice exported plaintext is the same as Bob exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        // Nothing changed between the two exports, so they must be identical.
        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");
        assert_ne!(alice_plain_2, bob_plain_2, "Alice re-exported plaintext is the same as Bob re-exported plaintext.");

        // Rebuild each contact from its own export and repeat the checks.
        let alice = Contact::import_plain(alice_plain.as_slice()).unwrap();
        let bob = Contact::import_plain(bob_plain.as_slice()).unwrap();

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(alice_plain, bob_plain, "Alice exported plaintext is the same as Bob exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        // Nothing changed between the two exports, so they must be identical.
        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");
        assert_ne!(alice_plain_2, bob_plain_2, "Alice re-exported plaintext is the same as Bob re-exported plaintext.");
    }
}
586
587