1use super::{AppState, client_key_from_request};
4use axum::{
5 extract::{ConnectInfo, State},
6 http::{HeaderMap, StatusCode, header},
7 response::{IntoResponse, Json},
8};
9use chrono::{DateTime, Utc};
10use parking_lot::Mutex;
11use rusqlite::Connection;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::net::SocketAddr;
15use std::path::{Path, PathBuf};
16use tracing::{debug, error, info, warn};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct DeviceInfo {
21 pub id: String,
22 pub name: Option<String>,
23 pub device_type: Option<String>,
24 pub paired_at: DateTime<Utc>,
25 pub last_seen: DateTime<Utc>,
26 pub ip_address: Option<String>,
27}
28
29#[derive(Debug)]
31pub struct DeviceRegistry {
32 cache: Mutex<HashMap<String, DeviceInfo>>,
33 db_path: PathBuf,
34}
35
36impl DeviceRegistry {
37 pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
42 use anyhow::Context;
43
44 let db_path = workspace_dir.join("devices.db");
45 let conn = Connection::open(&db_path)
46 .with_context(|| format!("open device registry DB at {}", db_path.display()))?;
47 conn.execute_batch(
48 "CREATE TABLE IF NOT EXISTS devices (
49 token_hash TEXT PRIMARY KEY,
50 id TEXT NOT NULL,
51 name TEXT,
52 device_type TEXT,
53 paired_at TEXT NOT NULL,
54 last_seen TEXT NOT NULL,
55 ip_address TEXT
56 )",
57 )
58 .context("create devices table")?;
59
60 let mut cache = HashMap::new();
61 let mut stmt = conn
62 .prepare("SELECT token_hash, id, name, device_type, paired_at, last_seen, ip_address FROM devices")
63 .context("prepare device select")?;
64 let rows = stmt
65 .query_map([], |row| {
66 let token_hash: String = row.get(0)?;
67 let id: String = row.get(1)?;
68 let name: Option<String> = row.get(2)?;
69 let device_type: Option<String> = row.get(3)?;
70 let paired_at_str: String = row.get(4)?;
71 let last_seen_str: String = row.get(5)?;
72 let ip_address: Option<String> = row.get(6)?;
73 let paired_at = DateTime::parse_from_rfc3339(&paired_at_str)
74 .map(|dt| dt.with_timezone(&Utc))
75 .unwrap_or_else(|_| Utc::now());
76 let last_seen = DateTime::parse_from_rfc3339(&last_seen_str)
77 .map(|dt| dt.with_timezone(&Utc))
78 .unwrap_or_else(|_| Utc::now());
79 Ok((
80 token_hash,
81 DeviceInfo {
82 id,
83 name,
84 device_type,
85 paired_at,
86 last_seen,
87 ip_address,
88 },
89 ))
90 })
91 .context("query devices")?;
92 for (hash, info) in rows.flatten() {
93 cache.insert(hash, info);
94 }
95
96 Ok(Self {
97 cache: Mutex::new(cache),
98 db_path,
99 })
100 }
101
102 fn open_db(&self) -> anyhow::Result<Connection> {
103 use anyhow::Context;
104 Connection::open(&self.db_path)
105 .with_context(|| format!("open device registry DB at {}", self.db_path.display()))
106 }
107
108 pub fn register(&self, token_hash: String, info: DeviceInfo) -> anyhow::Result<()> {
109 use anyhow::Context;
110 let conn = self.open_db()?;
111 let device_id = info.id.clone();
112 conn.execute(
113 "INSERT OR REPLACE INTO devices (token_hash, id, name, device_type, paired_at, last_seen, ip_address) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
114 rusqlite::params![
115 token_hash,
116 info.id,
117 info.name,
118 info.device_type,
119 info.paired_at.to_rfc3339(),
120 info.last_seen.to_rfc3339(),
121 info.ip_address,
122 ],
123 )
124 .context("insert device row")?;
125 let hash_prefix: String = token_hash.chars().take(8).collect();
126 self.cache.lock().insert(token_hash, info);
127 info!(device_id = %device_id, token_hash_prefix = %hash_prefix, "device registered in SQLite");
128 Ok(())
129 }
130
131 pub fn list(&self) -> Vec<DeviceInfo> {
132 let conn = match self.open_db() {
135 Ok(c) => c,
136 Err(e) => {
137 warn!(error = %e, "device registry list: open_db failed — returning empty list");
138 return Vec::new();
139 }
140 };
141 let mut stmt = match conn.prepare(
142 "SELECT token_hash, id, name, device_type, paired_at, last_seen, ip_address FROM devices",
143 ) {
144 Ok(s) => s,
145 Err(e) => {
146 warn!(error = %e, "device registry list: prepare failed — returning empty list");
147 return Vec::new();
148 }
149 };
150 let rows = match stmt.query_map([], |row| {
151 let id: String = row.get(1)?;
152 let name: Option<String> = row.get(2)?;
153 let device_type: Option<String> = row.get(3)?;
154 let paired_at_str: String = row.get(4)?;
155 let last_seen_str: String = row.get(5)?;
156 let ip_address: Option<String> = row.get(6)?;
157 let paired_at = DateTime::parse_from_rfc3339(&paired_at_str)
158 .map(|dt| dt.with_timezone(&Utc))
159 .unwrap_or_else(|_| Utc::now());
160 let last_seen = DateTime::parse_from_rfc3339(&last_seen_str)
161 .map(|dt| dt.with_timezone(&Utc))
162 .unwrap_or_else(|_| Utc::now());
163 Ok(DeviceInfo {
164 id,
165 name,
166 device_type,
167 paired_at,
168 last_seen,
169 ip_address,
170 })
171 }) {
172 Ok(r) => r,
173 Err(e) => {
174 warn!(error = %e, "device registry list: query_map failed — returning empty list");
175 return Vec::new();
176 }
177 };
178 rows.filter_map(|r| r.ok()).collect()
179 }
180
181 pub fn revoke(&self, device_id: &str) -> bool {
182 let conn = match self.open_db() {
183 Ok(c) => c,
184 Err(e) => {
185 warn!(error = %e, "device registry revoke: open_db failed");
186 return false;
187 }
188 };
189 let deleted = conn
190 .execute(
191 "DELETE FROM devices WHERE id = ?1",
192 rusqlite::params![device_id],
193 )
194 .unwrap_or(0);
195 if deleted > 0 {
196 let mut cache = self.cache.lock();
197 let key = cache
198 .iter()
199 .find(|(_, v)| v.id == device_id)
200 .map(|(k, _)| k.clone());
201 if let Some(key) = key {
202 cache.remove(&key);
203 }
204 true
205 } else {
206 false
207 }
208 }
209
210 pub fn update_last_seen(&self, token_hash: &str) {
211 let now = Utc::now();
212 if let Ok(conn) = self.open_db() {
213 let _ = conn.execute(
214 "UPDATE devices SET last_seen = ?1 WHERE token_hash = ?2",
215 rusqlite::params![now.to_rfc3339(), token_hash],
216 );
217 }
218 if let Some(device) = self.cache.lock().get_mut(token_hash) {
219 device.last_seen = now;
220 }
221 }
222
223 pub fn device_count(&self) -> usize {
224 self.cache.lock().len()
225 }
226}
227
228#[derive(Debug)]
230pub struct PairingStore {
231 pending: Mutex<Vec<PendingPairing>>,
232 max_pending: usize,
233}
234
235#[derive(Debug, Clone, Serialize)]
236struct PendingPairing {
237 code: String,
238 created_at: DateTime<Utc>,
239 expires_at: DateTime<Utc>,
240 client_ip: Option<String>,
241 attempts: u32,
242}
243
244impl PairingStore {
245 pub fn new(max_pending: usize) -> Self {
246 Self {
247 pending: Mutex::new(Vec::new()),
248 max_pending,
249 }
250 }
251
252 pub fn pending_count(&self) -> usize {
253 let mut pending = self.pending.lock();
254 pending.retain(|p| p.expires_at > Utc::now());
255 pending.len()
256 }
257}
258
259fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
260 headers
261 .get(header::AUTHORIZATION)
262 .and_then(|v| v.to_str().ok())
263 .and_then(|auth| auth.strip_prefix("Bearer "))
264}
265
266fn require_auth(state: &AppState, headers: &HeaderMap) -> Result<(), (StatusCode, &'static str)> {
267 if state.pairing.require_pairing() {
268 let token = extract_bearer(headers).unwrap_or("");
269 if !state.pairing.is_authenticated(token) {
270 return Err((StatusCode::UNAUTHORIZED, "Unauthorized"));
271 }
272 }
273 Ok(())
274}
275
276pub async fn initiate_pairing(
278 State(state): State<AppState>,
279 headers: HeaderMap,
280) -> impl IntoResponse {
281 if let Err(e) = require_auth(&state, &headers) {
282 warn!("initiate_pairing: unauthorized request");
283 return e.into_response();
284 }
285
286 info!("initiate_pairing: generating new pairing code");
287 match state.pairing.generate_new_pairing_code() {
288 Some(code) => {
289 let code_prefix: String = code.chars().take(2).collect();
290 info!(code_prefix = %code_prefix, "initiate_pairing: code generated");
291 Json(serde_json::json!({
292 "pairing_code": code,
293 "message": "New pairing code generated"
294 }))
295 .into_response()
296 }
297 None => {
298 warn!("initiate_pairing: pairing disabled or unavailable");
299 (
300 StatusCode::SERVICE_UNAVAILABLE,
301 "Pairing is disabled or not available",
302 )
303 .into_response()
304 }
305 }
306}
307
308pub async fn submit_pairing_enhanced(
310 State(state): State<AppState>,
311 ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
312 headers: HeaderMap,
313 Json(body): Json<serde_json::Value>,
314) -> impl IntoResponse {
315 let code = body["code"].as_str().unwrap_or("");
316 let device_name = body["device_name"].as_str().map(String::from);
317 let device_type = body["device_type"].as_str().map(String::from);
318
319 let client_id =
323 client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
324
325 info!(
326 client_id = %client_id,
327 code_len = code.len(),
328 device_name = ?device_name,
329 device_type = ?device_type,
330 "submit_pairing_enhanced: received pair request"
331 );
332
333 match state.pairing.try_pair(code, &client_id).await {
334 Ok(Some(token)) => {
335 let token_hash = {
337 use sha2::{Digest, Sha256};
338 let hash = Sha256::digest(token.as_bytes());
339 hex::encode(hash)
340 };
341 let hash_prefix: String = token_hash.chars().take(8).collect();
342 info!(
343 client_id = %client_id,
344 token_hash_prefix = %hash_prefix,
345 "submit_pairing_enhanced: pairing succeeded, registering device"
346 );
347 if let Some(ref registry) = state.device_registry {
348 if let Err(e) = registry.register(
349 token_hash,
350 DeviceInfo {
351 id: uuid::Uuid::new_v4().to_string(),
352 name: device_name,
353 device_type,
354 paired_at: Utc::now(),
355 last_seen: Utc::now(),
356 ip_address: Some(client_id.clone()),
357 },
358 ) {
359 error!(
360 client_id = %client_id,
361 error = %e,
362 "submit_pairing_enhanced: device registry insert failed"
363 );
364 return (
365 StatusCode::INTERNAL_SERVER_ERROR,
366 "Pairing succeeded but device registration failed",
367 )
368 .into_response();
369 }
370 } else {
371 debug!("submit_pairing_enhanced: no device_registry configured; skipping persist");
372 }
373 Json(serde_json::json!({
374 "token": token,
375 "message": "Pairing successful"
376 }))
377 .into_response()
378 }
379 Ok(None) => {
380 warn!(client_id = %client_id, "submit_pairing_enhanced: invalid or expired code");
381 (StatusCode::BAD_REQUEST, "Invalid or expired pairing code").into_response()
382 }
383 Err(lockout_secs) => {
384 warn!(
385 client_id = %client_id,
386 lockout_secs,
387 "submit_pairing_enhanced: client locked out"
388 );
389 (
390 StatusCode::TOO_MANY_REQUESTS,
391 format!("Too many attempts. Locked out for {lockout_secs}s"),
392 )
393 .into_response()
394 }
395 }
396}
397
398pub async fn list_devices(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
400 if let Err(e) = require_auth(&state, &headers) {
401 return e.into_response();
402 }
403
404 let devices = state
405 .device_registry
406 .as_ref()
407 .map(|r| r.list())
408 .unwrap_or_default();
409
410 let count = devices.len();
411 Json(serde_json::json!({
412 "devices": devices,
413 "count": count
414 }))
415 .into_response()
416}
417
418pub async fn revoke_device(
420 State(state): State<AppState>,
421 headers: HeaderMap,
422 axum::extract::Path(device_id): axum::extract::Path<String>,
423) -> impl IntoResponse {
424 if let Err(e) = require_auth(&state, &headers) {
425 return e.into_response();
426 }
427
428 let revoked = state
429 .device_registry
430 .as_ref()
431 .map(|r| r.revoke(&device_id))
432 .unwrap_or(false);
433
434 if revoked {
435 Json(serde_json::json!({
436 "message": "Device revoked",
437 "device_id": device_id
438 }))
439 .into_response()
440 } else {
441 (StatusCode::NOT_FOUND, "Device not found").into_response()
442 }
443}
444
445pub async fn rotate_token(
447 State(state): State<AppState>,
448 headers: HeaderMap,
449 axum::extract::Path(device_id): axum::extract::Path<String>,
450) -> impl IntoResponse {
451 if let Err(e) = require_auth(&state, &headers) {
452 return e.into_response();
453 }
454
455 match state.pairing.generate_new_pairing_code() {
457 Some(code) => Json(serde_json::json!({
458 "device_id": device_id,
459 "pairing_code": code,
460 "message": "Use this code to re-pair the device"
461 }))
462 .into_response(),
463 None => (
464 StatusCode::SERVICE_UNAVAILABLE,
465 "Cannot generate new pairing code",
466 )
467 .into_response(),
468 }
469}