Skip to main content

pacha/registry/
mod.rs

1//! Registry implementation with `SQLite` storage.
2
3mod database;
4
5pub use database::RegistryDb;
6
7use crate::data::{Dataset, DatasetId, Datasheet};
8use crate::error::{PachaError, Result};
9use crate::experiment::{ExperimentRun, RunId};
10use crate::lineage::LineageGraph;
11use crate::model::{Model, ModelCard, ModelId, ModelStage, ModelVersion};
12use crate::recipe::{RecipeId, RecipeReference, TrainingRecipe};
13use crate::storage::ObjectStore;
14use chrono::Utc;
15use std::fs;
16use std::path::{Path, PathBuf};
17
/// Location settings for a Pacha registry.
///
/// Every path the registry works with is derived from [`RegistryConfig::base_path`].
#[derive(Debug, Clone)]
pub struct RegistryConfig {
    /// Root directory under which all registry files are placed.
    pub base_path: PathBuf,
}

impl RegistryConfig {
    /// Build a config rooted at `base_path`.
    #[must_use]
    pub fn new<P: AsRef<Path>>(base_path: P) -> Self {
        let base_path = PathBuf::from(base_path.as_ref());
        Self { base_path }
    }

    /// Path of the database file: `<base_path>/registry.db`.
    #[must_use]
    pub fn db_path(&self) -> PathBuf {
        self.child("registry.db")
    }

    /// Path of the object directory: `<base_path>/objects`.
    #[must_use]
    pub fn objects_path(&self) -> PathBuf {
        self.child("objects")
    }

    /// Path of the config file: `<base_path>/config.toml`.
    #[must_use]
    pub fn config_path(&self) -> PathBuf {
        self.child("config.toml")
    }

