use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level pipeline definition, as parsed by [`PipelineConfig::from_yaml`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Optional pipeline name; `None` when the key is absent.
    #[serde(default)]
    pub name: Option<String>,
    /// Optional registry base URL. When unset, [`PipelineConfig::registry_url`]
    /// falls back to a built-in default.
    #[serde(default)]
    pub registry: Option<String>,
    /// Stages in declaration order. Required (no `#[serde(default)]`), and decoded
    /// by `deserialize_stages`, which accepts string or mapping entries and
    /// rejects an empty list.
    #[serde(deserialize_with = "deserialize_stages")]
    pub stages: Vec<StageConfig>,
}
impl PipelineConfig {
pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
serde_yaml::from_str(yaml)
}
pub fn registry_url(&self) -> &str {
self.registry.as_deref().unwrap_or("https://api.xybrid.dev")
}
pub fn stages(&self) -> &[StageConfig] {
&self.stages
}
pub fn stage_count(&self) -> usize {
self.stages.len()
}
pub fn stage_names(&self) -> Vec<String> {
self.stages.iter().map(|s| s.stage_id()).collect()
}
}
/// One pipeline stage, written either as a shorthand string or as a mapping.
///
/// Variant order matters for the derived `#[serde(untagged)]` impl: serde tries
/// `Simple` first, then `Object`. `PipelineConfig::stages` does not use this
/// derive; it goes through `deserialize_stages`, which dispatches on the value type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StageConfig {
    /// Shorthand `"model"` or `"model@version"`.
    Simple(String),
    /// Mapping form with named fields plus free-form options.
    Object(StageObjectConfig),
}
impl StageConfig {
pub fn model_id(&self) -> String {
match self {
StageConfig::Simple(s) => {
s.split('@').next().unwrap_or(s).to_string()
}
StageConfig::Object(obj) => obj
.model
.clone()
.unwrap_or_else(|| obj.id.clone().unwrap_or_else(|| "unknown".to_string())),
}
}
pub fn stage_id(&self) -> String {
match self {
StageConfig::Simple(_) => self.model_id(),
StageConfig::Object(obj) => obj.id.clone().unwrap_or_else(|| self.model_id()),
}
}
pub fn version(&self) -> Option<String> {
match self {
StageConfig::Simple(s) => {
s.split('@').nth(1).map(|v| v.to_string())
}
StageConfig::Object(obj) => obj.version.clone(),
}
}
pub fn target(&self) -> Option<&str> {
match self {
StageConfig::Simple(_) => None, StageConfig::Object(obj) => obj.target.as_deref(),
}
}
pub fn provider(&self) -> Option<&str> {
match self {
StageConfig::Simple(_) => None,
StageConfig::Object(obj) => obj.provider.as_deref(),
}
}
pub fn options(&self) -> HashMap<String, serde_json::Value> {
match self {
StageConfig::Simple(_) => HashMap::new(),
StageConfig::Object(obj) => obj.options.clone(),
}
}
pub fn is_cloud_stage(&self) -> bool {
matches!(self.target(), Some("cloud") | Some("integration")) || self.provider().is_some()
}
pub fn is_device_stage(&self) -> bool {
matches!(self.target(), Some("device"))
}
pub fn execution_provider(&self) -> Option<&str> {
match self {
StageConfig::Simple(_) => None,
StageConfig::Object(obj) => obj.execution_provider.as_deref(),
}
}
pub fn to_object(&self) -> StageObjectConfig {
match self {
StageConfig::Simple(s) => {
let (model, version) = if s.contains('@') {
let parts: Vec<&str> = s.split('@').collect();
(
Some(parts[0].to_string()),
parts.get(1).map(|v| v.to_string()),
)
} else {
(Some(s.clone()), None)
};
StageObjectConfig {
id: None,
model,
version,
target: None,
provider: None,
execution_provider: None,
options: HashMap::new(),
}
}
StageConfig::Object(obj) => obj.clone(),
}
}
}
/// Mapping form of a pipeline stage.
///
/// Every named field is optional. Keys that match none of them are collected
/// into `options` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StageObjectConfig {
    /// Stage name; `StageConfig::stage_id` prefers it over the model id, and
    /// `StageConfig::model_id` falls back to it when `model` is unset.
    #[serde(default)]
    pub id: Option<String>,
    /// Model to run.
    #[serde(default)]
    pub model: Option<String>,
    /// Optional model version, kept as text.
    #[serde(default)]
    pub version: Option<String>,
    /// Free-form execution target, interpreted by `StageConfig::is_cloud_stage`
    /// and `StageConfig::is_device_stage`.
    #[serde(default)]
    pub target: Option<String>,
    /// Optional provider name; any value makes `StageConfig::is_cloud_stage` true.
    #[serde(default)]
    pub provider: Option<String>,
    /// Optional execution-provider override, exposed by
    /// `StageConfig::execution_provider`.
    #[serde(default)]
    pub execution_provider: Option<String>,
    /// All remaining keys of the stage mapping. Because of `flatten`, these are
    /// also serialized back inline, next to the named fields.
    #[serde(default, flatten)]
    pub options: HashMap<String, serde_json::Value>,
}
fn deserialize_stages<'de, D>(deserializer: D) -> Result<Vec<StageConfig>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let raw: Vec<serde_json::Value> = Vec::deserialize(deserializer)?;
let mut stages = Vec::new();
for (i, item) in raw.into_iter().enumerate() {
match item {
serde_json::Value::String(s) => {
stages.push(StageConfig::Simple(s));
}
serde_json::Value::Object(_) => {
let config: StageObjectConfig = serde_json::from_value(item)
.map_err(|e| D::Error::custom(format!("Invalid stage {}: {}", i, e)))?;
stages.push(StageConfig::Object(config));
}
_ => {
return Err(D::Error::custom(format!(
"Stage {} must be a string or object",
i
)));
}
}
}
if stages.is_empty() {
return Err(D::Error::custom("Pipeline must have at least one stage"));
}
Ok(stages)
}
/// Where a stage should execute.
///
/// Serde uses the lowercase variant names (`device`, `cloud`, `auto`). The
/// default is `Auto`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionTarget {
    /// Local execution; `ExecutionTarget::from_str` maps `device`, `local` and
    /// `edge` here.
    Device,
    /// Remote execution; `ExecutionTarget::from_str` maps `cloud`, `server` and
    /// `integration` here.
    Cloud,
    /// No explicit preference. This is the `Default` variant and the fallback
    /// for unrecognized strings in `ExecutionTarget::from_str`.
    #[default]
    Auto,
}
impl ExecutionTarget {
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"device" | "local" | "edge" => ExecutionTarget::Device,
"cloud" | "server" | "integration" => ExecutionTarget::Cloud,
_ => ExecutionTarget::Auto,
}
}
}
impl std::fmt::Display for ExecutionTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutionTarget::Device => write!(f, "device"),
ExecutionTarget::Cloud => write!(f, "cloud"),
ExecutionTarget::Auto => write!(f, "auto"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_minimal_pipeline() {
        let yaml = r#"
stages:
  - kokoro-82m
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.stage_count(), 1);
        assert_eq!(config.stages[0].model_id(), "kokoro-82m");
        assert_eq!(config.registry_url(), "https://api.xybrid.dev");
    }

    #[test]
    fn test_pipeline_with_version() {
        let yaml = r#"
stages:
  - whisper-tiny@1.0
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.stages[0].model_id(), "whisper-tiny");
        assert_eq!(config.stages[0].version(), Some("1.0".to_string()));
    }

    #[test]
    fn test_mixed_stages() {
        let yaml = r#"
name: voice-assistant
stages:
  - whisper-tiny
  - model: gpt-4o-mini
    target: cloud
    provider: openai
    system_prompt: "Be concise."
  - kokoro-82m
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, Some("voice-assistant".to_string()));
        assert_eq!(config.stage_count(), 3);
        assert_eq!(config.stages[0].model_id(), "whisper-tiny");
        assert!(!config.stages[0].is_cloud_stage());
        assert_eq!(config.stages[1].model_id(), "gpt-4o-mini");
        assert!(config.stages[1].is_cloud_stage());
        assert_eq!(config.stages[1].provider(), Some("openai"));
        assert_eq!(config.stages[2].model_id(), "kokoro-82m");
    }

    #[test]
    fn test_custom_registry() {
        let yaml = r#"
registry: "http://localhost:8080"
stages:
  - test-model
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.registry_url(), "http://localhost:8080");
    }

    #[test]
    fn test_stage_with_id() {
        let yaml = r#"
stages:
  - id: asr
    model: whisper-tiny
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.stages[0].stage_id(), "asr");
        assert_eq!(config.stages[0].model_id(), "whisper-tiny");
    }

    #[test]
    fn test_empty_stages_fails() {
        let yaml = r#"
stages: []
"#;
        let result = PipelineConfig::from_yaml(yaml);
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_stages_fails() {
        let yaml = r#"
name: no-stages
"#;
        assert!(PipelineConfig::from_yaml(yaml).is_err());
    }

    #[test]
    fn test_non_string_stage_fails() {
        let yaml = r#"
stages:
  - 42
"#;
        let err = PipelineConfig::from_yaml(yaml).unwrap_err();
        assert!(err.to_string().contains("Stage 0 must be a string or object"));
    }

    #[test]
    fn test_invalid_stage_object_fails() {
        let yaml = r#"
stages:
  - model: [not, a, string]
"#;
        let err = PipelineConfig::from_yaml(yaml).unwrap_err();
        assert!(err.to_string().contains("Invalid stage 0"));
    }

    #[test]
    fn test_to_object() {
        let simple = StageConfig::Simple("model@1.0".to_string());
        let obj = simple.to_object();
        assert_eq!(obj.model, Some("model".to_string()));
        assert_eq!(obj.version, Some("1.0".to_string()));
    }

    #[test]
    fn test_simple_stage_without_version() {
        let stage = StageConfig::Simple("kokoro-82m".to_string());
        assert_eq!(stage.version(), None);
        assert_eq!(stage.stage_id(), "kokoro-82m");
        let obj = stage.to_object();
        assert_eq!(obj.model, Some("kokoro-82m".to_string()));
        assert_eq!(obj.version, None);
        assert!(obj.options.is_empty());
    }

    #[test]
    fn test_object_stage_fallbacks() {
        let only_id = StageConfig::Object(StageObjectConfig {
            id: Some("asr".to_string()),
            ..Default::default()
        });
        assert_eq!(only_id.model_id(), "asr");
        assert_eq!(only_id.stage_id(), "asr");

        let empty = StageConfig::Object(StageObjectConfig::default());
        assert_eq!(empty.model_id(), "unknown");
        assert_eq!(empty.stage_id(), "unknown");
    }

    #[test]
    fn test_stage_names_and_options() {
        let yaml = r#"
stages:
  - id: llm
    model: gpt-4o-mini
    provider: openai
    system_prompt: "Be concise."
  - kokoro-82m
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(
            config.stage_names(),
            vec!["llm".to_string(), "kokoro-82m".to_string()]
        );
        let options = config.stages()[0].options();
        assert_eq!(
            options.get("system_prompt"),
            Some(&serde_json::Value::String("Be concise.".to_string()))
        );
        // Named fields are consumed by the struct and must not leak into `options`.
        assert!(!options.contains_key("model"));
        assert!(!options.contains_key("provider"));
    }

    #[test]
    fn test_device_stage() {
        let yaml = r#"
stages:
  - model: mobilenet-v2
    target: device
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert!(config.stages[0].is_device_stage());
        assert!(!config.stages[0].is_cloud_stage());
    }

    #[test]
    fn test_execution_provider_override() {
        let yaml = r#"
stages:
  - model: mobilenet-v2
    execution_provider: coreml-ane
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.stages[0].execution_provider(), Some("coreml-ane"));
    }

    #[test]
    fn test_execution_provider_default_none() {
        let yaml = r#"
stages:
  - mobilenet-v2
"#;
        let config = PipelineConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.stages[0].execution_provider(), None);
    }

    #[test]
    fn test_execution_target_from_str() {
        assert_eq!(ExecutionTarget::from_str("device"), ExecutionTarget::Device);
        assert_eq!(ExecutionTarget::from_str("Local"), ExecutionTarget::Device);
        assert_eq!(ExecutionTarget::from_str("EDGE"), ExecutionTarget::Device);
        assert_eq!(ExecutionTarget::from_str("cloud"), ExecutionTarget::Cloud);
        assert_eq!(ExecutionTarget::from_str("Server"), ExecutionTarget::Cloud);
        assert_eq!(ExecutionTarget::from_str("integration"), ExecutionTarget::Cloud);
        assert_eq!(ExecutionTarget::from_str("gpu"), ExecutionTarget::Auto);
        assert_eq!(ExecutionTarget::from_str(""), ExecutionTarget::Auto);
        assert_eq!(ExecutionTarget::default(), ExecutionTarget::Auto);
    }

    #[test]
    fn test_execution_target_display_and_serde() {
        for &target in &[
            ExecutionTarget::Device,
            ExecutionTarget::Cloud,
            ExecutionTarget::Auto,
        ] {
            assert_eq!(ExecutionTarget::from_str(&target.to_string()), target);
        }
        assert_eq!(ExecutionTarget::Cloud.to_string(), "cloud");
        assert_eq!(
            serde_json::to_string(&ExecutionTarget::Device).unwrap(),
            "\"device\""
        );
        let parsed: ExecutionTarget = serde_json::from_str("\"auto\"").unwrap();
        assert_eq!(parsed, ExecutionTarget::Auto);
    }
}