1use crate::data::{Dataset, DatasetVersion};
4use crate::error::{PachaError, Result};
5use crate::experiment::{ExperimentRun, RunId};
6use crate::model::{Model, ModelId, ModelStage, ModelVersion};
7use crate::recipe::{RecipeReference, RecipeVersion, TrainingRecipe};
8use crate::storage::ContentAddress;
9use rusqlite::{params, Connection};
10use std::path::Path;
11
/// SQLite-backed store for registry metadata.
///
/// Wraps a single `rusqlite` connection; the methods on this type read and
/// write the `models`, `datasets`, `recipes` and `runs` tables.
pub struct RegistryDb {
    /// Open SQLite connection that every query on this type goes through.
    conn: Connection,
}
16
17impl RegistryDb {
18 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
24 let conn = Connection::open(path)?;
25 let db = Self { conn };
26 db.init_schema()?;
27 Ok(db)
28 }
29
30 pub fn vacuum_into<P: AsRef<Path>>(&self, target: P) -> Result<()> {
47 let raw = target.as_ref().to_str().ok_or_else(|| {
51 PachaError::Validation("snapshot path is not valid UTF-8".to_string())
52 })?;
53 let escaped = raw.replace('\'', "''");
54 let sql = format!("VACUUM INTO '{escaped}'");
55 self.conn.execute_batch(&sql)?;
56 Ok(())
57 }
58
59 fn init_schema(&self) -> Result<()> {
61 self.conn.execute_batch(
62 r"
63 -- Models
64 CREATE TABLE IF NOT EXISTS models (
65 id TEXT PRIMARY KEY,
66 name TEXT NOT NULL,
67 version TEXT NOT NULL,
68 content_hash TEXT NOT NULL,
69 content_size INTEGER NOT NULL,
70 card_json TEXT NOT NULL,
71 stage TEXT DEFAULT 'development',
72 created_at TEXT NOT NULL,
73 updated_at TEXT NOT NULL,
74 UNIQUE(name, version)
75 );
76
77 CREATE INDEX IF NOT EXISTS idx_models_name ON models(name);
78 CREATE INDEX IF NOT EXISTS idx_models_stage ON models(stage);
79
80 -- Datasets
81 CREATE TABLE IF NOT EXISTS datasets (
82 id TEXT PRIMARY KEY,
83 name TEXT NOT NULL,
84 version TEXT NOT NULL,
85 content_hash TEXT NOT NULL,
86 content_size INTEGER NOT NULL,
87 datasheet_json TEXT NOT NULL,
88 created_at TEXT NOT NULL,
89 UNIQUE(name, version)
90 );
91
92 CREATE INDEX IF NOT EXISTS idx_datasets_name ON datasets(name);
93
94 -- Recipes
95 CREATE TABLE IF NOT EXISTS recipes (
96 id TEXT PRIMARY KEY,
97 name TEXT NOT NULL,
98 version TEXT NOT NULL,
99 recipe_json TEXT NOT NULL,
100 created_at TEXT NOT NULL,
101 UNIQUE(name, version)
102 );
103
104 CREATE INDEX IF NOT EXISTS idx_recipes_name ON recipes(name);
105
106 -- Experiment Runs
107 CREATE TABLE IF NOT EXISTS runs (
108 id TEXT PRIMARY KEY,
109 recipe_name TEXT,
110 recipe_version TEXT,
111 hyperparameters_json TEXT NOT NULL,
112 status TEXT NOT NULL,
113 started_at TEXT NOT NULL,
114 finished_at TEXT,
115 run_json TEXT NOT NULL
116 );
117
118 CREATE INDEX IF NOT EXISTS idx_runs_recipe ON runs(recipe_name, recipe_version);
119 CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status);
120
121 -- Lineage edges
122 CREATE TABLE IF NOT EXISTS lineage (
123 id INTEGER PRIMARY KEY AUTOINCREMENT,
124 from_id TEXT NOT NULL,
125 to_id TEXT NOT NULL,
126 edge_type TEXT NOT NULL,
127 metadata_json TEXT
128 );
129
130 CREATE INDEX IF NOT EXISTS idx_lineage_from ON lineage(from_id);
131 CREATE INDEX IF NOT EXISTS idx_lineage_to ON lineage(to_id);
132 ",
133 )?;
134 Ok(())
135 }
136
137 pub fn insert_model(&self, model: &Model) -> Result<()> {
141 let card_json = serde_json::to_string(&model.card)?;
142 self.conn.execute(
143 r"INSERT INTO models (id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at)
144 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
145 params![
146 model.id.to_string(),
147 model.name,
148 model.version.to_string(),
149 model.content_address.hash_hex(),
150 model.content_address.size(),
151 card_json,
152 model.stage.to_string(),
153 model.created_at.to_rfc3339(),
154 model.updated_at.to_rfc3339(),
155 ],
156 )?;
157 Ok(())
158 }
159
160 pub fn model_exists(&self, name: &str, version: &ModelVersion) -> Result<bool> {
162 let count: i64 = self.conn.query_row(
163 "SELECT COUNT(*) FROM models WHERE name = ?1 AND version = ?2",
164 params![name, version.to_string()],
165 |row| row.get(0),
166 )?;
167 Ok(count > 0)
168 }
169
170 pub fn get_model(&self, name: &str, version: &ModelVersion) -> Result<Model> {
172 let row = self.conn.query_row(
173 r"SELECT id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at
174 FROM models WHERE name = ?1 AND version = ?2",
175 params![name, version.to_string()],
176 |row| {
177 Ok((
178 row.get::<_, String>(0)?,
179 row.get::<_, String>(1)?,
180 row.get::<_, String>(2)?,
181 row.get::<_, String>(3)?,
182 row.get::<_, i64>(4)?,
183 row.get::<_, String>(5)?,
184 row.get::<_, String>(6)?,
185 row.get::<_, String>(7)?,
186 row.get::<_, String>(8)?,
187 ))
188 },
189 ).map_err(|e| match e {
190 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
191 kind: "model".to_string(),
192 name: name.to_string(),
193 version: version.to_string(),
194 },
195 e => PachaError::Database(e),
196 })?;
197
198 Self::row_to_model(row)
199 }
200
201 pub fn get_model_by_id(&self, id: &ModelId) -> Result<Model> {
203 let row = self.conn.query_row(
204 r"SELECT id, name, version, content_hash, content_size, card_json, stage, created_at, updated_at
205 FROM models WHERE id = ?1",
206 params![id.to_string()],
207 |row| {
208 Ok((
209 row.get::<_, String>(0)?,
210 row.get::<_, String>(1)?,
211 row.get::<_, String>(2)?,
212 row.get::<_, String>(3)?,
213 row.get::<_, i64>(4)?,
214 row.get::<_, String>(5)?,
215 row.get::<_, String>(6)?,
216 row.get::<_, String>(7)?,
217 row.get::<_, String>(8)?,
218 ))
219 },
220 ).map_err(|e| match e {
221 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
222 kind: "model".to_string(),
223 name: id.to_string(),
224 version: "n/a".to_string(),
225 },
226 e => PachaError::Database(e),
227 })?;
228
229 Self::row_to_model(row)
230 }
231
232 fn row_to_model(
233 row: (String, String, String, String, i64, String, String, String, String),
234 ) -> Result<Model> {
235 let (
236 id_str,
237 name,
238 version_str,
239 hash_hex,
240 size,
241 card_json,
242 stage_str,
243 created_str,
244 updated_str,
245 ) = row;
246
247 let hash_bytes = hex_decode(&hash_hex)?;
249 let mut hash = [0u8; 32];
250 hash.copy_from_slice(&hash_bytes);
251
252 let size_u64 = u64::try_from(size).unwrap_or(0);
254
255 Ok(Model {
256 id: id_str
257 .parse()
258 .map_err(|_| PachaError::Validation("invalid model id".to_string()))?,
259 name,
260 version: version_str.parse()?,
261 content_address: ContentAddress::new(hash, size_u64, crate::storage::Compression::None),
262 card: serde_json::from_str(&card_json)?,
263 stage: stage_str.parse()?,
264 created_at: chrono::DateTime::parse_from_rfc3339(&created_str)
265 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
266 .with_timezone(&chrono::Utc),
267 updated_at: chrono::DateTime::parse_from_rfc3339(&updated_str)
268 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
269 .with_timezone(&chrono::Utc),
270 })
271 }
272
273 pub fn list_model_versions(&self, name: &str) -> Result<Vec<ModelVersion>> {
275 let mut stmt =
276 self.conn.prepare("SELECT version FROM models WHERE name = ?1 ORDER BY version")?;
277 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
278
279 let mut versions = Vec::new();
280 for row in rows {
281 let version_str = row?;
282 versions.push(version_str.parse()?);
283 }
284 Ok(versions)
285 }
286
287 pub fn list_model_names(&self) -> Result<Vec<String>> {
289 contract_pre_ols_fit!();
290 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM models ORDER BY name")?;
291 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
292
293 let mut names = Vec::new();
294 for row in rows {
295 names.push(row?);
296 }
297 Ok(names)
298 }
299
300 pub fn update_model_stage(&self, id: &ModelId, stage: ModelStage) -> Result<()> {
302 let updated_at = chrono::Utc::now().to_rfc3339();
303 self.conn.execute(
304 "UPDATE models SET stage = ?1, updated_at = ?2 WHERE id = ?3",
305 params![stage.to_string(), updated_at, id.to_string()],
306 )?;
307 Ok(())
308 }
309
310 pub fn count_models(&self) -> Result<usize> {
312 let count: i64 =
313 self.conn.query_row("SELECT COUNT(*) FROM models", [], |row| row.get(0))?;
314 Ok(usize::try_from(count).unwrap_or(0))
315 }
316
317 pub fn insert_dataset(&self, dataset: &Dataset) -> Result<()> {
321 let datasheet_json = serde_json::to_string(&dataset.datasheet)?;
322 self.conn.execute(
323 r"INSERT INTO datasets (id, name, version, content_hash, content_size, datasheet_json, created_at)
324 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
325 params![
326 dataset.id.to_string(),
327 dataset.name,
328 dataset.version.to_string(),
329 dataset.content_address.hash_hex(),
330 dataset.content_address.size(),
331 datasheet_json,
332 dataset.created_at.to_rfc3339(),
333 ],
334 )?;
335 Ok(())
336 }
337
338 pub fn dataset_exists(&self, name: &str, version: &DatasetVersion) -> Result<bool> {
340 let count: i64 = self.conn.query_row(
341 "SELECT COUNT(*) FROM datasets WHERE name = ?1 AND version = ?2",
342 params![name, version.to_string()],
343 |row| row.get(0),
344 )?;
345 Ok(count > 0)
346 }
347
348 pub fn get_dataset(&self, name: &str, version: &DatasetVersion) -> Result<Dataset> {
350 let row = self
351 .conn
352 .query_row(
353 r"SELECT id, name, version, content_hash, content_size, datasheet_json, created_at
354 FROM datasets WHERE name = ?1 AND version = ?2",
355 params![name, version.to_string()],
356 |row| {
357 Ok((
358 row.get::<_, String>(0)?,
359 row.get::<_, String>(1)?,
360 row.get::<_, String>(2)?,
361 row.get::<_, String>(3)?,
362 row.get::<_, i64>(4)?,
363 row.get::<_, String>(5)?,
364 row.get::<_, String>(6)?,
365 ))
366 },
367 )
368 .map_err(|e| match e {
369 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
370 kind: "dataset".to_string(),
371 name: name.to_string(),
372 version: version.to_string(),
373 },
374 e => PachaError::Database(e),
375 })?;
376
377 let (id_str, name, version_str, hash_hex, size, datasheet_json, created_str) = row;
378
379 let hash_bytes = hex_decode(&hash_hex)?;
380 let mut hash = [0u8; 32];
381 hash.copy_from_slice(&hash_bytes);
382
383 let size_u64 = u64::try_from(size).unwrap_or(0);
385
386 Ok(Dataset {
387 id: id_str
388 .parse()
389 .map_err(|_| PachaError::Validation("invalid dataset id".to_string()))?,
390 name,
391 version: version_str.parse()?,
392 content_address: ContentAddress::new(hash, size_u64, crate::storage::Compression::None),
393 datasheet: serde_json::from_str(&datasheet_json)?,
394 created_at: chrono::DateTime::parse_from_rfc3339(&created_str)
395 .map_err(|_| PachaError::Validation("invalid timestamp".to_string()))?
396 .with_timezone(&chrono::Utc),
397 })
398 }
399
400 pub fn list_dataset_names(&self) -> Result<Vec<String>> {
402 contract_pre_name_resolution!();
403 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM datasets ORDER BY name")?;
404 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
405
406 let mut names = Vec::new();
407 for row in rows {
408 names.push(row?);
409 }
410 Ok(names)
411 }
412
413 pub fn list_dataset_versions(&self, name: &str) -> Result<Vec<DatasetVersion>> {
415 let mut stmt =
416 self.conn.prepare("SELECT version FROM datasets WHERE name = ?1 ORDER BY version")?;
417 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
418
419 let mut versions = Vec::new();
420 for row in rows {
421 let version_str = row?;
422 versions.push(version_str.parse()?);
423 }
424 Ok(versions)
425 }
426
427 pub fn count_datasets(&self) -> Result<usize> {
429 let count: i64 =
430 self.conn.query_row("SELECT COUNT(*) FROM datasets", [], |row| row.get(0))?;
431 Ok(usize::try_from(count).unwrap_or(0))
432 }
433
434 pub fn insert_recipe(&self, recipe: &TrainingRecipe) -> Result<()> {
438 let recipe_json = serde_json::to_string(recipe)?;
439 self.conn.execute(
440 r"INSERT INTO recipes (id, name, version, recipe_json, created_at)
441 VALUES (?1, ?2, ?3, ?4, ?5)",
442 params![
443 recipe.id.to_string(),
444 recipe.name,
445 recipe.version.to_string(),
446 recipe_json,
447 recipe.created_at.to_rfc3339(),
448 ],
449 )?;
450 Ok(())
451 }
452
453 pub fn recipe_exists(&self, name: &str, version: &RecipeVersion) -> Result<bool> {
455 let count: i64 = self.conn.query_row(
456 "SELECT COUNT(*) FROM recipes WHERE name = ?1 AND version = ?2",
457 params![name, version.to_string()],
458 |row| row.get(0),
459 )?;
460 Ok(count > 0)
461 }
462
463 pub fn get_recipe(&self, name: &str, version: &RecipeVersion) -> Result<TrainingRecipe> {
465 let recipe_json: String = self
466 .conn
467 .query_row(
468 "SELECT recipe_json FROM recipes WHERE name = ?1 AND version = ?2",
469 params![name, version.to_string()],
470 |row| row.get(0),
471 )
472 .map_err(|e| match e {
473 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
474 kind: "recipe".to_string(),
475 name: name.to_string(),
476 version: version.to_string(),
477 },
478 e => PachaError::Database(e),
479 })?;
480
481 Ok(serde_json::from_str(&recipe_json)?)
482 }
483
484 pub fn list_recipe_names(&self) -> Result<Vec<String>> {
486 contract_pre_expand_recipe!();
487 let mut stmt = self.conn.prepare("SELECT DISTINCT name FROM recipes ORDER BY name")?;
488 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
489
490 let mut names = Vec::new();
491 for row in rows {
492 names.push(row?);
493 }
494 Ok(names)
495 }
496
497 pub fn list_recipe_versions(&self, name: &str) -> Result<Vec<RecipeVersion>> {
499 let mut stmt =
500 self.conn.prepare("SELECT version FROM recipes WHERE name = ?1 ORDER BY version")?;
501 let rows = stmt.query_map(params![name], |row| row.get::<_, String>(0))?;
502
503 let mut versions = Vec::new();
504 for row in rows {
505 let version_str = row?;
506 versions.push(version_str.parse()?);
507 }
508 Ok(versions)
509 }
510
511 pub fn count_recipes(&self) -> Result<usize> {
513 let count: i64 =
514 self.conn.query_row("SELECT COUNT(*) FROM recipes", [], |row| row.get(0))?;
515 Ok(usize::try_from(count).unwrap_or(0))
516 }
517
518 pub fn insert_run(&self, run: &ExperimentRun) -> Result<()> {
522 contract_pre_configuration!();
523 let hyperparams_json = serde_json::to_string(&run.hyperparameters)?;
524 let run_json = serde_json::to_string(run)?;
525 let (recipe_name, recipe_version) = run
526 .recipe
527 .as_ref()
528 .map_or((None, None), |r| (Some(r.name.clone()), Some(r.version.to_string())));
529
530 self.conn.execute(
531 r"INSERT INTO runs (id, recipe_name, recipe_version, hyperparameters_json, status, started_at, finished_at, run_json)
532 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
533 params![
534 run.run_id.to_string(),
535 recipe_name,
536 recipe_version,
537 hyperparams_json,
538 run.status.to_string(),
539 run.started_at.to_rfc3339(),
540 run.finished_at.map(|t| t.to_rfc3339()),
541 run_json,
542 ],
543 )?;
544 Ok(())
545 }
546
547 pub fn update_run(&self, run: &ExperimentRun) -> Result<()> {
549 let run_json = serde_json::to_string(run)?;
550 self.conn.execute(
551 r"UPDATE runs SET status = ?1, finished_at = ?2, run_json = ?3 WHERE id = ?4",
552 params![
553 run.status.to_string(),
554 run.finished_at.map(|t| t.to_rfc3339()),
555 run_json,
556 run.run_id.to_string(),
557 ],
558 )?;
559 Ok(())
560 }
561
562 pub fn get_run(&self, run_id: &RunId) -> Result<ExperimentRun> {
564 let run_json: String = self
565 .conn
566 .query_row(
567 "SELECT run_json FROM runs WHERE id = ?1",
568 params![run_id.to_string()],
569 |row| row.get(0),
570 )
571 .map_err(|e| match e {
572 rusqlite::Error::QueryReturnedNoRows => PachaError::NotFound {
573 kind: "run".to_string(),
574 name: run_id.to_string(),
575 version: "n/a".to_string(),
576 },
577 e => PachaError::Database(e),
578 })?;
579
580 Ok(serde_json::from_str(&run_json)?)
581 }
582
583 pub fn list_runs_for_recipe(&self, recipe_ref: &RecipeReference) -> Result<Vec<ExperimentRun>> {
585 contract_pre_expand_recipe!(recipe_ref);
586 let mut stmt = self.conn.prepare(
587 "SELECT run_json FROM runs WHERE recipe_name = ?1 AND recipe_version = ?2 ORDER BY started_at DESC"
588 )?;
589
590 let rows = stmt
591 .query_map(params![recipe_ref.name, recipe_ref.version.to_string()], |row| {
592 row.get::<_, String>(0)
593 })?;
594
595 let mut runs = Vec::new();
596 for row in rows {
597 let run_json = row?;
598 runs.push(serde_json::from_str(&run_json)?);
599 }
600 Ok(runs)
601 }
602}
603
604fn hex_decode(s: &str) -> Result<Vec<u8>> {
606 let mut bytes = Vec::with_capacity(s.len() / 2);
607 let chars: Vec<char> = s.chars().collect();
608
609 for chunk in chars.chunks(2) {
610 if chunk.len() != 2 {
611 return Err(PachaError::Validation("invalid hex string".to_string()));
612 }
613 let high = hex_char_to_nibble(chunk[0])?;
614 let low = hex_char_to_nibble(chunk[1])?;
615 bytes.push((high << 4) | low);
616 }
617
618 Ok(bytes)
619}
620
621fn hex_char_to_nibble(c: char) -> Result<u8> {
622 match c {
623 '0'..='9' => Ok(c as u8 - b'0'),
624 'a'..='f' => Ok(c as u8 - b'a' + 10),
625 'A'..='F' => Ok(c as u8 - b'A' + 10),
626 _ => Err(PachaError::Validation(format!("invalid hex char: {c}"))),
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{DatasetId, Datasheet};
    use crate::model::ModelCard;
    use tempfile::TempDir;

    /// Creates a fresh database inside a temporary directory. The `TempDir`
    /// is returned alongside the database so the directory outlives the test.
    fn setup() -> (TempDir, RegistryDb) {
        let dir = TempDir::new().unwrap();
        let db = RegistryDb::open(dir.path().join("test.db")).unwrap();
        (dir, db)
    }

    #[test]
    fn test_db_open() {
        let (_dir, _db) = setup();
    }

    #[test]
    fn test_hex_decode() {
        assert_eq!(hex_decode("00").unwrap(), vec![0]);
        assert_eq!(hex_decode("ff").unwrap(), vec![255]);
        assert_eq!(hex_decode("0123").unwrap(), vec![1, 35]);
        assert_eq!(hex_decode("deadbeef").unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_decode_empty_and_uppercase() {
        assert_eq!(hex_decode("").unwrap(), Vec::<u8>::new());
        assert_eq!(hex_decode("DEADBEEF").unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_decode_rejects_odd_length() {
        assert!(hex_decode("a").is_err());
        assert!(hex_decode("abc").is_err());
    }

    #[test]
    fn test_hex_decode_rejects_non_hex_chars() {
        assert!(hex_decode("zz").is_err());
        assert!(hex_decode("0g").is_err());
    }

    #[test]
    fn test_model_crud() {
        let (_dir, db) = setup();

        let model = Model {
            id: ModelId::new(),
            name: "test".to_string(),
            version: ModelVersion::new(1, 0, 0),
            content_address: ContentAddress::from_bytes(b"test"),
            card: ModelCard::new("Test model"),
            stage: ModelStage::Development,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        };

        db.insert_model(&model).unwrap();
        assert!(db.model_exists("test", &ModelVersion::new(1, 0, 0)).unwrap());

        let retrieved = db.get_model("test", &ModelVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, model.id);
        assert_eq!(retrieved.name, model.name);

        // The count and listing helpers must see the same single row.
        assert_eq!(db.count_models().unwrap(), 1);
        assert_eq!(db.list_model_names().unwrap(), vec!["test".to_string()]);
    }

    #[test]
    fn test_missing_model_is_not_found() {
        let (_dir, db) = setup();

        assert!(!db.model_exists("missing", &ModelVersion::new(1, 0, 0)).unwrap());
        let result = db.get_model("missing", &ModelVersion::new(1, 0, 0));
        assert!(matches!(result, Err(PachaError::NotFound { .. })));
        assert_eq!(db.count_models().unwrap(), 0);
    }

    #[test]
    fn test_dataset_crud() {
        let (_dir, db) = setup();

        let dataset = Dataset {
            id: DatasetId::new(),
            name: "test-data".to_string(),
            version: DatasetVersion::new(1, 0, 0),
            content_address: ContentAddress::from_bytes(b"data"),
            datasheet: Datasheet::new("Test dataset"),
            created_at: chrono::Utc::now(),
        };

        db.insert_dataset(&dataset).unwrap();
        assert!(db.dataset_exists("test-data", &DatasetVersion::new(1, 0, 0)).unwrap());

        let retrieved = db.get_dataset("test-data", &DatasetVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, dataset.id);
        assert_eq!(db.count_datasets().unwrap(), 1);
    }

    #[test]
    fn test_missing_dataset_is_not_found() {
        let (_dir, db) = setup();

        let result = db.get_dataset("missing", &DatasetVersion::new(1, 0, 0));
        assert!(matches!(result, Err(PachaError::NotFound { .. })));
    }

    #[test]
    fn test_recipe_crud() {
        let (_dir, db) = setup();

        let recipe = TrainingRecipe::builder()
            .name("test-recipe")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Test")
            .build();

        db.insert_recipe(&recipe).unwrap();
        assert!(db.recipe_exists("test-recipe", &RecipeVersion::new(1, 0, 0)).unwrap());

        let retrieved = db.get_recipe("test-recipe", &RecipeVersion::new(1, 0, 0)).unwrap();
        assert_eq!(retrieved.id, recipe.id);
        assert_eq!(db.count_recipes().unwrap(), 1);
    }

    #[test]
    fn test_missing_recipe_is_not_found() {
        let (_dir, db) = setup();

        let result = db.get_recipe("missing", &RecipeVersion::new(1, 0, 0));
        assert!(matches!(result, Err(PachaError::NotFound { .. })));
    }
}