    /// Join a single path component onto the base path.
    fn child(&self, leaf: &str) -> PathBuf {
        self.base_path.join(leaf)
    }
}
50
51impl Default for RegistryConfig {
52    fn default() -> Self {
53        let home = dirs_path();
54        Self::new(home.join(".pacha"))
55    }
56}
57
/// Resolve the directory that the default registry location hangs off.
///
/// Checks the `HOME` environment variable first and `USERPROFILE` second
/// (`HOME` is normally unset on Windows), and falls back to the current
/// directory (`.`) when neither is set.
///
/// The variables are read with `var_os`, so a home directory whose path is
/// not valid Unicode is still honoured instead of silently degrading to `.`.
fn dirs_path() -> PathBuf {
    std::env::var_os("HOME")
        .or_else(|| std::env::var_os("USERPROFILE"))
        .map_or_else(|| PathBuf::from("."), PathBuf::from)
}
61
62/// The main Pacha registry.
63pub struct Registry {
64    config: RegistryConfig,
65    db: RegistryDb,
66    objects: ObjectStore,
67}
68
69impl Registry {
70    /// Create or open a registry at the default location (~/.pacha).
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if initialization fails.
75    pub fn open_default() -> Result<Self> {
76        Self::open(RegistryConfig::default())
77    }
78
79    /// Create or open a registry with the given configuration.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if initialization fails.
84    pub fn open(config: RegistryConfig) -> Result<Self> {
85        // Create base directory
86        fs::create_dir_all(&config.base_path)?;
87
88        // Initialize database
89        let db = RegistryDb::open(config.db_path())?;
90
91        // Initialize object store
92        let objects = ObjectStore::new(config.objects_path())?;
93
94        Ok(Self { config, db, objects })
95    }
96
97    /// Get the registry configuration.
98    #[must_use]
99    pub fn config(&self) -> &RegistryConfig {
100        &self.config
101    }
102
103    // ==================== Model Registry ====================
104
105    /// Register a new model.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if registration fails.
110    pub fn register_model(
111        &self,
112        name: &str,
113        version: &ModelVersion,
114        artifact: &[u8],
115        card: ModelCard,
116    ) -> Result<ModelId> {
117        // Check if already exists
118        if self.db.model_exists(name, version)? {
119            return Err(PachaError::AlreadyExists {
120                kind: "model".to_string(),
121                name: name.to_string(),
122                version: version.to_string(),
123            });
124        }
125
126        // Store artifact
127        let content_address = self.objects.put(artifact)?;
128
129        // Create model
130        let model = Model {
131            id: ModelId::new(),
132            name: name.to_string(),
133            version: version.clone(),
134            content_address,
135            card,
136            stage: ModelStage::Development,
137            created_at: Utc::now(),
138            updated_at: Utc::now(),
139        };
140
141        // Save to database
142        self.db.insert_model(&model)?;
143
144        Ok(model.id)
145    }
146
147    /// Get a model by name and version.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if the model is not found.
152    pub fn get_model(&self, name: &str, version: &ModelVersion) -> Result<Model> {
153        // Contract: configuration-v1.yaml precondition (pv codegen)
154        contract_pre_configuration!(name.as_bytes());
155        self.db.get_model(name, version)
156    }
157
158    /// Get a model by ID.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if the model is not found.
163    pub fn get_model_by_id(&self, id: &ModelId) -> Result<Model> {
164        self.db.get_model_by_id(id)
165    }
166
167    /// List all versions of a model.
168    ///
169    /// # Errors
170    ///
171    /// Returns an error if the query fails.
172    pub fn list_model_versions(&self, name: &str) -> Result<Vec<ModelVersion>> {
173        contract_pre_ols_fit!();
174        let result = self.db.list_model_versions(name);
175        if let Ok(ref val) = result {
176            contract_post_configuration!(val);
177        }
178        result
179    }
180
181    /// List all model names.
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if the query fails.
186    pub fn list_models(&self) -> Result<Vec<String>> {
187        contract_pre_ols_fit!();
188        let result = self.db.list_model_names();
189        if let Ok(ref val) = result {
190            contract_post_configuration!(val);
191        }
192        result
193    }
194
195    /// Transition a model to a new stage.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the transition is invalid.
200    pub fn transition_model_stage(
201        &self,
202        name: &str,
203        version: &ModelVersion,
204        target_stage: ModelStage,
205    ) -> Result<()> {
206        let model = self.get_model(name, version)?;
207        let _new_stage = model.stage.transition_to(target_stage)?;
208        self.db.update_model_stage(&model.id, target_stage)
209    }
210
211    /// Get the artifact data for a model.
212    ///
213    /// # Errors
214    ///
215    /// Returns an error if the artifact cannot be retrieved.
216    pub fn get_model_artifact(&self, name: &str, version: &ModelVersion) -> Result<Vec<u8>> {
217        let model = self.get_model(name, version)?;
218        self.objects.get(&model.content_address)
219    }
220
221    /// Get model lineage graph.
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if the query fails.
226    pub fn get_model_lineage(&self, _model_id: &ModelId) -> Result<LineageGraph> {
227        // Returns empty graph until lineage data is populated
228        Ok(LineageGraph::new())
229    }
230
231    // ==================== Dataset Registry ====================
232
233    /// Register a new dataset.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if registration fails.
238    pub fn register_dataset(
239        &self,
240        name: &str,
241        version: &crate::data::DatasetVersion,
242        data: &[u8],
243        datasheet: Datasheet,
244    ) -> Result<DatasetId> {
245        // Check if already exists
246        if self.db.dataset_exists(name, version)? {
247            return Err(PachaError::AlreadyExists {
248                kind: "dataset".to_string(),
249                name: name.to_string(),
250                version: version.to_string(),
251            });
252        }
253
254        // Store data
255        let content_address = self.objects.put(data)?;
256
257        // Create dataset
258        let dataset = Dataset {
259            id: DatasetId::new(),
260            name: name.to_string(),
261            version: version.clone(),
262            content_address,
263            datasheet,
264            created_at: Utc::now(),
265        };
266
267        // Save to database
268        self.db.insert_dataset(&dataset)?;
269
270        Ok(dataset.id)
271    }
272
273    /// Get a dataset by name and version.
274    ///
275    /// # Errors
276    ///
277    /// Returns an error if the dataset is not found.
278    pub fn get_dataset(
279        &self,
280        name: &str,
281        version: &crate::data::DatasetVersion,
282    ) -> Result<Dataset> {
283        self.db.get_dataset(name, version)
284    }
285
286    /// List all dataset names.
287    ///
288    /// # Errors
289    ///
290    /// Returns an error if the query fails.
291    pub fn list_datasets(&self) -> Result<Vec<String>> {
292        contract_pre_configuration!();
293        let result = self.db.list_dataset_names();
294        if let Ok(ref val) = result {
295            contract_post_configuration!(val);
296        }
297        result
298    }
299
300    /// List all versions of a dataset.
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if the query fails.
305    pub fn list_dataset_versions(&self, name: &str) -> Result<Vec<crate::data::DatasetVersion>> {
306        contract_pre_configuration!(name);
307        let result = self.db.list_dataset_versions(name);
308        if let Ok(ref val) = result {
309            contract_post_configuration!(val);
310        }
311        result
312    }
313
314    /// Get the data for a dataset.
315    ///
316    /// # Errors
317    ///
318    /// Returns an error if the data cannot be retrieved.
319    pub fn get_dataset_data(
320        &self,
321        name: &str,
322        version: &crate::data::DatasetVersion,
323    ) -> Result<Vec<u8>> {
324        let dataset = self.get_dataset(name, version)?;
325        self.objects.get(&dataset.content_address)
326    }
327
328    // ==================== Recipe Registry ====================
329
330    /// Register a new recipe.
331    ///
332    /// # Errors
333    ///
334    /// Returns an error if registration fails.
335    pub fn register_recipe(&self, recipe: &TrainingRecipe) -> Result<RecipeId> {
336        // Check if already exists
337        if self.db.recipe_exists(&recipe.name, &recipe.version)? {
338            return Err(PachaError::AlreadyExists {
339                kind: "recipe".to_string(),
340                name: recipe.name.clone(),
341                version: recipe.version.to_string(),
342            });
343        }
344
345        let id = recipe.id.clone();
346        self.db.insert_recipe(recipe)?;
347        Ok(id)
348    }
349
350    /// Get a recipe by name and version.
351    ///
352    /// # Errors
353    ///
354    /// Returns an error if the recipe is not found.
355    pub fn get_recipe(
356        &self,
357        name: &str,
358        version: &crate::recipe::RecipeVersion,
359    ) -> Result<TrainingRecipe> {
360        self.db.get_recipe(name, version)
361    }
362
363    /// List all recipe names.
364    ///
365    /// # Errors
366    ///
367    /// Returns an error if the query fails.
368    pub fn list_recipes(&self) -> Result<Vec<String>> {
369        contract_pre_configuration!();
370        let result = self.db.list_recipe_names();
371        if let Ok(ref val) = result {
372            contract_post_configuration!(val);
373        }
374        result
375    }
376
377    /// List all versions of a recipe.
378    ///
379    /// # Errors
380    ///
381    /// Returns an error if the query fails.
382    pub fn list_recipe_versions(&self, name: &str) -> Result<Vec<crate::recipe::RecipeVersion>> {
383        contract_pre_expand_recipe!(name);
384        self.db.list_recipe_versions(name)
385    }
386
387    // ==================== Experiment Tracking ====================
388
389    /// Start a new experiment run.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if starting fails.
394    pub fn start_run(&self, mut run: ExperimentRun) -> Result<RunId> {
395        contract_pre_configuration!();
396        run.start();
397        let id = run.run_id.clone();
398        self.db.insert_run(&run)?;
399        Ok(id)
400    }
401
402    /// Update an experiment run.
403    ///
404    /// # Errors
405    ///
406    /// Returns an error if the update fails.
407    pub fn update_run(&self, run: &ExperimentRun) -> Result<()> {
408        contract_pre_configuration!();
409        self.db.update_run(run)
410    }
411
412    /// Get an experiment run by ID.
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if the run is not found.
417    pub fn get_run(&self, run_id: &RunId) -> Result<ExperimentRun> {
418        contract_pre_configuration!();
419        self.db.get_run(run_id)
420    }
421
422    /// List runs for a recipe.
423    ///
424    /// # Errors
425    ///
426    /// Returns an error if the query fails.
427    pub fn list_runs(&self, recipe_ref: &RecipeReference) -> Result<Vec<ExperimentRun>> {
428        contract_pre_configuration!();
429        self.db.list_runs_for_recipe(recipe_ref)
430    }
431
432    // ==================== Utility ====================
433
434    /// Get storage statistics.
435    ///
436    /// # Errors
437    ///
438    /// Returns an error if querying fails.
439    pub fn storage_stats(&self) -> Result<StorageStats> {
440        let total_size = self.objects.total_size()?;
441        let object_count = self.objects.list()?.len();
442        let model_count = self.db.count_models()?;
443        let dataset_count = self.db.count_datasets()?;
444        let recipe_count = self.db.count_recipes()?;
445
446        Ok(StorageStats {
447            total_size_bytes: total_size,
448            object_count,
449            model_count,
450            dataset_count,
451            recipe_count,
452        })
453    }
454}
455
/// Snapshot of registry size and entry counts, as produced by `Registry::storage_stats`.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Combined size, in bytes, reported by the object store.
    pub total_size_bytes: u64,
    /// How many objects the object store lists.
    pub object_count: usize,
    /// How many models the database counts.
    pub model_count: usize,
    /// How many datasets the database counts.
    pub dataset_count: usize,
    /// How many recipes the database counts.
    pub recipe_count: usize,
}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::DatasetVersion;
    use crate::recipe::{Hyperparameters, RecipeVersion};
    use tempfile::TempDir;

    /// Open a registry inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is handed back too; callers keep it bound for the
    /// duration of the test.
    fn setup() -> (TempDir, Registry) {
        let tmp = TempDir::new().unwrap();
        let registry = Registry::open(RegistryConfig::new(tmp.path())).unwrap();
        (tmp, registry)
    }

    /// Opening a registry leaves its base directory on disk.
    #[test]
    fn test_registry_open() {
        let (_dir, registry) = setup();
        let base = &registry.config.base_path;
        assert!(base.exists());
    }

    /// A registered model can be read back with matching metadata.
    #[test]
    fn test_register_and_get_model() {
        let (_dir, registry) = setup();

        let version = ModelVersion::new(1, 0, 0);
        let card = ModelCard::new("Test model");

        let id = registry
            .register_model("test-model", &version, b"model data", card.clone())
            .unwrap();

        let stored = registry.get_model("test-model", &version).unwrap();
        assert_eq!(stored.id, id);
        assert_eq!(stored.name, "test-model");
        assert_eq!(stored.version, version);
        assert_eq!(stored.card.description, card.description);
    }

    /// Registering the same name and version twice is rejected.
    #[test]
    fn test_register_duplicate_model_fails() {
        let (_dir, registry) = setup();

        let version = ModelVersion::new(1, 0, 0);
        let card = ModelCard::new("Test model");

        registry
            .register_model("test-model", &version, b"model data", card.clone())
            .unwrap();

        let second = registry.register_model("test-model", &version, b"model data", card);
        assert!(matches!(second, Err(PachaError::AlreadyExists { .. })));
    }

    /// Artifact bytes come back unchanged.
    #[test]
    fn test_model_artifact_roundtrip() {
        let (_dir, registry) = setup();

        let version = ModelVersion::new(1, 0, 0);
        let payload = b"model binary data here";

        registry
            .register_model("test-model", &version, payload, ModelCard::new("Test"))
            .unwrap();

        let fetched = registry.get_model_artifact("test-model", &version).unwrap();
        assert_eq!(fetched, payload);
    }

    /// A model can move from Development to Staging.
    #[test]
    fn test_model_stage_transition() {
        let (_dir, registry) = setup();

        let version = ModelVersion::new(1, 0, 0);
        registry
            .register_model("test-model", &version, b"data", ModelCard::new("Test"))
            .unwrap();

        // Development -> Staging is valid
        registry
            .transition_model_stage("test-model", &version, ModelStage::Staging)
            .unwrap();

        let stored = registry.get_model("test-model", &version).unwrap();
        assert_eq!(stored.stage, ModelStage::Staging);
    }

    /// A registered dataset can be read back with matching metadata.
    #[test]
    fn test_register_and_get_dataset() {
        let (_dir, registry) = setup();

        let version = DatasetVersion::new(1, 0, 0);
        let datasheet = Datasheet::new("Test dataset");

        let id = registry
            .register_dataset("test-dataset", &version, b"csv,data,here", datasheet.clone())
            .unwrap();

        let stored = registry.get_dataset("test-dataset", &version).unwrap();
        assert_eq!(stored.id, id);
        assert_eq!(stored.datasheet.purpose, datasheet.purpose);
    }

    /// Dataset bytes come back unchanged.
    #[test]
    fn test_dataset_data_roundtrip() {
        let (_dir, registry) = setup();

        let version = DatasetVersion::new(1, 0, 0);
        let payload = b"raw dataset bytes";

        registry
            .register_dataset("test-dataset", &version, payload, Datasheet::new("Test"))
            .unwrap();

        let fetched = registry.get_dataset_data("test-dataset", &version).unwrap();
        assert_eq!(fetched, payload);
    }

    /// A registered recipe can be read back by name and version.
    #[test]
    fn test_register_and_get_recipe() {
        let (_dir, registry) = setup();

        let recipe = TrainingRecipe::builder()
            .name("test-recipe")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Test recipe")
            .hyperparameters(Hyperparameters::default())
            .build();

        let id = registry.register_recipe(&recipe).unwrap();

        let lookup_version = RecipeVersion::new(1, 0, 0);
        let stored = registry.get_recipe("test-recipe", &lookup_version).unwrap();
        assert_eq!(stored.id, id);
        assert_eq!(stored.description, "Test recipe");
    }

    /// A started run can be fetched again, including its logged metric.
    #[test]
    fn test_experiment_run() {
        let (_dir, registry) = setup();

        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric("loss", 0.5, 100);

        let run_id = registry.start_run(run).unwrap();

        let stored = registry.get_run(&run_id).unwrap();
        assert_eq!(stored.run_id, run_id);
        assert_eq!(stored.metrics.len(), 1);
    }

    /// Counts reflect one model and one dataset with distinct payloads.
    #[test]
    fn test_storage_stats() {
        let (_dir, registry) = setup();

        let model_version = ModelVersion::new(1, 0, 0);
        registry
            .register_model("model1", &model_version, b"data1", ModelCard::new("M1"))
            .unwrap();

        let dataset_version = DatasetVersion::new(1, 0, 0);
        registry
            .register_dataset("dataset1", &dataset_version, b"data2", Datasheet::new("D1"))
            .unwrap();

        let stats = registry.storage_stats().unwrap();
        assert_eq!(stats.model_count, 1);
        assert_eq!(stats.dataset_count, 1);
        assert_eq!(stats.object_count, 2);
    }

    /// Listing returns each model name once and all versions of a model.
    #[test]
    fn test_list_operations() {
        let (_dir, registry) = setup();

        let v1_0 = ModelVersion::new(1, 0, 0);
        let v1_1 = ModelVersion::new(1, 1, 0);

        registry
            .register_model("model-a", &v1_0, b"data", ModelCard::new("A"))
            .unwrap();
        registry
            .register_model("model-a", &v1_1, b"data2", ModelCard::new("A v1.1"))
            .unwrap();
        registry
            .register_model("model-b", &ModelVersion::new(1, 0, 0), b"data3", ModelCard::new("B"))
            .unwrap();

        let names = registry.list_models().unwrap();
        assert_eq!(names.len(), 2);
        for expected in ["model-a", "model-b"] {
            assert!(names.contains(&expected.to_string()));
        }

        let versions = registry.list_model_versions("model-a").unwrap();
        assert_eq!(versions.len(), 2);
    }
}