light_double_ratchet/
kdf_chains.rs1use sha2::Sha256;
5use hkdf::Hkdf;
6use hmac::{Hmac, Mac};
7
8pub struct RootChain {
12 cur_root_key: Vec<u8>,
13}
14
15
16impl RootChain {
17 pub fn new(root_key: Vec<u8>) -> Self {
19 Self {
20 cur_root_key: root_key
21 }
22 }
23
24 pub fn step(&mut self, diffie_hellman_output: &[u8]) -> Vec<u8> {
28 let salt = self.cur_root_key.as_slice();
29 let initial_key_material = diffie_hellman_output;
30 let hkdf = Hkdf::<Sha256>::new(Some(salt), initial_key_material);
31
32 let mut output_key_material: [u8; 64] = [0; 64];
34 hkdf.expand("RootChainKDF".as_bytes(), &mut output_key_material)
35 .expect("output_key_material should be of a valid length");
36
37 self.cur_root_key.copy_from_slice(&output_key_material[..32]);
38 output_key_material[32..].to_vec()
39 }
40}
41
42pub struct SendReceiveChain {
47 cur_chain_key: Vec<u8>,
48}
49
50impl SendReceiveChain {
51 pub fn new(chain_key: Vec<u8>) -> Self {
53 Self {
54 cur_chain_key: chain_key
55 }
56 }
57
58 fn evaluate_hmac(&self, input: &[u8]) -> Vec<u8> {
60 let mut hmac = Hmac::<Sha256>::new_from_slice(&self.cur_chain_key)
61 .expect("while some MACs might reject a key, HMAC should work with any key");
62
63 hmac.update(input);
64 hmac.finalize().into_bytes().to_vec()
65 }
66
67 pub fn step(&mut self) -> Vec<u8> {
76
77 let new_chain_key = self.evaluate_hmac(b"\x01");
78 let message_key = self.evaluate_hmac(b"\x02");
79
80 self.cur_chain_key = new_chain_key;
81
82 message_key
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Identically seeded root chains stay in lockstep, and diverge once one
    /// of them takes an extra step.
    #[test]
    fn basic_root_chain() {
        let key = vec![0x42; 32];
        let mut chain1 = RootChain::new(key.clone());
        let mut chain2 = RootChain::new(key);
        let dh_output = vec![0xff; 32];
        assert_eq!(chain1.step(&dh_output).len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain2.step(&dh_output).len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain1.step(&dh_output), chain2.step(&dh_output), "chains should match");
        assert_eq!(chain1.step(&dh_output), chain2.step(&dh_output), "chains should match");
        // Advance chain1 one extra step so the two chains fall out of sync.
        chain1.step(&dh_output);
        assert_ne!(chain1.step(&dh_output), chain2.step(&dh_output), "chains shouldn't match");
    }

    /// Successive root-chain outputs differ even with a repeated DH output,
    /// because the root key changes on every step.
    #[test]
    fn root_chain_outputs_change_each_step() {
        let mut chain = RootChain::new(vec![0x42; 32]);
        let dh_output = [0xff; 32];
        let first = chain.step(&dh_output);
        let second = chain.step(&dh_output);
        assert_ne!(first, second, "consecutive chain outputs should differ");
    }

    /// Identically seeded send/receive chains stay in lockstep, and diverge
    /// once one of them takes an extra step.
    #[test]
    fn basic_send_receive_chain() {
        let shared_key = vec![0x42; 32];
        let mut chain1 = SendReceiveChain::new(shared_key.clone());
        let mut chain2 = SendReceiveChain::new(shared_key);
        assert_eq!(chain1.step().len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain2.step().len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain1.step(), chain2.step(), "chains should match");
        assert_eq!(chain1.step(), chain2.step(), "chains should match");
        // Advance chain1 one extra step so the two chains fall out of sync.
        chain1.step();
        assert_ne!(chain1.step(), chain2.step(), "chains shouldn't match");
    }

    /// Consecutive message keys from one chain must differ; otherwise two
    /// messages would be protected by the same key.
    #[test]
    fn send_receive_chain_keys_change_each_step() {
        let mut chain = SendReceiveChain::new(vec![0x42; 32]);
        let first = chain.step();
        let second = chain.step();
        assert_ne!(first, second, "consecutive message keys should differ");
    }
}