1use crate::error::{RegistryError, Result};
2use rusqlite::{Connection, params};
3use std::sync::{Arc, Mutex};
4
5#[derive(Clone)]
7pub struct Db {
8 pub(crate) conn: Arc<Mutex<Connection>>,
10}
11
12impl Db {
13 pub fn open(path: &str) -> Result<Self> {
14 let conn = if path == ":memory:" {
15 Connection::open_in_memory()?
16 } else {
17 Connection::open(path)?
18 };
19 let db = Db { conn: Arc::new(Mutex::new(conn)) };
20 db.migrate()?;
21 Ok(db)
22 }
23
24 fn migrate(&self) -> Result<()> {
25 let conn = self.conn.lock().unwrap();
26 conn.execute_batch(r#"
27 CREATE TABLE IF NOT EXISTS users (
28 id INTEGER PRIMARY KEY,
29 username TEXT UNIQUE NOT NULL,
30 password_hash TEXT NOT NULL,
31 created_at TEXT NOT NULL
32 );
33 CREATE TABLE IF NOT EXISTS tokens (
34 id INTEGER PRIMARY KEY,
35 token TEXT UNIQUE NOT NULL,
36 user_id INTEGER REFERENCES users(id),
37 name TEXT,
38 expires_at TEXT,
39 revoked INTEGER DEFAULT 0
40 );
41 CREATE TABLE IF NOT EXISTS layer_meta (
42 id INTEGER PRIMARY KEY,
43 namespace TEXT NOT NULL,
44 name TEXT NOT NULL,
45 version TEXT NOT NULL,
46 description TEXT,
47 tags TEXT,
48 pushed_by INTEGER REFERENCES users(id),
49 pushed_at TEXT NOT NULL,
50 UNIQUE(namespace, name, version)
51 );
52 "#)?;
53 Ok(())
54 }
55
56 pub fn create_user(&self, username: &str, password_hash: &str) -> Result<i64> {
59 let conn = self.conn.lock().unwrap();
60 let now = chrono::Utc::now().to_rfc3339();
61 conn.execute(
62 "INSERT INTO users (username, password_hash, created_at) VALUES (?1, ?2, ?3)",
63 params![username, password_hash, now],
64 )?;
65 Ok(conn.last_insert_rowid())
66 }
67
68 pub fn get_user_by_username(&self, username: &str) -> Result<Option<(i64, String)>> {
69 let conn = self.conn.lock().unwrap();
70 let mut stmt = conn.prepare(
71 "SELECT id, password_hash FROM users WHERE username = ?1"
72 )?;
73 let mut rows = stmt.query(params![username])?;
74 if let Some(row) = rows.next()? {
75 Ok(Some((row.get(0)?, row.get(1)?)))
76 } else {
77 Ok(None)
78 }
79 }
80
81 pub fn insert_token(&self, token: &str, user_id: Option<i64>, name: Option<&str>, expires_at: Option<&str>) -> Result<()> {
84 let conn = self.conn.lock().unwrap();
85 conn.execute(
86 "INSERT INTO tokens (token, user_id, name, expires_at) VALUES (?1, ?2, ?3, ?4)",
87 params![token, user_id, name, expires_at],
88 )?;
89 Ok(())
90 }
91
92 pub fn validate_token(&self, token: &str) -> Result<Option<(Option<i64>, Option<String>)>> {
94 let conn = self.conn.lock().unwrap();
95 let mut stmt = conn.prepare(
96 "SELECT user_id, name, expires_at, revoked FROM tokens WHERE token = ?1"
97 )?;
98 let mut rows = stmt.query(params![token])?;
99 if let Some(row) = rows.next()? {
100 let revoked: i32 = row.get(3)?;
101 if revoked != 0 {
102 return Ok(None);
103 }
104 let expires_at: Option<String> = row.get(2)?;
105 if let Some(ref exp) = expires_at {
106 let exp_time = chrono::DateTime::parse_from_rfc3339(exp)
107 .map_err(|e| RegistryError::Internal(e.to_string()))?;
108 if exp_time < chrono::Utc::now() {
109 return Ok(None); }
111 }
112 let user_id: Option<i64> = row.get(0)?;
113 let name: Option<String> = row.get(1)?;
114 Ok(Some((user_id, name)))
115 } else {
116 Ok(None)
117 }
118 }
119
120 pub fn layer_exists(&self, namespace: &str, name: &str, version: &str) -> Result<bool> {
123 let conn = self.conn.lock().unwrap();
124 let count: i64 = conn.query_row(
125 "SELECT COUNT(*) FROM layer_meta WHERE namespace=?1 AND name=?2 AND version=?3",
126 params![namespace, name, version],
127 |row| row.get(0),
128 )?;
129 Ok(count > 0)
130 }
131
132 pub fn insert_layer(&self, namespace: &str, name: &str, version: &str,
133 description: Option<&str>, tags: &[String], pushed_by: Option<i64>) -> Result<()> {
134 let conn = self.conn.lock().unwrap();
135 let now = chrono::Utc::now().to_rfc3339();
136 let tags_json = serde_json::to_string(tags).unwrap_or_else(|_| "[]".to_string());
137 conn.execute(
138 "INSERT INTO layer_meta (namespace, name, version, description, tags, pushed_by, pushed_at)
139 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
140 params![namespace, name, version, description, tags_json, pushed_by, now],
141 )?;
142 Ok(())
143 }
144
145 pub fn list_layers(&self) -> Result<Vec<LayerSummary>> {
146 let conn = self.conn.lock().unwrap();
147 let mut stmt = conn.prepare(
149 "SELECT namespace, name, version FROM layer_meta ORDER BY namespace, name, pushed_at ASC"
150 )?;
151 let rows = stmt.query_map([], |row| {
152 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
153 })?;
154
155 let mut map: std::collections::BTreeMap<(String, String), Vec<String>> = Default::default();
156 for row in rows {
157 let (ns, nm, ver) = row?;
158 map.entry((ns, nm)).or_default().push(ver);
159 }
160
161 Ok(map.into_iter().map(|((namespace, name), versions)| {
162 let latest = versions.last().cloned().unwrap_or_default();
164 LayerSummary { namespace, name, latest, versions }
165 }).collect())
166 }
167
168 pub fn search_layers(&self, query: &str) -> Result<Vec<LayerSummary>> {
169 let pattern = format!("%{}%", query.to_lowercase());
170 let conn = self.conn.lock().unwrap();
171 let mut stmt = conn.prepare(
172 "SELECT namespace, name, version FROM layer_meta
173 WHERE LOWER(name) LIKE ?1 OR LOWER(namespace) LIKE ?1 OR LOWER(description) LIKE ?1
174 ORDER BY namespace, name, pushed_at ASC"
175 )?;
176 let rows = stmt.query_map(params![pattern], |row| {
177 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
178 })?;
179
180 let mut map: std::collections::BTreeMap<(String, String), Vec<String>> = Default::default();
181 for row in rows {
182 let (ns, nm, ver) = row?;
183 map.entry((ns, nm)).or_default().push(ver);
184 }
185
186 Ok(map.into_iter().map(|((namespace, name), versions)| {
187 let latest = versions.last().cloned().unwrap_or_default();
188 LayerSummary { namespace, name, latest, versions }
189 }).collect())
190 }
191
192 pub fn get_versions(&self, namespace: &str, name: &str) -> Result<Vec<String>> {
193 let conn = self.conn.lock().unwrap();
194 let mut stmt = conn.prepare(
195 "SELECT version FROM layer_meta WHERE namespace=?1 AND name=?2 ORDER BY version"
196 )?;
197 let rows = stmt.query_map(params![namespace, name], |row| row.get(0))?;
198 Ok(rows.collect::<rusqlite::Result<Vec<String>>>()?)
199 }
200
201 pub fn get_stats(&self) -> Result<RegistryStats> {
202 let conn = self.conn.lock().unwrap();
203 let total_layers: i64 = conn.query_row(
204 "SELECT COUNT(DISTINCT namespace || '/' || name) FROM layer_meta", [], |r| r.get(0))?;
205 let total_versions: i64 = conn.query_row(
206 "SELECT COUNT(*) FROM layer_meta", [], |r| r.get(0))?;
207 let namespaces: i64 = conn.query_row(
208 "SELECT COUNT(DISTINCT namespace) FROM layer_meta", [], |r| r.get(0))?;
209 Ok(RegistryStats { total_layers, total_versions, namespaces })
210 }
211}
212
213#[derive(Debug, serde::Serialize)]
214pub struct RegistryStats {
215 pub total_layers: i64,
216 pub total_versions: i64,
217 pub namespaces: i64,
218}
219
220#[derive(Debug, serde::Serialize)]
221pub struct LayerSummary {
222 pub namespace: String,
223 pub name: String,
224 pub latest: String,
225 pub versions: Vec<String>,
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh, migrated in-memory database for a single test.
    fn test_db() -> Db {
        Db::open(":memory:").unwrap()
    }

    #[test]
    fn test_create_and_get_user() {
        let db = test_db();
        db.create_user("alice", "hashed_pw").unwrap();

        let found = db.get_user_by_username("alice").unwrap();
        assert!(found.is_some());
        let (_id, stored_hash) = found.unwrap();
        assert_eq!(stored_hash, "hashed_pw");
    }

    #[test]
    fn test_get_unknown_user_returns_none() {
        let db = test_db();
        let found = db.get_user_by_username("nobody").unwrap();
        assert!(found.is_none());
    }

    #[test]
    fn test_token_valid() {
        let db = test_db();
        db.insert_token("phrt_test", None, Some("ci"), None).unwrap();

        let outcome = db.validate_token("phrt_test").unwrap();
        assert!(outcome.is_some());
    }

    #[test]
    fn test_token_unknown_returns_none() {
        let db = test_db();
        let outcome = db.validate_token("phrt_unknown").unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_token_revoked_is_invalid() {
        let db = test_db();
        db.insert_token("phrt_revoke_me", None, Some("ci"), None).unwrap();
        {
            // Scope the guard so the lock is released before validate_token
            // acquires it again.
            let guard = db.conn.lock().unwrap();
            guard
                .execute("UPDATE tokens SET revoked = 1 WHERE token = 'phrt_revoke_me'", [])
                .unwrap();
        }

        let outcome = db.validate_token("phrt_revoke_me").unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_token_expired_is_invalid() {
        let db = test_db();
        // An expiry far in the past must make the token unusable.
        let long_ago = "2000-01-01T00:00:00Z";
        db.insert_token("phrt_old", None, Some("ci"), Some(long_ago)).unwrap();

        let outcome = db.validate_token("phrt_old").unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_layer_exists_after_insert() {
        let db = test_db();
        let before = db.layer_exists("base", "expert", "v1.0").unwrap();
        assert!(!before);

        db.insert_layer("base", "expert", "v1.0", Some("desc"), &[], None).unwrap();

        let after = db.layer_exists("base", "expert", "v1.0").unwrap();
        assert!(after);
    }

    #[test]
    fn test_list_layers() {
        let db = test_db();
        let pushes = [
            ("base", "expert", "v1.0"),
            ("base", "expert", "v2.0"),
            ("style", "concise", "v1.0"),
        ];
        for (namespace, name, version) in pushes.iter() {
            db.insert_layer(namespace, name, version, Some("desc"), &[], None).unwrap();
        }

        let layers = db.list_layers().unwrap();
        assert_eq!(layers.len(), 2);

        let first = &layers[0];
        assert_eq!(first.name, "expert");
        assert_eq!(first.versions.len(), 2);
        assert!(first.versions.iter().any(|v| v == "v1.0"));
        assert!(first.versions.iter().any(|v| v == "v2.0"));
    }

    #[test]
    fn test_search_layers() {
        let db = test_db();
        db.insert_layer("base", "code-reviewer", "v1.0", Some("reviews code"), &[], None).unwrap();
        db.insert_layer("style", "concise", "v1.0", Some("brief output"), &[], None).unwrap();

        let hits = db.search_layers("code").unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].name, "code-reviewer");
    }

    #[test]
    fn test_get_versions() {
        let db = test_db();
        for version in ["v1.0", "v2.0"].iter() {
            db.insert_layer("base", "expert", version, None, &[], None).unwrap();
        }

        let versions = db.get_versions("base", "expert").unwrap();
        assert_eq!(versions, vec!["v1.0", "v2.0"]);
    }
}
325}