1use std::future::Future;
41use std::sync::Arc;
42use std::time::Duration;
43
44use chrono::{DateTime, TimeZone, Utc};
45use sqlx::PgPool;
46
47use axess_clock::{Clock, SystemClock};
48
49use crate::authn::ids::{DeviceId, TenantId, UserId};
50use crate::device::storage::sql_common::{BindingsCodec, SqlDeviceStoreError, trust_level_codec};
51use crate::device::store::{DeviceStore, SweepConfig, SweepCounts};
52use crate::device::types::{Device, DeviceBinding, DeviceTrustLevel, FingerprintHash};
53use crate::session::crypto::SessionCrypto;
54
55#[derive(Clone)]
67pub struct PostgresDeviceStore {
68 pool: PgPool,
69 codec: BindingsCodec,
70 clock: Arc<dyn Clock>,
71 sweep_config: SweepConfig,
72}
73
74impl PostgresDeviceStore {
75 pub fn new(pool: PgPool, crypto: SessionCrypto) -> Self {
77 Self {
78 pool,
79 codec: BindingsCodec::encrypted(crypto),
80 clock: Arc::new(SystemClock),
81 sweep_config: SweepConfig::default(),
82 }
83 }
84
85 pub fn plaintext(pool: PgPool) -> Self {
87 tracing::warn!(
88 "PostgresDeviceStore created without encryption; \
89 do not use in production"
90 );
91 Self {
92 pool,
93 codec: BindingsCodec::plaintext(),
94 clock: Arc::new(SystemClock),
95 sweep_config: SweepConfig::default(),
96 }
97 }
98
99 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
101 self.clock = clock;
102 self
103 }
104
105 pub fn with_sweep_config(mut self, config: SweepConfig) -> Self {
107 self.sweep_config = config;
108 self
109 }
110
111 pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
113 sqlx::query(
114 r#"
115 CREATE TABLE IF NOT EXISTS devices (
116 tenant_id TEXT NOT NULL,
117 id TEXT NOT NULL,
118 user_id TEXT,
119 trust_level TEXT NOT NULL,
120 fingerprint_hash BYTEA NOT NULL,
121 first_seen_at BIGINT NOT NULL,
122 last_seen_at BIGINT NOT NULL,
123 revoked_at BIGINT,
124 bindings TEXT NOT NULL,
125 PRIMARY KEY (tenant_id, id)
126 )
127 "#,
128 )
129 .execute(&self.pool)
130 .await?;
131
132 sqlx::query(
133 "CREATE INDEX IF NOT EXISTS idx_devices_fingerprint \
134 ON devices (tenant_id, fingerprint_hash)",
135 )
136 .execute(&self.pool)
137 .await?;
138 sqlx::query(
139 "CREATE INDEX IF NOT EXISTS idx_devices_user \
140 ON devices (tenant_id, user_id, last_seen_at DESC)",
141 )
142 .execute(&self.pool)
143 .await?;
144
145 sqlx::query(
146 r#"
147 CREATE TABLE IF NOT EXISTS device_bindings_refresh (
148 tenant_id TEXT NOT NULL,
149 device_id TEXT NOT NULL,
150 family_id TEXT NOT NULL,
151 PRIMARY KEY (tenant_id, device_id, family_id),
152 FOREIGN KEY (tenant_id, device_id)
153 REFERENCES devices (tenant_id, id) ON DELETE CASCADE
154 )
155 "#,
156 )
157 .execute(&self.pool)
158 .await?;
159
160 sqlx::query(
161 "CREATE INDEX IF NOT EXISTS idx_device_bindings_refresh_family \
162 ON device_bindings_refresh (tenant_id, family_id)",
163 )
164 .execute(&self.pool)
165 .await?;
166
167 Ok(())
168 }
169
170 fn decode_row(&self, row: DeviceRow) -> Result<Device, SqlDeviceStoreError> {
171 let DeviceRow {
172 tenant_id,
173 id,
174 user_id,
175 trust_level,
176 fingerprint_hash,
177 first_seen_at,
178 last_seen_at,
179 revoked_at,
180 bindings,
181 } = row;
182
183 let tenant = TenantId::try_new(&tenant_id)
184 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("tenant_id: {e}")))?;
185 let device_id = DeviceId::try_new(&id)
186 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("device id: {e}")))?;
187 let user = match user_id {
188 Some(u) => Some(
189 UserId::try_new(&u)
190 .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("user_id: {e}")))?,
191 ),
192 None => None,
193 };
194
195 let trust = trust_level_codec::from_str(&trust_level)
196 .ok_or(SqlDeviceStoreError::UnknownTrustLevel(trust_level))?;
197
198 let fp_bytes: [u8; 32] = fingerprint_hash
199 .try_into()
200 .map_err(|_| SqlDeviceStoreError::MalformedRow("fingerprint_hash length".into()))?;
201
202 let first = unix_to_utc(first_seen_at)?;
203 let last = unix_to_utc(last_seen_at)?;
204 let revoked = match revoked_at {
205 Some(t) => Some(unix_to_utc(t)?),
206 None => None,
207 };
208
209 let bindings = self.codec.decode(&bindings)?;
210
211 Ok(Device {
212 id: device_id,
213 tenant_id: tenant,
214 user_id: user,
215 trust_level: trust,
216 fingerprint_hash: FingerprintHash::from_bytes(fp_bytes),
217 first_seen_at: first,
218 last_seen_at: last,
219 revoked_at: revoked,
220 bindings,
221 })
222 }
223}
224
225#[derive(sqlx::FromRow)]
226struct DeviceRow {
227 tenant_id: String,
228 id: String,
229 user_id: Option<String>,
230 trust_level: String,
231 fingerprint_hash: Vec<u8>,
232 first_seen_at: i64,
233 last_seen_at: i64,
234 revoked_at: Option<i64>,
235 bindings: String,
236}
237
238fn unix_to_utc(secs: i64) -> Result<DateTime<Utc>, SqlDeviceStoreError> {
239 Utc.timestamp_opt(secs, 0).single().ok_or_else(|| {
240 SqlDeviceStoreError::MalformedRow(format!("unrepresentable Unix timestamp: {secs}"))
241 })
242}
243
244fn utc_to_unix(dt: DateTime<Utc>) -> i64 {
245 dt.timestamp()
246}
247
248fn refresh_family_ids(bindings: &[DeviceBinding]) -> Vec<String> {
249 bindings
250 .iter()
251 .filter_map(|b| match b {
252 DeviceBinding::Refresh { family_id, .. } => Some(family_id.clone()),
253 _ => None,
254 })
255 .collect()
256}
257
258impl DeviceStore for PostgresDeviceStore {
259 type Error = SqlDeviceStoreError;
260
261 fn load(
262 &self,
263 tenant_id: &TenantId,
264 id: &DeviceId,
265 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
266 let pool = self.pool.clone();
267 let store = self.clone();
268 let tenant = tenant_id.to_string().to_string();
269 let device_id = id.to_string().to_string();
270 async move {
271 let row: Option<DeviceRow> = sqlx::query_as(
272 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
273 first_seen_at, last_seen_at, revoked_at, bindings \
274 FROM devices WHERE tenant_id = $1 AND id = $2",
275 )
276 .bind(&tenant)
277 .bind(&device_id)
278 .fetch_optional(&pool)
279 .await?;
280
281 match row {
282 Some(r) => Ok(Some(store.decode_row(r)?)),
283 None => Ok(None),
284 }
285 }
286 }
287
288 fn find_by_fingerprint(
289 &self,
290 tenant_id: &TenantId,
291 hash: &FingerprintHash,
292 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
293 let pool = self.pool.clone();
294 let store = self.clone();
295 let tenant = tenant_id.to_string().to_string();
296 let bytes = hash.as_bytes().to_vec();
297 async move {
298 let row: Option<DeviceRow> = sqlx::query_as(
299 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
300 first_seen_at, last_seen_at, revoked_at, bindings \
301 FROM devices WHERE tenant_id = $1 AND fingerprint_hash = $2 \
302 ORDER BY last_seen_at DESC LIMIT 1",
303 )
304 .bind(&tenant)
305 .bind(&bytes)
306 .fetch_optional(&pool)
307 .await?;
308
309 match row {
310 Some(r) => Ok(Some(store.decode_row(r)?)),
311 None => Ok(None),
312 }
313 }
314 }
315
316 fn find_for_user(
317 &self,
318 tenant_id: &TenantId,
319 user_id: &UserId,
320 limit: usize,
321 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
322 let pool = self.pool.clone();
323 let store = self.clone();
324 let tenant = tenant_id.to_string().to_string();
325 let uid = user_id.to_string().to_string();
326 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
327 async move {
328 let rows: Vec<DeviceRow> = sqlx::query_as(
329 "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
330 first_seen_at, last_seen_at, revoked_at, bindings \
331 FROM devices WHERE tenant_id = $1 AND user_id = $2 \
332 ORDER BY last_seen_at DESC LIMIT $3",
333 )
334 .bind(&tenant)
335 .bind(&uid)
336 .bind(limit_i64)
337 .fetch_all(&pool)
338 .await?;
339
340 let mut out = Vec::with_capacity(rows.len());
341 for r in rows {
342 out.push(store.decode_row(r)?);
343 }
344 Ok(out)
345 }
346 }
347
348 fn find_by_refresh_family(
349 &self,
350 tenant_id: &TenantId,
351 family_id: &str,
352 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
353 let pool = self.pool.clone();
354 let store = self.clone();
355 let tenant = tenant_id.to_string().to_string();
356 let family = family_id.to_string();
357 async move {
358 let rows: Vec<DeviceRow> = sqlx::query_as(
359 "SELECT d.tenant_id, d.id, d.user_id, d.trust_level, d.fingerprint_hash, \
360 d.first_seen_at, d.last_seen_at, d.revoked_at, d.bindings \
361 FROM devices d \
362 INNER JOIN device_bindings_refresh r \
363 ON d.tenant_id = r.tenant_id AND d.id = r.device_id \
364 WHERE r.tenant_id = $1 AND r.family_id = $2 \
365 ORDER BY d.last_seen_at DESC",
366 )
367 .bind(&tenant)
368 .bind(&family)
369 .fetch_all(&pool)
370 .await?;
371
372 let mut out = Vec::with_capacity(rows.len());
373 for r in rows {
374 out.push(store.decode_row(r)?);
375 }
376 Ok(out)
377 }
378 }
379
380 fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send {
381 let pool = self.pool.clone();
382 let codec = self.codec.clone();
383 let device = device.clone();
384 async move {
385 let bindings_blob = codec.encode(&device.bindings)?;
386 let trust = trust_level_codec::to_str(device.trust_level);
387 let fp = device.fingerprint_hash.as_bytes().to_vec();
388 let user_id_col = device.user_id.as_ref().map(|u| u.to_string().to_string());
389 let first = utc_to_unix(device.first_seen_at);
390 let last = utc_to_unix(device.last_seen_at);
391 let revoked = device.revoked_at.map(utc_to_unix);
392 let family_ids = refresh_family_ids(&device.bindings);
393 let tenant = device.tenant_id.to_string().to_string();
394 let id = device.id.to_string().to_string();
395
396 let mut tx = pool.begin().await?;
397
398 sqlx::query(
399 r#"
400 INSERT INTO devices
401 (tenant_id, id, user_id, trust_level, fingerprint_hash,
402 first_seen_at, last_seen_at, revoked_at, bindings)
403 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
404 ON CONFLICT (tenant_id, id) DO UPDATE SET
405 user_id = EXCLUDED.user_id,
406 trust_level = EXCLUDED.trust_level,
407 fingerprint_hash = EXCLUDED.fingerprint_hash,
408 first_seen_at = EXCLUDED.first_seen_at,
409 last_seen_at = EXCLUDED.last_seen_at,
410 revoked_at = EXCLUDED.revoked_at,
411 bindings = EXCLUDED.bindings
412 "#,
413 )
414 .bind(&tenant)
415 .bind(&id)
416 .bind(user_id_col.as_deref())
417 .bind(trust)
418 .bind(&fp)
419 .bind(first)
420 .bind(last)
421 .bind(revoked)
422 .bind(&bindings_blob)
423 .execute(&mut *tx)
424 .await?;
425
426 sqlx::query(
427 "DELETE FROM device_bindings_refresh \
428 WHERE tenant_id = $1 AND device_id = $2",
429 )
430 .bind(&tenant)
431 .bind(&id)
432 .execute(&mut *tx)
433 .await?;
434
435 for family_id in &family_ids {
436 sqlx::query(
437 "INSERT INTO device_bindings_refresh \
438 (tenant_id, device_id, family_id) VALUES ($1, $2, $3)",
439 )
440 .bind(&tenant)
441 .bind(&id)
442 .bind(family_id)
443 .execute(&mut *tx)
444 .await?;
445 }
446
447 tx.commit().await?;
448 Ok(())
449 }
450 }
451
452 fn record_sighting(
453 &self,
454 tenant_id: &TenantId,
455 id: &DeviceId,
456 now: DateTime<Utc>,
457 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
458 let pool = self.pool.clone();
459 let tenant = tenant_id.to_string().to_string();
460 let device_id = id.to_string().to_string();
461 let ts = utc_to_unix(now);
462 async move {
463 sqlx::query(
464 "UPDATE devices SET last_seen_at = $3 \
465 WHERE tenant_id = $1 AND id = $2",
466 )
467 .bind(&tenant)
468 .bind(&device_id)
469 .bind(ts)
470 .execute(&pool)
471 .await?;
472 Ok(())
473 }
474 }
475
476 fn set_trust_level(
477 &self,
478 tenant_id: &TenantId,
479 id: &DeviceId,
480 level: DeviceTrustLevel,
481 now: DateTime<Utc>,
482 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
483 let pool = self.pool.clone();
484 let tenant = tenant_id.to_string().to_string();
485 let device_id = id.to_string().to_string();
486 let trust = trust_level_codec::to_str(level);
487 let ts = utc_to_unix(now);
488 let revoked_at = match level {
489 DeviceTrustLevel::Revoked => Some(ts),
490 _ => None,
491 };
492 async move {
493 sqlx::query(
494 "UPDATE devices SET trust_level = $3, revoked_at = $4 \
495 WHERE tenant_id = $1 AND id = $2",
496 )
497 .bind(&tenant)
498 .bind(&device_id)
499 .bind(trust)
500 .bind(revoked_at)
501 .execute(&pool)
502 .await?;
503 Ok(())
504 }
505 }
506
507 fn delete(
508 &self,
509 tenant_id: &TenantId,
510 id: &DeviceId,
511 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
512 let pool = self.pool.clone();
513 let tenant = tenant_id.to_string().to_string();
514 let device_id = id.to_string().to_string();
515 async move {
516 sqlx::query("DELETE FROM devices WHERE tenant_id = $1 AND id = $2")
517 .bind(&tenant)
518 .bind(&device_id)
519 .execute(&pool)
520 .await?;
521 Ok(())
522 }
523 }
524
525 fn sweep(
526 &self,
527 tenant_id: &TenantId,
528 now: DateTime<Utc>,
529 ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send {
530 let pool = self.pool.clone();
531 let cfg = self.sweep_config;
532 let tenant = tenant_id.to_string().to_string();
533 let now_secs = utc_to_unix(now);
534 async move {
535 let trusted_cutoff = now_secs - cfg.trusted_idle.num_seconds();
536 let seen_cutoff = now_secs - cfg.seen_idle.num_seconds();
537 let grace_cutoff = now_secs - cfg.revoked_grace.num_seconds();
538
539 let trusted_demoted = sqlx::query(
540 "UPDATE devices SET trust_level = 'Seen' \
541 WHERE tenant_id = $1 \
542 AND trust_level = 'Trusted' \
543 AND last_seen_at < $2",
544 )
545 .bind(&tenant)
546 .bind(trusted_cutoff)
547 .execute(&pool)
548 .await?
549 .rows_affected();
550
551 let seen_demoted = sqlx::query(
552 "UPDATE devices SET trust_level = 'Revoked', revoked_at = $3 \
553 WHERE tenant_id = $1 \
554 AND trust_level = 'Seen' \
555 AND last_seen_at < $2",
556 )
557 .bind(&tenant)
558 .bind(seen_cutoff)
559 .bind(now_secs)
560 .execute(&pool)
561 .await?
562 .rows_affected();
563
564 let purged = sqlx::query(
565 "DELETE FROM devices \
566 WHERE tenant_id = $1 \
567 AND trust_level = 'Revoked' \
568 AND revoked_at IS NOT NULL \
569 AND revoked_at < $2",
570 )
571 .bind(&tenant)
572 .bind(grace_cutoff)
573 .execute(&pool)
574 .await?
575 .rows_affected();
576
577 Ok(SweepCounts {
578 trusted_to_seen: trusted_demoted,
579 seen_to_revoked: seen_demoted,
580 revoked_purged: purged,
581 })
582 }
583 }
584}
585
586use crate::health::{HealthCheck, HealthStatus};
589
590impl HealthCheck for PostgresDeviceStore {
591 fn check(
592 &self,
593 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
594 Box::pin(async {
595 match tokio::time::timeout(
596 Duration::from_secs(2),
597 sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
598 )
599 .await
600 {
601 Ok(Ok(_)) => HealthStatus::Healthy,
602 Ok(Err(e)) => HealthStatus::Unhealthy(format!("postgres SELECT 1 failed: {e}")),
603 Err(_) => HealthStatus::Unhealthy("postgres SELECT 1 timeout (2s)".into()),
604 }
605 })
606 }
607}
608
609