1use crate::data::{Dataset, DatasetVersion};
4use crate::error::{PachaError, Result};
5use crate::experiment::{ExperimentRun, RunId};
6use crate::model::{Model, ModelId, ModelStage, ModelVersion};
7use crate::recipe::{RecipeReference, RecipeVersion, TrainingRecipe};
8use crate::storage::ContentAddress;
9use rusqlite::{params, Connection};
10use std::path::Path;
11
/// SQLite-backed metadata store for the registry.
///
/// Holds the tables for models, datasets, recipes, experiment runs and
/// lineage edges (see `init_schema`).
pub struct RegistryDb {
    /// The single SQLite connection every query on this type goes through.
    conn: Connection,
}
16
17impl RegistryDb {
18 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
24 let conn = Connection::open(path)?;
25 let db = Self { conn };
26 db.init_schema()?;
27 Ok(db)
28 }
29
30 fn init_schema(&self) -> Result<()> {
32 self.conn.execute_batch(
33 r"
34 -- Models
35 CREATE TABLE IF NOT EXISTS models (
36 id TEXT PRIMARY KEY,
37 name TEXT NOT NULL,
38 version TEXT NOT NULL,
39 content_hash TEXT NOT NULL,
40 content_size INTEGER NOT NULL,
41 card_json TEXT NOT NULL,
42 stage TEXT DEFAULT 'development',
43 created_at TEXT NOT NULL,
44 updated_at TEXT NOT NULL,
45 UNIQUE(name, version)
46 );
47
48 CREATE INDEX IF NOT EXISTS idx_models_name ON models(name);
49 CREATE INDEX IF NOT EXISTS idx_models_stage ON models(stage);
50
51 -- Datasets
52 CREATE TABLE IF NOT EXISTS datasets (
53 id TEXT PRIMARY KEY,
54 name TEXT NOT NULL,
55 version TEXT NOT NULL,
56 content_hash TEXT NOT NULL,
57 content_size INTEGER NOT NULL,
58 datasheet_json TEXT NOT NULL,
59 created_at TEXT NOT NULL,
60 UNIQUE(name, version)
61 );
62
63 CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
64
65 -- Recipes
66 CREATE TABLE IF NOT EXISTS recipes (
67 id TEXT PRIMARY KEY,
68 name TEXT NOT NULL,
69 version TEXT NOT NULL,
70 recipe_json TEXT NOT NULL,
71 created_at TEXT NOT NULL,
72 UNIQUE(name, version)
73 );
74
75 CREATE INDEX IF NOT EXISTS idx_recipes_name ON recipes(name);
76
77 -- Experiment Runs
78 CREATE TABLE IF NOT EXISTS runs (
79 id TEXT PRIMARY KEY,
80 recipe_name TEXT,
81 recipe_version TEXT,
82 hyperparameters_json TEXT NOT NULL,
83 status TEXT NOT NULL,
84 started_at TEXT NOT NULL,
85 finished_at TEXT,
86 run_json TEXT NOT NULL
87 );
88
89 CREATE INDEX IF NOT EXISTS idx_runs_recipe ON runs(recipe_name, recipe_version);
90 CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status);
91
92 -- Lineage edges
93 CREATE TABLE IF NOT EXISTS lineage (
94 id INTEGER PRIMARY KEY AUTOINCREMENT,
95 from_id TEXT NOT NULL,
96 to_id TEXT NOT NULL,
97 edge_type TEXT NOT NULL,
98 metadata_json TEXT
99 );
100
101 CREATE INDEX IF NOT EXISTS idx_lineage_from ON lineage(from_id);
102 CREATE INDEX IF NOT EXISTS idx_lineage_to ON lineage(to_id);
103 ",
104 )?;
105 Ok(())
106 }
107
108 pub fn insert_model(&self, model: &Model) -> Result<()> {
112 let card_json = serde_json::to_string(&model.card)?;
113 self.conn.execute(
114 r"INSERT INTO models (id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at)
115 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
116 params![
117 model.id.to_string(),
118 model.name,
119 model.version.to_string(),
120 model.content_address.hash_hex(),
121 model.content_address.size(),
122 card_json,
123 model.stage.to_string(),
124 model.created_at.to_rfc3339(),
125 model.updated_at.to_rfc3339(),
126 ],
127 )?;
128 Ok(())
129 }
130
131 pub fn model_exists(&self, name: &str, version: &ModelVersion) -> Result<bool> {
133 let count: i64 = self.conn.query_row(
134 "SELECT COUNT(*) FROM models WHERE name = ?1 AND version = ?2",
135 params![name, version.to_string()],
136 |row| row.get(0),
137 )?;
138 Ok(count > 0)
139 }
140
141 pub fn get_model(&self, name: &str, version: &ModelVersion) -> Result<Model> {
143 let row = self.conn.query_row(
144 r"SELECT id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at
145 FROM models WHERE name = ?1 AND version = ?2",
146 params![name, version.to_string()],
147 |row| {
148 Ok((
149 row.get::<_, String>(0)?,
150 row.get::<_, String>(1)?,
151 row.get::<_, String>(2)?,
152 row.get::<_, String>(3)?,
153 row.get::<_, i64>(4)?,
154 row.get::<_, String>(5)?,
155 row.get::<_, String>(6)?,
156 row.get::<_, String>(7)?,
157 row.get::<_, String>(8)?,
158 ))
159 },
160 ).map_err(|e| match e {
161 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
162 kind: "model".to_string(),
163 name: name.to_string(),
164 version: version.to_string(),
165 },
166 e => PachaError::Database(e),
167 })?;
168
169 Self::row_to_model(row)
170 }
171
172 pub fn get_model_by_id(&self, id: &ModelId) -> Result<Model> {
174 let row = self.conn.query_row(
175 r"SELECT id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at
176 FROM models WHERE id = ?1",
177 params![id.to_string()],
178 |row| {
179 Ok((
180 row.get::<_, String>(0)?,
181 row.get::<_, String>(1)?,
182 row.get::<_, String>(2)?,
183 row.get::<_, String>(3)?,
184 row.get::<_, i64>(4)?,
185 row.get::<_, String>(5)?,
186 row.get::<_, String>(6)?,
187 row.get::<_, String>(7)?,
188 row.get::<_, String>(8)?,
189 ))
190 },
191 ).map_err(|e| match e {
192 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
193 kind: "model".to_string(),
194 name: id.to_string(),
195 version: "n/a".to_string(),
196 },
197 e => PachaError::Database(e),
198 })?;
199
200 Self::row_to_model(row)
201 }
202
203 fn row_to_model(
204 row: (String, String, String, String, i64, String, String, String, String),
205 ) -> Result<Model> {
206 let (
207 id_str,
208 name,
209 version_str,
210 hash_hex,
211 size,
212 card_json,
213 stage_str,
214 created_str,
215 updated_str,
216 ) = row;
217
218 let hash_bytes = hex_decode(&hash_hex)?;
220 let mut hash = [0u8; 32];
221 hash.copy_from_slice(&hash_bytes);
222
223 let size_u64 = u64::try_from(size).unwrap_or(0);
225
226 Ok(Model {
227 id: id_str
228 .parse()
229 .map_err(|_| PachaError::Validation("invalid model id".to_string()))?,
230 name,
231 version: version_str.parse()?,
232 content_address: ContentAddress::new(hash, size_u64, crate::storage::Compression::None),
233 card: serde_json::from_str(&card_json)?,
234 stage: stage_str.parse()?,
235 created_at: chrono::DateTime::parse_from_rfc3339(&created_str)
236 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
237 .with_timezone(&chrono::Utc),
238 updated_at: chrono::DateTime::parse_from_rfc3339(&updated_str)
239 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
240 .with_timezone(&chrono::Utc),
241 })
242 }
243
244 pub fn list_model_versions(&self, name: &str) -> Result<Vec<ModelVersion>> {
246 let mut stmt =
247 self.conn.prepare("SELECT version FROM models WHERE name = ?1 ORDER BY version")?;
248 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
249
250 let mut versions = Vec::new();
251 for row in rows {
252 let version_str = row?;
253 versions.push(version_str.parse()?);
254 }
255 Ok(versions)
256 }
257
258 pub fn list_model_names(&self) -> Result<Vec<String>> {
260 contract_pre_ols_fit!();
261 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM models ORDER BY name")?;
262 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
263
264 let mut names = Vec::new();
265 for row in rows {
266 names.push(row?);
267 }
268 Ok(names)
269 }
270
271 pub fn update_model_stage(&self, id: &ModelId, stage: ModelStage) -> Result<()> {
273 let updated_at = chrono::Utc::now().to_rfc3339();
274 self.conn.execute(
275 "UPDATE models SET stage = ?1, updated_at = ?2 WHERE id = ?3",
276 params![stage.to_string(), updated_at, id.to_string()],
277 )?;
278 Ok(())
279 }
280
281 pub fn count_models(&self) -> Result<usize> {
283 let count: i64 =
284 self.conn.query_row("SELECT COUNT(*) FROM models", [], |row| row.get(0))?;
285 Ok(usize::try_from(count).unwrap_or(0))
286 }
287
288 pub fn insert_dataset(&self, dataset: &Dataset) -> Result<()> {
292 let datasheet_json = serde_json::to_string(&dataset.datasheet)?;
293 self.conn.execute(
294 r"INSERT INTO datasets (id, name, version, content_hash, content_size, datasheet_json, created_at)
295 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
296 params![
297 dataset.id.to_string(),
298 dataset.name,
299 dataset.version.to_string(),
300 dataset.content_address.hash_hex(),
301 dataset.content_address.size(),
302 datasheet_json,
303 dataset.created_at.to_rfc3339(),
304 ],
305 )?;
306 Ok(())
307 }
308
309 pub fn dataset_exists(&self, name: &str, version: &DatasetVersion) -> Result<bool> {
311 let count: i64 = self.conn.query_row(
312 "SELECT COUNT(*) FROM datasets WHERE name = ?1 AND version = ?2",
313 params![name, version.to_string()],
314 |row| row.get(0),
315 )?;
316 Ok(count > 0)
317 }
318
319 pub fn get_dataset(&self, name: &str, version: &DatasetVersion) -> Result<Dataset> {
321 let row = self
322 .conn
323 .query_row(
324 r"SELECT id, name, version, content_hash, content_size, datasheet_json, created_at
325 FROM datasets WHERE name = ?1 AND version = ?2",
326 params![name, version.to_string()],
327 |row| {
328 Ok((
329 row.get::<_, String>(0)?,
330 row.get::<_, String>(1)?,
331 row.get::<_, String>(2)?,
332 row.get::<_, String>(3)?,
333 row.get::<_, i64>(4)?,
334 row.get::<_, String>(5)?,
335 row.get::<_, String>(6)?,
336 ))
337 },
338 )
339 .map_err(|e| match e {
340 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
341 kind: "dataset".to_string(),
342 name: name.to_string(),
343 version: version.to_string(),
344 },
345 e => PachaError::Database(e),
346 })?;
347
348 let (id_str, name, version_str, hash_hex, size, datasheet_json, created_str) = row;
349
350 let hash_bytes = hex_decode(&hash_hex)?;
351 let mut hash = [0u8; 32];
352 hash.copy_from_slice(&hash_bytes);
353
354 let size_u64 = u64::try_from(size).unwrap_or(0);
356
357 Ok(Dataset {
358 id: id_str
359 .parse()
360 .map_err(|_| PachaError::Validation("invalid dataset id".to_string()))?,
361 name,
362 version: version_str.parse()?,
363 content_address: ContentAddress::new(hash, size_u64, crate::storage::Compression::None),
364 datasheet: serde_json::from_str(&datasheet_json)?,
365 created_at: chrono::DateTime::parse_from_rfc3339(&created_str)
366 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
367 .with_timezone(&chrono::Utc),
368 })
369 }
370
371 pub fn list_dataset_names(&self) -> Result<Vec<String>> {
373 contract_pre_name_resolution!();
374 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM datasets ORDER BY name")?;
375 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
376
377 let mut names = Vec::new();
378 for row in rows {
379 names.push(row?);
380 }
381 Ok(names)
382 }
383
384 pub fn list_dataset_versions(&self, name: &str) -> Result<Vec<DatasetVersion>> {
386 let mut stmt =
387 self.conn.prepare("SELECT version FROM datasets WHERE name = ?1 ORDER BY version")?;
388 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
389
390 let mut versions = Vec::new();
391 for row in rows {
392 let version_str = row?;
393 versions.push(version_str.parse()?);
394 }
395 Ok(versions)
396 }
397
398 pub fn count_datasets(&self) -> Result<usize> {
400 let count: i64 =
401 self.conn.query_row("SELECT COUNT(*) FROM datasets", [], |row| row.get(0))?;
402 Ok(usize::try_from(count).unwrap_or(0))
403 }
404
405 pub fn insert_recipe(&self, recipe: &TrainingRecipe) -> Result<()> {
409 let recipe_json = serde_json::to_string(recipe)?;
410 self.conn.execute(
411 r"INSERT INTO recipes (id, name, version, recipe_json, created_at)
412 VALUES (?1, ?2, ?3, ?4, ?5)",
413 params![
414 recipe.id.to_string(),
415 recipe.name,
416 recipe.version.to_string(),
417 recipe_json,
418 recipe.created_at.to_rfc3339(),
419 ],
420 )?;
421 Ok(())
422 }
423
424 pub fn recipe_exists(&self, name: &str, version: &RecipeVersion) -> Result<bool> {
426 let count: i64 = self.conn.query_row(
427 "SELECT COUNT(*) FROM recipes WHERE name = ?1 AND version = ?2",
428 params![name, version.to_string()],
429 |row| row.get(0),
430 )?;
431 Ok(count > 0)
432 }
433
434 pub fn get_recipe(&self, name: &str, version: &RecipeVersion) -> Result<TrainingRecipe> {
436 let recipe_json: String = self
437 .conn
438 .query_row(
439 "SELECT recipe_json FROM recipes WHERE name = ?1 AND version = ?2",
440 params![name, version.to_string()],
441 |row| row.get(0),
442 )
443 .map_err(|e| match e {
444 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
445 kind: "recipe".to_string(),
446 name: name.to_string(),
447 version: version.to_string(),
448 },
449 e => PachaError::Database(e),
450 })?;
451
452 Ok(serde_json::from_str(&recipe_json)?)
453 }
454
455 pub fn list_recipe_names(&self) -> Result<Vec<String>> {
457 contract_pre_expand_recipe!();
458 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM recipes ORDER BY name")?;
459 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
460
461 let mut names = Vec::new();
462 for row in rows {
463 names.push(row?);
464 }
465 Ok(names)
466 }
467
468 pub fn list_recipe_versions(&self, name: &str) -> Result<Vec<RecipeVersion>> {
470 let mut stmt =
471 self.conn.prepare("SELECT version FROM recipes WHERE name = ?1 ORDER BY version")?;
472 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
473
474 let mut versions = Vec::new();
475 for row in rows {
476 let version_str = row?;
477 versions.push(version_str.parse()?);
478 }
479 Ok(versions)
480 }
481
482 pub fn count_recipes(&self) -> Result<usize> {
484 let count: i64 =
485 self.conn.query_row("SELECT COUNT(*) FROM recipes", [], |row| row.get(0))?;
486 Ok(usize::try_from(count).unwrap_or(0))
487 }
488
489 pub fn insert_run(&self, run: &ExperimentRun) -> Result<()> {
493 contract_pre_configuration!();
494 let hyperparams_json = serde_json::to_string(&run.hyperparameters)?;
495 let run_json = serde_json::to_string(run)?;
496 let (recipe_name, recipe_version) = run
497 .recipe
498 .as_ref()
499 .map_or((None, None), |r| (Some(r.name.clone()), Some(r.version.to_string())));
500
501 self.conn.execute(
502 r"INSERT INTO runs (id, recipe_name, recipe_version, hyperparameters_json, status, started_at, finished_at, run_json)
503 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
504 params![
505 run.run_id.to_string(),
506 recipe_name,
507 recipe_version,
508 hyperparams_json,
509 run.status.to_string(),
510 run.started_at.to_rfc3339(),
511 run.finished_at.map(|t| t.to_rfc3339()),
512 run_json,
513 ],
514 )?;
515 Ok(())
516 }
517
518 pub fn update_run(&self, run: &ExperimentRun) -> Result<()> {
520 let run_json = serde_json::to_string(run)?;
521 self.conn.execute(
522 r"UPDATE runs SET status = ?1, finished_at = ?2, run_json = ?3 WHERE id = ?4",
523 params![
524 run.status.to_string(),
525 run.finished_at.map(|t| t.to_rfc3339()),
526 run_json,
527 run.run_id.to_string(),
528 ],
529 )?;
530 Ok(())
531 }
532
533 pub fn get_run(&self, run_id: &RunId) -> Result<ExperimentRun> {
535 let run_json: String = self
536 .conn
537 .query_row(
538 "SELECT run_json FROM runs WHERE id = ?1",
539 params![run_id.to_string()],
540 |row| row.get(0),
541 )
542 .map_err(|e| match e {
543 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
544 kind: "run".to_string(),
545 name: run_id.to_string(),
546 version: "n/a".to_string(),
547 },
548 e => PachaError::Database(e),
549 })?;
550
551 Ok(serde_json::from_str(&run_json)?)
552 }
553
554 pub fn list_runs_for_recipe(&self, recipe_ref: &RecipeReference) -> Result<Vec<ExperimentRun>> {
556 contract_pre_expand_recipe!(recipe_ref);
557 let mut stmt = self.conn.prepare(
558 "SELECT run_json FROM runs WHERE recipe_name = ?1 AND recipe_version = ?2 ORDER BY started_at DESC"
559 )?;
560
561 let rows = stmt
562 .query_map(params![recipe_ref.name, recipe_ref.version.to_string()], |row| {
563 row.get::<_, String>(0)
564 })?;
565
566 let mut runs = Vec::new();
567 for row in rows {
568 let run_json = row?;
569 runs.push(serde_json::from_str(&run_json)?);
570 }
571 Ok(runs)
572 }
573}
574
575fn hex_decode(s: &str) -> Result<Vec<u8>> {
577 let mut bytes = Vec::with_capacity(s.len() / 2);
578 let chars: Vec<char> = s.chars().collect();
579
580 for chunk in chars.chunks(2) {
581 if chunk.len() != 2 {
582 return Err(PachaError::Validation("invalid hex string".to_string()));
583 }
584 let high = hex_char_to_nibble(chunk[0])?;
585 let low = hex_char_to_nibble(chunk[1])?;
586 bytes.push((high << 4) | low);
587 }
588
589 Ok(bytes)
590}
591
592fn hex_char_to_nibble(c: char) -> Result<u8> {
593 match c {
594 '0'..='9' => Ok(c as u8 - b'0'),
595 'a'..='f' => Ok(c as u8 - b'a' + 10),
596 'A'..='F' => Ok(c as u8 - b'A' + 10),
597 _ => Err(PachaError::Validation(format!("invalid hex char: {c}"))),
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{DatasetId, Datasheet};
    use crate::model::ModelCard;
    use tempfile::TempDir;

    /// Opens a fresh registry in a temporary directory.
    ///
    /// The returned `TempDir` must stay alive for as long as the database is used.
    fn setup() -> (TempDir, RegistryDb) {
        let dir = TempDir::new().unwrap();
        let db = RegistryDb::open(dir.path().join("test.db")).unwrap();
        (dir, db)
    }

    /// Builds a development-stage model with the given name and version.
    fn sample_model(name: &str, version: ModelVersion) -> Model {
        Model {
            id: ModelId::new(),
            name: name.to_string(),
            version,
            content_address: ContentAddress::from_bytes(b"test"),
            card: ModelCard::new("Test model"),
            stage: ModelStage::Development,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_db_open() {
        let (_dir, db) = setup();
        assert_eq!(db.count_models().unwrap(), 0);
        assert_eq!(db.count_datasets().unwrap(), 0);
        assert_eq!(db.count_recipes().unwrap(), 0);
    }

    #[test]
    fn test_db_reopen_keeps_data() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        {
            let db = RegistryDb::open(&path).unwrap();
            db.insert_model(&sample_model("test", ModelVersion::new(1, 0, 0))).unwrap();
        }
        // Re-running the schema on an existing file must not drop data.
        let db = RegistryDb::open(&path).unwrap();
        assert_eq!(db.count_models().unwrap(), 1);
    }

    #[test]
    fn test_hex_decode() {
        assert_eq!(hex_decode("").unwrap(), Vec::<u8>::new());
        assert_eq!(hex_decode("00").unwrap(), vec![0]);
        assert_eq!(hex_decode("ff").unwrap(), vec![255]);
        assert_eq!(hex_decode("0123").unwrap(), vec![1, 35]);
        assert_eq!(hex_decode("deadbeef").unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
        assert_eq!(hex_decode("DEADBEEF").unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_decode_rejects_malformed_input() {
        assert!(hex_decode("abc").is_err());
        assert!(hex_decode("zz").is_err());
        assert!(hex_decode("0g").is_err());
    }

    #[test]
    fn test_model_crud() {
        let (_dir, db) = setup();
        let model = sample_model("test", ModelVersion::new(1, 0, 0));

        db.insert_model(&model).unwrap();
        assert!(db.model_exists("test", &ModelVersion::new(1, 0, 0)).unwrap());
        assert!(!db.model_exists("test", &ModelVersion::new(2, 0, 0)).unwrap());

        let retrieved = db.get_model("test", &ModelVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, model.id);
        assert_eq!(retrieved.name, model.name);

        let by_id = db.get_model_by_id(&model.id).unwrap();
        assert_eq!(by_id.id, model.id);
        assert_eq!(by_id.name, "test");
        assert_eq!(db.count_models().unwrap(), 1);
    }

    #[test]
    fn test_get_model_not_found() {
        let (_dir, db) = setup();
        assert!(matches!(
            db.get_model("missing", &ModelVersion::new(1, 0, 0)),
            Err(PachaError::NotFound { .. })
        ));
        assert!(matches!(db.get_model_by_id(&ModelId::new()), Err(PachaError::NotFound { .. })));
    }

    #[test]
    fn test_duplicate_model_version_rejected() {
        let (_dir, db) = setup();
        db.insert_model(&sample_model("test", ModelVersion::new(1, 0, 0))).unwrap();
        assert!(db.insert_model(&sample_model("test", ModelVersion::new(1, 0, 0))).is_err());
        assert_eq!(db.count_models().unwrap(), 1);
    }

    #[test]
    fn test_list_model_versions() {
        let (_dir, db) = setup();
        db.insert_model(&sample_model("test", ModelVersion::new(1, 0, 0))).unwrap();
        db.insert_model(&sample_model("test", ModelVersion::new(2, 0, 0))).unwrap();
        db.insert_model(&sample_model("other", ModelVersion::new(1, 0, 0))).unwrap();

        assert_eq!(db.list_model_versions("test").unwrap().len(), 2);
        assert_eq!(db.list_model_versions("other").unwrap().len(), 1);
        assert!(db.list_model_versions("absent").unwrap().is_empty());
    }

    #[test]
    fn test_dataset_crud() {
        let (_dir, db) = setup();

        let dataset = Dataset {
            id: DatasetId::new(),
            name: "test-data".to_string(),
            version: DatasetVersion::new(1, 0, 0),
            content_address: ContentAddress::from_bytes(b"data"),
            datasheet: Datasheet::new("Test dataset"),
            created_at: chrono::Utc::now(),
        };

        db.insert_dataset(&dataset).unwrap();
        assert!(db.dataset_exists("test-data", &DatasetVersion::new(1, 0, 0)).unwrap());
        assert!(!db.dataset_exists("test-data", &DatasetVersion::new(9, 0, 0)).unwrap());

        let retrieved = db.get_dataset("test-data", &DatasetVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, dataset.id);
        assert_eq!(retrieved.name, "test-data");
        assert_eq!(db.count_datasets().unwrap(), 1);
    }

    #[test]
    fn test_get_dataset_not_found() {
        let (_dir, db) = setup();
        assert!(matches!(
            db.get_dataset("missing", &DatasetVersion::new(1, 0, 0)),
            Err(PachaError::NotFound { .. })
        ));
    }

    #[test]
    fn test_recipe_crud() {
        let (_dir, db) = setup();

        let recipe = TrainingRecipe::builder()
            .name("test-recipe")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Test")
            .build();

        db.insert_recipe(&recipe).unwrap();
        assert!(db.recipe_exists("test-recipe", &RecipeVersion::new(1, 0, 0)).unwrap());

        let retrieved = db.get_recipe("test-recipe", &RecipeVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, recipe.id);
        assert_eq!(db.count_recipes().unwrap(), 1);
    }

    #[test]
    fn test_get_recipe_not_found() {
        let (_dir, db) = setup();
        assert!(matches!(
            db.get_recipe("missing", &RecipeVersion::new(1, 0, 0)),
            Err(PachaError::NotFound { .. })
        ));
    }
}