use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Top-level recipe deserialized from a YAML document.
///
/// Instances produced by `OxidizeRecipe::load_from_path` have already passed
/// `validate`; values built by hand have not been checked.
#[derive(Debug, Deserialize, Clone)]
pub struct OxidizeRecipe {
/// Recipe format version. Validation accepts only `1` or `2`.
pub version: u32,
/// Name of the embedding model. Validation requires it for version 1; when
/// it is absent, `get_model_name` falls back to the model named by
/// `Config::load()`.
pub embedding_model: Option<String>,
/// Projection settings keyed by projection name. Validation rejects an
/// empty map.
pub projections: HashMap<String, ProjectionConfig>,
}
/// Settings for a single named projection inside an `OxidizeRecipe`.
#[derive(Debug, Deserialize, Clone)]
pub struct ProjectionConfig {
/// Layer dimensions. With three entries they are read as
/// `[input, hidden, output]`; with two entries as `[hidden, output]`, and
/// the input dimension is then looked up from the recipe's embedding model.
/// Validation requires exactly three entries for recipe v1, two or three
/// for v2, and every entry to be greater than zero.
pub layers: Vec<usize>,
/// Must be greater than zero (enforced by validation).
/// NOTE(review): presumably the number of training passes — confirm with the trainer code.
pub epochs: usize,
/// Must be greater than zero (enforced by validation).
/// NOTE(review): presumably the number of samples per training step — confirm with the trainer code.
pub batch_size: usize,
}
impl OxidizeRecipe {
    /// Loads the recipe from the default location, `.patina/oxidize.yaml`.
    ///
    /// # Errors
    /// Returns whatever [`OxidizeRecipe::load_from_path`] reports.
    pub fn load() -> Result<Self> {
        Self::load_from_path(".patina/oxidize.yaml")
    }

    /// Reads the file at `path`, parses it as YAML and validates the result.
    ///
    /// # Errors
    /// Fails when the file cannot be read, when the YAML does not
    /// deserialize into a recipe, or when the recipe does not pass validation.
    pub fn load_from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
        let recipe_path = path.as_ref();
        let raw = fs::read_to_string(recipe_path)
            .with_context(|| format!("Failed to read recipe: {}", recipe_path.display()))?;
        let parsed = serde_yaml::from_str::<OxidizeRecipe>(&raw).with_context(|| {
            format!("Failed to parse recipe YAML: {}", recipe_path.display())
        })?;
        // Only hand the recipe back once validation has succeeded.
        parsed.validate().map(|()| parsed)
    }

    /// Returns the embedding model name for this recipe.
    ///
    /// The recipe's own `embedding_model` wins; when it is absent the name is
    /// taken from the configuration returned by `Config::load()`.
    ///
    /// # Errors
    /// Fails when the recipe names no model and `Config::load()` fails.
    pub fn get_model_name(&self) -> Result<String> {
        use patina::embeddings::models::Config;

        match self.embedding_model.as_ref() {
            Some(name) => Ok(name.clone()),
            None => {
                let project_config = Config::load()
                    .context("No embedding_model in recipe and config.toml not found")?;
                Ok(project_config.embeddings.model)
            }
        }
    }

    /// Looks up the `dimensions` value the model registry records for this
    /// recipe's embedding model.
    ///
    /// # Errors
    /// Fails when the model name cannot be resolved, the registry cannot be
    /// loaded, or the registry lookup for that model fails.
    pub fn get_input_dim(&self) -> Result<usize> {
        use patina::embeddings::models::ModelRegistry;

        let name = self.get_model_name()?;
        let registry = ModelRegistry::load()?;
        let dimensions = registry.get_model(&name)?.dimensions;
        Ok(dimensions)
    }

    /// Checks the recipe-level invariants and then every projection.
    ///
    /// Rules, in the order they are checked: the version is 1 or 2; a v1
    /// recipe names an embedding model; at least one projection exists; each
    /// projection passes its own validation.
    fn validate(&self) -> Result<()> {
        match self.version {
            1 | 2 => {}
            unsupported => anyhow::bail!(
                "Unsupported recipe version: {} (expected 1 or 2)",
                unsupported
            ),
        }

        let is_v1 = self.version == 1;
        if is_v1 && self.embedding_model.is_none() {
            anyhow::bail!("Recipe v1 requires 'embedding_model' field");
        }

        if self.projections.is_empty() {
            anyhow::bail!("Recipe must define at least one projection");
        }

        // Stops at the first projection that fails, in map iteration order.
        self.projections.iter().try_for_each(|(name, projection)| {
            projection
                .validate(name, self.version)
                .with_context(|| format!("Invalid projection config: {}", name))
        })
    }
}
impl ProjectionConfig {
    /// Validates this projection against the rules of recipe `version`.
    ///
    /// `name` is the projection's key in the recipe and is only used in error
    /// messages. A v1 projection needs exactly three layer entries, a v2
    /// projection two or three; no layer may be zero, and `epochs` and
    /// `batch_size` must both be non-zero.
    fn validate(&self, name: &str, version: u32) -> Result<()> {
        let expected = match version {
            1 => "[input, hidden, output]",
            _ => "[hidden, output]",
        };

        match (version, self.layers.len()) {
            // Accepted layer counts for each known version.
            (1, 3) | (2, 2) | (2, 3) => {}
            (1, got) => anyhow::bail!(
                "Projection '{}': v1 layers must have 3 dimensions {}, got {}",
                name,
                expected,
                got
            ),
            (2, got) => anyhow::bail!(
                "Projection '{}': v2 layers must have 2 or 3 dimensions, got {}",
                name,
                got
            ),
            // Other versions carry no layer-count rule here.
            _ => {}
        }

        if self.layers.iter().any(|&dim| dim == 0) {
            anyhow::bail!("Projection '{}': layer dimensions must be > 0", name);
        }
        if self.epochs == 0 {
            anyhow::bail!("Projection '{}': epochs must be > 0", name);
        }
        if self.batch_size == 0 {
            anyhow::bail!("Projection '{}': batch_size must be > 0", name);
        }
        Ok(())
    }

