use hkdf::Hkdf;
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
pub struct RecursiveChain {
chain: Zeroizing<[u8; 32]>,
alice_key: Zeroizing<[u8; 32]>,
bob_key: Zeroizing<[u8; 32]>,
generation: u64,
is_alice: bool,
}
impl RecursiveChain {
pub fn new(shared_secret: &[u8; 32], is_alice: bool) -> Self {
let hk = Hkdf::<Sha256>::new(Some(b"zks-recursive-chain-v1"), shared_secret);
let mut alice_key = Zeroizing::new([0u8; 32]);
let mut bob_key = Zeroizing::new([0u8; 32]);
let mut chain = Zeroizing::new([0u8; 32]);
hk.expand(b"alice-initial-key", alice_key.as_mut())
.expect("HKDF expansion should not fail");
hk.expand(b"bob-initial-key", bob_key.as_mut())
.expect("HKDF expansion should not fail");
let mut xor_result = [0u8; 32];
for i in 0..32 {
xor_result[i] = alice_key[i] ^ bob_key[i];
}
let chain_hk = Hkdf::<Sha256>::new(Some(b"zks-chain-key"), &xor_result);
chain_hk.expand(b"chain-0", chain.as_mut())
.expect("HKDF expansion should not fail");
xor_result.zeroize();
Self {
chain,
alice_key,
bob_key,
generation: 0,
is_alice,
}
}
pub fn advance(&mut self, new_entropy: &[u8]) -> [u8; 32] {
let mut input = Vec::with_capacity(32 + 32 + new_entropy.len() + 8);
input.extend_from_slice(&*self.chain);
input.extend_from_slice(if self.is_alice { &*self.alice_key } else { &*self.bob_key });
input.extend_from_slice(new_entropy);
input.extend_from_slice(&self.generation.to_le_bytes());
let hk = Hkdf::<Sha256>::new(Some(b"zks-key-advance"), &input);
let mut new_session_key = [0u8; 32];
hk.expand(b"session-key", &mut new_session_key)
.expect("HKDF expansion should not fail");
let mut new_contribution = Zeroizing::new([0u8; 32]);
hk.expand(b"contribution-key", new_contribution.as_mut())
.expect("HKDF expansion should not fail");
if self.is_alice {
self.alice_key = new_contribution;
} else {
self.bob_key = new_contribution;
}
let mut xor_result = [0u8; 32];
for i in 0..32 {
xor_result[i] = self.alice_key[i] ^ self.bob_key[i];
}
let chain_hk = Hkdf::<Sha256>::new(Some(b"zks-chain-key"), &xor_result);
chain_hk.expand(&(self.generation + 1).to_le_bytes(), self.chain.as_mut())
.expect("HKDF expansion should not fail");
xor_result.zeroize();
input.zeroize();
self.generation += 1;
tracing::debug!(
"🔗 RecursiveChain advanced to generation {} (is_alice: {})",
self.generation,
self.is_alice
);
new_session_key
}
pub fn update_peer_key(&mut self, peer_contribution: &[u8; 32]) {
if self.is_alice {
self.bob_key.copy_from_slice(peer_contribution);
} else {
self.alice_key.copy_from_slice(peer_contribution);
}
let mut xor_result = [0u8; 32];
for i in 0..32 {
xor_result[i] = self.alice_key[i] ^ self.bob_key[i];
}
let chain_hk = Hkdf::<Sha256>::new(Some(b"zks-chain-key"), &xor_result);
chain_hk.expand(&self.generation.to_le_bytes(), self.chain.as_mut())
.expect("HKDF expansion should not fail");
xor_result.zeroize();
}
pub fn generation(&self) -> u64 {
self.generation
}
pub fn our_contribution(&self) -> [u8; 32] {
if self.is_alice {
*self.alice_key
} else {
*self.bob_key
}
}
pub fn export_state(&self) -> ChainState {
ChainState {
chain: *self.chain,
alice_key: *self.alice_key,
bob_key: *self.bob_key,
generation: self.generation,
is_alice: self.is_alice,
}
}
pub fn import_state(state: ChainState) -> Self {
Self {
chain: Zeroizing::new(state.chain),
alice_key: Zeroizing::new(state.alice_key),
bob_key: Zeroizing::new(state.bob_key),
generation: state.generation,
is_alice: state.is_alice,
}
}
}
pub struct ChainState {
pub chain: [u8; 32],
pub alice_key: [u8; 32],
pub bob_key: [u8; 32],
pub generation: u64,
pub is_alice: bool,
}
impl Drop for ChainState {
fn drop(&mut self) {
self.chain.zeroize();
self.alice_key.zeroize();
self.bob_key.zeroize();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chain_creation() {
let shared_secret = [0x42u8; 32];
let chain = RecursiveChain::new(&shared_secret, true);
assert_eq!(chain.generation(), 0);
}
#[test]
fn test_chain_advance() {
let shared_secret = [0x42u8; 32];
let mut chain = RecursiveChain::new(&shared_secret, true);
let entropy = [0xABu8; 16];
let key1 = chain.advance(&entropy);
assert_eq!(chain.generation(), 1);
let key2 = chain.advance(&entropy);
assert_eq!(chain.generation(), 2);
assert_ne!(key1, key2);
}
#[test]
fn test_synchronized_chains() {
let shared_secret = [0x42u8; 32];
let mut alice = RecursiveChain::new(&shared_secret, true);
let mut bob = RecursiveChain::new(&shared_secret, false);
assert_eq!(alice.generation(), 0);
assert_eq!(bob.generation(), 0);
let alice_contrib = alice.our_contribution();
let bob_contrib = bob.our_contribution();
alice.update_peer_key(&bob_contrib);
bob.update_peer_key(&alice_contrib);
let entropy = [0xABu8; 16];
let _alice_key = alice.advance(&entropy);
let _bob_key = bob.advance(&entropy);
assert_eq!(alice.generation(), bob.generation());
}
#[test]
fn test_export_import_state() {
let shared_secret = [0x42u8; 32];
let mut chain = RecursiveChain::new(&shared_secret, true);
let entropy = [0xABu8; 16];
chain.advance(&entropy);
let state = chain.export_state();
let restored = RecursiveChain::import_state(state);
assert_eq!(chain.generation(), restored.generation());
}
}