1use std::collections::HashMap;
38use std::sync::{Arc, Mutex};
39
40use hkdf::Hkdf;
41use hmac::{Hmac, Mac};
42use sha2::Sha256;
43use subtle::ConstantTimeEq as _;
44
45use crate::error::{CrdtError, Result};
46
/// HMAC instantiated over SHA-256; used by `compute_hmac` to produce the
/// `SIGNATURE_SIZE`-byte delta tags.
type HmacSha256 = Hmac<Sha256>;
48
/// Length in bytes of a delta signature: one full HMAC-SHA256 tag (256 bits).
pub const SIGNATURE_SIZE: usize = 256 / 8;

/// HKDF salt used when deriving per-device keys from a user's registered key.
/// Changing it changes every derived key, and therefore every signature.
const DEVICE_KEY_SALT: &[u8] = b"nodedb-crdt-device-key";
54
/// Replay-protection state tracked for a single `(user_id, device_id)` pair.
#[derive(Default, Debug, Clone)]
struct DeviceState {
    /// Highest sequence number committed for this device; `0` means nothing
    /// has been committed yet.
    last_seq_no: u64,
}

/// Mutex-guarded table of per-device replay-protection state, keyed by
/// `(user_id, device_id)`.
#[derive(Default, Debug)]
pub struct DeviceRegistry {
    inner: Mutex<HashMap<(u64, u64), DeviceState>>,
}
70
71impl DeviceRegistry {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn check_seq(&self, user_id: u64, device_id: u64, seq_no: u64) -> Result<u64> {
86 let guard = self
87 .inner
88 .lock()
89 .map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
90 let last_seen = guard
91 .get(&(user_id, device_id))
92 .map_or(0u64, |s| s.last_seq_no);
93 if seq_no <= last_seen {
94 return Err(CrdtError::ReplayDetected {
95 user_id,
96 device_id,
97 seq_no,
98 last_seen,
99 });
100 }
101 Ok(last_seen)
102 }
103
104 pub fn commit_seq(&self, user_id: u64, device_id: u64, seq_no: u64) -> Result<()> {
108 let mut guard = self
109 .inner
110 .lock()
111 .map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
112 let entry = guard.entry((user_id, device_id)).or_default();
113 if seq_no > entry.last_seq_no {
114 entry.last_seq_no = seq_no;
115 }
116 Ok(())
117 }
118
119 pub fn seed(&self, user_id: u64, device_id: u64, last_seq_no: u64) -> Result<()> {
122 let mut guard = self
123 .inner
124 .lock()
125 .map_err(|_| CrdtError::DeltaApplyFailed("device registry lock poisoned".into()))?;
126 guard.entry((user_id, device_id)).or_default().last_seq_no = last_seq_no;
127 Ok(())
128 }
129
130 pub fn last_seen(&self, user_id: u64, device_id: u64) -> u64 {
132 self.inner
133 .lock()
134 .ok()
135 .and_then(|g| g.get(&(user_id, device_id)).map(|s| s.last_seq_no))
136 .unwrap_or(0)
137 }
138}
139
140pub struct DeltaSigner {
144 keys: HashMap<u64, [u8; 32]>,
147 pub(crate) registry: Arc<DeviceRegistry>,
149}
150
151impl DeltaSigner {
152 pub fn new() -> Self {
154 Self {
155 keys: HashMap::new(),
156 registry: Arc::new(DeviceRegistry::new()),
157 }
158 }
159
160 pub fn with_registry(registry: Arc<DeviceRegistry>) -> Self {
163 Self {
164 keys: HashMap::new(),
165 registry,
166 }
167 }
168
169 pub fn register_key(&mut self, user_id: u64, key: [u8; 32]) {
171 self.keys.insert(user_id, key);
172 }
173
174 pub fn remove_key(&mut self, user_id: u64) {
176 self.keys.remove(&user_id);
177 }
178
179 fn device_key(&self, user_id: u64, device_id: u64) -> Result<[u8; 32]> {
183 let stored = self
184 .keys
185 .get(&user_id)
186 .ok_or_else(|| CrdtError::InvalidSignature {
187 user_id,
188 detail: "no signing key registered for user".into(),
189 })?;
190
191 let hk = Hkdf::<Sha256>::new(Some(DEVICE_KEY_SALT), stored.as_slice());
192 let mut okm = [0u8; 32];
193 hk.expand(&device_id.to_le_bytes(), &mut okm)
194 .map_err(|_| CrdtError::InvalidSignature {
195 user_id,
196 detail: "HKDF expand failed (output too long)".into(),
197 })?;
198 Ok(okm)
199 }
200
201 pub fn sign(
206 &self,
207 user_id: u64,
208 device_id: u64,
209 seq_no: u64,
210 delta_bytes: &[u8],
211 ) -> Result<[u8; SIGNATURE_SIZE]> {
212 let key = self.device_key(user_id, device_id)?;
213 Ok(compute_hmac(&key, user_id, device_id, seq_no, delta_bytes))
214 }
215
216 pub fn verify(
220 &self,
221 user_id: u64,
222 device_id: u64,
223 seq_no: u64,
224 delta_bytes: &[u8],
225 signature: &[u8; SIGNATURE_SIZE],
226 ) -> Result<()> {
227 let key = self.device_key(user_id, device_id)?;
228 let expected = compute_hmac(&key, user_id, device_id, seq_no, delta_bytes);
229
230 if expected.ct_eq(signature).into() {
232 Ok(())
233 } else {
234 Err(CrdtError::InvalidSignature {
235 user_id,
236 detail: "HMAC-SHA256 mismatch".into(),
237 })
238 }
239 }
240
241 pub fn registry(&self) -> &Arc<DeviceRegistry> {
243 &self.registry
244 }
245}
246
247impl Default for DeltaSigner {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
253fn compute_hmac(
258 key: &[u8; 32],
259 user_id: u64,
260 device_id: u64,
261 seq_no: u64,
262 delta_bytes: &[u8],
263) -> [u8; SIGNATURE_SIZE] {
264 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key size");
265 mac.update(delta_bytes);
266 mac.update(&user_id.to_le_bytes());
267 mac.update(&device_id.to_le_bytes());
268 mac.update(&seq_no.to_le_bytes());
269 let result = mac.finalize();
270 let mut out = [0u8; SIGNATURE_SIZE];
271 out.copy_from_slice(&result.into_bytes());
272 out
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Key material used by most tests.
    const KEY_42: [u8; 32] = [0x42; 32];

    /// Builds a signer with exactly one registered user key.
    fn make_signer(user_id: u64, key: [u8; 32]) -> DeltaSigner {
        let mut signer = DeltaSigner::new();
        signer.register_key(user_id, key);
        signer
    }

    #[test]
    fn hmac_golden_vector() {
        // NOTE(review): the expected tag is recomputed with the same HKDF +
        // compute_hmac pipeline, so this checks internal consistency rather
        // than pinning fixed bytes. TODO: hard-code the 32-byte expected tag so
        // a change to the derivation or message layout is actually caught.
        let signer = make_signer(1, KEY_42);
        let sig = signer.sign(1, 2, 1, b"nodedb").unwrap();

        let mut device_key = [0u8; 32];
        Hkdf::<Sha256>::new(Some(DEVICE_KEY_SALT), &KEY_42)
            .expand(&2u64.to_le_bytes(), &mut device_key)
            .unwrap();

        assert_eq!(
            sig,
            compute_hmac(&device_key, 1, 2, 1, b"nodedb"),
            "HMAC golden vector must be stable"
        );
    }

    #[test]
    fn replay_rejected_same_device_seq() {
        let signer = make_signer(1, KEY_42);
        let registry = signer.registry();
        let delta: &[u8] = b"test delta";
        let sig = signer.sign(1, 2, 1, delta).unwrap();

        // First delivery: passes the replay check, verifies, then commits.
        registry.check_seq(1, 2, 1).unwrap();
        signer.verify(1, 2, 1, delta, &sig).unwrap();
        registry.commit_seq(1, 2, 1).unwrap();

        // Re-delivering the same (device, seq) must be flagged as a replay.
        let err = registry.check_seq(1, 2, 1).unwrap_err();
        assert!(
            matches!(
                err,
                CrdtError::ReplayDetected {
                    seq_no: 1,
                    last_seen: 1,
                    ..
                }
            ),
            "expected ReplayDetected, got {err}"
        );
    }

    #[test]
    fn cross_device_replay_rejected() {
        let signer = make_signer(1, KEY_42);
        let delta = b"cross device test";
        let sig_from_device_2 = signer.sign(1, 2, 1, delta).unwrap();

        // Same user, seq and bytes, but presented as coming from device 3.
        let err = signer
            .verify(1, 3, 1, delta, &sig_from_device_2)
            .unwrap_err();
        assert!(
            matches!(err, CrdtError::InvalidSignature { .. }),
            "cross-device replay must be rejected"
        );
    }

    #[test]
    fn seq_zero_rejected() {
        let err = DeviceRegistry::new().check_seq(1, 0, 0).unwrap_err();
        assert!(
            matches!(
                err,
                CrdtError::ReplayDetected {
                    seq_no: 0,
                    last_seen: 0,
                    ..
                }
            ),
            "seq_no=0 must be rejected (not strictly greater than last_seen=0)"
        );
    }

    #[test]
    fn tampered_delta_fails_verification() {
        let signer = make_signer(1, KEY_42);
        let sig = signer.sign(1, 2, 1, b"original").unwrap();
        let outcome = signer.verify(1, 2, 1, b"tampered", &sig);
        assert!(matches!(outcome, Err(CrdtError::InvalidSignature { .. })));
    }

    #[test]
    fn wrong_user_fails_verification() {
        let mut signer = make_signer(1, KEY_42);
        signer.register_key(2, [0x99; 32]);

        let sig = signer.sign(1, 5, 1, b"delta").unwrap();
        let outcome = signer.verify(2, 5, 1, b"delta", &sig);
        assert!(matches!(outcome, Err(CrdtError::InvalidSignature { .. })));
    }

    #[test]
    fn unregistered_user_fails() {
        let outcome = DeltaSigner::new().sign(99, 1, 1, b"data");
        assert!(matches!(
            outcome,
            Err(CrdtError::InvalidSignature { user_id: 99, .. })
        ));
    }

    #[test]
    fn seq_no_must_advance() {
        let reg = DeviceRegistry::new();
        reg.check_seq(1, 1, 5).unwrap();
        reg.commit_seq(1, 1, 5).unwrap();

        // Equal and older sequence numbers are both replays.
        for stale in [5, 4] {
            assert!(reg.check_seq(1, 1, stale).is_err());
        }
        // Strictly newer is accepted.
        reg.check_seq(1, 1, 6).unwrap();
    }
}