Skip to main content

nexo_pairing/
session_store.rs

1//! Companion session tokens — issued after a successful WS handshake.
2//!
3//! `PairingSessionStore` persists session tokens in SQLite so the daemon
4//! can validate companion requests after restart without forcing re-pairing.
5
6use std::path::Path;
7use std::time::Duration;
8
9use chrono::{DateTime, Utc};
10use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
11use sqlx::SqlitePool;
12use tracing;
13
14use crate::types::PairingError;
15
16pub struct PairingSessionStore {
17    pool: SqlitePool,
18}
19
20pub struct SessionRow {
21    pub profile: String,
22    pub device_label: Option<String>,
23    pub expires_at: DateTime<Utc>,
24}
25
26impl PairingSessionStore {
27    pub async fn open(path: &Path) -> Result<Self, PairingError> {
28        Self::connect(&path.to_string_lossy(), 4).await
29    }
30
31    pub async fn open_memory() -> Result<Self, PairingError> {
32        Self::connect(":memory:", 1).await
33    }
34
35    async fn connect(path: &str, max_conns: u32) -> Result<Self, PairingError> {
36        let opts = SqliteConnectOptions::new()
37            .filename(path)
38            .create_if_missing(true);
39        let pool = SqlitePoolOptions::new()
40            .max_connections(max_conns)
41            .connect_with(opts)
42            .await
43            .map_err(|e| PairingError::Storage(e.to_string()))?;
44        sqlx::query(
45            "CREATE TABLE IF NOT EXISTS pairing_sessions (\
46                token        TEXT PRIMARY KEY,\
47                profile      TEXT NOT NULL,\
48                device_label TEXT,\
49                issued_at    INTEGER NOT NULL,\
50                expires_at   INTEGER NOT NULL\
51            )",
52        )
53        .execute(&pool)
54        .await
55        .map_err(|e| PairingError::Storage(e.to_string()))?;
56        Ok(Self { pool })
57    }
58
59    pub async fn insert_session(
60        &self,
61        token: &str,
62        profile: &str,
63        device_label: Option<&str>,
64        ttl: Duration,
65    ) -> Result<(), PairingError> {
66        let now = Utc::now();
67        let issued_at = now.timestamp();
68        let ttl_dur = chrono::Duration::from_std(ttl)
69            .map_err(|_| PairingError::Invalid("session ttl out of range"))?;
70        let expires_at = (now + ttl_dur).timestamp();
71        // Best-effort GC — don't block the insert if cleanup fails.
72        if let Err(e) = self.expire_sessions().await {
73            tracing::warn!(error = %e, "pairing session GC failed");
74        }
75        sqlx::query(
76            "INSERT OR REPLACE INTO pairing_sessions \
77             (token, profile, device_label, issued_at, expires_at) \
78             VALUES (?1, ?2, ?3, ?4, ?5)",
79        )
80        .bind(token)
81        .bind(profile)
82        .bind(device_label)
83        .bind(issued_at)
84        .bind(expires_at)
85        .execute(&self.pool)
86        .await
87        .map_err(|e| PairingError::Storage(e.to_string()))?;
88        Ok(())
89    }
90
91    pub async fn lookup_session(&self, token: &str) -> Result<Option<SessionRow>, PairingError> {
92        let now = Utc::now().timestamp();
93        let row: Option<(String, Option<String>, i64)> = sqlx::query_as(
94            "SELECT profile, device_label, expires_at \
95             FROM pairing_sessions WHERE token = ?1 AND expires_at >= ?2",
96        )
97        .bind(token)
98        .bind(now)
99        .fetch_optional(&self.pool)
100        .await
101        .map_err(|e| PairingError::Storage(e.to_string()))?;
102        row.map(|(profile, device_label, expires_at_ts)| {
103            let expires_at = DateTime::<Utc>::from_timestamp(expires_at_ts, 0).ok_or(
104                PairingError::Storage(format!("corrupt expires_at timestamp: {expires_at_ts}")),
105            )?;
106            Ok(SessionRow {
107                profile,
108                device_label,
109                expires_at,
110            })
111        })
112        .transpose()
113    }
114
115    pub async fn expire_sessions(&self) -> Result<usize, PairingError> {
116        let now = Utc::now().timestamp();
117        let result = sqlx::query("DELETE FROM pairing_sessions WHERE expires_at < ?1")
118            .bind(now)
119            .execute(&self.pool)
120            .await
121            .map_err(|e| PairingError::Storage(e.to_string()))?;
122        Ok(result.rows_affected() as usize)
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Inserts a row directly, bypassing `insert_session`, so tests can plant sessions that
    /// are already expired.
    async fn insert_raw(
        store: &PairingSessionStore,
        token: &str,
        profile: &str,
        issued_at: i64,
        expires_at: i64,
    ) {
        sqlx::query(
            "INSERT INTO pairing_sessions \
             (token, profile, device_label, issued_at, expires_at) \
             VALUES (?1, ?2, NULL, ?3, ?4)",
        )
        .bind(token)
        .bind(profile)
        .bind(issued_at)
        .bind(expires_at)
        .execute(&store.pool)
        .await
        .unwrap();
    }

    /// Counts stored rows (expired or not) for `token`.
    async fn count_rows(store: &PairingSessionStore, token: &str) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM pairing_sessions WHERE token = ?1")
            .bind(token)
            .fetch_one(&store.pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn insert_and_lookup() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        store
            .insert_session(
                "tok1",
                "companion-v1",
                Some("phone"),
                Duration::from_secs(3600),
            )
            .await
            .unwrap();
        let row = store.lookup_session("tok1").await.unwrap().unwrap();
        assert_eq!(row.profile, "companion-v1");
        assert_eq!(row.device_label.as_deref(), Some("phone"));
    }

    #[tokio::test]
    async fn lookup_reports_expiry_from_ttl() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        let before = Utc::now().timestamp();
        store
            .insert_session("tok", "p", None, Duration::from_secs(3600))
            .await
            .unwrap();
        let after = Utc::now().timestamp();
        let expires = store
            .lookup_session("tok")
            .await
            .unwrap()
            .unwrap()
            .expires_at
            .timestamp();
        assert!(expires >= before + 3600 && expires <= after + 3600);
    }

    #[tokio::test]
    async fn unknown_token_returns_none() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        assert!(store.lookup_session("nosuchtoken").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn expired_row_not_returned() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        let now = Utc::now().timestamp();
        insert_raw(&store, "expired_tok", "companion-v1", now - 10, now - 1).await;
        assert!(store.lookup_session("expired_tok").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn expire_sessions_clears_old_rows() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        let now = Utc::now().timestamp();
        insert_raw(&store, "old_tok", "p", now - 10, now - 1).await;
        let deleted = store.expire_sessions().await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(count_rows(&store, "old_tok").await, 0);
    }

    #[tokio::test]
    async fn expire_sessions_keeps_live_rows() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        store
            .insert_session("live_tok", "p", None, Duration::from_secs(3600))
            .await
            .unwrap();
        assert_eq!(store.expire_sessions().await.unwrap(), 0);
        assert!(store.lookup_session("live_tok").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn reinserting_token_replaces_session() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        let ttl = Duration::from_secs(3600);
        store
            .insert_session("tok", "first", Some("a"), ttl)
            .await
            .unwrap();
        store
            .insert_session("tok", "second", None, ttl)
            .await
            .unwrap();
        let row = store.lookup_session("tok").await.unwrap().unwrap();
        assert_eq!(row.profile, "second");
        assert_eq!(row.device_label, None);
        assert_eq!(count_rows(&store, "tok").await, 1);
    }
}