mod hyperparams;
mod version;
pub use hyperparams::{HyperparamValue, Hyperparameters};
pub use version::RecipeVersion;
use crate::data::DatasetReference;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Unique identifier of a [`TrainingRecipe`], wrapping a [`Uuid`].
///
/// Serialized through serde's newtype handling of the inner UUID; `Display` and
/// `FromStr` round-trip through the UUID's textual form.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RecipeId(Uuid);
impl RecipeId {
#[must_use]
pub fn new() -> Self {
Self(Uuid::new_v4())
}
#[must_use]
pub fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
#[must_use]
pub fn as_uuid(&self) -> &Uuid {
&self.0
}
}
impl Default for RecipeId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for RecipeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::str::FromStr for RecipeId {
    type Err = uuid::Error;

    /// Parses `s` with [`Uuid::parse_str`] and wraps the result; parse errors are
    /// returned unchanged.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Uuid::parse_str(s).map(Self)
    }
}
/// Points at a recipe by its name and version, without carrying the recipe body.
///
/// Its `Display` form is `name:version`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RecipeReference {
    /// Name of the referenced recipe.
    pub name: String,
    /// Version of the referenced recipe.
    pub version: RecipeVersion,
}
impl RecipeReference {
#[must_use]
pub fn new(name: impl Into<String>, version: RecipeVersion) -> Self {
Self { name: name.into(), version }
}
}
impl std::fmt::Display for RecipeReference {
    /// Formats the reference as `name:version`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name, version } = self;
        write!(f, "{}:{}", name, version)
    }
}
/// A complete, serializable description of a training run.
///
/// Can be assembled with [`TrainingRecipe::builder`]. `Option` fields are omitted from
/// serialized output when `None`. Fields marked `#[serde(default)]` fall back to their
/// `Default` value when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipe {
    /// Identifier of this recipe instance.
    pub id: RecipeId,
    /// Recipe name; together with `version` it forms a [`RecipeReference`].
    pub name: String,
    /// Version of the recipe.
    pub version: RecipeVersion,
    /// Free-text description.
    pub description: String,
    /// Model architecture identifier, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<String>,
    /// Core training hyperparameters.
    pub hyperparameters: Hyperparameters,
    /// Optimizer configuration, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub optimizer: Option<OptimizerSpec>,
    /// Learning-rate scheduler configuration, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scheduler: Option<SchedulerSpec>,
    /// Loss function configuration, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub loss: Option<LossSpec>,
    /// Training dataset, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub train_data: Option<DatasetReference>,
    /// Validation dataset, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_data: Option<DatasetReference>,
    /// Preprocessing steps as opaque strings; empty when absent from the input.
    #[serde(default)]
    pub preprocessing: Vec<String>,
    /// Augmentation steps as opaque strings; empty when absent from the input.
    #[serde(default)]
    pub augmentation: Vec<String>,
    /// Software dependencies recorded for the run.
    ///
    /// Defaults to an empty [`Dependencies`] when absent. This matches the other
    /// defaultable fields, so recipes without a `dependencies` section still deserialize.
    #[serde(default)]
    pub dependencies: Dependencies,
    /// Hardware requirements, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hardware: Option<HardwareSpec>,
    /// Random seed, if one was fixed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u64>,
    /// Whether deterministic execution was requested; `false` when absent.
    #[serde(default)]
    pub deterministic: bool,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
    /// Arbitrary extra metadata; empty when absent from the input.
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}
impl TrainingRecipe {
    /// Starts a [`TrainingRecipeBuilder`] populated with its initial defaults.
    #[must_use]
    pub fn builder() -> TrainingRecipeBuilder {
        TrainingRecipeBuilder::new()
    }

    /// Returns a name/version [`RecipeReference`] pointing at this recipe.
    #[must_use]
    pub fn reference(&self) -> RecipeReference {
        let version = self.version.clone();
        RecipeReference::new(self.name.as_str(), version)
    }
}
/// Optimizer selection plus its named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerSpec {
    /// Optimizer type name (an unvalidated string, e.g. `"adam"` as used in the tests).
    pub optimizer_type: String,
    /// Named optimizer parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
