// pacha/recipe/mod.rs
//! Recipe registry types and operations.
//!
//! Provides training recipes for reproducible ML workflows.

5mod hyperparams;
6mod version;
7
8pub use hyperparams::{HyperparamValue, Hyperparameters};
9pub use version::RecipeVersion;
10
11use crate::data::DatasetReference;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use uuid::Uuid;
16
17/// Unique identifier for a registered recipe.
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub struct RecipeId(Uuid);
20
21impl RecipeId {
22    /// Create a new random recipe ID.
23    #[must_use]
24    pub fn new() -> Self {
25        Self(Uuid::new_v4())
26    }
27
28    /// Create from a UUID.
29    #[must_use]
30    pub fn from_uuid(uuid: Uuid) -> Self {
31        Self(uuid)
32    }
33
34    /// Get the underlying UUID.
35    #[must_use]
36    pub fn as_uuid(&self) -> &Uuid {
37        &self.0
38    }
39}
40
41impl Default for RecipeId {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl std::fmt::Display for RecipeId {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{}", self.0)
50    }
51}
52
53impl std::str::FromStr for RecipeId {
54    type Err = uuid::Error;
55
56    fn from_str(s: &str) -> Result<Self, Self::Err> {
57        Ok(Self(Uuid::parse_str(s)?))
58    }
59}
60
61/// Reference to a recipe (name + version).
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
63pub struct RecipeReference {
64    /// Recipe name.
65    pub name: String,
66    /// Recipe version.
67    pub version: RecipeVersion,
68}
69
70impl RecipeReference {
71    /// Create a new recipe reference.
72    #[must_use]
73    pub fn new(name: impl Into<String>, version: RecipeVersion) -> Self {
74        Self { name: name.into(), version }
75    }
76}
77
78impl std::fmt::Display for RecipeReference {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        write!(f, "{}:{}", self.name, self.version)
81    }
82}
83
84/// Training recipe for reproducible ML workflows.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct TrainingRecipe {
87    /// Unique identifier.
88    pub id: RecipeId,
89    /// Recipe name.
90    pub name: String,
91    /// Recipe version.
92    pub version: RecipeVersion,
93    /// Description.
94    pub description: String,
95
96    /// Model architecture specification.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub architecture: Option<String>,
99
100    /// Training hyperparameters.
101    pub hyperparameters: Hyperparameters,
102
103    /// Optimizer configuration.
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub optimizer: Option<OptimizerSpec>,
106
107    /// Learning rate scheduler.
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub scheduler: Option<SchedulerSpec>,
110
111    /// Loss function.
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub loss: Option<LossSpec>,
114
115    /// Training data reference.
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub train_data: Option<DatasetReference>,
118
119    /// Validation data reference.
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub validation_data: Option<DatasetReference>,
122
123    /// Preprocessing steps.
124    #[serde(default)]
125    pub preprocessing: Vec<String>,
126
127    /// Data augmentation steps.
128    #[serde(default)]
129    pub augmentation: Vec<String>,
130
131    /// Environment dependencies.
132    pub dependencies: Dependencies,
133
134    /// Hardware requirements.
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub hardware: Option<HardwareSpec>,
137
138    /// Random seed for reproducibility.
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub random_seed: Option<u64>,
141
142    /// Whether the recipe produces deterministic results.
143    #[serde(default)]
144    pub deterministic: bool,
145
146    /// Registration timestamp.
147    pub created_at: DateTime<Utc>,
148
149    /// Additional metadata.
150    #[serde(default)]
151    pub extra: HashMap<String, serde_json::Value>,
152}
153
154impl TrainingRecipe {
155    /// Create a new recipe builder.
156    #[must_use]
157    pub fn builder() -> TrainingRecipeBuilder {
158        TrainingRecipeBuilder::new()
159    }
160
161    /// Create a reference to this recipe.
162    #[must_use]
163    pub fn reference(&self) -> RecipeReference {
164        RecipeReference::new(&self.name, self.version.clone())
165    }
166}
167
168/// Optimizer specification.
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct OptimizerSpec {
171    /// Optimizer type (e.g., "adam", "sgd").
172    pub optimizer_type: String,
173    /// Optimizer-specific parameters.
174    #[serde(default)]
175    pub params: HashMap<String, HyperparamValue>,
176}
177
178impl OptimizerSpec {
179    /// Create a new optimizer spec.
180    #[must_use]
181    pub fn new(optimizer_type: impl Into<String>) -> Self {
182        Self { optimizer_type: optimizer_type.into(), params: HashMap::new() }
183    }
184
185    /// Add a parameter.
186    #[must_use]
187    pub fn with_param(mut self, name: impl Into<String>, value: HyperparamValue) -> Self {
188        self.params.insert(name.into(), value);
189        self
190    }
191}
192
193/// Learning rate scheduler specification.
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct SchedulerSpec {
196    /// Scheduler type (e.g., "cosine", "step").
197    pub scheduler_type: String,
198    /// Scheduler-specific parameters.
199    #[serde(default)]
200    pub params: HashMap<String, HyperparamValue>,
201}
202
203impl SchedulerSpec {
204    /// Create a new scheduler spec.
205    #[must_use]
206    pub fn new(scheduler_type: impl Into<String>) -> Self {
207        Self { scheduler_type: scheduler_type.into(), params: HashMap::new() }
208    }
209}
210
211/// Loss function specification.
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct LossSpec {
214    /// Loss function type (e.g., "`cross_entropy`", "mse").
215    pub loss_type: String,
216    /// Loss-specific parameters.
217    #[serde(default)]
218    pub params: HashMap<String, HyperparamValue>,
219}
220
221impl LossSpec {
222    /// Create a new loss spec.
223    #[must_use]
224    pub fn new(loss_type: impl Into<String>) -> Self {
225        Self { loss_type: loss_type.into(), params: HashMap::new() }
226    }
227}
228
229/// Environment dependencies.
230#[derive(Debug, Clone, Default, Serialize, Deserialize)]
231pub struct Dependencies {
232    /// Rust toolchain version.
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub rust_version: Option<String>,
235    /// Cargo.lock hash for exact reproducibility.
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub cargo_lock_hash: Option<String>,
238    /// System dependencies.
239    #[serde(default)]
240    pub system_deps: Vec<String>,
241    /// Environment variables (non-sensitive).
242    #[serde(default)]
243    pub env_vars: HashMap<String, String>,
244}
245
246/// Hardware requirements.
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct HardwareSpec {
249    /// Minimum CPU cores.
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub min_cpu_cores: Option<usize>,
252    /// Minimum RAM in GB.
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub min_ram_gb: Option<usize>,
255    /// GPU requirements.
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub gpu: Option<GpuRequirement>,
258    /// Estimated training time in seconds.
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub estimated_duration_secs: Option<u64>,
261}
262
263/// GPU requirement.
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct GpuRequirement {
266    /// Minimum GPU count.
267    pub count: usize,
268    /// Minimum VRAM per GPU in GB.
269    #[serde(skip_serializing_if = "Option::is_none")]
270    pub min_vram_gb: Option<usize>,
271    /// Required compute capability.
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub compute_capability: Option<String>,
274}
275
276/// Builder for creating training recipes.
277#[derive(Debug)]
278pub struct TrainingRecipeBuilder {
279    name: String,
280    version: RecipeVersion,
281    description: String,
282    architecture: Option<String>,
283    hyperparameters: Hyperparameters,
284    optimizer: Option<OptimizerSpec>,
285    scheduler: Option<SchedulerSpec>,
286    loss: Option<LossSpec>,
287    train_data: Option<DatasetReference>,
288    validation_data: Option<DatasetReference>,
289    preprocessing: Vec<String>,
290    augmentation: Vec<String>,
291    dependencies: Dependencies,
292    hardware: Option<HardwareSpec>,
293    random_seed: Option<u64>,
294    deterministic: bool,
295}
296
297impl TrainingRecipeBuilder {
298    /// Create a new builder.
299    #[must_use]
300    pub fn new() -> Self {
301        Self {
302            name: String::new(),
303            version: RecipeVersion::initial(),
304            description: String::new(),
305            architecture: None,
306            hyperparameters: Hyperparameters::default(),
307            optimizer: None,
308            scheduler: None,
309            loss: None,
310            train_data: None,
311            validation_data: None,
312            preprocessing: Vec::new(),
313            augmentation: Vec::new(),
314            dependencies: Dependencies::default(),
315            hardware: None,
316            random_seed: None,
317            deterministic: false,
318        }
319    }
320
321    /// Set the name.
322    #[must_use]
323    pub fn name(mut self, name: impl Into<String>) -> Self {
324        self.name = name.into();
325        self
326    }
327
328    /// Set the version.
329    #[must_use]
330    pub fn version(mut self, version: RecipeVersion) -> Self {
331        self.version = version;
332        self
333    }
334
335    /// Set the description.
336    #[must_use]
337    pub fn description(mut self, description: impl Into<String>) -> Self {
338        self.description = description.into();
339        self
340    }
341
342    /// Set the architecture.
343    #[must_use]
344    pub fn architecture(mut self, architecture: impl Into<String>) -> Self {
345        self.architecture = Some(architecture.into());
346        self
347    }
348
349    /// Set hyperparameters.
350    #[must_use]
351    pub fn hyperparameters(mut self, hyperparameters: Hyperparameters) -> Self {
352        self.hyperparameters = hyperparameters;
353        self
354    }
355
356    /// Set optimizer.
357    #[must_use]
358    pub fn optimizer(mut self, optimizer: OptimizerSpec) -> Self {
359        self.optimizer = Some(optimizer);
360        self
361    }
362
363    /// Set scheduler.
364    #[must_use]
365    pub fn scheduler(mut self, scheduler: SchedulerSpec) -> Self {
366        self.scheduler = Some(scheduler);
367        self
368    }
369
370    /// Set loss.
371    #[must_use]
372    pub fn loss(mut self, loss: LossSpec) -> Self {
373        self.loss = Some(loss);
374        self
375    }
376
377    /// Set training data.
378    #[must_use]
379    pub fn train_data(mut self, data: DatasetReference) -> Self {
380        self.train_data = Some(data);
381        self
382    }
383
384    /// Set validation data.
385    #[must_use]
386    pub fn validation_data(mut self, data: DatasetReference) -> Self {
387        self.validation_data = Some(data);
388        self
389    }
390
391    /// Set random seed.
392    #[must_use]
393    pub fn random_seed(mut self, seed: u64) -> Self {
394        self.random_seed = Some(seed);
395        self
396    }
397
398    /// Set deterministic flag.
399    #[must_use]
400    pub fn deterministic(mut self, deterministic: bool) -> Self {
401        self.deterministic = deterministic;
402        self
403    }
404
405    /// Build the recipe.
406    #[must_use]
407    pub fn build(self) -> TrainingRecipe {
408        TrainingRecipe {
409            id: RecipeId::new(),
410            name: self.name,
411            version: self.version,
412            description: self.description,
413            architecture: self.architecture,
414            hyperparameters: self.hyperparameters,
415            optimizer: self.optimizer,
416            scheduler: self.scheduler,
417            loss: self.loss,
418            train_data: self.train_data,
419            validation_data: self.validation_data,
420            preprocessing: self.preprocessing,
421            augmentation: self.augmentation,
422            dependencies: self.dependencies,
423            hardware: self.hardware,
424            random_seed: self.random_seed,
425            deterministic: self.deterministic,
426            created_at: Utc::now(),
427            extra: HashMap::new(),
428        }
429    }
430}
431
432impl Default for TrainingRecipeBuilder {
433    fn default() -> Self {
434        Self::new()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recipe_id_generation() {
        let id1 = RecipeId::new();
        let id2 = RecipeId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_recipe_id_display_parse_roundtrip() {
        let id = RecipeId::new();
        let parsed: RecipeId = id.to_string().parse().unwrap();
        assert_eq!(parsed, id);
        assert_eq!(parsed.as_uuid(), id.as_uuid());
    }

    #[test]
    fn test_recipe_id_rejects_invalid() {
        assert!("not-a-uuid".parse::<RecipeId>().is_err());
    }

    #[test]
    fn test_recipe_reference_display() {
        let reference = RecipeReference::new("bert-finetune", RecipeVersion::new(1, 2, 3));
        assert_eq!(reference.to_string(), "bert-finetune:1.2.3");
    }

    #[test]
    fn test_recipe_reference_from_recipe() {
        let recipe = TrainingRecipe::builder()
            .name("bert-finetune")
            .version(RecipeVersion::new(2, 1, 0))
            .build();

        assert_eq!(
            recipe.reference(),
            RecipeReference::new("bert-finetune", RecipeVersion::new(2, 1, 0))
        );
    }

    #[test]
    fn test_recipe_builder() {
        let hyperparams = Hyperparameters {
            learning_rate: 2e-5,
            batch_size: 32,
            epochs: 3,
            ..Default::default()
        };

        let recipe = TrainingRecipe::builder()
            .name("bert-finetune")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Fine-tune BERT for sentiment analysis")
            .hyperparameters(hyperparams)
            .optimizer(OptimizerSpec::new("adam"))
            .loss(LossSpec::new("cross_entropy"))
            .random_seed(42)
            .deterministic(true)
            .build();

        assert_eq!(recipe.name, "bert-finetune");
        assert_eq!(recipe.hyperparameters.learning_rate, 2e-5);
        assert_eq!(recipe.hyperparameters.batch_size, 32);
        assert_eq!(recipe.random_seed, Some(42));
        assert!(recipe.deterministic);
    }

    #[test]
    fn test_optimizer_spec() {
        let optimizer = OptimizerSpec::new("adam")
            .with_param("beta1", HyperparamValue::Float(0.9))
            .with_param("beta2", HyperparamValue::Float(0.999));

        assert_eq!(optimizer.optimizer_type, "adam");
        assert_eq!(optimizer.params.len(), 2);
    }

    #[test]
    fn test_recipe_serialization() {
        let recipe = TrainingRecipe::builder()
            .name("test-recipe")
            .version(RecipeVersion::new(1, 2, 3))
            .description("Test")
            .random_seed(7)
            .deterministic(true)
            .build();

        let json = serde_json::to_string(&recipe).unwrap();
        let deserialized: TrainingRecipe = serde_json::from_str(&json).unwrap();

        // Check identity, versioning and reproducibility fields survive the round-trip,
        // not just the name.
        assert_eq!(deserialized.id, recipe.id);
        assert_eq!(deserialized.reference(), recipe.reference());
        assert_eq!(deserialized.description, recipe.description);
        assert_eq!(deserialized.random_seed, Some(7));
        assert!(deserialized.deterministic);
        assert_eq!(deserialized.created_at, recipe.created_at);
    }
}