light_double_ratchet/
kdf_chains.rs

1//! This module defines key derivation function (KDF) chains that are used to generate new keys
2//! from old ones in a manner that is irreversible.
3
4use sha2::Sha256;
5use hkdf::Hkdf;
6use hmac::{Hmac, Mac};
7
8/// This chain starts with an initial shared secret key and is maintained by
9/// both parties in order to derive sending and receiving chain keys, taking
10/// Diffie-Hellman agreements as input.
11pub struct RootChain {
12    cur_root_key: Vec<u8>,
13}
14
15
16impl RootChain {
17    /// Creates a new chain from a shared secret, `root_key`.
18    pub fn new(root_key: Vec<u8>) -> Self {
19        Self {
20            cur_root_key: root_key
21        }
22    }
23
24    /// Processes the `diffie_hellman_output` as the input to this KDF and returns the 32-byte
25    /// output chain key, storing the other 32 bytes of the key expansion into an internal "current
26    /// root key" vector.
27    pub fn step(&mut self, diffie_hellman_output: &[u8]) -> Vec<u8> {
28        let salt = self.cur_root_key.as_slice();
29        let initial_key_material = diffie_hellman_output;
30        let hkdf = Hkdf::<Sha256>::new(Some(salt), initial_key_material);
31
32        // 64 bytes because first 32 will be the next root key and second 32 will be the output
33        let mut output_key_material: [u8; 64] = [0; 64];
34        hkdf.expand("RootChainKDF".as_bytes(), &mut output_key_material)
35            .expect("output_key_material should be of a valid length");
36
37        self.cur_root_key.copy_from_slice(&output_key_material[..32]);
38        output_key_material[32..].to_vec()
39    }
40}
41
42/// This chain starts with an initial chain key (generated via the root chain)
43/// and is used by both parties to derive message keys. The sender and the receiver both
44/// create a `SendReceiveChain` using the same initial chain key, which ensures that the
45/// derived message keys are the same for both parties.
46pub struct SendReceiveChain {
47    cur_chain_key: Vec<u8>,
48}
49
50impl SendReceiveChain {
51    /// Creates a new chain from a shared `chain_key`
52    pub fn new(chain_key: Vec<u8>) -> Self {
53        Self {
54            cur_chain_key: chain_key
55        }
56    }
57
58    /// Evaluates HMAC keyed by the current chain key on the given `input`.
59    fn evaluate_hmac(&self, input: &[u8]) -> Vec<u8> {
60        let mut hmac = Hmac::<Sha256>::new_from_slice(&self.cur_chain_key)
61            .expect("while some MACs might reject a key, HMAC should work with any key");
62
63        hmac.update(input);
64        hmac.finalize().into_bytes().to_vec()
65    }
66
67    /// Steps the key derivation function (KDF) once and returns the derived message key.
68    //
69    /// This implementation follows the Signal protocol specification (Revision 1, 2016-11-20; see
70    /// [PDF][a]), using an HMAC keyed by the current chain key. As per the protocol, the new chain
71    /// key and the message key are generated via `HMAC(cur_key, 0x1)` and `HMAC(cur_key, 0x2)`
72    /// respectively.
73    ///
74    /// [a]: https://signal.org/docs/specifications/doubleratchet/doubleratchet.pdf
75    pub fn step(&mut self) -> Vec<u8> {
76
77        let new_chain_key = self.evaluate_hmac(b"\x01");
78        let message_key = self.evaluate_hmac(b"\x02");
79
80        self.cur_chain_key = new_chain_key;
81
82        message_key
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Two root chains seeded identically and fed the same inputs must stay in lockstep, and
    /// must diverge as soon as one of them takes an extra step.
    #[test]
    fn basic_root_chain() {
        let key = vec![0x42; 32];
        let mut chain1 = RootChain::new(key.clone());
        let mut chain2 = RootChain::new(key);
        let dh_output = [0xff_u8; 32];

        // Each chain is stepped once here, so they remain in sync afterwards.
        assert_eq!(chain1.step(&dh_output).len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain2.step(&dh_output).len(), 32, "chain output should be 32 bytes");

        assert_eq!(chain1.step(&dh_output), chain2.step(&dh_output), "chains should match");
        assert_eq!(chain1.step(&dh_output), chain2.step(&dh_output), "chains should match");

        chain1.step(&dh_output); // chains are now out of sync
        assert_ne!(chain1.step(&dh_output), chain2.step(&dh_output), "chains shouldn't match");
    }

    /// Two send/receive chains seeded identically must produce the same message keys, and must
    /// diverge as soon as one of them takes an extra step.
    #[test]
    fn basic_send_receive_chain() {
        let shared_key = vec![0x42; 32];
        let mut chain1 = SendReceiveChain::new(shared_key.clone());
        let mut chain2 = SendReceiveChain::new(shared_key);

        // Each chain is stepped once here, so they remain in sync afterwards.
        assert_eq!(chain1.step().len(), 32, "chain output should be 32 bytes");
        assert_eq!(chain2.step().len(), 32, "chain output should be 32 bytes");

        assert_eq!(chain1.step(), chain2.step(), "chains should match");
        assert_eq!(chain1.step(), chain2.step(), "chains should match");

        chain1.step(); // chains are now out of sync
        assert_ne!(chain1.step(), chain2.step(), "chains shouldn't match");
    }
}
116}