    /// Returns the input dimension of this projection.
    ///
    /// With three layer entries the first one is the input dimension;
    /// otherwise the value comes from `recipe.get_input_dim()`.
    ///
    /// # Errors
    /// Fails only on the fallback path, when `recipe.get_input_dim()` fails.
    pub fn input_dim(&self, recipe: &OxidizeRecipe) -> Result<usize> {
        match self.layers.as_slice() {
            [input, _, _] => Ok(*input),
            _ => recipe.get_input_dim(),
        }
    }

    /// Returns the hidden dimension: the middle entry of a three-entry layer
    /// list, otherwise the first entry.
    ///
    /// # Panics
    /// Panics if `layers` is empty.
    pub fn hidden_dim(&self) -> usize {
        match self.layers.as_slice() {
            [_, hidden, _] => *hidden,
            shorter => shorter[0],
        }
    }

    /// Returns the output dimension: the last entry of a three-entry layer
    /// list, otherwise the second entry.
    ///
    /// # Panics
    /// Panics if `layers` has fewer than two entries.
    pub fn output_dim(&self) -> usize {
        match self.layers.as_slice() {
            [_, _, output] => *output,
            shorter => shorter[1],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `yaml` to a fresh temporary file and returns its handle.
    ///
    /// The handle must stay alive for as long as the path is in use, because
    /// the file is removed when the handle is dropped.
    fn write_recipe(yaml: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(yaml.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// A well-formed v1 recipe loads and exposes its fields and dimensions.
    #[test]
    fn test_parse_valid_recipe() {
        // Nesting matters: `semantic` must sit under `projections`, and the
        // settings under `semantic`, or the document will not deserialize.
        let yaml = r#"
version: 1
embedding_model: e5-base-v2
projections:
  semantic:
    layers: [768, 1024, 256]
    epochs: 10
    batch_size: 32
"#;
        let temp_file = write_recipe(yaml);
        let recipe = OxidizeRecipe::load_from_path(temp_file.path()).unwrap();

        assert_eq!(recipe.version, 1);
        assert_eq!(recipe.embedding_model.as_deref(), Some("e5-base-v2"));
        assert_eq!(recipe.projections.len(), 1);

        let semantic = recipe.projections.get("semantic").unwrap();
        assert_eq!(semantic.layers, vec![768, 1024, 256]);
        assert_eq!(semantic.epochs, 10);
        assert_eq!(semantic.batch_size, 32);
        assert_eq!(semantic.input_dim(&recipe).unwrap(), 768);
        assert_eq!(semantic.hidden_dim(), 1024);
        assert_eq!(semantic.output_dim(), 256);
    }

    /// A version other than 1 or 2 is rejected with the version error.
    #[test]
    fn test_invalid_version() {
        let yaml = r#"
version: 99
embedding_model: e5-base-v2
projections:
  semantic:
    layers: [768, 1024, 256]
    epochs: 10
    batch_size: 32
"#;
        let temp_file = write_recipe(yaml);
        let result = OxidizeRecipe::load_from_path(temp_file.path());

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported recipe version"));
    }

    /// A v1 projection with only two layer entries is rejected.
    #[test]
    fn test_invalid_layers() {
        let yaml = r#"
version: 1
embedding_model: e5-base-v2
projections:
  semantic:
    layers: [768, 256]
    epochs: 10
    batch_size: 32
"#;
        let temp_file = write_recipe(yaml);
        let result = OxidizeRecipe::load_from_path(temp_file.path());

        assert!(result.is_err(), "Should reject invalid layer count");
        // `{:#}` renders the whole context chain, so this checks that the
        // failure is the layer-count rule and not an unrelated parse error.
        let chain = format!("{:#}", result.unwrap_err());
        assert!(
            chain.contains("v1 layers must have 3 dimensions"),
            "unexpected error: {}",
            chain
        );
    }

    /// A projection with `epochs: 0` is rejected.
    #[test]
    fn test_zero_epochs() {
        let yaml = r#"
version: 1
embedding_model: e5-base-v2
projections:
  semantic:
    layers: [768, 1024, 256]
    epochs: 0
    batch_size: 32
"#;
        let temp_file = write_recipe(yaml);
        let result = OxidizeRecipe::load_from_path(temp_file.path());

        assert!(result.is_err(), "Should reject zero epochs");
        // Check the cause as well, so a parse failure cannot satisfy this test.
        let chain = format!("{:#}", result.unwrap_err());
        assert!(
            chain.contains("epochs must be > 0"),
            "unexpected error: {}",
            chain
        );
    }
}