impl OptimizerSpec {
    /// Creates a spec for `optimizer_type` with no parameters set.
    #[must_use]
    pub fn new(optimizer_type: impl Into<String>) -> Self {
        let params = HashMap::new();
        Self { optimizer_type: optimizer_type.into(), params }
    }

    /// Sets parameter `name` to `value` and returns the spec for chaining.
    ///
    /// A parameter with the same name is overwritten.
    #[must_use]
    pub fn with_param(self, name: impl Into<String>, value: HyperparamValue) -> Self {
        let mut spec = self;
        spec.params.insert(name.into(), value);
        spec
    }
}
/// Learning-rate scheduler selection plus its named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerSpec {
    /// Scheduler type name (an unvalidated string).
    pub scheduler_type: String,
    /// Named scheduler parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
impl SchedulerSpec {
#[must_use]
pub fn new(scheduler_type: impl Into<String>) -> Self {
Self { scheduler_type: scheduler_type.into(), params: HashMap::new() }
}
}
/// Loss function selection plus its named parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSpec {
    /// Loss type name (an unvalidated string, e.g. `"cross_entropy"` as used in the tests).
    pub loss_type: String,
    /// Named loss parameters; empty when absent from the input.
    #[serde(default)]
    pub params: HashMap<String, HyperparamValue>,
}
impl LossSpec {
#[must_use]
pub fn new(loss_type: impl Into<String>) -> Self {
Self { loss_type: loss_type.into(), params: HashMap::new() }
}
}
/// Software environment details recorded alongside a recipe.
///
/// Every field is optional or empty by default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Dependencies {
    /// Rust version string, if recorded (format is not validated here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rust_version: Option<String>,
    /// Hash of `Cargo.lock`, if recorded; the hashing scheme is not defined in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cargo_lock_hash: Option<String>,
    /// System-level dependencies as opaque strings; empty when absent.
    #[serde(default)]
    pub system_deps: Vec<String>,
    /// Environment variables, presumably name -> value; empty when absent.
    #[serde(default)]
    pub env_vars: HashMap<String, String>,
}
/// Hardware requirements for running a recipe.
///
/// Every field is optional, so the derived `Default` (all `None`) means "no constraints".
/// This allows `HardwareSpec { min_cpu_cores: Some(8), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareSpec {
    /// Minimum number of CPU cores, if constrained.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_cpu_cores: Option<usize>,
    /// Minimum RAM in GB (per the field name), if constrained.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_ram_gb: Option<usize>,
    /// GPU requirements, if any GPU is needed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gpu: Option<GpuRequirement>,
    /// Estimated run duration in seconds (per the field name), if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_duration_secs: Option<u64>,
}
/// GPU requirements, used by [`HardwareSpec::gpu`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirement {
    /// Number of GPUs required.
    pub count: usize,
    /// Minimum VRAM in GB, if constrained (presumably per GPU -- confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_vram_gb: Option<usize>,
    /// Required GPU compute capability, stored as an unvalidated string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_capability: Option<String>,
}
/// Fluent builder for [`TrainingRecipe`].
///
/// Holds every user-settable recipe field. `id`, `created_at` and `extra` are not stored
/// here; they are filled in by `build()`.
#[derive(Debug)]
pub struct TrainingRecipeBuilder {
    name: String,
    version: RecipeVersion,
    description: String,
    architecture: Option<String>,
    hyperparameters: Hyperparameters,
    optimizer: Option<OptimizerSpec>,
    scheduler: Option<SchedulerSpec>,
    loss: Option<LossSpec>,
    train_data: Option<DatasetReference>,
    validation_data: Option<DatasetReference>,
    preprocessing: Vec<String>,
    augmentation: Vec<String>,
    dependencies: Dependencies,
    hardware: Option<HardwareSpec>,
    random_seed: Option<u64>,
    deterministic: bool,
}
impl TrainingRecipeBuilder {
#[must_use]
pub fn new() -> Self {
Self {
name: String::new(),
version: RecipeVersion::initial(),
description: String::new(),
architecture: None,
hyperparameters: Hyperparameters::default(),
optimizer: None,
scheduler: None,
loss: None,
train_data: None,
validation_data: None,
preprocessing: Vec::new(),
augmentation: Vec::new(),
dependencies: Dependencies::default(),
hardware: None,
random_seed: None,
deterministic: false,
}
}
#[must_use]
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
#[must_use]
pub fn version(mut self, version: RecipeVersion) -> Self {
self.version = version;
self
}
#[must_use]
pub fn description(mut self, description: impl Into<String>) -> Self {
self.description = description.into();
self
}
#[must_use]
pub fn architecture(mut self, architecture: impl Into<String>) -> Self {
self.architecture = Some(architecture.into());
self
}
#[must_use]
pub fn hyperparameters(mut self, hyperparameters: Hyperparameters) -> Self {
self.hyperparameters = hyperparameters;
self
}
#[must_use]
pub fn optimizer(mut self, optimizer: OptimizerSpec) -> Self {
self.optimizer = Some(optimizer);
self
}
#[must_use]
pub fn scheduler(mut self, scheduler: SchedulerSpec) -> Self {
self.scheduler = Some(scheduler);
self
}
#[must_use]
pub fn loss(mut self, loss: LossSpec) -> Self {
self.loss = Some(loss);
self
}
#[must_use]
pub fn train_data(mut self, data: DatasetReference) -> Self {
self.train_data = Some(data);
self
}
#[must_use]
pub fn validation_data(mut self, data: DatasetReference) -> Self {
self.validation_data = Some(data);
self
}
#[must_use]
pub fn random_seed(mut self, seed: u64) -> Self {
self.random_seed = Some(seed);
self
}
#[must_use]
pub fn deterministic(mut self, deterministic: bool) -> Self {
self.deterministic = deterministic;
self
}
#[must_use]
pub fn build(self) -> TrainingRecipe {
TrainingRecipe {
id: RecipeId::new(),
name: self.name,
version: self.version,
description: self.description,
architecture: self.architecture,
hyperparameters: self.hyperparameters,
optimizer: self.optimizer,
scheduler: self.scheduler,
loss: self.loss,
train_data: self.train_data,
validation_data: self.validation_data,
preprocessing: self.preprocessing,
augmentation: self.augmentation,
dependencies: self.dependencies,
hardware: self.hardware,
random_seed: self.random_seed,
deterministic: self.deterministic,
created_at: Utc::now(),
extra: HashMap::new(),
}
}
}
impl Default for TrainingRecipeBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two independently generated ids must not collide.
    #[test]
    fn test_recipe_id_generation() {
        let ids = [RecipeId::new(), RecipeId::new()];
        assert_ne!(ids[0], ids[1]);
    }

    /// A reference renders as `name:version`.
    #[test]
    fn test_recipe_reference_display() {
        let version = RecipeVersion::new(1, 2, 3);
        let rendered = RecipeReference::new("bert-finetune", version).to_string();
        assert_eq!(rendered, "bert-finetune:1.2.3");
    }

    /// Values passed to the builder end up on the built recipe.
    #[test]
    fn test_recipe_builder() {
        let hp = Hyperparameters {
            learning_rate: 2e-5,
            batch_size: 32,
            epochs: 3,
            ..Default::default()
        };

        let built = TrainingRecipe::builder()
            .name("bert-finetune")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Fine-tune BERT for sentiment analysis")
            .hyperparameters(hp)
            .optimizer(OptimizerSpec::new("adam"))
            .loss(LossSpec::new("cross_entropy"))
            .random_seed(42)
            .deterministic(true)
            .build();

        assert_eq!(built.name, "bert-finetune");
        assert_eq!(built.hyperparameters.learning_rate, 2e-5);
        assert_eq!(built.hyperparameters.batch_size, 32);
        assert_eq!(built.random_seed, Some(42));
        assert!(built.deterministic);
    }

    /// `with_param` accumulates distinct parameters.
    #[test]
    fn test_optimizer_spec() {
        let spec = OptimizerSpec::new("adam")
            .with_param("beta1", HyperparamValue::Float(0.9))
            .with_param("beta2", HyperparamValue::Float(0.999));
        assert_eq!(spec.optimizer_type, "adam");
        assert_eq!(spec.params.len(), 2);
    }

    /// A recipe survives a JSON round trip.
    #[test]
    fn test_recipe_serialization() {
        let original = TrainingRecipe::builder()
            .name("test-recipe")
            .description("Test")
            .build();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TrainingRecipe = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.name, decoded.name);
    }
}