1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use chrono::{DateTime, Utc};
4use rand::RngCore;
5use sqlx::{Row, SqlitePool};
6use uuid::Uuid;
7
8use crate::error::{GuardError, GuardResult};
9
10#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct ApiKeyRecord {
13 pub id: Uuid,
15 pub key_hash: String,
17 pub workspace_id: Uuid,
19 pub label: Option<String>,
21 pub created_at: DateTime<Utc>,
23 pub revoked: bool,
25 pub last_used_at: Option<DateTime<Utc>>,
27}
28
29#[derive(Clone)]
31pub struct ApiKeyManager {
32 pool: SqlitePool,
33}
34
35impl ApiKeyManager {
36 pub fn new(pool: SqlitePool) -> Self {
38 Self { pool }
39 }
40
41 pub async fn create_key(
43 &self,
44 workspace_id: Uuid,
45 label: &str,
46 ) -> GuardResult<(String, ApiKeyRecord)> {
47 let mut bytes = [0_u8; 32];
48 rand::thread_rng().fill_bytes(&mut bytes);
49 let raw_key = URL_SAFE_NO_PAD.encode(bytes);
50 let key_hash = blake3::hash(raw_key.as_bytes()).to_hex().to_string();
51 let id = Uuid::new_v4();
52 let created_at = Utc::now();
53
54 sqlx::query(
55 "INSERT INTO api_keys (id, key_hash, workspace_id, label, created_at, revoked, last_used_at)
56 VALUES (?1, ?2, ?3, ?4, ?5, 0, NULL)",
57 )
58 .bind(id.to_string())
59 .bind(&key_hash)
60 .bind(workspace_id.to_string())
61 .bind(if label.is_empty() { None::<String> } else { Some(label.to_owned()) })
62 .bind(created_at.timestamp_millis())
63 .execute(&self.pool)
64 .await?;
65
66 Ok((
67 raw_key,
68 ApiKeyRecord {
69 id,
70 key_hash,
71 workspace_id,
72 label: (!label.is_empty()).then(|| label.to_owned()),
73 created_at,
74 revoked: false,
75 last_used_at: None,
76 },
77 ))
78 }
79
80 pub async fn validate_key(&self, raw_key: &str) -> GuardResult<ApiKeyRecord> {
82 let candidate_hash = blake3::hash(raw_key.as_bytes());
83 let candidate_hex = candidate_hash.to_hex().to_string();
84 let row = sqlx::query(
85 "SELECT id, key_hash, workspace_id, label, created_at, revoked, last_used_at
86 FROM api_keys WHERE key_hash = ?1 LIMIT 1",
87 )
88 .bind(&candidate_hex)
89 .fetch_optional(&self.pool)
90 .await?
91 .ok_or(GuardError::InvalidToken)?;
92
93 if row.try_get::<i64, _>("revoked")? != 0 {
94 return Err(GuardError::InvalidToken);
95 }
96
97 let stored_hex: String = row.try_get("key_hash")?;
98 let stored_hash = blake3::Hash::from_hex(stored_hex.as_str()).map_err(|error| {
99 GuardError::ConfigError(format!("invalid stored key hash: {error}"))
100 })?;
101 if stored_hash != candidate_hash {
102 return Err(GuardError::InvalidToken);
103 }
104
105 let record = row_to_api_key_record(&row)?;
106 let pool = self.pool.clone();
107 let id = record.id;
108 tokio::spawn(async move {
109 let _ = sqlx::query("UPDATE api_keys SET last_used_at = ?1 WHERE id = ?2")
110 .bind(Utc::now().timestamp_millis())
111 .bind(id.to_string())
112 .execute(&pool)
113 .await;
114 });
115
116 Ok(record)
117 }
118
119 pub async fn revoke_key(&self, id: Uuid) -> GuardResult<()> {
121 sqlx::query("UPDATE api_keys SET revoked = 1 WHERE id = ?1")
122 .bind(id.to_string())
123 .execute(&self.pool)
124 .await?;
125 Ok(())
126 }
127
128 pub async fn list_keys(&self, workspace_id: Uuid) -> GuardResult<Vec<ApiKeyRecord>> {
130 let rows = sqlx::query(
131 "SELECT id, key_hash, workspace_id, label, created_at, revoked, last_used_at
132 FROM api_keys WHERE workspace_id = ?1 ORDER BY created_at DESC",
133 )
134 .bind(workspace_id.to_string())
135 .fetch_all(&self.pool)
136 .await?;
137
138 rows.iter().map(row_to_api_key_record).collect()
139 }
140}
141
142fn row_to_api_key_record(row: &sqlx::sqlite::SqliteRow) -> GuardResult<ApiKeyRecord> {
143 Ok(ApiKeyRecord {
144 id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
145 key_hash: row.try_get("key_hash")?,
146 workspace_id: Uuid::parse_str(&row.try_get::<String, _>("workspace_id")?)?,
147 label: row.try_get("label")?,
148 created_at: from_ms(row.try_get("created_at")?)?,
149 revoked: row.try_get::<i64, _>("revoked")? != 0,
150 last_used_at: row
151 .try_get::<Option<i64>, _>("last_used_at")?
152 .map(from_ms)
153 .transpose()?,
154 })
155}
156
157fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
158 DateTime::from_timestamp_millis(value)
159 .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
160}