use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use hkdf::Hkdf;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq as _;
use crate::error::{CrdtError, Result};
/// HMAC instantiated over SHA-256; the MAC used to tag every delta.
type HmacSha256 = Hmac<Sha256>;
/// Length in bytes of a delta signature: the full, untruncated HMAC-SHA256 output.
pub const SIGNATURE_SIZE: usize = 32;
/// HKDF salt used when deriving a per-device key from a user's registered root key.
/// Changing these bytes changes every derived key, so previously issued signatures would
/// no longer verify.
const DEVICE_KEY_SALT: &[u8] = b"nodedb-crdt-device-key";
/// Replay-protection state kept for a single `(user_id, device_id)` pair.
#[derive(Clone, Default, Debug)]
struct DeviceState {
    /// Highest sequence number committed so far; the default `0` means none yet,
    /// which also makes `0` itself an unusable sequence number.
    last_seq_no: u64,
}
/// Tracks the highest committed sequence number for every `(user_id, device_id)` pair.
///
/// All state sits behind one `Mutex`, so a single registry can be shared through an
/// `Arc` by several `DeltaSigner`s.
#[derive(Default, Debug)]
pub struct DeviceRegistry {
    /// Keyed by `(user_id, device_id)`.
    inner: Mutex<HashMap<(u64, u64), DeviceState>>,
}
impl DeviceRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn check_seq(&self, user_id: u64, device_id: u64, seq_no: u64) -> Result<u64> {
let guard = self
.inner
.lock()
.map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
let last_seen = guard
.get(&(user_id, device_id))
.map_or(0u64, |s| s.last_seq_no);
if seq_no <= last_seen {
return Err(CrdtError::ReplayDetected {
user_id,
device_id,
seq_no,
last_seen,
});
}
Ok(last_seen)
}
pub fn commit_seq(&self, user_id: u64, device_id: u64, seq_no: u64) -> Result<()> {
let mut guard = self
.inner
.lock()
.map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
let entry = guard.entry((user_id, device_id)).or_default();
if seq_no > entry.last_seq_no {
entry.last_seq_no = seq_no;
}
Ok(())
}
pub fn seed(&self, user_id: u64, device_id: u64, last_seq_no: u64) -> Result<()> {
let mut guard = self
.inner
.lock()
.map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
guard.entry((user_id, device_id)).or_default().last_seq_no = last_seq_no;
Ok(())
}
pub fn last_seen(&self, user_id: u64, device_id: u64) -> u64 {
self.inner
.lock()
.ok()
.and_then(|g| g.get(&(user_id, device_id)).map(|s| s.last_seq_no))
.unwrap_or(0)
}
}
/// Signs and verifies deltas with per-device HMAC-SHA256 keys derived (via HKDF) from a
/// per-user root key, and exposes the [`DeviceRegistry`] used for sequence-number checks.
///
/// NOTE(review): no `Debug` derive, presumably so root key material is never printed.
/// Confirm before adding one.
pub struct DeltaSigner {
    // Root key per user_id; device keys are derived from these on every sign/verify.
    keys: HashMap<u64, [u8; 32]>,
    // Replay-protection state; can be shared with other signers via `with_registry`.
    pub(crate) registry: Arc<DeviceRegistry>,
}
impl DeltaSigner {
pub fn new() -> Self {
Self {
keys: HashMap::new(),
registry: Arc::new(DeviceRegistry::new()),
}
}
pub fn with_registry(registry: Arc<DeviceRegistry>) -> Self {
Self {
keys: HashMap::new(),
registry,
}
}
pub fn register_key(&mut self, user_id: u64, key: [u8; 32]) {
self.keys.insert(user_id, key);
}
pub fn remove_key(&mut self, user_id: u64) {
self.keys.remove(&user_id);
}
fn device_key(&self, user_id: u64, device_id: u64) -> Result<[u8; 32]> {
let stored = self
.keys
.get(&user_id)
.ok_or_else(|| CrdtError::InvalidSignature {
user_id,
detail: "no signing key registered for user".into(),
})?;
let hk = Hkdf::<Sha256>::new(Some(DEVICE_KEY_SALT), stored.as_slice());
let mut okm = [0u8; 32];
hk.expand(&device_id.to_le_bytes(), &mut okm)
.map_err(|_| CrdtError::InvalidSignature {
user_id,
detail: "HKDF expand failed (output too long)".into(),
})?;
Ok(okm)
}
pub fn sign(
&self,
user_id: u64,
device_id: u64,
seq_no: u64,
delta_bytes: &[u8],
) -> Result<[u8; SIGNATURE_SIZE]> {
let key = self.device_key(user_id, device_id)?;
Ok(compute_hmac(&key, user_id, device_id, seq_no, delta_bytes))
}
pub fn verify(
&self,
user_id: u64,
device_id: u64,
seq_no: u64,
delta_bytes: &[u8],
signature: &[u8; SIGNATURE_SIZE],
) -> Result<()> {
let key = self.device_key(user_id, device_id)?;
let expected = compute_hmac(&key, user_id, device_id, seq_no, delta_bytes);
if expected.ct_eq(signature).into() {
Ok(())
} else {
Err(CrdtError::InvalidSignature {
user_id,
detail: "HMAC-SHA256 mismatch".into(),
})
}
}
pub fn registry(&self) -> &Arc<DeviceRegistry> {
&self.registry
}
}
impl Default for DeltaSigner {
fn default() -> Self {
Self::new()
}
}
/// Computes HMAC-SHA256 under `key` over `delta_bytes` followed by the little-endian
/// encodings of `user_id`, `device_id` and `seq_no`.
///
/// The three identity fields form a fixed 24-byte suffix, so the encoding stays unambiguous
/// even though `delta_bytes` is variable-length. Any change to this byte layout makes
/// previously produced signatures fail verification.
fn compute_hmac(
    key: &[u8; 32],
    user_id: u64,
    device_id: u64,
    seq_no: u64,
    delta_bytes: &[u8],
) -> [u8; SIGNATURE_SIZE] {
    let mut state = HmacSha256::new_from_slice(key).expect("HMAC accepts any key size");
    state.update(delta_bytes);
    // Order matters: user_id, device_id, seq_no.
    for field in [user_id, device_id, seq_no] {
        state.update(&field.to_le_bytes());
    }
    let tag = state.finalize().into_bytes();
    let mut signature = [0u8; SIGNATURE_SIZE];
    signature.copy_from_slice(&tag);
    signature
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a signer with a single registered root key.
    fn make_signer(user_id: u64, key: [u8; 32]) -> DeltaSigner {
        let mut s = DeltaSigner::new();
        s.register_key(user_id, key);
        s
    }

    #[test]
    fn hmac_golden_vector() {
        let signer = make_signer(1, [0x42u8; 32]);
        let sig = signer.sign(1, 2, 1, b"nodedb").unwrap();
        let device_key = {
            let hk = Hkdf::<Sha256>::new(Some(DEVICE_KEY_SALT), &[0x42u8; 32]);
            let mut okm = [0u8; 32];
            hk.expand(&2u64.to_le_bytes(), &mut okm).unwrap();
            okm
        };
        let expected = compute_hmac(&device_key, 1, 2, 1, b"nodedb");
        assert_eq!(sig, expected, "HMAC golden vector must be stable");
    }

    #[test]
    fn replay_rejected_same_device_seq() {
        let signer = make_signer(1, [0x42u8; 32]);
        let delta = b"test delta";
        let sig = signer.sign(1, 2, 1, delta).unwrap();
        signer.registry.check_seq(1, 2, 1).unwrap();
        signer.verify(1, 2, 1, delta, &sig).unwrap();
        signer.registry.commit_seq(1, 2, 1).unwrap();
        let err = signer.registry.check_seq(1, 2, 1).unwrap_err();
        assert!(
            matches!(
                err,
                CrdtError::ReplayDetected {
                    seq_no: 1,
                    last_seen: 1,
                    ..
                }
            ),
            "expected ReplayDetected, got {err}"
        );
    }

    #[test]
    fn cross_device_replay_rejected() {
        let signer = make_signer(1, [0x42u8; 32]);
        let delta = b"cross device test";
        let sig = signer.sign(1, 2, 1, delta).unwrap();
        let err = signer.verify(1, 3, 1, delta, &sig).unwrap_err();
        assert!(
            matches!(err, CrdtError::InvalidSignature { .. }),
            "cross-device replay must be rejected"
        );
    }

    #[test]
    fn seq_zero_rejected() {
        let registry = DeviceRegistry::new();
        let err = registry.check_seq(1, 0, 0).unwrap_err();
        assert!(
            matches!(
                err,
                CrdtError::ReplayDetected {
                    seq_no: 0,
                    last_seen: 0,
                    ..
                }
            ),
            "seq_no=0 must be rejected (not strictly greater than last_seen=0)"
        );
    }

    #[test]
    fn tampered_delta_fails_verification() {
        let signer = make_signer(1, [0x42u8; 32]);
        let sig = signer.sign(1, 2, 1, b"original").unwrap();
        let err = signer.verify(1, 2, 1, b"tampered", &sig).unwrap_err();
        assert!(matches!(err, CrdtError::InvalidSignature { .. }));
    }

    #[test]
    fn wrong_user_fails_verification() {
        let mut signer = DeltaSigner::new();
        signer.register_key(1, [0x42u8; 32]);
        signer.register_key(2, [0x99u8; 32]);
        let sig = signer.sign(1, 5, 1, b"delta").unwrap();
        let err = signer.verify(2, 5, 1, b"delta", &sig).unwrap_err();
        assert!(matches!(err, CrdtError::InvalidSignature { .. }));
    }

    #[test]
    fn unregistered_user_fails() {
        let signer = DeltaSigner::new();
        let err = signer.sign(99, 1, 1, b"data").unwrap_err();
        assert!(matches!(
            err,
            CrdtError::InvalidSignature { user_id: 99, .. }
        ));
    }

    #[test]
    fn seq_no_must_advance() {
        let reg = DeviceRegistry::new();
        reg.check_seq(1, 1, 5).unwrap();
        reg.commit_seq(1, 1, 5).unwrap();
        assert!(reg.check_seq(1, 1, 5).is_err());
        assert!(reg.check_seq(1, 1, 4).is_err());
        reg.check_seq(1, 1, 6).unwrap();
    }

    // seq_no is part of the MAC input, so a captured signature cannot be re-presented
    // under a fresh sequence number to slip past the registry check.
    #[test]
    fn wrong_seq_fails_verification() {
        let signer = make_signer(1, [0x42u8; 32]);
        let sig = signer.sign(1, 2, 1, b"delta").unwrap();
        let err = signer.verify(1, 2, 2, b"delta", &sig).unwrap_err();
        assert!(matches!(err, CrdtError::InvalidSignature { .. }));
    }

    #[test]
    fn removed_key_cannot_sign_or_verify() {
        let mut signer = make_signer(1, [0x42u8; 32]);
        let sig = signer.sign(1, 2, 1, b"delta").unwrap();
        signer.remove_key(1);
        assert!(matches!(
            signer.sign(1, 2, 1, b"delta").unwrap_err(),
            CrdtError::InvalidSignature { user_id: 1, .. }
        ));
        assert!(matches!(
            signer.verify(1, 2, 1, b"delta", &sig).unwrap_err(),
            CrdtError::InvalidSignature { user_id: 1, .. }
        ));
    }

    #[test]
    fn commit_seq_never_lowers() {
        let reg = DeviceRegistry::new();
        reg.commit_seq(1, 1, 10).unwrap();
        reg.commit_seq(1, 1, 3).unwrap();
        assert_eq!(reg.last_seen(1, 1), 10);
        assert!(reg.check_seq(1, 1, 10).is_err());
        assert_eq!(reg.check_seq(1, 1, 11).unwrap(), 10);
    }

    #[test]
    fn seed_sets_last_seen() {
        let reg = DeviceRegistry::new();
        reg.seed(4, 7, 100).unwrap();
        assert_eq!(reg.last_seen(4, 7), 100);
        assert!(reg.check_seq(4, 7, 100).is_err());
        reg.check_seq(4, 7, 101).unwrap();
    }

    #[test]
    fn sequence_state_is_per_user_and_device() {
        let reg = DeviceRegistry::new();
        reg.commit_seq(1, 1, 5).unwrap();
        assert_eq!(reg.check_seq(1, 2, 1).unwrap(), 0);
        assert_eq!(reg.check_seq(2, 1, 1).unwrap(), 0);
        assert_eq!(reg.last_seen(9, 9), 0);
    }

    #[test]
    fn with_registry_shares_state() {
        let shared = Arc::new(DeviceRegistry::new());
        let a = DeltaSigner::with_registry(Arc::clone(&shared));
        let b = DeltaSigner::with_registry(Arc::clone(&shared));
        a.registry().commit_seq(1, 2, 3).unwrap();
        assert_eq!(b.registry().last_seen(1, 2), 3);
        assert!(b.registry().check_seq(1, 2, 3).is_err());
    }
}