Skip to main content

libcold/
contact.rs

1use zeroize::{Zeroize, Zeroizing};
2use std::ops::Deref;
3
4use crate::consts;
5use crate::crypto;
6use crate::wire::{ContactOutput, WireMessage, UserPrompt, UserAnswer, NewMessage};
7use crate::error::Error;
8
9mod export_import;
10mod clone;
11mod smp;
12mod pfs;
13mod msgs;
14mod ratchet;
15pub(crate) use smp::normalize_smp_answer;
16
17
/// Lifecycle state of a single contact.
///
/// The state decides how [`Contact::process`] handles the next incoming blob:
/// each pre-`Verified` variant routes to exactly one SMP step, while
/// `Verified` routes to `process_verified`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ContactState {
    /// Fresh contact (as built by `Contact::new`); next blob goes to SMP step 2.
    Uninitialized,
    /// Next blob goes to SMP step 3.
    SMPInit,
    /// Next blob goes to SMP step 4, which requests the user's answer.
    SMPStep2,
    /// Next blob goes to SMP step 5.
    SMPStep3,
    /// Next blob is decrypted and dispatched by `process_verified`.
    Verified,
}
27
28// Public contact struct (one per contact)
29#[derive(Zeroize, Debug)]
30#[zeroize(drop)]
31pub struct Contact {
32    #[zeroize(skip)]
33    state: ContactState,
34
35    message_locked: bool,
36
37    // stored key material
38
39    our_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
40    our_signing_secret_key: Option<Zeroizing<Vec<u8>>>,
41    contact_signing_pub_key: Option<Zeroizing<Vec<u8>>>,
42    
43
44    our_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
45    our_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
46    contact_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
47
48    our_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
49    our_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
50    contact_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
51
52    our_staged_ml_kem_pub_key: Option<Zeroizing<Vec<u8>>>,
53    our_staged_ml_kem_secret_key: Option<Zeroizing<Vec<u8>>>,
54
55    our_staged_mceliece_pub_key: Option<Zeroizing<Vec<u8>>>,
56    our_staged_mceliece_secret_key: Option<Zeroizing<Vec<u8>>>,
57
58    our_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
59    our_smp_tmp_secret_key: Option<Zeroizing<Vec<u8>>>,
60    contact_smp_tmp_pub_key: Option<Zeroizing<Vec<u8>>>,
61
62    our_next_strand_key  : Option<Zeroizing<Vec<u8>>>,
63    our_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
64
65    contact_next_strand_key  : Option<Zeroizing<Vec<u8>>>,
66    contact_next_strand_nonce: Option<Zeroizing<Vec<u8>>>,
67
68    our_smp_nonce: Option<Zeroizing<Vec<u8>>>,
69    contact_smp_nonce: Option<Zeroizing<Vec<u8>>>,
70    contact_smp_proof: Option<Zeroizing<Vec<u8>>>,
71
72    smp_answer: Option<Zeroizing<String>>,
73    smp_question: Option<String>,
74
75    our_pads: Option<Zeroizing<Vec<u8>>>,
76    contact_pads: Option<Zeroizing<Vec<u8>>>,
77
78    our_hash_chain: Option<Zeroizing<Vec<u8>>>,
79    contact_hash_chain: Option<Zeroizing<Vec<u8>>>,
80
81    #[zeroize(skip)]
82    backup: Option<Box<Contact>>
83}
84
85
86impl Contact {
87    /// Create new contact
88    pub fn new() -> Result<Self, Error> {
89        let mut contact = Contact {
90            state: ContactState::Uninitialized,
91
92            message_locked: false,
93
94            our_smp_tmp_pub_key: None,
95            our_smp_tmp_secret_key: None,
96            contact_smp_tmp_pub_key: None,
97
98            our_ml_kem_pub_key: None, 
99            our_ml_kem_secret_key: None, 
100            contact_ml_kem_pub_key: None,
101
102            our_mceliece_pub_key: None,
103            our_mceliece_secret_key: None, 
104            contact_mceliece_pub_key: None,
105
106
107            our_staged_ml_kem_pub_key: None, 
108            our_staged_ml_kem_secret_key: None, 
109            
110            our_staged_mceliece_pub_key: None, 
111            our_staged_mceliece_secret_key: None, 
112           
113
114            our_signing_pub_key: None,
115            our_signing_secret_key: None,
116
117            contact_signing_pub_key: None,
118
119            our_next_strand_key: None,
120            our_next_strand_nonce: None,
121            contact_next_strand_key: None,
122            contact_next_strand_nonce: None,
123            our_smp_nonce: None,
124            contact_smp_nonce: None,
125            contact_smp_proof: None,
126
127            smp_answer: None,
128            smp_question: None,
129
130            our_pads: None,
131            contact_pads: None,
132
133            our_hash_chain: None, 
134            contact_hash_chain: None,
135
136            backup: None,
137        };
138
139        contact.init_lt_sign_keypair()?;
140        Ok(contact)
141    }
142
143
144   
145    /// Process an incoming blob, returning optional outgoing blob
146    pub fn process(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
147        self.save_backup();
148        let result = match self.state {
149            ContactState::Uninitialized => self.do_smp_step_2(data),
150            ContactState::SMPInit       => self.do_smp_step_3(data),
151            ContactState::SMPStep2      => self.do_smp_step_4_request_answer(data),
152            ContactState::SMPStep3      => self.do_smp_step_5(data),
153            ContactState::Verified      => self.process_verified(data)
154        };
155        
156        if self.state != ContactState::Verified {
157            // If we are not verified, we must still be in SMP, therefore if we encounter any
158            // error, we send failure.
159            if result.is_ok() {
160                return result;
161            } else {
162                return self.do_smp_failure();
163            }
164        }
165
166        result
167    }
168    
169    pub fn process_verified(&mut self, data: &[u8]) -> Result<ContactOutput, Error> {
170        let data_plaintext = self.decrypt_incoming_data(data)?;
171
172        let type_byte = data_plaintext.get(0)
173            .ok_or(Error::InvalidDataPlaintextLength)?;
174
175        if type_byte == &consts::PFS_TYPE_PFS_NEW {
176            let pfs_plaintext = data_plaintext.get(1..)
177                .ok_or(Error::InvalidPfsPlaintextLength)?;
178            return self.do_pfs_new(pfs_plaintext);
179        
180        } else if type_byte == &consts::PFS_TYPE_PFS_ACK {
181            let pfs_plaintext = data_plaintext.get(1..)
182                .ok_or(Error::InvalidPfsPlaintextLength)?;
183            return self.do_pfs_ack(pfs_plaintext);
184        
185        } else if type_byte == &consts::MSG_TYPE_MSG_BATCH {
186            let msgs_plaintext = data_plaintext.get(1..)
187                .ok_or(Error::InvalidMsgsPlaintextLength)?;
188
189            return self.do_process_otp_batch(msgs_plaintext);
190
191        } else if type_byte == &consts::MSG_TYPE_MSG_NEW {
192            let msgs_plaintext = data_plaintext.get(1..)
193                .ok_or(Error::InvalidMsgsPlaintextLength)?;
194
195            return self.do_process_new_msg(msgs_plaintext);
196        }
197
198        Err(Error::InvalidDataType)
199    }
200  
201    
202    fn init_lt_sign_keypair(&mut self) -> Result<(), Error> {
203        let (pk, sk) = crypto::generate_ml_dsa_87_keypair()
204            .map_err(|_| Error::CryptoFail)?;
205
206
207        self.our_signing_pub_key = Some(pk);
208        self.our_signing_secret_key = Some(sk);
209
210        Ok(())
211    }
212
213
214    fn init_tmp_kem_keypair(&mut self) -> Result<(), Error> {
215        let (pk, sk) = crypto::generate_kem_keypair(oqs::kem::Algorithm::MlKem1024)
216            .map_err(|_| Error::CryptoFail)?;
217
218
219        self.our_smp_tmp_pub_key = Some(pk);
220        self.our_smp_tmp_secret_key = Some(sk);
221
222        Ok(())
223    }
224
225}
226
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly created contacts must not export identical state.
    #[test]
    fn test_export_import() {
        // `export_plain` only borrows the contact, so no `mut` is needed.
        let alice = Contact::new().expect("Failed to create new contact instance");
        let bob = Contact::new().expect("Failed to create new contact instance");

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is equal to Alice exported plaintext.");
    }

    /// End-to-end run between two contacts:
    /// - SMP verification;
    /// - a PFS exchange;
    /// - two messages from Alice to Bob;
    /// - export / re-export / re-import stability checks on both sides.
    #[test]
    fn test_full_two_sessions() {
        // Alice initiates a new SMP session
        let alice_question = String::from("This is a question");
        let alice_answer = String::from("This is an answer");

        let mut alice = Contact::new().expect("Failed to create new contact instance");

        let result = alice.init_smp(
            alice_question.clone(),
            alice_answer
        );

        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");
        assert_eq!(result[0].len(), 1 + consts::ML_KEM_1024_PK_SIZE, "SMP init output length mismatch");
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Bob processes Alice's result.
        let mut bob = Contact::new().expect("Failed to create new contact instance");

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        assert!(
            result[0].len() >=
            1 + (consts::ML_KEM_1024_CT_SIZE * 2) + 16 + (consts::CHACHA20POLY1305_NONCE_SIZE * 3) + 32 + consts::SMP_NONCE_SIZE + consts::ML_DSA_87_PK_SIZE,
            "SMP step 2 output length mismatch"
        );
        assert_eq!(result[0][0], consts::SMP_TYPE_INIT_SMP, "SMP type byte mismatch");

        // Alice processes Bob's result.
        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // Bob processes Alice's result; this step asks Bob for the SMP answer.
        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Prompt(p) => p,
            _ => panic!("Expected Prompt output"),
        };

        let bob_question = result.question;

        assert_eq!(bob_question, alice_question, "Bob question and Alice question do not match");

        let bob_answer = String::from("This is an answer");

        let bob_user_answer = UserAnswer::new(bob_answer).expect("Failed to create new UserAnswer instance");

        let result = bob.provide_smp_answer(bob_user_answer);
        println!("Bob provide_smp_answer: result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        let result = alice.process(result[0].as_ref());
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        // PFS

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result.len(), 2, "Expected exactly 2 wire messages");

        let result_1 = alice.process(result[0].as_ref());
        println!("Alice result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        match result_1.unwrap() {
            ContactOutput::None => {},
            _ => panic!("Expected None output"),
        };

        let result_2 = alice.process(result[1].as_ref());
        println!("Alice result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        assert_eq!(result_2.len(), 1, "Expected exactly one wire message");

        let result = bob.process(result_2[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        // Only the variant matters here; there is no payload to keep.
        match result.unwrap() {
            ContactOutput::None => {},
            _ => panic!("Expected None output"),
        };

        // MSGS:

        let alice_message_1 = Zeroizing::new(String::from("Hello, World!"));

        let result = alice.send_message(&alice_message_1);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // 2 because we never sent pads to Bob yet.
        assert_eq!(result.len(), 2, "Expected exactly 2 wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result_1 = bob.process(result[0].as_ref());
        println!("Bob result 1: {:?}", result_1);
        assert!(result_1.is_ok());

        match result_1.unwrap() {
            ContactOutput::Wire(_) => {},
            _ => panic!("Expected Wire output"),
        };

        let result_2 = bob.process(result[1].as_ref());
        println!("Bob result 2: {:?}", result_2);
        assert!(result_2.is_ok());

        let result_2 = match result_2.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_1, result_2.message, "Decrypted message not equal to original message");

        let alice_message_2 = Zeroizing::new(String::from("Hi Bob!!"));

        let result = alice.send_message(&alice_message_2);
        println!("Alice result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Wire(w) => w,
            _ => panic!("Expected Wire output"),
        };

        // 1 because we should've at this point sent Bob enough pads
        assert_eq!(result.len(), 1, "Expected exactly one wire message");

        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_ok());

        let result = bob.process(result[0].as_ref());
        println!("Bob result: {:?}", result);
        assert!(result.is_ok());

        let result = match result.unwrap() {
            ContactOutput::Message(m) => m,
            _ => panic!("Expected Message output"),
        };

        assert_eq!(alice_message_2, result.message, "Decrypted message not equal to original message");

        // This should error
        let r = alice.i_confirm_message_has_been_sent();
        assert!(r.is_err(), "Confirmation over use did not cause an error");

        // Each variable now holds the export of the party it is named after.
        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        // There should not be any changes on both ends since we have not altered the struct.
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");

        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");

        assert_ne!(bob_plain_2, alice_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");

        // Re-import each party from its own export and repeat the checks.
        let alice = Contact::import_plain(alice_plain.as_slice()).unwrap();
        let bob = Contact::import_plain(bob_plain.as_slice()).unwrap();

        let alice_plain = alice.export_plain().unwrap();
        let bob_plain = bob.export_plain().unwrap();

        assert_ne!(bob_plain, alice_plain, "Bob exported plaintext is the same as Alice exported plaintext.");

        let alice_plain_2 = alice.export_plain().unwrap();
        let bob_plain_2 = bob.export_plain().unwrap();

        // There should not be any changes on both ends since we have not altered the struct.
        assert_eq!(bob_plain_2, bob_plain, "Bob re-exported plaintext is not equal to Bob exported plaintext.");

        assert_eq!(alice_plain_2, alice_plain, "Alice re-exported plaintext is not equal to Alice exported plaintext.");

        assert_ne!(bob_plain_2, alice_plain_2, "Bob re-exported plaintext is the same as Alice re-exported plaintext.");
    }
}
519
520