Skip to main content

axess_core/device/storage/
postgres.rs

1//! PostgreSQL-backed [`DeviceStore`] using sqlx.
2//!
3//! # Schema
4//!
5//! ```sql
6//! CREATE TABLE IF NOT EXISTS devices (
7//!     tenant_id        TEXT   NOT NULL,
8//!     id               TEXT   NOT NULL,
9//!     user_id          TEXT,
10//!     trust_level      TEXT   NOT NULL,
11//!     fingerprint_hash BYTEA  NOT NULL,
12//!     first_seen_at    BIGINT NOT NULL,
13//!     last_seen_at     BIGINT NOT NULL,
14//!     revoked_at       BIGINT,
15//!     bindings         TEXT   NOT NULL,
16//!     PRIMARY KEY (tenant_id, id)
17//! );
18//!
19//! CREATE INDEX IF NOT EXISTS idx_devices_fingerprint
20//!     ON devices (tenant_id, fingerprint_hash);
21//! CREATE INDEX IF NOT EXISTS idx_devices_user
22//!     ON devices (tenant_id, user_id, last_seen_at DESC);
23//!
24//! CREATE TABLE IF NOT EXISTS device_bindings_refresh (
25//!     tenant_id TEXT NOT NULL,
26//!     device_id TEXT NOT NULL,
27//!     family_id TEXT NOT NULL,
28//!     PRIMARY KEY (tenant_id, device_id, family_id),
29//!     FOREIGN KEY (tenant_id, device_id)
30//!         REFERENCES devices (tenant_id, id) ON DELETE CASCADE
31//! );
32//!
33//! CREATE INDEX IF NOT EXISTS idx_device_bindings_refresh_family
34//!     ON device_bindings_refresh (tenant_id, family_id);
35//! ```
36//!
37//! Mirrors [`super::sqlite::SqliteDeviceStore`] with the SQL-dialect
38//! changes Postgres requires (`$1` binds, `BIGINT`, `BYTEA`).
39
40use std::future::Future;
41use std::sync::Arc;
42use std::time::Duration;
43
44use chrono::{DateTime, TimeZone, Utc};
45use sqlx::PgPool;
46
47use axess_clock::{Clock, SystemClock};
48
49use crate::authn::ids::{DeviceId, TenantId, UserId};
50use crate::device::storage::sql_common::{BindingsCodec, SqlDeviceStoreError, trust_level_codec};
51use crate::device::store::{DeviceStore, SweepConfig, SweepCounts};
52use crate::device::types::{Device, DeviceBinding, DeviceTrustLevel, FingerprintHash};
53use crate::session::crypto::SessionCrypto;
54
55/// PostgreSQL-backed [`DeviceStore`].
56///
57/// Wrap an existing [`PgPool`] and call
58/// [`init_schema`](Self::init_schema) once at startup.
59///
60/// # Encryption
61///
62/// Same shape as
63/// [`SqliteDeviceStore`](super::sqlite::SqliteDeviceStore): the
64/// optional [`SessionCrypto`] envelope is applied to the bindings
65/// blob only.
66#[derive(Clone)]
67pub struct PostgresDeviceStore {
68    pool: PgPool,
69    codec: BindingsCodec,
70    clock: Arc<dyn Clock>,
71    sweep_config: SweepConfig,
72}
73
74impl PostgresDeviceStore {
75    /// Create an encrypted store (recommended for production).
76    pub fn new(pool: PgPool, crypto: SessionCrypto) -> Self {
77        Self {
78            pool,
79            codec: BindingsCodec::encrypted(crypto),
80            clock: Arc::new(SystemClock),
81            sweep_config: SweepConfig::default(),
82        }
83    }
84
85    /// Create a plaintext store (development/testing only).
86    pub fn plaintext(pool: PgPool) -> Self {
87        tracing::warn!(
88            "PostgresDeviceStore created without encryption; \
89             do not use in production"
90        );
91        Self {
92            pool,
93            codec: BindingsCodec::plaintext(),
94            clock: Arc::new(SystemClock),
95            sweep_config: SweepConfig::default(),
96        }
97    }
98
99    /// Inject a [`Clock`] for deterministic-simulation testing.
100    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
101        self.clock = clock;
102        self
103    }
104
105    /// Override the [`SweepConfig`] driving the retention ladder.
106    pub fn with_sweep_config(mut self, config: SweepConfig) -> Self {
107        self.sweep_config = config;
108        self
109    }
110
111    /// Create tables + indexes. Idempotent.
112    pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
113        sqlx::query(
114            r#"
115            CREATE TABLE IF NOT EXISTS devices (
116                tenant_id        TEXT   NOT NULL,
117                id               TEXT   NOT NULL,
118                user_id          TEXT,
119                trust_level      TEXT   NOT NULL,
120                fingerprint_hash BYTEA  NOT NULL,
121                first_seen_at    BIGINT NOT NULL,
122                last_seen_at     BIGINT NOT NULL,
123                revoked_at       BIGINT,
124                bindings         TEXT   NOT NULL,
125                PRIMARY KEY (tenant_id, id)
126            )
127            "#,
128        )
129        .execute(&self.pool)
130        .await?;
131
132        sqlx::query(
133            "CREATE INDEX IF NOT EXISTS idx_devices_fingerprint \
134             ON devices (tenant_id, fingerprint_hash)",
135        )
136        .execute(&self.pool)
137        .await?;
138        sqlx::query(
139            "CREATE INDEX IF NOT EXISTS idx_devices_user \
140             ON devices (tenant_id, user_id, last_seen_at DESC)",
141        )
142        .execute(&self.pool)
143        .await?;
144
145        sqlx::query(
146            r#"
147            CREATE TABLE IF NOT EXISTS device_bindings_refresh (
148                tenant_id TEXT NOT NULL,
149                device_id TEXT NOT NULL,
150                family_id TEXT NOT NULL,
151                PRIMARY KEY (tenant_id, device_id, family_id),
152                FOREIGN KEY (tenant_id, device_id)
153                    REFERENCES devices (tenant_id, id) ON DELETE CASCADE
154            )
155            "#,
156        )
157        .execute(&self.pool)
158        .await?;
159
160        sqlx::query(
161            "CREATE INDEX IF NOT EXISTS idx_device_bindings_refresh_family \
162             ON device_bindings_refresh (tenant_id, family_id)",
163        )
164        .execute(&self.pool)
165        .await?;
166
167        Ok(())
168    }
169
170    fn decode_row(&self, row: DeviceRow) -> Result<Device, SqlDeviceStoreError> {
171        let DeviceRow {
172            tenant_id,
173            id,
174            user_id,
175            trust_level,
176            fingerprint_hash,
177            first_seen_at,
178            last_seen_at,
179            revoked_at,
180            bindings,
181        } = row;
182
183        let tenant = TenantId::try_new(&tenant_id)
184            .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("tenant_id: {e}")))?;
185        let device_id = DeviceId::try_new(&id)
186            .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("device id: {e}")))?;
187        let user = match user_id {
188            Some(u) => Some(
189                UserId::try_new(&u)
190                    .map_err(|e| SqlDeviceStoreError::MalformedRow(format!("user_id: {e}")))?,
191            ),
192            None => None,
193        };
194
195        let trust = trust_level_codec::from_str(&trust_level)
196            .ok_or(SqlDeviceStoreError::UnknownTrustLevel(trust_level))?;
197
198        let fp_bytes: [u8; 32] = fingerprint_hash
199            .try_into()
200            .map_err(|_| SqlDeviceStoreError::MalformedRow("fingerprint_hash length".into()))?;
201
202        let first = unix_to_utc(first_seen_at)?;
203        let last = unix_to_utc(last_seen_at)?;
204        let revoked = match revoked_at {
205            Some(t) => Some(unix_to_utc(t)?),
206            None => None,
207        };
208
209        let bindings = self.codec.decode(&bindings)?;
210
211        Ok(Device {
212            id: device_id,
213            tenant_id: tenant,
214            user_id: user,
215            trust_level: trust,
216            fingerprint_hash: FingerprintHash::from_bytes(fp_bytes),
217            first_seen_at: first,
218            last_seen_at: last,
219            revoked_at: revoked,
220            bindings,
221        })
222    }
223}
224
225#[derive(sqlx::FromRow)]
226struct DeviceRow {
227    tenant_id: String,
228    id: String,
229    user_id: Option<String>,
230    trust_level: String,
231    fingerprint_hash: Vec<u8>,
232    first_seen_at: i64,
233    last_seen_at: i64,
234    revoked_at: Option<i64>,
235    bindings: String,
236}
237
238fn unix_to_utc(secs: i64) -> Result<DateTime<Utc>, SqlDeviceStoreError> {
239    Utc.timestamp_opt(secs, 0).single().ok_or_else(|| {
240        SqlDeviceStoreError::MalformedRow(format!("unrepresentable Unix timestamp: {secs}"))
241    })
242}
243
244fn utc_to_unix(dt: DateTime<Utc>) -> i64 {
245    dt.timestamp()
246}
247
248fn refresh_family_ids(bindings: &[DeviceBinding]) -> Vec<String> {
249    bindings
250        .iter()
251        .filter_map(|b| match b {
252            DeviceBinding::Refresh { family_id, .. } => Some(family_id.clone()),
253            _ => None,
254        })
255        .collect()
256}
257
258impl DeviceStore for PostgresDeviceStore {
259    type Error = SqlDeviceStoreError;
260
261    fn load(
262        &self,
263        tenant_id: &TenantId,
264        id: &DeviceId,
265    ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
266        let pool = self.pool.clone();
267        let store = self.clone();
268        let tenant = tenant_id.to_string().to_string();
269        let device_id = id.to_string().to_string();
270        async move {
271            let row: Option<DeviceRow> = sqlx::query_as(
272                "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
273                        first_seen_at, last_seen_at, revoked_at, bindings \
274                 FROM devices WHERE tenant_id = $1 AND id = $2",
275            )
276            .bind(&tenant)
277            .bind(&device_id)
278            .fetch_optional(&pool)
279            .await?;
280
281            match row {
282                Some(r) => Ok(Some(store.decode_row(r)?)),
283                None => Ok(None),
284            }
285        }
286    }
287
288    fn find_by_fingerprint(
289        &self,
290        tenant_id: &TenantId,
291        hash: &FingerprintHash,
292    ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
293        let pool = self.pool.clone();
294        let store = self.clone();
295        let tenant = tenant_id.to_string().to_string();
296        let bytes = hash.as_bytes().to_vec();
297        async move {
298            let row: Option<DeviceRow> = sqlx::query_as(
299                "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
300                        first_seen_at, last_seen_at, revoked_at, bindings \
301                 FROM devices WHERE tenant_id = $1 AND fingerprint_hash = $2 \
302                 ORDER BY last_seen_at DESC LIMIT 1",
303            )
304            .bind(&tenant)
305            .bind(&bytes)
306            .fetch_optional(&pool)
307            .await?;
308
309            match row {
310                Some(r) => Ok(Some(store.decode_row(r)?)),
311                None => Ok(None),
312            }
313        }
314    }
315
316    fn find_for_user(
317        &self,
318        tenant_id: &TenantId,
319        user_id: &UserId,
320        limit: usize,
321    ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
322        let pool = self.pool.clone();
323        let store = self.clone();
324        let tenant = tenant_id.to_string().to_string();
325        let uid = user_id.to_string().to_string();
326        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
327        async move {
328            let rows: Vec<DeviceRow> = sqlx::query_as(
329                "SELECT tenant_id, id, user_id, trust_level, fingerprint_hash, \
330                        first_seen_at, last_seen_at, revoked_at, bindings \
331                 FROM devices WHERE tenant_id = $1 AND user_id = $2 \
332                 ORDER BY last_seen_at DESC LIMIT $3",
333            )
334            .bind(&tenant)
335            .bind(&uid)
336            .bind(limit_i64)
337            .fetch_all(&pool)
338            .await?;
339
340            let mut out = Vec::with_capacity(rows.len());
341            for r in rows {
342                out.push(store.decode_row(r)?);
343            }
344            Ok(out)
345        }
346    }
347
348    fn find_by_refresh_family(
349        &self,
350        tenant_id: &TenantId,
351        family_id: &str,
352    ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
353        let pool = self.pool.clone();
354        let store = self.clone();
355        let tenant = tenant_id.to_string().to_string();
356        let family = family_id.to_string();
357        async move {
358            let rows: Vec<DeviceRow> = sqlx::query_as(
359                "SELECT d.tenant_id, d.id, d.user_id, d.trust_level, d.fingerprint_hash, \
360                        d.first_seen_at, d.last_seen_at, d.revoked_at, d.bindings \
361                 FROM devices d \
362                 INNER JOIN device_bindings_refresh r \
363                   ON d.tenant_id = r.tenant_id AND d.id = r.device_id \
364                 WHERE r.tenant_id = $1 AND r.family_id = $2 \
365                 ORDER BY d.last_seen_at DESC",
366            )
367            .bind(&tenant)
368            .bind(&family)
369            .fetch_all(&pool)
370            .await?;
371
372            let mut out = Vec::with_capacity(rows.len());
373            for r in rows {
374                out.push(store.decode_row(r)?);
375            }
376            Ok(out)
377        }
378    }
379
380    fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send {
381        let pool = self.pool.clone();
382        let codec = self.codec.clone();
383        let device = device.clone();
384        async move {
385            let bindings_blob = codec.encode(&device.bindings)?;
386            let trust = trust_level_codec::to_str(device.trust_level);
387            let fp = device.fingerprint_hash.as_bytes().to_vec();
388            let user_id_col = device.user_id.as_ref().map(|u| u.to_string().to_string());
389            let first = utc_to_unix(device.first_seen_at);
390            let last = utc_to_unix(device.last_seen_at);
391            let revoked = device.revoked_at.map(utc_to_unix);
392            let family_ids = refresh_family_ids(&device.bindings);
393            let tenant = device.tenant_id.to_string().to_string();
394            let id = device.id.to_string().to_string();
395
396            let mut tx = pool.begin().await?;
397
398            sqlx::query(
399                r#"
400                INSERT INTO devices
401                    (tenant_id, id, user_id, trust_level, fingerprint_hash,
402                     first_seen_at, last_seen_at, revoked_at, bindings)
403                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
404                ON CONFLICT (tenant_id, id) DO UPDATE SET
405                    user_id          = EXCLUDED.user_id,
406                    trust_level      = EXCLUDED.trust_level,
407                    fingerprint_hash = EXCLUDED.fingerprint_hash,
408                    first_seen_at    = EXCLUDED.first_seen_at,
409                    last_seen_at     = EXCLUDED.last_seen_at,
410                    revoked_at       = EXCLUDED.revoked_at,
411                    bindings         = EXCLUDED.bindings
412                "#,
413            )
414            .bind(&tenant)
415            .bind(&id)
416            .bind(user_id_col.as_deref())
417            .bind(trust)
418            .bind(&fp)
419            .bind(first)
420            .bind(last)
421            .bind(revoked)
422            .bind(&bindings_blob)
423            .execute(&mut *tx)
424            .await?;
425
426            sqlx::query(
427                "DELETE FROM device_bindings_refresh \
428                 WHERE tenant_id = $1 AND device_id = $2",
429            )
430            .bind(&tenant)
431            .bind(&id)
432            .execute(&mut *tx)
433            .await?;
434
435            for family_id in &family_ids {
436                sqlx::query(
437                    "INSERT INTO device_bindings_refresh \
438                     (tenant_id, device_id, family_id) VALUES ($1, $2, $3)",
439                )
440                .bind(&tenant)
441                .bind(&id)
442                .bind(family_id)
443                .execute(&mut *tx)
444                .await?;
445            }
446
447            tx.commit().await?;
448            Ok(())
449        }
450    }
451
452    fn record_sighting(
453        &self,
454        tenant_id: &TenantId,
455        id: &DeviceId,
456        now: DateTime<Utc>,
457    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
458        let pool = self.pool.clone();
459        let tenant = tenant_id.to_string().to_string();
460        let device_id = id.to_string().to_string();
461        let ts = utc_to_unix(now);
462        async move {
463            sqlx::query(
464                "UPDATE devices SET last_seen_at = $3 \
465                 WHERE tenant_id = $1 AND id = $2",
466            )
467            .bind(&tenant)
468            .bind(&device_id)
469            .bind(ts)
470            .execute(&pool)
471            .await?;
472            Ok(())
473        }
474    }
475
476    fn set_trust_level(
477        &self,
478        tenant_id: &TenantId,
479        id: &DeviceId,
480        level: DeviceTrustLevel,
481        now: DateTime<Utc>,
482    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
483        let pool = self.pool.clone();
484        let tenant = tenant_id.to_string().to_string();
485        let device_id = id.to_string().to_string();
486        let trust = trust_level_codec::to_str(level);
487        let ts = utc_to_unix(now);
488        let revoked_at = match level {
489            DeviceTrustLevel::Revoked => Some(ts),
490            _ => None,
491        };
492        async move {
493            sqlx::query(
494                "UPDATE devices SET trust_level = $3, revoked_at = $4 \
495                 WHERE tenant_id = $1 AND id = $2",
496            )
497            .bind(&tenant)
498            .bind(&device_id)
499            .bind(trust)
500            .bind(revoked_at)
501            .execute(&pool)
502            .await?;
503            Ok(())
504        }
505    }
506
507    fn delete(
508        &self,
509        tenant_id: &TenantId,
510        id: &DeviceId,
511    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
512        let pool = self.pool.clone();
513        let tenant = tenant_id.to_string().to_string();
514        let device_id = id.to_string().to_string();
515        async move {
516            sqlx::query("DELETE FROM devices WHERE tenant_id = $1 AND id = $2")
517                .bind(&tenant)
518                .bind(&device_id)
519                .execute(&pool)
520                .await?;
521            Ok(())
522        }
523    }
524
525    fn sweep(
526        &self,
527        tenant_id: &TenantId,
528        now: DateTime<Utc>,
529    ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send {
530        let pool = self.pool.clone();
531        let cfg = self.sweep_config;
532        let tenant = tenant_id.to_string().to_string();
533        let now_secs = utc_to_unix(now);
534        async move {
535            let trusted_cutoff = now_secs - cfg.trusted_idle.num_seconds();
536            let seen_cutoff = now_secs - cfg.seen_idle.num_seconds();
537            let grace_cutoff = now_secs - cfg.revoked_grace.num_seconds();
538
539            let trusted_demoted = sqlx::query(
540                "UPDATE devices SET trust_level = 'Seen' \
541                 WHERE tenant_id = $1 \
542                   AND trust_level = 'Trusted' \
543                   AND last_seen_at < $2",
544            )
545            .bind(&tenant)
546            .bind(trusted_cutoff)
547            .execute(&pool)
548            .await?
549            .rows_affected();
550
551            let seen_demoted = sqlx::query(
552                "UPDATE devices SET trust_level = 'Revoked', revoked_at = $3 \
553                 WHERE tenant_id = $1 \
554                   AND trust_level = 'Seen' \
555                   AND last_seen_at < $2",
556            )
557            .bind(&tenant)
558            .bind(seen_cutoff)
559            .bind(now_secs)
560            .execute(&pool)
561            .await?
562            .rows_affected();
563
564            let purged = sqlx::query(
565                "DELETE FROM devices \
566                 WHERE tenant_id = $1 \
567                   AND trust_level = 'Revoked' \
568                   AND revoked_at IS NOT NULL \
569                   AND revoked_at < $2",
570            )
571            .bind(&tenant)
572            .bind(grace_cutoff)
573            .execute(&pool)
574            .await?
575            .rows_affected();
576
577            Ok(SweepCounts {
578                trusted_to_seen: trusted_demoted,
579                seen_to_revoked: seen_demoted,
580                revoked_purged: purged,
581            })
582        }
583    }
584}
585
586// ── HealthCheck ──────────────────────────────────────────────────────
587
588use crate::health::{HealthCheck, HealthStatus};
589
590impl HealthCheck for PostgresDeviceStore {
591    fn check(
592        &self,
593    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
594        Box::pin(async {
595            match tokio::time::timeout(
596                Duration::from_secs(2),
597                sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
598            )
599            .await
600            {
601                Ok(Ok(_)) => HealthStatus::Healthy,
602                Ok(Err(e)) => HealthStatus::Unhealthy(format!("postgres SELECT 1 failed: {e}")),
603                Err(_) => HealthStatus::Unhealthy("postgres SELECT 1 timeout (2s)".into()),
604            }
605        })
606    }
607}
608
// Integration tests live in `axess-core/tests/postgres_device_store.rs`;
// they require a running Postgres at $POSTGRES_URL and run with
// `cargo test -- --ignored`. This mirrors the pattern used by
// `tests/postgres_store.rs` for the session-store equivalent.