1use std::sync::Arc;
38use std::time::Duration;
39
40use chrono::{DateTime, Utc};
41use fred::prelude::*;
42
43use crate::authn::ids::{DeviceId, TenantId, UserId};
44use crate::device::store::{DeviceStore, SweepConfig, SweepCounts};
45use crate::device::types::{Device, DeviceBinding, DeviceTrustLevel, FingerprintHash};
46use crate::session::crypto::SessionCrypto;
47use std::future::Future;
48
/// Key namespace used by [`ValkeyDeviceStore`] unless it is overridden with
/// `ValkeyDeviceStore::with_prefix`; every key this module builds starts with it.
const DEFAULT_PREFIX: &str = "axess";
50
/// Key of the (possibly encrypted) MessagePack row for one device:
/// `{prefix}:dev:{tenant}:{id}`.
fn device_key(prefix: &str, tenant: &str, id: &str) -> String {
    [prefix, "dev", tenant, id].join(":")
}
54
/// Key of the pointer from a fingerprint (lower-case hex) to a device id:
/// `{prefix}:dev:fp:{tenant}:{hex_hash}`.
fn fingerprint_key(prefix: &str, tenant: &str, hex_hash: &str) -> String {
    [prefix, "dev", "fp", tenant, hex_hash].join(":")
}
58
/// Key of the set holding the ids of every device bound to `user`:
/// `{prefix}:dev:user:{tenant}:{user}`.
fn user_index_key(prefix: &str, tenant: &str, user: &str) -> String {
    [prefix, "dev", "user", tenant, user].join(":")
}
62
/// Key of the set holding the ids of every device bound to a refresh-token
/// family: `{prefix}:dev:fam:{tenant}:{family}`.
fn family_index_key(prefix: &str, tenant: &str, family: &str) -> String {
    [prefix, "dev", "fam", tenant, family].join(":")
}
66
/// Key of the set holding the ids of every device in a tenant (the set the
/// sweep walks): `{prefix}:dev:tenant:{tenant}`.
fn tenant_index_key(prefix: &str, tenant: &str) -> String {
    [prefix, "dev", "tenant", tenant].join(":")
}
70
71fn fingerprint_hex(h: &FingerprintHash) -> String {
73 use std::fmt::Write as _;
74 let bytes = h.as_bytes();
75 let mut s = String::with_capacity(bytes.len() * 2);
76 for b in bytes {
77 write!(s, "{:02x}", b).expect("writing into a String never fails");
78 }
79 s
80}
81
/// Errors returned by [`ValkeyDeviceStore`].
///
/// NOTE(review): every variant both interpolates the inner error into its
/// Display text (`{0}`) and exposes it through `source()`. Reporters that walk
/// the error chain will therefore print the inner message twice.
#[derive(Debug, thiserror::Error)]
pub enum ValkeyDeviceStoreError {
    /// Any error surfaced by the `fred` client; the `From` conversion maps every
    /// `fred::error::Error` here, not only connection failures.
    #[error("connection error: {0}")]
    Connection(#[source] fred::error::Error),

    /// Serialising a `Device` with `rmp_serde::to_vec_named` failed.
    #[error("device row MessagePack encoding failed: {0}")]
    Encode(#[source] rmp_serde::encode::Error),

    /// A stored row (after optional decryption) did not deserialise into a `Device`.
    #[error("device row MessagePack decoding failed: {0}")]
    Decode(#[source] rmp_serde::decode::Error),

    /// `SessionCrypto` failed to encrypt or decrypt a row.
    #[error("encryption/decryption error: {0}")]
    Crypto(#[source] crate::session::crypto::CryptoError),
}
106
107impl From<fred::error::Error> for ValkeyDeviceStoreError {
108 fn from(e: fred::error::Error) -> Self {
109 Self::Connection(e)
110 }
111}
112
113impl From<rmp_serde::encode::Error> for ValkeyDeviceStoreError {
114 fn from(e: rmp_serde::encode::Error) -> Self {
115 Self::Encode(e)
116 }
117}
118
119impl From<rmp_serde::decode::Error> for ValkeyDeviceStoreError {
120 fn from(e: rmp_serde::decode::Error) -> Self {
121 Self::Decode(e)
122 }
123}
124
125impl From<crate::session::crypto::CryptoError> for ValkeyDeviceStoreError {
126 fn from(e: crate::session::crypto::CryptoError) -> Self {
127 Self::Crypto(e)
128 }
129}
130
/// [`DeviceStore`] backed by Valkey through a `fred` client.
///
/// Each device is stored as one MessagePack row, encrypted when the store was
/// built with [`ValkeyDeviceStore::new`]. The row is accompanied by a
/// fingerprint-to-id pointer plus tenant, user and refresh-family index sets.
/// Cloning is used to move the store into the `'static` futures returned by
/// the trait methods.
#[derive(Clone)]
pub struct ValkeyDeviceStore {
    /// Valkey client every command is issued through.
    client: Client,
    /// Namespace prepended to every key (defaults to `DEFAULT_PREFIX`).
    prefix: Arc<str>,
    /// Row cipher; `None` means rows are stored as plain MessagePack.
    crypto: Option<SessionCrypto>,
    /// Idle/grace windows used for both row TTLs and `sweep` transitions.
    sweep_config: SweepConfig,
}
147
148impl ValkeyDeviceStore {
149 pub fn new(client: Client, key: [u8; 32]) -> Self {
151 Self {
152 client,
153 prefix: DEFAULT_PREFIX.into(),
154 crypto: Some(SessionCrypto::new(key)),
155 sweep_config: SweepConfig::default(),
156 }
157 }
158
159 pub fn plaintext(client: Client) -> Self {
162 tracing::warn!(
163 "ValkeyDeviceStore created without encryption; \
164 do not use in production"
165 );
166 Self {
167 client,
168 prefix: DEFAULT_PREFIX.into(),
169 crypto: None,
170 sweep_config: SweepConfig::default(),
171 }
172 }
173
174 pub fn with_prefix(mut self, prefix: impl Into<Arc<str>>) -> Self {
176 self.prefix = prefix.into();
177 self
178 }
179
180 pub fn with_sweep_config(mut self, config: SweepConfig) -> Self {
182 self.sweep_config = config;
183 self
184 }
185
186 fn ttl_seconds_for(&self, device: &Device, now: DateTime<Utc>) -> i64 {
194 let cfg = &self.sweep_config;
195 let (anchor, window) = match device.trust_level {
196 DeviceTrustLevel::Trusted => (device.last_seen_at, cfg.trusted_idle),
197 DeviceTrustLevel::Unknown | DeviceTrustLevel::Seen => {
198 (device.last_seen_at, cfg.seen_idle)
199 }
200 DeviceTrustLevel::Revoked => (device.revoked_at.unwrap_or(now), cfg.revoked_grace),
201 };
202 let expiry = anchor + window;
203 let remaining = expiry.signed_duration_since(now).num_seconds();
204 remaining.max(1)
208 }
209
210 fn encode_row(&self, device: &Device) -> Result<Vec<u8>, ValkeyDeviceStoreError> {
214 let bytes = rmp_serde::to_vec_named(device)?;
215 match &self.crypto {
216 Some(c) => Ok(c.encrypt(&bytes)?),
217 None => Ok(bytes),
218 }
219 }
220
221 fn decode_row(&self, payload: &[u8]) -> Result<Device, ValkeyDeviceStoreError> {
222 let plaintext = match &self.crypto {
223 Some(c) => c.decrypt(payload)?,
224 None => payload.to_vec(),
225 };
226 Ok(rmp_serde::from_slice(&plaintext)?)
227 }
228
229 async fn get_device(
231 &self,
232 tenant: &str,
233 id: &str,
234 ) -> Result<Option<Device>, ValkeyDeviceStoreError> {
235 let key = device_key(&self.prefix, tenant, id);
236 let bytes: Option<Vec<u8>> = self.client.get(&key).await?;
237 match bytes {
238 Some(b) => Ok(Some(self.decode_row(&b)?)),
239 None => Ok(None),
240 }
241 }
242}
243
244impl DeviceStore for ValkeyDeviceStore {
245 type Error = ValkeyDeviceStoreError;
246
247 fn load(
248 &self,
249 tenant_id: &TenantId,
250 id: &DeviceId,
251 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
252 let store = self.clone();
253 let tenant = tenant_id.to_string().to_string();
254 let device_id = id.to_string().to_string();
255 async move { store.get_device(&tenant, &device_id).await }
256 }
257
258 fn find_by_fingerprint(
259 &self,
260 tenant_id: &TenantId,
261 hash: &FingerprintHash,
262 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
263 let store = self.clone();
264 let tenant = tenant_id.to_string().to_string();
265 let hex = fingerprint_hex(hash);
266 async move {
267 let fp_key = fingerprint_key(&store.prefix, &tenant, &hex);
268 let device_id: Option<String> = store.client.get(&fp_key).await?;
269 match device_id {
270 Some(id) => store.get_device(&tenant, &id).await,
271 None => Ok(None),
272 }
273 }
274 }
275
276 fn find_for_user(
277 &self,
278 tenant_id: &TenantId,
279 user_id: &UserId,
280 limit: usize,
281 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
282 let store = self.clone();
283 let tenant = tenant_id.to_string().to_string();
284 let user = user_id.to_string().to_string();
285 async move {
286 let idx = user_index_key(&store.prefix, &tenant, &user);
287 let members: Vec<String> = store.client.smembers(&idx).await?;
288 let mut out = Vec::with_capacity(members.len().min(limit));
289 for member in members {
290 if let Some(device) = store.get_device(&tenant, &member).await? {
291 out.push(device);
292 }
293 }
294 out.sort_by_key(|d| std::cmp::Reverse(d.last_seen_at));
296 out.truncate(limit);
297 Ok(out)
298 }
299 }
300
301 fn find_by_refresh_family(
302 &self,
303 tenant_id: &TenantId,
304 family_id: &str,
305 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
306 let store = self.clone();
307 let tenant = tenant_id.to_string().to_string();
308 let family = family_id.to_string();
309 async move {
310 let idx = family_index_key(&store.prefix, &tenant, &family);
311 let members: Vec<String> = store.client.smembers(&idx).await?;
312 let mut out = Vec::with_capacity(members.len());
313 for member in members {
314 if let Some(device) = store.get_device(&tenant, &member).await? {
315 out.push(device);
316 }
317 }
318 out.sort_by_key(|d| std::cmp::Reverse(d.last_seen_at));
319 Ok(out)
320 }
321 }
322
323 fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send {
324 let store = self.clone();
325 let device = device.clone();
326 async move {
327 let now = Utc::now();
328 let ttl_secs = store.ttl_seconds_for(&device, now);
329 let payload = store.encode_row(&device)?;
330 let tenant = device.tenant_id.to_string();
331 let id = device.id.to_string();
332
333 let row_key = device_key(&store.prefix, &tenant, &id);
334 let fp_key = fingerprint_key(
335 &store.prefix,
336 &tenant,
337 &fingerprint_hex(&device.fingerprint_hash),
338 );
339 let tenant_idx = tenant_index_key(&store.prefix, &tenant);
340
341 if let Some(prev) = store.get_device(&tenant, &id).await? {
345 let prev_hex = fingerprint_hex(&prev.fingerprint_hash);
346 if prev_hex != fingerprint_hex(&device.fingerprint_hash) {
347 let stale_fp = fingerprint_key(&store.prefix, &tenant, &prev_hex);
348 let _: () = store.client.del(&stale_fp).await?;
349 }
350 let new_families: Vec<&str> = device
353 .bindings
354 .iter()
355 .filter_map(|b| match b {
356 DeviceBinding::Refresh { family_id, .. } => Some(family_id.as_str()),
357 _ => None,
358 })
359 .collect();
360 for binding in &prev.bindings {
361 if let DeviceBinding::Refresh { family_id, .. } = binding
362 && !new_families.contains(&family_id.as_str())
363 {
364 let stale_idx = family_index_key(&store.prefix, &tenant, family_id);
365 let _: () = store.client.srem(&stale_idx, &id).await?;
366 }
367 }
368 if let Some(prev_user) = &prev.user_id
370 && device.user_id.as_ref() != Some(prev_user)
371 {
372 let stale_user_idx =
373 user_index_key(&store.prefix, &tenant, &prev_user.to_string());
374 let _: () = store.client.srem(&stale_user_idx, &id).await?;
375 }
376 }
377
378 let _: () = store
380 .client
381 .set(
382 &row_key,
383 payload,
384 Some(Expiration::EX(ttl_secs)),
385 None,
386 false,
387 )
388 .await?;
389 let _: () = store
392 .client
393 .set(&fp_key, &id, Some(Expiration::EX(ttl_secs)), None, false)
394 .await?;
395
396 let _: () = store.client.sadd(&tenant_idx, &id).await?;
399 if let Some(uid) = &device.user_id {
400 let user_idx = user_index_key(&store.prefix, &tenant, &uid.to_string());
401 let _: () = store.client.sadd(&user_idx, &id).await?;
402 }
403 for binding in &device.bindings {
404 if let DeviceBinding::Refresh { family_id, .. } = binding {
405 let fam_idx = family_index_key(&store.prefix, &tenant, family_id);
406 let _: () = store.client.sadd(&fam_idx, &id).await?;
407 }
408 }
409
410 Ok(())
411 }
412 }
413
414 fn record_sighting(
415 &self,
416 tenant_id: &TenantId,
417 id: &DeviceId,
418 now: DateTime<Utc>,
419 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
420 let store = self.clone();
421 let tenant = tenant_id.to_string().to_string();
422 let device_id = id.to_string().to_string();
423 async move {
424 let Some(mut device) = store.get_device(&tenant, &device_id).await? else {
430 return Ok(());
431 };
432 device.last_seen_at = now;
433 let payload = store.encode_row(&device)?;
434 let ttl = store.ttl_seconds_for(&device, now);
435 let row_key = device_key(&store.prefix, &tenant, &device_id);
436 let _: () = store
437 .client
438 .set(&row_key, payload, Some(Expiration::EX(ttl)), None, false)
439 .await?;
440 let fp_key = fingerprint_key(
442 &store.prefix,
443 &tenant,
444 &fingerprint_hex(&device.fingerprint_hash),
445 );
446 let _: () = store.client.expire(&fp_key, ttl, None).await?;
447 Ok(())
448 }
449 }
450
451 fn set_trust_level(
452 &self,
453 tenant_id: &TenantId,
454 id: &DeviceId,
455 level: DeviceTrustLevel,
456 now: DateTime<Utc>,
457 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
458 let store = self.clone();
459 let tenant = tenant_id.to_string().to_string();
460 let device_id = id.to_string().to_string();
461 async move {
462 let Some(mut device) = store.get_device(&tenant, &device_id).await? else {
463 return Ok(());
464 };
465 device.trust_level = level;
466 device.revoked_at = matches!(level, DeviceTrustLevel::Revoked).then_some(now);
467 let payload = store.encode_row(&device)?;
468 let ttl = store.ttl_seconds_for(&device, now);
469 let row_key = device_key(&store.prefix, &tenant, &device_id);
470 let _: () = store
471 .client
472 .set(&row_key, payload, Some(Expiration::EX(ttl)), None, false)
473 .await?;
474 let fp_key = fingerprint_key(
475 &store.prefix,
476 &tenant,
477 &fingerprint_hex(&device.fingerprint_hash),
478 );
479 let _: () = store.client.expire(&fp_key, ttl, None).await?;
480 Ok(())
481 }
482 }
483
484 fn delete(
485 &self,
486 tenant_id: &TenantId,
487 id: &DeviceId,
488 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
489 let store = self.clone();
490 let tenant = tenant_id.to_string().to_string();
491 let device_id = id.to_string().to_string();
492 async move {
493 let device = store.get_device(&tenant, &device_id).await?;
497 let row_key = device_key(&store.prefix, &tenant, &device_id);
498 let _: () = store.client.del(&row_key).await?;
499
500 let tenant_idx = tenant_index_key(&store.prefix, &tenant);
501 let _: () = store.client.srem(&tenant_idx, &device_id).await?;
502
503 if let Some(d) = device {
504 let fp_key = fingerprint_key(
505 &store.prefix,
506 &tenant,
507 &fingerprint_hex(&d.fingerprint_hash),
508 );
509 let _: () = store.client.del(&fp_key).await?;
510 if let Some(uid) = &d.user_id {
511 let user_idx = user_index_key(&store.prefix, &tenant, &uid.to_string());
512 let _: () = store.client.srem(&user_idx, &device_id).await?;
513 }
514 for binding in &d.bindings {
515 if let DeviceBinding::Refresh { family_id, .. } = binding {
516 let fam_idx = family_index_key(&store.prefix, &tenant, family_id);
517 let _: () = store.client.srem(&fam_idx, &device_id).await?;
518 }
519 }
520 }
521 Ok(())
522 }
523 }
524
525 fn sweep(
526 &self,
527 tenant_id: &TenantId,
528 now: DateTime<Utc>,
529 ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send {
530 let store = self.clone();
531 let tenant = tenant_id.to_string().to_string();
532 async move {
533 let mut counts = SweepCounts::default();
534 let tenant_idx = tenant_index_key(&store.prefix, &tenant);
539 let members: Vec<String> = store.client.smembers(&tenant_idx).await?;
540 let cfg = store.sweep_config;
541
542 for member in members {
543 let Some(mut device) = store.get_device(&tenant, &member).await? else {
544 let _: () = store.client.srem(&tenant_idx, &member).await?;
546 continue;
547 };
548
549 let mut changed = false;
550
551 if device.trust_level == DeviceTrustLevel::Trusted
553 && now.signed_duration_since(device.last_seen_at) > cfg.trusted_idle
554 {
555 device.trust_level = DeviceTrustLevel::Seen;
556 counts.trusted_to_seen += 1;
557 changed = true;
558 }
559
560 if device.trust_level == DeviceTrustLevel::Seen
564 && now.signed_duration_since(device.last_seen_at) > cfg.seen_idle
565 {
566 device.trust_level = DeviceTrustLevel::Revoked;
567 device.revoked_at = Some(now);
568 counts.seen_to_revoked += 1;
569 changed = true;
570 }
571
572 let should_purge = device.trust_level == DeviceTrustLevel::Revoked
576 && device
577 .revoked_at
578 .map(|r| now.signed_duration_since(r) > cfg.revoked_grace)
579 .unwrap_or(false)
580 && !(counts.seen_to_revoked > 0 && device.revoked_at == Some(now));
583
584 if should_purge {
585 counts.revoked_purged += 1;
586 let row_key = device_key(&store.prefix, &tenant, &member);
587 let _: () = store.client.del(&row_key).await?;
588 let _: () = store.client.srem(&tenant_idx, &member).await?;
589 let fp_key = fingerprint_key(
590 &store.prefix,
591 &tenant,
592 &fingerprint_hex(&device.fingerprint_hash),
593 );
594 let _: () = store.client.del(&fp_key).await?;
595 if let Some(uid) = &device.user_id {
596 let user_idx = user_index_key(&store.prefix, &tenant, &uid.to_string());
597 let _: () = store.client.srem(&user_idx, &member).await?;
598 }
599 for binding in &device.bindings {
600 if let DeviceBinding::Refresh { family_id, .. } = binding {
601 let fam_idx = family_index_key(&store.prefix, &tenant, family_id);
602 let _: () = store.client.srem(&fam_idx, &member).await?;
603 }
604 }
605 } else if changed {
606 let payload = store.encode_row(&device)?;
607 let ttl = store.ttl_seconds_for(&device, now);
608 let row_key = device_key(&store.prefix, &tenant, &member);
609 let _: () = store
610 .client
611 .set(&row_key, payload, Some(Expiration::EX(ttl)), None, false)
612 .await?;
613 let fp_key = fingerprint_key(
614 &store.prefix,
615 &tenant,
616 &fingerprint_hex(&device.fingerprint_hash),
617 );
618 let _: () = store.client.expire(&fp_key, ttl, None).await?;
619 }
620 }
621
622 Ok(counts)
623 }
624 }
625}
626
627use crate::health::{HealthCheck, HealthStatus};
630
631impl HealthCheck for ValkeyDeviceStore {
632 fn check(
633 &self,
634 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
635 Box::pin(async {
636 match tokio::time::timeout(Duration::from_secs(2), self.client.ping::<()>(None)).await {
637 Ok(Ok(_)) => HealthStatus::Healthy,
638 Ok(Err(e)) => HealthStatus::Unhealthy(format!("valkey PING failed: {e}")),
639 Err(_) => HealthStatus::Unhealthy("valkey PING timeout (2s)".into()),
640 }
641 })
642 }
643}
644
645