1use std::future::Future;
57use std::sync::Arc;
58use std::time::Duration;
59
60use chrono::{DateTime, TimeZone, Utc};
61use sqlx::SqlitePool;
62
63use axess_clock::{Clock, SystemClock};
64
65use crate::authn::ids::{DeviceId, TenantId, UserId};
66use crate::device::storage::sql_common::{BindingsCodec, SqlDeviceStoreError, trust_level_codec};
67use crate::device::store::{DeviceStore, SweepConfig, SweepCounts};
68use crate::device::types::{Device, DeviceBinding, DeviceTrustLevel, FingerprintHash};
69use crate::session::crypto::SessionCrypto;
70
71#[derive(Clone)]
85pub struct SqliteDeviceStore {
86 pool: SqlitePool,
87 codec: BindingsCodec,
88 clock: Arc<dyn Clock>,
89 sweep_config: SweepConfig,
90}
91
92impl SqliteDeviceStore {
93 pub fn new(pool: SqlitePool, crypto: SessionCrypto) -> Self {
95 Self {
96 pool,
97 codec: BindingsCodec::encrypted(crypto),
98 clock: Arc::new(SystemClock),
99 sweep_config: SweepConfig::default(),
100 }
101 }
102
103 pub fn plaintext(pool: SqlitePool) -> Self {
106 tracing::warn!(
107 "SqliteDeviceStore created without encryption; \
108 do not use in production"
109 );
110 Self {
111 pool,
112 codec: BindingsCodec::plaintext(),
113 clock: Arc::new(SystemClock),
114 sweep_config: SweepConfig::default(),
115 }
116 }
117
118 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
122 self.clock = clock;
123 self
124 }
125
126 pub fn with_sweep_config(mut self, config: SweepConfig) -> Self {
128 self.sweep_config = config;
129 self
130 }
131
132 #[doc(hidden)]
135 pub fn pool_for_test(&self) -> &SqlitePool {
136 &self.pool
137 }
138
139 pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
142 sqlx::query(
143 r#"
144 CREATE TABLE IF NOT EXISTS devices (
145 tenant_id TEXT NOT NULL,
146 id TEXT NOT NULL,
147 user_id TEXT,
148 trust_level TEXT NOT NULL,
149 fingerprint_hash BLOB NOT NULL,
150 first_seen_at INTEGER NOT NULL,
151 last_seen_at INTEGER NOT NULL,
152 revoked_at INTEGER,
153 bindings TEXT NOT NULL,
154 PRIMARY KEY (tenant_id, id)
155 )
156 "#,
157 )
158 .execute(&self.pool)
159 .await?;
160
161 sqlx::query(
162 "CREATE INDEX IF NOT EXISTS idx_devices_fingerprint \
163 ON devices (tenant_id, fingerprint_hash)",
164 )
165 .execute(&self.pool)
166 .await?;
167 sqlx::query(
168 "CREATE INDEX IF NOT EXISTS idx_devices_user \
169 ON devices (tenant_id, user_id, last_seen_at DESC)",
170 )
171 .execute(&self.pool)
172 .await?;
173
174 sqlx::query(
175 r#"
176 CREATE TABLE IF NOT EXISTS device_bindings_refresh (
177 tenant_id TEXT NOT NULL,
178 device_id TEXT NOT NULL,
179 family_id TEXT NOT NULL,
180 PRIMARY KEY (tenant_id, device_id, family_id),
181 FOREIGN KEY (tenant_id, device_id)
182 REFERENCES devices (tenant_id, id) ON DELETE CASCADE
183 )
184 "#,
185 )
186 .execute(&self.pool)
187 .await?;
188
189 sqlx::query(
190 "CREATE INDEX IF NOT EXISTS idx_device_bindings_refresh_family \
191 ON device_bindings_refresh (tenant_id, family_id)",
192 )
193 .execute(&self.pool)
194 .await?;
195
196 Ok(())
197 }
198
199 fn decode_row(&self, row: DeviceRow) -> Result<Device, SqlDeviceStoreError> {
202 let DeviceRow {
203 tenant_id,
204 id,
205 user_id,
206 trust_level,
207 fingerprint_hash,
208 first_seen_at,
209 last_seen_at,
210 revoked_at,
211 bindings,
212 } = row;
213
214 let tenant = TenantId::try_new(&tenant_id)
215 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("tenant_id: {e}")))?;
216 let device_id = DeviceId::try_new(&id)
217 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("device id: {e}")))?;
218 let user = match user_id {
219 Some(u) => Some(
220 UserId::try_new(&u)
221 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("user_id: {e}")))?,
222 ),
223 None => None,
224 };
225
226 let trust = trust_level_codec::from_str(&trust_level)
227 .ok_or(SqlDeviceStoreError::UnknownTrustLevel(trust_level))?;
228
229 let fp_bytes: [u8; 32] = fingerprint_hash
230 .try_into()
231 .map_err(|_| SqlDeviceStoreError::MalformedRow("fingerprint_hash length".into()))?;
232
233 let first = unix_to_utc(first_seen_at)?;
234 let last = unix_to_utc(last_seen_at)?;
235 let revoked = match revoked_at {
236 Some(t) => Some(unix_to_utc(t)?),
237 None => None,
238 };
239
240 let bindings = self.codec.decode(&bindings)?;
241
242 Ok(Device {
243 id: device_id,
244 tenant_id: tenant,
245 user_id: user,
246 trust_level: trust,
247 fingerprint_hash: FingerprintHash::from_bytes(fp_bytes),
248 first_seen_at: first,
249 last_seen_at: last,
250 revoked_at: revoked,
251 bindings,
252 })
253 }
254}
255
256#[derive(sqlx::FromRow)]
259struct DeviceRow {
260 tenant_id: String,
261 id: String,
262 user_id: Option<String>,
263 trust_level: String,
264 fingerprint_hash: Vec<u8>,
265 first_seen_at: i64,
266 last_seen_at: i64,
267 revoked_at: Option<i64>,
268 bindings: String,
269}
270
271fn unix_to_utc(secs: i64) -> Result<DateTime<Utc>, SqlDeviceStoreError> {
272 Utc.timestamp_opt(secs, 0).single().ok_or_else(|| {
273 SqlDeviceStoreError::MalformedRow(format!("unrepresentable Unix timestamp: {secs}"))
274 })
275}
276
277fn utc_to_unix(dt: DateTime<Utc>) -> i64 {
278 dt.timestamp()
279}
280
281fn refresh_family_ids(bindings: &[DeviceBinding]) -> Vec<String> {
282 bindings
283 .iter()
284 .filter_map(|b| match b {
285 DeviceBinding::Refresh { family_id, .. } => Some(family_id.clone()),
286 _ => None,
287 })
288 .collect()
289}
290
291impl DeviceStore for SqliteDeviceStore {
292 type Error = SqlDeviceStoreError;
293
294 fn load(
295 &self,
296 tenant_id: &TenantId,
297 id: &DeviceId,
298 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
299 let pool = self.pool.clone();
300 let store = self.clone();
301 let tenant = tenant_id.to_string().to_string();
302 let device_id = id.to_string().to_string();
303 async move {
304 let row: Option<DeviceRow> = sqlx::query_as(
305 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
306 first_seen_at, last_seen_at, revoked_at, bindings \
307 FROM devices WHERE tenant_id = ?1 AND id = ?2",
308 )
309 .bind(&tenant)
310 .bind(&device_id)
311 .fetch_optional(&pool)
312 .await?;
313
314 match row {
315 Some(r) => Ok(Some(store.decode_row(r)?)),
316 None => Ok(None),
317 }
318 }
319 }
320
321 fn find_by_fingerprint(
322 &self,
323 tenant_id: &TenantId,
324 hash: &FingerprintHash,
325 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
326 let pool = self.pool.clone();
327 let store = self.clone();
328 let tenant = tenant_id.to_string().to_string();
329 let bytes = hash.as_bytes().to_vec();
330 async move {
331 let row: Option<DeviceRow> = sqlx::query_as(
332 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
333 first_seen_at, last_seen_at, revoked_at, bindings \
334 FROM devices WHERE tenant_id = ?1 AND fingerprint_hash = ?2 \
335 ORDER BY last_seen_at DESC LIMIT 1",
336 )
337 .bind(&tenant)
338 .bind(&bytes)
339 .fetch_optional(&pool)
340 .await?;
341
342 match row {
343 Some(r) => Ok(Some(store.decode_row(r)?)),
344 None => Ok(None),
345 }
346 }
347 }
348
349 fn find_for_user(
350 &self,
351 tenant_id: &TenantId,
352 user_id: &UserId,
353 limit: usize,
354 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
355 let pool = self.pool.clone();
356 let store = self.clone();
357 let tenant = tenant_id.to_string().to_string();
358 let uid = user_id.to_string().to_string();
359 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
360 async move {
361 let rows: Vec<DeviceRow> = sqlx::query_as(
362 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
363 first_seen_at, last_seen_at, revoked_at, bindings \
364 FROM devices WHERE tenant_id = ?1 AND user_id = ?2 \
365 ORDER BY last_seen_at DESC LIMIT ?3",
366 )
367 .bind(&tenant)
368 .bind(&uid)
369 .bind(limit_i64)
370 .fetch_all(&pool)
371 .await?;
372
373 let mut out = Vec::with_capacity(rows.len());
374 for r in rows {
375 out.push(store.decode_row(r)?);
376 }
377 Ok(out)
378 }
379 }
380
381 fn find_by_refresh_family(
382 &self,
383 tenant_id: &TenantId,
384 family_id: &str,
385 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
386 let pool = self.pool.clone();
387 let store = self.clone();
388 let tenant = tenant_id.to_string().to_string();
389 let family = family_id.to_string();
390 async move {
391 let rows: Vec<DeviceRow> = sqlx::query_as(
392 "SELECT d.tenant_id, d.id, d.user_id, d.trust_level, d.fingerprint_hash, \
393 d.first_seen_at, d.last_seen_at, d.revoked_at, d.bindings \
394 FROM devices d \
395 INNER JOIN device_bindings_refresh r \
396 ON d.tenant_id = r.tenant_id AND d.id = r.device_id \
397 WHERE r.tenant_id = ?1 AND r.family_id = ?2 \
398 ORDER BY d.last_seen_at DESC",
399 )
400 .bind(&tenant)
401 .bind(&family)
402 .fetch_all(&pool)
403 .await?;
404
405 let mut out = Vec::with_capacity(rows.len());
406 for r in rows {
407 out.push(store.decode_row(r)?);
408 }
409 Ok(out)
410 }
411 }
412
413 fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send {
414 let pool = self.pool.clone();
415 let codec = self.codec.clone();
416 let device = device.clone();
417 async move {
418 let bindings_blob = codec.encode(&device.bindings)?;
419 let trust = trust_level_codec::to_str(device.trust_level);
420 let fp = device.fingerprint_hash.as_bytes().to_vec();
421 let user_id_col = device.user_id.as_ref().map(|u| u.to_string().to_string());
422 let first = utc_to_unix(device.first_seen_at);
423 let last = utc_to_unix(device.last_seen_at);
424 let revoked = device.revoked_at.map(utc_to_unix);
425 let family_ids = refresh_family_ids(&device.bindings);
426 let tenant = device.tenant_id.to_string().to_string();
427 let id = device.id.to_string().to_string();
428
429 let mut tx = pool.begin().await?;
434
435 sqlx::query(
436 r#"
437 INSERT INTO devices
438 (tenant_id, id, user_id, trust_level, fingerprint_hash,
439 first_seen_at, last_seen_at, revoked_at, bindings)
440 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
441 ON CONFLICT(tenant_id, id) DO UPDATE SET
442 user_id = excluded.user_id,
443 trust_level = excluded.trust_level,
444 fingerprint_hash = excluded.fingerprint_hash,
445 first_seen_at = excluded.first_seen_at,
446 last_seen_at = excluded.last_seen_at,
447 revoked_at = excluded.revoked_at,
448 bindings = excluded.bindings
449 "#,
450 )
451 .bind(&tenant)
452 .bind(&id)
453 .bind(user_id_col.as_deref())
454 .bind(trust)
455 .bind(&fp)
456 .bind(first)
457 .bind(last)
458 .bind(revoked)
459 .bind(&bindings_blob)
460 .execute(&mut *tx)
461 .await?;
462
463 sqlx::query(
465 "DELETE FROM device_bindings_refresh \
466 WHERE tenant_id = ?1 AND device_id = ?2",
467 )
468 .bind(&tenant)
469 .bind(&id)
470 .execute(&mut *tx)
471 .await?;
472
473 for family_id in &family_ids {
474 sqlx::query(
475 "INSERT INTO device_bindings_refresh \
476 (tenant_id, device_id, family_id) VALUES (?1, ?2, ?3)",
477 )
478 .bind(&tenant)
479 .bind(&id)
480 .bind(family_id)
481 .execute(&mut *tx)
482 .await?;
483 }
484
485 tx.commit().await?;
486 Ok(())
487 }
488 }
489
490 fn record_sighting(
491 &self,
492 tenant_id: &TenantId,
493 id: &DeviceId,
494 now: DateTime<Utc>,
495 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
496 let pool = self.pool.clone();
497 let tenant = tenant_id.to_string().to_string();
498 let device_id = id.to_string().to_string();
499 let ts = utc_to_unix(now);
500 async move {
501 sqlx::query(
502 "UPDATE devices SET last_seen_at = ?3 \
503 WHERE tenant_id = ?1 AND id = ?2",
504 )
505 .bind(&tenant)
506 .bind(&device_id)
507 .bind(ts)
508 .execute(&pool)
509 .await?;
510 Ok(())
511 }
512 }
513
514 fn set_trust_level(
515 &self,
516 tenant_id: &TenantId,
517 id: &DeviceId,
518 level: DeviceTrustLevel,
519 now: DateTime<Utc>,
520 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
521 let pool = self.pool.clone();
522 let tenant = tenant_id.to_string().to_string();
523 let device_id = id.to_string().to_string();
524 let trust = trust_level_codec::to_str(level);
525 let ts = utc_to_unix(now);
526 let revoked_at = match level {
529 DeviceTrustLevel::Revoked => Some(ts),
530 _ => None,
531 };
532 async move {
533 sqlx::query(
534 "UPDATE devices SET trust_level = ?3, revoked_at = ?4 \
535 WHERE tenant_id = ?1 AND id = ?2",
536 )
537 .bind(&tenant)
538 .bind(&device_id)
539 .bind(trust)
540 .bind(revoked_at)
541 .execute(&pool)
542 .await?;
543 Ok(())
544 }
545 }
546
547 fn delete(
548 &self,
549 tenant_id: &TenantId,
550 id: &DeviceId,
551 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
552 let pool = self.pool.clone();
553 let tenant = tenant_id.to_string().to_string();
554 let device_id = id.to_string().to_string();
555 async move {
556 sqlx::query("DELETE FROM devices WHERE tenant_id = ?1 AND id = ?2")
559 .bind(&tenant)
560 .bind(&device_id)
561 .execute(&pool)
562 .await?;
563 Ok(())
564 }
565 }
566
567 fn sweep(
568 &self,
569 tenant_id: &TenantId,
570 now: DateTime<Utc>,
571 ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send {
572 let pool = self.pool.clone();
573 let cfg = self.sweep_config;
574 let tenant = tenant_id.to_string().to_string();
575 let now_secs = utc_to_unix(now);
576 async move {
577 let trusted_cutoff = now_secs - cfg.trusted_idle.num_seconds();
587 let seen_cutoff = now_secs - cfg.seen_idle.num_seconds();
588 let grace_cutoff = now_secs - cfg.revoked_grace.num_seconds();
589
590 let trusted_demoted = sqlx::query(
592 "UPDATE devices SET trust_level = 'Seen' \
593 WHERE tenant_id = ?1 \
594 AND trust_level = 'Trusted' \
595 AND last_seen_at < ?2",
596 )
597 .bind(&tenant)
598 .bind(trusted_cutoff)
599 .execute(&pool)
600 .await?
601 .rows_affected();
602
603 let seen_demoted = sqlx::query(
607 "UPDATE devices SET trust_level = 'Revoked', revoked_at = ?3 \
608 WHERE tenant_id = ?1 \
609 AND trust_level = 'Seen' \
610 AND last_seen_at < ?2",
611 )
612 .bind(&tenant)
613 .bind(seen_cutoff)
614 .bind(now_secs)
615 .execute(&pool)
616 .await?
617 .rows_affected();
618
619 let purged = sqlx::query(
627 "DELETE FROM devices \
628 WHERE tenant_id = ?1 \
629 AND trust_level = 'Revoked' \
630 AND revoked_at IS NOT NULL \
631 AND revoked_at < ?2",
632 )
633 .bind(&tenant)
634 .bind(grace_cutoff)
635 .execute(&pool)
636 .await?
637 .rows_affected();
638
639 Ok(SweepCounts {
640 trusted_to_seen: trusted_demoted,
641 seen_to_revoked: seen_demoted,
642 revoked_purged: purged,
643 })
644 }
645 }
646}
647
648use crate::health::{HealthCheck, HealthStatus};
651
652impl HealthCheck for SqliteDeviceStore {
653 fn check(
654 &self,
655 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
656 Box::pin(async {
657 match tokio::time::timeout(
658 Duration::from_secs(2),
659 sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
660 )
661 .await
662 {
663 Ok(Ok(_)) => HealthStatus::Healthy,
664 Ok(Err(e)) => HealthStatus::Unhealthy(format!("sqlite SELECT 1 failed: {e}")),
665 Err(_) => HealthStatus::Unhealthy("sqlite SELECT 1 timeout (2s)".into()),
666 }
667 })
668 }
669}
670#[cfg(test)]
671mod tests;