Skip to main content

ph_registry/
db.rs

1use crate::error::{RegistryError, Result};
2use rusqlite::{Connection, params};
3use std::sync::{Arc, Mutex};
4
5/// Thread-safe SQLite connection pool (single connection with Mutex for simplicity).
6#[derive(Clone)]
7pub struct Db {
8    /// pub(crate) for test helpers that need to inject raw SQL (e.g. simulate revoked tokens).
9    pub(crate) conn: Arc<Mutex<Connection>>,
10}
11
12impl Db {
13    pub fn open(path: &str) -> Result<Self> {
14        let conn = if path == ":memory:" {
15            Connection::open_in_memory()?
16        } else {
17            Connection::open(path)?
18        };
19        let db = Db { conn: Arc::new(Mutex::new(conn)) };
20        db.migrate()?;
21        Ok(db)
22    }
23
24    fn migrate(&self) -> Result<()> {
25        let conn = self.conn.lock().unwrap();
26        conn.execute_batch(r#"
27            CREATE TABLE IF NOT EXISTS users (
28                id            INTEGER PRIMARY KEY,
29                username      TEXT UNIQUE NOT NULL,
30                password_hash TEXT NOT NULL,
31                created_at    TEXT NOT NULL
32            );
33            CREATE TABLE IF NOT EXISTS tokens (
34                id         INTEGER PRIMARY KEY,
35                token      TEXT UNIQUE NOT NULL,
36                user_id    INTEGER REFERENCES users(id),
37                name       TEXT,
38                expires_at TEXT,
39                revoked    INTEGER DEFAULT 0
40            );
41            CREATE TABLE IF NOT EXISTS layer_meta (
42                id          INTEGER PRIMARY KEY,
43                namespace   TEXT NOT NULL,
44                name        TEXT NOT NULL,
45                version     TEXT NOT NULL,
46                description TEXT,
47                tags        TEXT,
48                pushed_by   INTEGER REFERENCES users(id),
49                pushed_at   TEXT NOT NULL,
50                UNIQUE(namespace, name, version)
51            );
52        "#)?;
53        Ok(())
54    }
55
56    // ── users ──────────────────────────────────────────────────────────────
57
58    pub fn create_user(&self, username: &str, password_hash: &str) -> Result<i64> {
59        let conn = self.conn.lock().unwrap();
60        let now = chrono::Utc::now().to_rfc3339();
61        conn.execute(
62            "INSERT INTO users (username, password_hash, created_at) VALUES (?1, ?2, ?3)",
63            params![username, password_hash, now],
64        )?;
65        Ok(conn.last_insert_rowid())
66    }
67
68    pub fn get_user_by_username(&self, username: &str) -> Result<Option<(i64, String)>> {
69        let conn = self.conn.lock().unwrap();
70        let mut stmt = conn.prepare(
71            "SELECT id, password_hash FROM users WHERE username = ?1"
72        )?;
73        let mut rows = stmt.query(params![username])?;
74        if let Some(row) = rows.next()? {
75            Ok(Some((row.get(0)?, row.get(1)?)))
76        } else {
77            Ok(None)
78        }
79    }
80
81    // ── tokens ─────────────────────────────────────────────────────────────
82
83    pub fn insert_token(&self, token: &str, user_id: Option<i64>, name: Option<&str>, expires_at: Option<&str>) -> Result<()> {
84        let conn = self.conn.lock().unwrap();
85        conn.execute(
86            "INSERT INTO tokens (token, user_id, name, expires_at) VALUES (?1, ?2, ?3, ?4)",
87            params![token, user_id, name, expires_at],
88        )?;
89        Ok(())
90    }
91
92    /// Returns (user_id, name) if token is valid (not revoked, not expired).
93    pub fn validate_token(&self, token: &str) -> Result<Option<(Option<i64>, Option<String>)>> {
94        let conn = self.conn.lock().unwrap();
95        let mut stmt = conn.prepare(
96            "SELECT user_id, name, expires_at, revoked FROM tokens WHERE token = ?1"
97        )?;
98        let mut rows = stmt.query(params![token])?;
99        if let Some(row) = rows.next()? {
100            let revoked: i32 = row.get(3)?;
101            if revoked != 0 {
102                return Ok(None);
103            }
104            let expires_at: Option<String> = row.get(2)?;
105            if let Some(ref exp) = expires_at {
106                let exp_time = chrono::DateTime::parse_from_rfc3339(exp)
107                    .map_err(|e| RegistryError::Internal(e.to_string()))?;
108                if exp_time < chrono::Utc::now() {
109                    return Ok(None); // expired
110                }
111            }
112            let user_id: Option<i64> = row.get(0)?;
113            let name: Option<String> = row.get(1)?;
114            Ok(Some((user_id, name)))
115        } else {
116            Ok(None)
117        }
118    }
119
120    // ── layer_meta ─────────────────────────────────────────────────────────
121
122    pub fn layer_exists(&self, namespace: &str, name: &str, version: &str) -> Result<bool> {
123        let conn = self.conn.lock().unwrap();
124        let count: i64 = conn.query_row(
125            "SELECT COUNT(*) FROM layer_meta WHERE namespace=?1 AND name=?2 AND version=?3",
126            params![namespace, name, version],
127            |row| row.get(0),
128        )?;
129        Ok(count > 0)
130    }
131
132    pub fn insert_layer(&self, namespace: &str, name: &str, version: &str,
133                         description: Option<&str>, tags: &[String], pushed_by: Option<i64>) -> Result<()> {
134        let conn = self.conn.lock().unwrap();
135        let now = chrono::Utc::now().to_rfc3339();
136        let tags_json = serde_json::to_string(tags).unwrap_or_else(|_| "[]".to_string());
137        conn.execute(
138            "INSERT INTO layer_meta (namespace, name, version, description, tags, pushed_by, pushed_at)
139             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
140            params![namespace, name, version, description, tags_json, pushed_by, now],
141        )?;
142        Ok(())
143    }
144
145    pub fn list_layers(&self) -> Result<Vec<LayerSummary>> {
146        let conn = self.conn.lock().unwrap();
147        // Order by pushed_at ASC so the last entry per (namespace, name) is the most recently pushed
148        let mut stmt = conn.prepare(
149            "SELECT namespace, name, version FROM layer_meta ORDER BY namespace, name, pushed_at ASC"
150        )?;
151        let rows = stmt.query_map([], |row| {
152            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
153        })?;
154
155        let mut map: std::collections::BTreeMap<(String, String), Vec<String>> = Default::default();
156        for row in rows {
157            let (ns, nm, ver) = row?;
158            map.entry((ns, nm)).or_default().push(ver);
159        }
160
161        Ok(map.into_iter().map(|((namespace, name), versions)| {
162            // `versions` is in ascending pushed_at order; last() is the most recently pushed
163            let latest = versions.last().cloned().unwrap_or_default();
164            LayerSummary { namespace, name, latest, versions }
165        }).collect())
166    }
167
168    pub fn search_layers(&self, query: &str) -> Result<Vec<LayerSummary>> {
169        let pattern = format!("%{}%", query.to_lowercase());
170        let conn = self.conn.lock().unwrap();
171        let mut stmt = conn.prepare(
172            "SELECT namespace, name, version FROM layer_meta
173             WHERE LOWER(name) LIKE ?1 OR LOWER(namespace) LIKE ?1 OR LOWER(description) LIKE ?1
174             ORDER BY namespace, name, pushed_at ASC"
175        )?;
176        let rows = stmt.query_map(params![pattern], |row| {
177            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
178        })?;
179
180        let mut map: std::collections::BTreeMap<(String, String), Vec<String>> = Default::default();
181        for row in rows {
182            let (ns, nm, ver) = row?;
183            map.entry((ns, nm)).or_default().push(ver);
184        }
185
186        Ok(map.into_iter().map(|((namespace, name), versions)| {
187            let latest = versions.last().cloned().unwrap_or_default();
188            LayerSummary { namespace, name, latest, versions }
189        }).collect())
190    }
191
192    pub fn get_versions(&self, namespace: &str, name: &str) -> Result<Vec<String>> {
193        let conn = self.conn.lock().unwrap();
194        let mut stmt = conn.prepare(
195            "SELECT version FROM layer_meta WHERE namespace=?1 AND name=?2 ORDER BY version"
196        )?;
197        let rows = stmt.query_map(params![namespace, name], |row| row.get(0))?;
198        Ok(rows.collect::<rusqlite::Result<Vec<String>>>()?)
199    }
200
201    pub fn get_stats(&self) -> Result<RegistryStats> {
202        let conn = self.conn.lock().unwrap();
203        let total_layers: i64 = conn.query_row(
204            "SELECT COUNT(DISTINCT namespace || '/' || name) FROM layer_meta", [], |r| r.get(0))?;
205        let total_versions: i64 = conn.query_row(
206            "SELECT COUNT(*) FROM layer_meta", [], |r| r.get(0))?;
207        let namespaces: i64 = conn.query_row(
208            "SELECT COUNT(DISTINCT namespace) FROM layer_meta", [], |r| r.get(0))?;
209        Ok(RegistryStats { total_layers, total_versions, namespaces })
210    }
211}
212
213#[derive(Debug, serde::Serialize)]
214pub struct RegistryStats {
215    pub total_layers: i64,
216    pub total_versions: i64,
217    pub namespaces: i64,
218}
219
220#[derive(Debug, serde::Serialize)]
221pub struct LayerSummary {
222    pub namespace: String,
223    pub name: String,
224    pub latest: String,
225    pub versions: Vec<String>,
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, isolated in-memory database for each test.
    fn test_db() -> Db { Db::open(":memory:").unwrap() }

    #[test]
    fn test_create_and_get_user() {
        let db = test_db();
        db.create_user("alice", "hashed_pw").unwrap();
        let user = db.get_user_by_username("alice").unwrap();
        assert!(user.is_some());
        let (_, hash) = user.unwrap();
        assert_eq!(hash, "hashed_pw");
    }

    #[test]
    fn test_get_unknown_user_returns_none() {
        let db = test_db();
        assert!(db.get_user_by_username("nobody").unwrap().is_none());
    }

    #[test]
    fn test_duplicate_username_is_rejected() {
        let db = test_db();
        db.create_user("alice", "h1").unwrap();
        assert!(db.create_user("alice", "h2").is_err());
    }

    #[test]
    fn test_token_valid() {
        let db = test_db();
        db.insert_token("phrt_test", None, Some("ci"), None).unwrap();
        let result = db.validate_token("phrt_test").unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn test_token_valid_returns_owner_and_name() {
        let db = test_db();
        let uid = db.create_user("bob", "pw").unwrap();
        db.insert_token("phrt_bob", Some(uid), Some("laptop"), None).unwrap();
        let (user_id, name) = db.validate_token("phrt_bob").unwrap().unwrap();
        assert_eq!(user_id, Some(uid));
        assert_eq!(name.as_deref(), Some("laptop"));
    }

    #[test]
    fn test_token_unknown_returns_none() {
        let db = test_db();
        assert!(db.validate_token("phrt_unknown").unwrap().is_none());
    }

    #[test]
    fn test_token_revoked_is_invalid() {
        let db = test_db();
        db.insert_token("phrt_revoke_me", None, Some("ci"), None).unwrap();
        // Manually revoke by direct SQL update
        let conn = db.conn.lock().unwrap();
        conn.execute("UPDATE tokens SET revoked = 1 WHERE token = 'phrt_revoke_me'", []).unwrap();
        drop(conn);
        assert!(db.validate_token("phrt_revoke_me").unwrap().is_none());
    }

    #[test]
    fn test_token_expired_is_invalid() {
        let db = test_db();
        // Insert a token that expired in the past
        let past = "2000-01-01T00:00:00Z";
        db.insert_token("phrt_old", None, Some("ci"), Some(past)).unwrap();
        assert!(db.validate_token("phrt_old").unwrap().is_none());
    }

    #[test]
    fn test_token_future_expiry_is_valid() {
        let db = test_db();
        let future = "2999-01-01T00:00:00Z";
        db.insert_token("phrt_future", None, Some("ci"), Some(future)).unwrap();
        assert!(db.validate_token("phrt_future").unwrap().is_some());
    }

    #[test]
    fn test_layer_exists_after_insert() {
        let db = test_db();
        assert!(!db.layer_exists("base", "expert", "v1.0").unwrap());
        db.insert_layer("base", "expert", "v1.0", Some("desc"), &[], None).unwrap();
        assert!(db.layer_exists("base", "expert", "v1.0").unwrap());
    }

    #[test]
    fn test_duplicate_layer_version_is_rejected() {
        let db = test_db();
        db.insert_layer("base", "expert", "v1.0", None, &[], None).unwrap();
        assert!(db.insert_layer("base", "expert", "v1.0", None, &[], None).is_err());
    }

    #[test]
    fn test_list_layers() {
        let db = test_db();
        db.insert_layer("base", "expert", "v1.0", Some("desc"), &[], None).unwrap();
        db.insert_layer("base", "expert", "v2.0", Some("desc"), &[], None).unwrap();
        db.insert_layer("style", "concise", "v1.0", Some("desc"), &[], None).unwrap();
        let layers = db.list_layers().unwrap();
        assert_eq!(layers.len(), 2);
        // BTreeMap sorts alphabetically: "base/expert" before "style/concise"
        assert_eq!(layers[0].name, "expert");
        assert_eq!(layers[0].versions.len(), 2);
        assert!(layers[0].versions.contains(&"v1.0".to_string()));
        assert!(layers[0].versions.contains(&"v2.0".to_string()));
    }

    #[test]
    fn test_search_layers() {
        let db = test_db();
        db.insert_layer("base", "code-reviewer", "v1.0", Some("reviews code"), &[], None).unwrap();
        db.insert_layer("style", "concise", "v1.0", Some("brief output"), &[], None).unwrap();
        let results = db.search_layers("code").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "code-reviewer");
    }

    #[test]
    fn test_search_matches_description_case_insensitively() {
        let db = test_db();
        db.insert_layer("style", "concise", "v1.0", Some("Brief Output"), &[], None).unwrap();
        let results = db.search_layers("BRIEF").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "concise");
    }

    #[test]
    fn test_get_versions() {
        let db = test_db();
        db.insert_layer("base", "expert", "v1.0", None, &[], None).unwrap();
        db.insert_layer("base", "expert", "v2.0", None, &[], None).unwrap();
        let versions = db.get_versions("base", "expert").unwrap();
        assert_eq!(versions, vec!["v1.0", "v2.0"]);
    }

    #[test]
    fn test_get_stats_empty() {
        let stats = test_db().get_stats().unwrap();
        assert_eq!(stats.total_layers, 0);
        assert_eq!(stats.total_versions, 0);
        assert_eq!(stats.namespaces, 0);
    }

    #[test]
    fn test_get_stats() {
        let db = test_db();
        db.insert_layer("base", "expert", "v1.0", None, &[], None).unwrap();
        db.insert_layer("base", "expert", "v2.0", None, &[], None).unwrap();
        db.insert_layer("style", "concise", "v1.0", None, &[], None).unwrap();
        let stats = db.get_stats().unwrap();
        assert_eq!(stats.total_layers, 2);
        assert_eq!(stats.total_versions, 3);
        assert_eq!(stats.namespaces, 2);
    }
}