1mod hyperparams;
6mod version;
7
8pub use hyperparams::{HyperparamValue, Hyperparameters};
9pub use version::RecipeVersion;
10
11use crate::data::DatasetReference;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use uuid::Uuid;
16
/// Identifier for a training recipe, implemented as a newtype over [`Uuid`].
///
/// The wrapped value is private; use [`RecipeId::from_uuid`] and
/// [`RecipeId::as_uuid`] to convert from and to a raw `Uuid`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RecipeId(Uuid);
20
21impl RecipeId {
22 #[must_use]
24 pub fn new() -> Self {
25 Self(Uuid::new_v4())
26 }
27
28 #[must_use]
30 pub fn from_uuid(uuid: Uuid) -> Self {
31 Self(uuid)
32 }
33
34 #[must_use]
36 pub fn as_uuid(&self) -> &Uuid {
37 &self.0
38 }
39}
40
41impl Default for RecipeId {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl std::fmt::Display for RecipeId {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{}", self.0)
50 }
51}
52
53impl std::str::FromStr for RecipeId {
54 type Err = uuid::Error;
55
56 fn from_str(s: &str) -> Result<Self, Self::Err> {
57 Ok(Self(Uuid::parse_str(s)?))
58 }
59}
60
/// Reference to a recipe by name and version.
///
/// Its `Display` implementation renders as `name:version`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RecipeReference {
    /// Name of the referenced recipe.
    pub name: String,
    /// Version of the referenced recipe.
    pub version: RecipeVersion,
}
69
70impl RecipeReference {
71 #[must_use]
73 pub fn new(name: impl Into<String>, version: RecipeVersion) -> Self {
74 Self { name: name.into(), version }
75 }
76}
77
78impl std::fmt::Display for RecipeReference {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}:{}", self.name, self.version)
81 }
82}
83
/// A versioned training recipe bundling hyperparameters with optional
/// optimizer, scheduler, loss, dataset, dependency and hardware
/// specifications.
///
/// Serialization notes: `Option` fields marked `skip_serializing_if` are left
/// out of the output when `None`; fields marked `serde(default)` fall back to
/// their `Default` value when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipe {
    /// Identifier of this recipe.
    pub id: RecipeId,
    /// Recipe name; combined with `version` by [`TrainingRecipe::reference`].
    pub name: String,
    /// Recipe version.
    pub version: RecipeVersion,
    /// Free-form description text.
    pub description: String,

    /// Optional architecture label; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<String>,

    /// Hyperparameters of the recipe.
    pub hyperparameters: Hyperparameters,

    /// Optional optimizer specification; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub optimizer: Option<OptimizerSpec>,

    /// Optional scheduler specification; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scheduler: Option<SchedulerSpec>,

    /// Optional loss specification; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub loss: Option<LossSpec>,

    /// Optional reference to the training dataset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub train_data: Option<DatasetReference>,

    /// Optional reference to the validation dataset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_data: Option<DatasetReference>,

    /// Preprocessing entries as strings (presumably step names; confirm with
    /// consumers). Empty when absent from the input.
    #[serde(default)]
    pub preprocessing: Vec<String>,

    /// Augmentation entries as strings (presumably step names; confirm with
    /// consumers). Empty when absent from the input.
    #[serde(default)]
    pub augmentation: Vec<String>,

    /// Dependency information recorded for the recipe.
    pub dependencies: Dependencies,

    /// Optional hardware requirements; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hardware: Option<HardwareSpec>,

    /// Optional random seed; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u64>,

    /// Determinism flag; `false` when absent from the input.
    /// NOTE(review): what this flag guarantees is decided by consumers and is
    /// not visible in this file.
    #[serde(default)]
    pub deterministic: bool,

    /// UTC timestamp of the recipe; [`TrainingRecipeBuilder::build`] sets it
    /// to the time of the `build` call.
    pub created_at: DateTime<Utc>,

    /// Additional JSON values keyed by string; empty when absent from the input.
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}
153
154impl TrainingRecipe {
155 #[must_use]
157 pub fn builder() -> TrainingRecipeBuilder {
158 TrainingRecipeBuilder::new()
159 }
160
161 #[must_use]
163 pub fn reference(&self) -> RecipeReference {
164 RecipeReference::new(&self.name, self.version.clone())
165 }
166}
167
/// Optimizer specification: a type name plus named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerSpec {
    /// Optimizer type as a free-form string (e.g. `"adam"` in this file's tests).
    pub optimizer_type: String,
    /// Named optimizer parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
