nexo_pairing/
session_store.rs1use std::path::Path;
7use std::time::Duration;
8
9use chrono::{DateTime, Utc};
10use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
11use sqlx::SqlitePool;
12use tracing;
13
14use crate::types::PairingError;
15
16pub struct PairingSessionStore {
17 pool: SqlitePool,
18}
19
20pub struct SessionRow {
21 pub profile: String,
22 pub device_label: Option<String>,
23 pub expires_at: DateTime<Utc>,
24}
25
26impl PairingSessionStore {
27 pub async fn open(path: &Path) -> Result<Self, PairingError> {
28 Self::connect(&path.to_string_lossy(), 4).await
29 }
30
31 pub async fn open_memory() -> Result<Self, PairingError> {
32 Self::connect(":memory:", 1).await
33 }
34
35 async fn connect(path: &str, max_conns: u32) -> Result<Self, PairingError> {
36 let opts = SqliteConnectOptions::new()
37 .filename(path)
38 .create_if_missing(true);
39 let pool = SqlitePoolOptions::new()
40 .max_connections(max_conns)
41 .connect_with(opts)
42 .await
43 .map_err(|e| PairingError::Storage(e.to_string()))?;
44 sqlx::query(
45 "CREATE TABLE IF NOT EXISTS pairing_sessions (\
46 token TEXT PRIMARY KEY,\
47 profile TEXT NOT NULL,\
48 device_label TEXT,\
49 issued_at INTEGER NOT NULL,\
50 expires_at INTEGER NOT NULL\
51 )",
52 )
53 .execute(&pool)
54 .await
55 .map_err(|e| PairingError::Storage(e.to_string()))?;
56 Ok(Self { pool })
57 }
58
59 pub async fn insert_session(
60 &self,
61 token: &str,
62 profile: &str,
63 device_label: Option<&str>,
64 ttl: Duration,
65 ) -> Result<(), PairingError> {
66 let now = Utc::now();
67 let issued_at = now.timestamp();
68 let ttl_dur = chrono::Duration::from_std(ttl)
69 .map_err(|_| PairingError::Invalid("session ttl out of range"))?;
70 let expires_at = (now + ttl_dur).timestamp();
71 if let Err(e) = self.expire_sessions().await {
73 tracing::warn!(error = %e, "pairing session GC failed");
74 }
75 sqlx::query(
76 "INSERT OR REPLACE INTO pairing_sessions \
77 (token, profile, device_label, issued_at, expires_at) \
78 VALUES (?1, ?2, ?3, ?4, ?5)",
79 )
80 .bind(token)
81 .bind(profile)
82 .bind(device_label)
83 .bind(issued_at)
84 .bind(expires_at)
85 .execute(&self.pool)
86 .await
87 .map_err(|e| PairingError::Storage(e.to_string()))?;
88 Ok(())
89 }
90
91 pub async fn lookup_session(&self, token: &str) -> Result<Option<SessionRow>, PairingError> {
92 let now = Utc::now().timestamp();
93 let row: Option<(String, Option<String>, i64)> = sqlx::query_as(
94 "SELECT profile, device_label, expires_at \
95 FROM pairing_sessions WHERE token = ?1 AND expires_at >= ?2",
96 )
97 .bind(token)
98 .bind(now)
99 .fetch_optional(&self.pool)
100 .await
101 .map_err(|e| PairingError::Storage(e.to_string()))?;
102 row.map(|(profile, device_label, expires_at_ts)| {
103 let expires_at = DateTime::<Utc>::from_timestamp(expires_at_ts, 0).ok_or(
104 PairingError::Storage(format!("corrupt expires_at timestamp: {expires_at_ts}")),
105 )?;
106 Ok(SessionRow {
107 profile,
108 device_label,
109 expires_at,
110 })
111 })
112 .transpose()
113 }
114
115 pub async fn expire_sessions(&self) -> Result<usize, PairingError> {
116 let now = Utc::now().timestamp();
117 let result = sqlx::query("DELETE FROM pairing_sessions WHERE expires_at < ?1")
118 .bind(now)
119 .execute(&self.pool)
120 .await
121 .map_err(|e| PairingError::Storage(e.to_string()))?;
122 Ok(result.rows_affected() as usize)
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A session written through the public API can be read back with the
    /// same profile and device label.
    #[tokio::test]
    async fn insert_and_lookup() {
        let store = PairingSessionStore::open_memory().await.unwrap();
        let one_hour = Duration::from_secs(3600);

        let inserted = store
            .insert_session("tok1", "companion-v1", Some("phone"), one_hour)
            .await;
        inserted.unwrap();

        let found = store.lookup_session("tok1").await.unwrap();
        let session = found.unwrap();
        assert_eq!(session.profile, "companion-v1");
        assert_eq!(session.device_label.as_deref(), Some("phone"));
    }

    /// Looking up a token that was never stored yields `None`, not an error.
    #[tokio::test]
    async fn unknown_token_returns_none() {
        let store = PairingSessionStore::open_memory().await.unwrap();

        let found = store.lookup_session("nosuchtoken").await.unwrap();

        assert!(found.is_none());
    }

    /// A row whose expiry lies in the past is invisible to `lookup_session`.
    #[tokio::test]
    async fn expired_row_not_returned() {
        let store = PairingSessionStore::open_memory().await.unwrap();

        // Write the row directly so the expiry can be placed in the past.
        let ten_seconds_ago = Utc::now().timestamp() - 10;
        let one_second_ago = Utc::now().timestamp() - 1;
        let insert = sqlx::query(
            "INSERT INTO pairing_sessions \
             (token, profile, device_label, issued_at, expires_at) \
             VALUES ('expired_tok', 'companion-v1', NULL, ?1, ?2)",
        )
        .bind(ten_seconds_ago)
        .bind(one_second_ago);
        insert.execute(&store.pool).await.unwrap();

        let found = store.lookup_session("expired_tok").await.unwrap();
        assert!(found.is_none());
    }

    /// `expire_sessions` reports and physically removes past-expiry rows.
    #[tokio::test]
    async fn expire_sessions_clears_old_rows() {
        let store = PairingSessionStore::open_memory().await.unwrap();

        // Write the row directly so the expiry can be placed in the past.
        let ten_seconds_ago = Utc::now().timestamp() - 10;
        let one_second_ago = Utc::now().timestamp() - 1;
        let insert = sqlx::query(
            "INSERT INTO pairing_sessions \
             (token, profile, device_label, issued_at, expires_at) \
             VALUES ('old_tok', 'p', NULL, ?1, ?2)",
        )
        .bind(ten_seconds_ago)
        .bind(one_second_ago);
        insert.execute(&store.pool).await.unwrap();

        let removed = store.expire_sessions().await.unwrap();
        assert_eq!(removed, 1);

        let remaining: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM pairing_sessions WHERE token = 'old_tok'")
                .fetch_one(&store.pool)
                .await
                .unwrap();
        assert_eq!(remaining, 0);
    }
}