use crate::authn::ids::{DeviceId, TenantId, UserId};
use crate::device::types::{Device, DeviceTrustLevel, FingerprintHash};
use crate::health::{HealthCheck, HealthStatus};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Persistence backend for [`Device`] records.
///
/// Every lookup and mutation is scoped to a tenant. All methods return `Send` futures so a
/// store can be driven from a multi-threaded async runtime.
pub trait DeviceStore: Send + Sync + Clone + 'static {
    /// Error produced by every operation of this store.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Loads device `id` of `tenant_id`, or `Ok(None)` if it does not exist.
    fn load(
        &self,
        tenant_id: &TenantId,
        id: &DeviceId,
    ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send;

    /// Looks up the device registered under fingerprint `hash` in `tenant_id`.
    fn find_by_fingerprint(
        &self,
        tenant_id: &TenantId,
        hash: &FingerprintHash,
    ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send;

    /// Returns at most `limit` devices owned by `user_id` in `tenant_id`.
    ///
    /// Implementations should return fewer than `limit` devices only when the user has no
    /// more. The default [`DeviceStore::find_active_for_user`] treats a short result as
    /// "exhausted".
    fn find_for_user(
        &self,
        tenant_id: &TenantId,
        user_id: &UserId,
        limit: usize,
    ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send;

    /// Returns at most `limit` non-revoked devices owned by `user_id` in `tenant_id`.
    ///
    /// The default implementation pages through [`DeviceStore::find_for_user`]. It doubles
    /// the requested window until it has `limit` active devices or the user's devices are
    /// exhausted. Filtering a single `find_for_user(.., limit)` result would silently
    /// return fewer than `limit` devices whenever revoked ones occupy part of that window.
    /// Backends that can filter natively should override this.
    fn find_active_for_user(
        &self,
        tenant_id: &TenantId,
        user_id: &UserId,
        limit: usize,
    ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
        async move {
            let mut fetch = limit;
            loop {
                let batch = self.find_for_user(tenant_id, user_id, fetch).await?;
                // A short page means the backend has nothing more for this user.
                let exhausted = batch.len() < fetch;
                let mut active: Vec<Device> = batch
                    .into_iter()
                    .filter(|d| d.trust_level != DeviceTrustLevel::Revoked)
                    .collect();
                // `fetch == usize::MAX` guarantees termination even if a backend ignores
                // `limit` and always over-returns.
                if active.len() >= limit || exhausted || fetch == usize::MAX {
                    active.truncate(limit);
                    return Ok(active);
                }
                fetch = fetch.saturating_mul(2);
            }
        }
    }

    /// Returns the devices of `tenant_id` that carry a refresh binding for `family_id`.
    fn find_by_refresh_family(
        &self,
        tenant_id: &TenantId,
        family_id: &str,
    ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send;

    /// Inserts `device`, or replaces the stored device with the same tenant and id.
    fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send;

    /// Records that device `id` was seen at `now`.
    fn record_sighting(
        &self,
        tenant_id: &TenantId,
        id: &DeviceId,
        now: DateTime<Utc>,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;

    /// Sets the trust level of device `id`.
    ///
    /// `now` is the time of the change. The in-memory store records it as the revocation
    /// time when `level` is `Revoked`.
    fn set_trust_level(
        &self,
        tenant_id: &TenantId,
        id: &DeviceId,
        level: DeviceTrustLevel,
        now: DateTime<Utc>,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;

    /// Removes device `id` from `tenant_id`.
    fn delete(
        &self,
        tenant_id: &TenantId,
        id: &DeviceId,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;

    /// Applies idle-based trust decay to the devices of `tenant_id` as of `now`.
    ///
    /// Returns how many devices changed state.
    fn sweep(
        &self,
        tenant_id: &TenantId,
        now: DateTime<Utc>,
    ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send;
}
/// Number of devices affected by each transition during one sweep.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SweepCounts {
/// Devices demoted from `Trusted` to `Seen`.
pub trusted_to_seen: u64,
/// Devices demoted from `Seen` to `Revoked`.
pub seen_to_revoked: u64,
/// Revoked devices removed from the store after their grace period.
pub revoked_purged: u64,
}
/// Idle thresholds that drive device trust decay during a sweep.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SweepConfig {
    /// How long a `Trusted` device may go unseen before it is demoted to `Seen`.
    pub trusted_idle: chrono::Duration,
    /// How long a `Seen` device may go unseen before it is revoked.
    pub seen_idle: chrono::Duration,
    /// How long a revoked device is kept before it is purged.
    pub revoked_grace: chrono::Duration,
}

impl Default for SweepConfig {
    /// 90 days trusted idle, 30 days seen idle, 7 days revoked grace.
    fn default() -> Self {
        let days = chrono::Duration::days;
        SweepConfig {
            revoked_grace: days(7),
            seen_idle: days(30),
            trusted_idle: days(90),
        }
    }
}

impl SweepConfig {
    /// Starts a builder with no overrides. Unset fields fall back to [`SweepConfig::default`].
    pub fn builder() -> SweepConfigBuilder {
        <SweepConfigBuilder as Default>::default()
    }
}
/// Incremental builder for [`SweepConfig`]. Any threshold left unset uses the default value.
#[derive(Debug, Default, Clone, Copy)]
pub struct SweepConfigBuilder {
    trusted_idle: Option<chrono::Duration>,
    seen_idle: Option<chrono::Duration>,
    revoked_grace: Option<chrono::Duration>,
}

impl SweepConfigBuilder {
    /// Overrides [`SweepConfig::trusted_idle`].
    pub fn trusted_idle(self, d: chrono::Duration) -> Self {
        Self {
            trusted_idle: Some(d),
            ..self
        }
    }

    /// Overrides [`SweepConfig::seen_idle`].
    pub fn seen_idle(self, d: chrono::Duration) -> Self {
        Self {
            seen_idle: Some(d),
            ..self
        }
    }

    /// Overrides [`SweepConfig::revoked_grace`].
    pub fn revoked_grace(self, d: chrono::Duration) -> Self {
        Self {
            revoked_grace: Some(d),
            ..self
        }
    }

    /// Produces the config, filling unset thresholds from [`SweepConfig::default`].
    pub fn build(self) -> SweepConfig {
        let fallback = SweepConfig::default();
        let Self {
            trusted_idle,
            seen_idle,
            revoked_grace,
        } = self;
        SweepConfig {
            trusted_idle: trusted_idle.unwrap_or(fallback.trusted_idle),
            seen_idle: seen_idle.unwrap_or(fallback.seen_idle),
            revoked_grace: revoked_grace.unwrap_or(fallback.revoked_grace),
        }
    }
}
type DeviceKey = (TenantId, DeviceId);
type FingerprintKey = (TenantId, FingerprintHash);
#[derive(Clone, Default)]
pub struct MemoryDeviceStore {
devices: Arc<DashMap<DeviceKey, Device>>,
fingerprint_index: Arc<DashMap<FingerprintKey, DeviceId>>,
write_count: Arc<AtomicU64>,
sweep_config: SweepConfig,
}
impl MemoryDeviceStore {
pub fn new() -> Self {
Self::default()
}
pub fn with_sweep_config(mut self, config: SweepConfig) -> Self {
self.sweep_config = config;
self
}
pub fn len(&self) -> usize {
self.devices.len()
}
pub fn is_empty(&self) -> bool {
self.devices.is_empty()
}
fn key(tenant_id: &TenantId, id: &DeviceId) -> DeviceKey {
(*tenant_id, *id)
}
}
/// Errors returned by [`MemoryDeviceStore`].
#[derive(Debug, thiserror::Error)]
pub enum MemoryDeviceStoreError {
/// The addressed device does not exist in the tenant. Ids are kept as their `Display` strings.
#[error("device not found: tenant={tenant_id} id={device_id}")]
NotFound {
tenant_id: String,
device_id: String,
},
}
impl DeviceStore for MemoryDeviceStore {
type Error = MemoryDeviceStoreError;
async fn load(
&self,
tenant_id: &TenantId,
id: &DeviceId,
) -> Result<Option<Device>, Self::Error> {
Ok(self
.devices
.get(&Self::key(tenant_id, id))
.map(|d| d.clone()))
}
async fn find_by_fingerprint(
&self,
tenant_id: &TenantId,
hash: &FingerprintHash,
) -> Result<Option<Device>, Self::Error> {
let key = (*tenant_id, *hash);
let Some(device_id) = self.fingerprint_index.get(&key).map(|v| *v) else {
return Ok(None);
};
let pk = (*tenant_id, device_id);
Ok(self.devices.get(&pk).map(|d| d.clone()))
}
async fn find_for_user(
&self,
tenant_id: &TenantId,
user_id: &UserId,
limit: usize,
) -> Result<Vec<Device>, Self::Error> {
let mut hits: Vec<Device> = self
.devices
.iter()
.filter_map(|entry| {
let device = entry.value();
let same_tenant = device.tenant_id == *tenant_id;
let owned_by_user = device.user_id.as_ref().is_some_and(|u| u == user_id);
(same_tenant && owned_by_user).then(|| device.clone())
})
.collect();
hits.sort_by_key(|d| std::cmp::Reverse(d.last_seen_at));
hits.truncate(limit);
Ok(hits)
}
async fn find_by_refresh_family(
&self,
tenant_id: &TenantId,
family_id: &str,
) -> Result<Vec<Device>, Self::Error> {
let mut hits: Vec<Device> = self
.devices
.iter()
.filter_map(|entry| {
let device = entry.value();
if device.tenant_id != *tenant_id {
return None;
}
let matched = device.bindings.iter().any(|b| {
matches!(b, crate::device::types::DeviceBinding::Refresh { family_id: fid, .. } if fid == family_id)
});
matched.then(|| device.clone())
})
.collect();
hits.sort_by_key(|d| std::cmp::Reverse(d.last_seen_at));
Ok(hits)
}
async fn save(&self, device: &Device) -> Result<(), Self::Error> {
let key = Self::key(&device.tenant_id, &device.id);
let fp_key = (device.tenant_id, device.fingerprint_hash);
self.devices.insert(key, device.clone());
self.fingerprint_index.insert(fp_key, device.id);
self.write_count.fetch_add(1, Ordering::Relaxed);
Ok(())
}
async fn record_sighting(
&self,
tenant_id: &TenantId,
id: &DeviceId,
now: DateTime<Utc>,
) -> Result<(), Self::Error> {
let key = Self::key(tenant_id, id);
if let Some(mut entry) = self.devices.get_mut(&key) {
entry.last_seen_at = now;
return Ok(());
}
Err(MemoryDeviceStoreError::NotFound {
tenant_id: tenant_id.to_string(),
device_id: id.to_string(),
})
}
async fn set_trust_level(
&self,
tenant_id: &TenantId,
id: &DeviceId,
level: DeviceTrustLevel,
now: DateTime<Utc>,
) -> Result<(), Self::Error> {
let key = Self::key(tenant_id, id);
if let Some(mut entry) = self.devices.get_mut(&key) {
entry.trust_level = level;
entry.revoked_at = matches!(level, DeviceTrustLevel::Revoked).then_some(now);
return Ok(());
}
Err(MemoryDeviceStoreError::NotFound {
tenant_id: tenant_id.to_string(),
device_id: id.to_string(),
})
}
async fn delete(&self, tenant_id: &TenantId, id: &DeviceId) -> Result<(), Self::Error> {
let key = Self::key(tenant_id, id);
if let Some((_, device)) = self.devices.remove(&key) {
let fp_key = (device.tenant_id, device.fingerprint_hash);
self.fingerprint_index.remove(&fp_key);
}
Ok(())
}
async fn sweep(
&self,
tenant_id: &TenantId,
now: DateTime<Utc>,
) -> Result<SweepCounts, Self::Error> {
let cfg = self.sweep_config;
let mut counts = SweepCounts::default();
enum Action {
DemoteToSeen,
DemoteToRevoked,
Purge,
}
let mut actions: Vec<(DeviceKey, Action)> = Vec::new();
for entry in self.devices.iter() {
let device = entry.value();
if device.tenant_id != *tenant_id {
continue;
}
let mut current_level = device.trust_level;
if current_level == DeviceTrustLevel::Trusted
&& now.signed_duration_since(device.last_seen_at) > cfg.trusted_idle
{
actions.push((*entry.key(), Action::DemoteToSeen));
counts.trusted_to_seen += 1;
current_level = DeviceTrustLevel::Seen;
}
if current_level == DeviceTrustLevel::Seen
&& now.signed_duration_since(device.last_seen_at) > cfg.seen_idle
{
actions.push((*entry.key(), Action::DemoteToRevoked));
counts.seen_to_revoked += 1;
current_level = DeviceTrustLevel::Revoked;
}
if current_level == DeviceTrustLevel::Revoked
&& let Some(revoked_at) = device.revoked_at
&& now.signed_duration_since(revoked_at) > cfg.revoked_grace
{
actions.push((*entry.key(), Action::Purge));
counts.revoked_purged += 1;
}
}
for (key, action) in actions {
match action {
Action::DemoteToSeen => {
if let Some(mut entry) = self.devices.get_mut(&key) {
entry.trust_level = DeviceTrustLevel::Seen;
}
}
Action::DemoteToRevoked => {
if let Some(mut entry) = self.devices.get_mut(&key) {
entry.trust_level = DeviceTrustLevel::Revoked;
entry.revoked_at = Some(now);
}
}
Action::Purge => {
if let Some((_, device)) = self.devices.remove(&key) {
let fp_key = (device.tenant_id, device.fingerprint_hash);
self.fingerprint_index.remove(&fp_key);
}
}
}
}
Ok(counts)
}
}
impl HealthCheck for MemoryDeviceStore {
    /// Always reports healthy: the store lives entirely in process memory, so there is no
    /// external dependency to probe.
    fn check(&self) -> Pin<Box<dyn Future<Output = HealthStatus> + Send + '_>> {
        let probe = async move { HealthStatus::Healthy };
        Box::pin(probe)
    }
}
#[cfg(test)]
mod tests;