177
178impl OptimizerSpec {
179 #[must_use]
181 pub fn new(optimizer_type: impl Into<String>) -> Self {
182 Self { optimizer_type: optimizer_type.into(), params: HashMap::new() }
183 }
184
185 #[must_use]
187 pub fn with_param(mut self, name: impl Into<String>, value: HyperparamValue) -> Self {
188 self.params.insert(name.into(), value);
189 self
190 }
191}
192
/// Scheduler specification: a type name plus named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerSpec {
    /// Scheduler type as a free-form string.
    pub scheduler_type: String,
    /// Named scheduler parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
202
203impl SchedulerSpec {
204 #[must_use]
206 pub fn new(scheduler_type: impl Into<String>) -> Self {
207 Self { scheduler_type: scheduler_type.into(), params: HashMap::new() }
208 }
209}
210
/// Loss specification: a type name plus named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSpec {
    /// Loss type as a free-form string (e.g. `"cross_entropy"` in this file's tests).
    pub loss_type: String,
    /// Named loss parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
220
221impl LossSpec {
222 #[must_use]
224 pub fn new(loss_type: impl Into<String>) -> Self {
225 Self { loss_type: loss_type.into(), params: HashMap::new() }
226 }
227}
228
/// Dependency and environment information attached to a recipe.
///
/// Every member is optional or has an empty default, so
/// `Dependencies::default()` is a value with nothing recorded.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Dependencies {
    /// Rust version string (presumably the toolchain version; confirm with
    /// producers). Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rust_version: Option<String>,
    /// Hash of a Cargo lock file as a string; the hash algorithm is not
    /// specified here. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cargo_lock_hash: Option<String>,
    /// System-level dependency entries; empty when absent from the input.
    #[serde(default)]
    pub system_deps: Vec<String>,
    /// Environment variable names mapped to values; empty when absent from
    /// the input.
    #[serde(default)]
    pub env_vars: HashMap<String, String>,
}
245
/// Hardware requirements for a recipe.
///
/// All members are optional and are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareSpec {
    /// Minimum number of CPU cores, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_cpu_cores: Option<usize>,
    /// Minimum amount of RAM in gigabytes (per the field name), if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_ram_gb: Option<usize>,
    /// GPU requirement, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gpu: Option<GpuRequirement>,
    /// Estimated duration in seconds (per the field name), if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_duration_secs: Option<u64>,
}
262
/// GPU requirement used by [`HardwareSpec`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirement {
    /// Number of GPUs.
    pub count: usize,
    /// Minimum VRAM in gigabytes (per the field name), if specified.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_vram_gb: Option<usize>,
    /// Compute capability as a string; its format is not defined in this
    /// file. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_capability: Option<String>,
}
275
/// Builder for [`TrainingRecipe`].
///
/// Obtain one with [`TrainingRecipe::builder`] or
/// [`TrainingRecipeBuilder::new`], chain the consuming setter methods, and
/// finish with [`TrainingRecipeBuilder::build`]. The recipe's `id`,
/// `created_at` and `extra` are not stored here; `build` fills them in.
#[derive(Debug)]
pub struct TrainingRecipeBuilder {
    // Each field mirrors the `TrainingRecipe` field of the same name.
    name: String,
    version: RecipeVersion,
    description: String,
    architecture: Option<String>,
    hyperparameters: Hyperparameters,
    optimizer: Option<OptimizerSpec>,
    scheduler: Option<SchedulerSpec>,
    loss: Option<LossSpec>,
    train_data: Option<DatasetReference>,
    validation_data: Option<DatasetReference>,
    preprocessing: Vec<String>,
    augmentation: Vec<String>,
    dependencies: Dependencies,
    hardware: Option<HardwareSpec>,
    random_seed: Option<u64>,
    deterministic: bool,
}
296
297impl TrainingRecipeBuilder {
298 #[must_use]
300 pub fn new() -> Self {
301 Self {
302 name: String::new(),
303 version: RecipeVersion::initial(),
304 description: String::new(),
305 architecture: None,
306 hyperparameters: Hyperparameters::default(),
307 optimizer: None,
308 scheduler: None,
309 loss: None,
310 train_data: None,
311 validation_data: None,
312 preprocessing: Vec::new(),
313 augmentation: Vec::new(),
314 dependencies: Dependencies::default(),
315 hardware: None,
316 random_seed: None,
317 deterministic: false,
318 }
319 }
320
321 #[must_use]
323 pub fn name(mut self, name: impl Into<String>) -> Self {
324 self.name = name.into();
325 self
326 }
327
328 #[must_use]
330 pub fn version(mut self, version: RecipeVersion) -> Self {
331 self.version = version;
332 self
333 }
334
335 #[must_use]
337 pub fn description(mut self, description: impl Into<String>) -> Self {
338 self.description = description.into();
339 self
340 }
341
342 #[must_use]
344 pub fn architecture(mut self, architecture: impl Into<String>) -> Self {
345 self.architecture = Some(architecture.into());
346 self
347 }
348
349 #[must_use]
351 pub fn hyperparameters(mut self, hyperparameters: Hyperparameters) -> Self {
352 self.hyperparameters = hyperparameters;
353 self
354 }
355
356 #[must_use]
358 pub fn optimizer(mut self, optimizer: OptimizerSpec) -> Self {
359 self.optimizer = Some(optimizer);
360 self
361 }
362
363 #[must_use]
365 pub fn scheduler(mut self, scheduler: SchedulerSpec) -> Self {
366 self.scheduler = Some(scheduler);
367 self
368 }
369
370 #[must_use]
372 pub fn loss(mut self, loss: LossSpec) -> Self {
373 self.loss = Some(loss);
374 self
375 }
376
377 #[must_use]
379 pub fn train_data(mut self, data: DatasetReference) -> Self {
380 self.train_data = Some(data);
381 self
382 }
383
384 #[must_use]
386 pub fn validation_data(mut self, data: DatasetReference) -> Self {
387 self.validation_data = Some(data);
388 self
389 }
390
391 #[must_use]
393 pub fn random_seed(mut self, seed: u64) -> Self {
394 self.random_seed = Some(seed);
395 self
396 }
397
398 #[must_use]
400 pub fn deterministic(mut self, deterministic: bool) -> Self {
401 self.deterministic = deterministic;
402 self
403 }
404
405 #[must_use]
407 pub fn build(self) -> TrainingRecipe {
408 TrainingRecipe {
409 id: RecipeId::new(),
410 name: self.name,
411 version: self.version,
412 description: self.description,
413 architecture: self.architecture,
414 hyperparameters: self.hyperparameters,
415 optimizer: self.optimizer,
416 scheduler: self.scheduler,
417 loss: self.loss,
418 train_data: self.train_data,
419 validation_data: self.validation_data,
420 preprocessing: self.preprocessing,
421 augmentation: self.augmentation,
422 dependencies: self.dependencies,
423 hardware: self.hardware,
424 random_seed: self.random_seed,
425 deterministic: self.deterministic,
426 created_at: Utc::now(),
427 extra: HashMap::new(),
428 }
429 }
430}
431
432impl Default for TrainingRecipeBuilder {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated identifiers must not compare equal.
    #[test]
    fn test_recipe_id_generation() {
        let first = RecipeId::new();
        let second = RecipeId::new();
        assert_ne!(first, second);
    }

    /// A reference renders as `name:version`.
    #[test]
    fn test_recipe_reference_display() {
        let version = RecipeVersion::new(1, 2, 3);
        let reference = RecipeReference::new("bert-finetune", version);
        assert_eq!(reference.to_string(), "bert-finetune:1.2.3");
    }

    /// Values passed to the builder show up unchanged on the built recipe.
    #[test]
    fn test_recipe_builder() {
        let hyperparams = Hyperparameters {
            learning_rate: 2e-5,
            batch_size: 32,
            epochs: 3,
            ..Default::default()
        };

        let builder = TrainingRecipe::builder()
            .name("bert-finetune")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Fine-tune BERT for sentiment analysis")
            .hyperparameters(hyperparams)
            .optimizer(OptimizerSpec::new("adam"))
            .loss(LossSpec::new("cross_entropy"))
            .random_seed(42)
            .deterministic(true);
        let recipe = builder.build();

        assert_eq!(recipe.name, "bert-finetune");
        assert_eq!(recipe.hyperparameters.learning_rate, 2e-5);
        assert_eq!(recipe.hyperparameters.batch_size, 32);
        assert_eq!(recipe.random_seed, Some(42));
        assert!(recipe.deterministic);
    }

    /// `with_param` accumulates parameters on the optimizer spec.
    #[test]
    fn test_optimizer_spec() {
        let spec = OptimizerSpec::new("adam")
            .with_param("beta1", HyperparamValue::Float(0.9))
            .with_param("beta2", HyperparamValue::Float(0.999));

        assert_eq!(spec.optimizer_type, "adam");
        assert_eq!(spec.params.len(), 2);
    }

    /// A recipe survives a JSON round trip with its name intact.
    #[test]
    fn test_recipe_serialization() {
        let original = TrainingRecipe::builder().name("test-recipe").description("Test").build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TrainingRecipe = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.name, decoded.name);
    }
}