use super::{
ConditionEvaluator, ConditionResult, IntegrationProvider, PipelineConfig, ResolutionContext,
StageConfig, StageOutputContext, TargetResolver,
};
use crate::context::{DeviceMetrics, StageDescriptor};
use crate::device::capabilities::HardwareCapabilities;
use crate::ir::{Envelope, EnvelopeKind};
use crate::orchestrator::routing_engine::LocalAvailability;
use crate::orchestrator::{Orchestrator, OrchestratorError};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::time::Instant;
/// Errors that abort a pipeline run.
///
/// Every variant carries a detail message; [`std::fmt::Display`] renders it
/// behind a prefix naming the phase that failed.
#[derive(Debug, Clone)]
pub enum PipelineRunnerError {
    /// The pipeline definition failed to parse or to validate.
    ValidationFailed(String),
    /// A stage's `when` condition could not be evaluated.
    ConditionFailed(String),
    /// No execution target could be resolved for a stage.
    ResolutionFailed(String),
    /// A stage failed during execution ("Stage execution failed").
    ExecutionFailed(String),
    /// Input conversion failure ("Input conversion failed").
    InputConversionFailed(String),
    /// Output conversion failure ("Output conversion failed").
    OutputConversionFailed(String),
    /// An error reported by the orchestrator, captured as its display text.
    OrchestratorError(String),
}

impl std::fmt::Display for PipelineRunnerError {
    /// Formats the error as `"<phase>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ValidationFailed(msg) => write!(f, "Pipeline validation failed: {}", msg),
            Self::ConditionFailed(msg) => write!(f, "Condition evaluation failed: {}", msg),
            Self::ResolutionFailed(msg) => write!(f, "Target resolution failed: {}", msg),
            Self::ExecutionFailed(msg) => write!(f, "Stage execution failed: {}", msg),
            Self::InputConversionFailed(msg) => write!(f, "Input conversion failed: {}", msg),
            Self::OutputConversionFailed(msg) => write!(f, "Output conversion failed: {}", msg),
            Self::OrchestratorError(msg) => write!(f, "Orchestrator error: {}", msg),
        }
    }
}

impl std::error::Error for PipelineRunnerError {}
impl From<OrchestratorError> for PipelineRunnerError {
fn from(err: OrchestratorError) -> Self {
PipelineRunnerError::OrchestratorError(err.to_string())
}
}
/// Outcome of a single stage, recorded by the runner while a pipeline runs.
#[derive(Debug, Clone)]
pub struct StageResult {
    /// The stage's `id` from the pipeline definition.
    pub stage_id: String,
    /// `true` if the stage ran; `false` if its `when` condition skipped it.
    pub executed: bool,
    /// Human-readable reason the stage was skipped; `None` when it ran.
    pub skip_reason: Option<String>,
    /// The resolved execution target rendered as a string; `None` when skipped.
    pub target: Option<String>,
    /// JSON summary of the stage's output envelope; `None` when skipped.
    pub output: Option<Value>,
    /// Wall-clock time spent on the stage in milliseconds (0 when skipped).
    pub latency_ms: u32,
}
/// Serializable summary of a finished pipeline run.
///
/// `name`, `error` and `stages` fall back to their defaults when absent
/// during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineResult {
    /// The pipeline's declared name, if any.
    #[serde(default)]
    pub name: Option<String>,
    /// Whether the run succeeded. `PipelineRunner::run` only produces `true`
    /// here; failures are returned as `Err` instead.
    pub success: bool,
    /// Error message for an unsuccessful result (`None` from `PipelineRunner::run`).
    #[serde(default)]
    pub error: Option<String>,
    /// Per-stage summaries keyed by stage id.
    #[serde(default)]
    pub stages: HashMap<String, StageResultSummary>,
    /// Which kind of payload `output` holds.
    pub output_type: OutputResultType,
    /// The final payload produced by the last executed stage (or the input
    /// itself when no stage ran).
    pub output: OutputResult,
    /// Wall-clock duration of the whole run in milliseconds.
    pub total_latency_ms: u32,
    /// Number of stages that ran.
    pub stages_executed: usize,
    /// Number of stages skipped by their `when` condition.
    pub stages_skipped: usize,
}
/// Serializable per-stage entry stored in [`PipelineResult::stages`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageResultSummary {
    /// Whether the stage ran.
    pub executed: bool,
    /// Why the stage was skipped; absent/`None` when it ran.
    #[serde(default)]
    pub skip_reason: Option<String>,
    /// Resolved execution target; absent/`None` when skipped.
    #[serde(default)]
    pub target: Option<String>,
    /// Stage latency in milliseconds.
    pub latency_ms: u32,
}
/// Discriminates the payload held by [`OutputResult`].
///
/// Serialized as the lowercase variant name, e.g. `"text"`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputResultType {
    /// Plain text payload.
    Text,
    /// Audio payload.
    Audio,
    /// Embedding vector payload.
    Embedding,
    /// Image payload.
    Image,
    /// Arbitrary JSON payload.
    Json,
    /// No payload.
    None,
}
/// Final payload of a pipeline run.
///
/// Serialized untagged: only the JSON shape identifies the variant. When
/// deserializing, serde tries untagged variants in declaration order and keeps
/// the first that matches. `Json(Value)` accepts any JSON value, so it must be
/// last. In particular, `None` (serialized as `null`) is declared before it.
/// Otherwise `null` would deserialize as `Json(Value::Null)`, and `None` could
/// never round-trip.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(untagged)]
pub enum OutputResult {
    /// Plain text output.
    Text(String),
    /// Raw audio bytes and the sample rate they are reported at.
    Audio { bytes: Vec<u8>, sample_rate: u32 },
    /// Embedding vector. Note: any JSON array of numbers deserializes here.
    Embedding(Vec<f32>),
    /// Encoded image bytes and their format name.
    Image { bytes: Vec<u8>, format: String },
    /// No output; serialized as `null`.
    #[default]
    None,
    /// Arbitrary JSON. Must stay last: it matches every JSON value.
    Json(Value),
}
/// Environment the runner consults when resolving and executing stages.
#[derive(Clone, Debug, Default)]
pub struct RunnerConfig {
    /// Device metrics given to target resolution and to the orchestrator.
    pub metrics: DeviceMetrics,
    /// Hardware capabilities given to target resolution.
    pub capabilities: HardwareCapabilities,
    /// Local availability flag per model id.
    pub local_models: HashMap<String, bool>,
    /// Server availability flag per model id.
    pub server_models: HashMap<String, bool>,
    /// Availability flag per integration provider.
    pub integrations: HashMap<IntegrationProvider, bool>,
}
/// Runs a [`PipelineConfig`] stage by stage through an [`Orchestrator`].
pub struct PipelineRunner {
    /// Executes each individual stage.
    orchestrator: Orchestrator,
    /// Availability and device information consulted for every stage.
    config: RunnerConfig,
    /// JSON outputs of the stages that ran so far; reset at the start of each run
    /// and read by `when` conditions.
    output_context: StageOutputContext,
}
impl PipelineRunner {
pub fn new() -> Self {
Self {
orchestrator: Orchestrator::new(),
config: RunnerConfig::default(),
output_context: StageOutputContext::new(),
}
}
pub fn with_config(config: RunnerConfig) -> Self {
Self {
orchestrator: Orchestrator::new(),
config,
output_context: StageOutputContext::new(),
}
}
pub fn with_orchestrator(orchestrator: Orchestrator, config: RunnerConfig) -> Self {
Self {
orchestrator,
config,
output_context: StageOutputContext::new(),
}
}
pub fn run_yaml(
&mut self,
yaml: &str,
input: Envelope,
) -> Result<PipelineResult, PipelineRunnerError> {
let pipeline =
PipelineConfig::from_yaml(yaml).map_err(PipelineRunnerError::ValidationFailed)?;
self.run(&pipeline, input)
}
pub fn run(
&mut self,
pipeline: &PipelineConfig,
input: Envelope,
) -> Result<PipelineResult, PipelineRunnerError> {
let start_time = Instant::now();
pipeline
.validate()
.map_err(PipelineRunnerError::ValidationFailed)?;
self.output_context = StageOutputContext::new();
let mut stage_results: Vec<StageResult> = Vec::new();
let mut current_input = input;
let mut stages_executed = 0;
let mut stages_skipped = 0;
for stage_config in &pipeline.stages {
let stage_result = self.execute_stage(stage_config, ¤t_input)?;
if stage_result.executed {
stages_executed += 1;
if let Some(ref output) = stage_result.output {
current_input = self.value_to_envelope(output);
self.output_context
.add_output(&stage_config.id, output.clone());
}
} else {
stages_skipped += 1;
}
stage_results.push(stage_result);
}
let total_latency_ms = start_time.elapsed().as_millis() as u32;
let (output_type, output) = self.extract_final_output(¤t_input);
let mut stages = HashMap::new();
for result in &stage_results {
stages.insert(
result.stage_id.clone(),
StageResultSummary {
executed: result.executed,
skip_reason: result.skip_reason.clone(),
target: result.target.clone(),
latency_ms: result.latency_ms,
},
);
}
Ok(PipelineResult {
name: pipeline.name.clone(),
success: true,
error: None,
stages,
output_type,
output,
total_latency_ms,
stages_executed,
stages_skipped,
})
}
fn execute_stage(
&mut self,
stage_config: &StageConfig,
input: &Envelope,
) -> Result<StageResult, PipelineRunnerError> {
let stage_start = Instant::now();
if let Some(ref condition) = stage_config.when {
match ConditionEvaluator::evaluate(condition, &self.output_context) {
ConditionResult::True => {
}
ConditionResult::False => {
return Ok(StageResult {
stage_id: stage_config.id.clone(),
executed: false,
skip_reason: Some(format!("Condition '{}' evaluated to false", condition)),
target: None,
output: None,
latency_ms: 0,
});
}
ConditionResult::Error(err) => {
return Err(PipelineRunnerError::ConditionFailed(format!(
"Stage '{}': {}",
stage_config.id, err
)));
}
}
}
let resolution_context = self.build_resolution_context(stage_config);
let resolved = TargetResolver::resolve(stage_config, &resolution_context)
.map_err(|e| PipelineRunnerError::ResolutionFailed(format!("{:?}", e)))?;
let stage_descriptor = self.stage_config_to_descriptor(stage_config);
let availability = LocalAvailability::new(
self.config
.local_models
.get(&stage_config.model)
.copied()
.unwrap_or(true),
);
let exec_result = self.orchestrator.execute_stage(
&stage_descriptor,
input,
&self.config.metrics,
&availability,
)?;
let output_value = self.envelope_to_value(&exec_result.output);
let latency_ms = stage_start.elapsed().as_millis() as u32;
Ok(StageResult {
stage_id: stage_config.id.clone(),
executed: true,
skip_reason: None,
target: Some(resolved.target.to_string()),
output: Some(output_value),
latency_ms,
})
}
fn build_resolution_context(&self, stage_config: &StageConfig) -> ResolutionContext {
ResolutionContext {
metrics: self.config.metrics.clone(),
local_available: self
.config
.local_models
.get(&stage_config.model)
.copied()
.unwrap_or(false),
server_available: self
.config
.server_models
.get(&stage_config.model)
.copied()
.unwrap_or(false),
integration_available: self.config.integrations.clone(),
capabilities: self.config.capabilities.clone(),
}
}
fn stage_config_to_descriptor(&self, stage_config: &StageConfig) -> StageDescriptor {
StageDescriptor {
name: stage_config.model_identifier(),
bundle_path: None, target: Some(stage_config.target.clone()),
provider: stage_config.provider,
model: Some(stage_config.model.clone()),
options: Some(stage_config.options.clone()),
}
}
fn envelope_to_value(&self, envelope: &Envelope) -> Value {
match &envelope.kind {
EnvelopeKind::Text(text) => {
serde_json::json!({
"type": "text",
"output": text
})
}
EnvelopeKind::Audio(bytes) => {
serde_json::json!({
"type": "audio",
"bytes_len": bytes.len()
})
}
EnvelopeKind::Embedding(values) => {
serde_json::json!({
"type": "embedding",
"dimensions": values.len(),
"output": values
})
}
}
}
fn value_to_envelope(&self, value: &Value) -> Envelope {
if let Some(obj) = value.as_object() {
if let Some(type_str) = obj.get("type").and_then(|v| v.as_str()) {
match type_str {
"text" => {
if let Some(text) = obj.get("output").and_then(|v| v.as_str()) {
return Envelope::new(EnvelopeKind::Text(text.to_string()));
}
}
"embedding" => {
if let Some(values) = obj.get("output").and_then(|v| v.as_array()) {
let floats: Vec<f32> = values
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
return Envelope::new(EnvelopeKind::Embedding(floats));
}
}
"audio" => {
return Envelope::new(EnvelopeKind::Audio(Vec::new()));
}
_ => {}
}
}
}
Envelope::new(EnvelopeKind::Text(value.to_string()))
}
fn extract_final_output(&self, envelope: &Envelope) -> (OutputResultType, OutputResult) {
match &envelope.kind {
EnvelopeKind::Text(text) => (OutputResultType::Text, OutputResult::Text(text.clone())),
EnvelopeKind::Audio(bytes) => (
OutputResultType::Audio,
OutputResult::Audio {
bytes: bytes.clone(),
sample_rate: 16000, },
),
EnvelopeKind::Embedding(values) => (
OutputResultType::Embedding,
OutputResult::Embedding(values.clone()),
),
}
}
pub fn output_context(&self) -> &StageOutputContext {
&self.output_context
}
pub fn orchestrator_mut(&mut self) -> &mut Orchestrator {
&mut self.orchestrator
}
pub fn config(&self) -> &RunnerConfig {
&self.config
}
pub fn set_config(&mut self, config: RunnerConfig) {
self.config = config;
}
pub fn set_metrics(&mut self, metrics: DeviceMetrics) {
self.config.metrics = metrics;
}
pub fn register_local_model(&mut self, model_id: &str, available: bool) {
self.config
.local_models
.insert(model_id.to_string(), available);
}
pub fn register_server_model(&mut self, model_id: &str, available: bool) {
self.config
.server_models
.insert(model_id.to_string(), available);
}
pub fn register_integration(&mut self, provider: IntegrationProvider, available: bool) {
self.config.integrations.insert(provider, available);
}
}
impl Default for PipelineRunner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestrator::LocalAuthority;
    use crate::runtime_adapter::RuntimeAdapter;
    use crate::testing::mocks::MockRuntimeAdapter;
    use std::sync::Arc;

    /// Wraps `value` in a text envelope.
    fn text_envelope(value: &str) -> Envelope {
        Envelope::new(EnvelopeKind::Text(value.to_string()))
    }

    /// Wraps a copy of `bytes` in an audio envelope.
    fn audio_envelope(bytes: &[u8]) -> Envelope {
        Envelope::new(EnvelopeKind::Audio(bytes.to_vec()))
    }

    /// Swaps the runner's orchestrator for one backed by a `LocalAuthority`.
    /// It also registers a mock adapter built with the text "mock output",
    /// whose model is already loaded.
    fn preload_mock_local_adapter(runner: &mut PipelineRunner) {
        let mut fresh = Orchestrator::with_authority(Box::new(LocalAuthority::new()));
        let mut adapter = MockRuntimeAdapter::with_text_output("mock output");
        adapter
            .load_model("/mock/model.onnx")
            .expect("mock adapter should load its model");
        fresh.executor_mut().register_adapter(Arc::new(adapter));
        *runner.orchestrator_mut() = fresh;
    }

    #[test]
    fn test_pipeline_runner_new() {
        let runner = PipelineRunner::new();
        assert_eq!(runner.config().metrics.capabilities.battery_level, 100);
    }

    #[test]
    fn test_pipeline_runner_with_config() {
        let mut metrics = DeviceMetrics::default();
        metrics.capabilities.battery_level = 80;
        let config = RunnerConfig {
            metrics,
            ..Default::default()
        };
        let runner = PipelineRunner::with_config(config);
        assert_eq!(runner.config().metrics.capabilities.battery_level, 80);
    }

    #[test]
    fn test_run_simple_pipeline() {
        let yaml = r#"
name: "Test Pipeline"
version: "1.0"
input:
  type: text
stages:
  - id: process
    model: test-model
    target: device
"#;
        let mut runner = PipelineRunner::new();
        runner.register_local_model("test-model", true);
        preload_mock_local_adapter(&mut runner);
        let pipeline_result = runner
            .run_yaml(yaml, text_envelope("Hello, world!"))
            .expect("simple pipeline should run");
        assert!(pipeline_result.success);
        assert_eq!(pipeline_result.stages_executed, 1);
        assert_eq!(pipeline_result.stages_skipped, 0);
    }

    #[test]
    fn test_run_pipeline_with_condition_skip() {
        let yaml = r#"
name: "Conditional Pipeline"
version: "1.0"
input:
  type: text
stages:
  - id: first
    model: model-a
    target: device
  - id: second
    model: model-b
    target: device
    when: "first.output == 'trigger'"
"#;
        let mut runner = PipelineRunner::new();
        runner.register_local_model("model-a", true);
        runner.register_local_model("model-b", true);
        preload_mock_local_adapter(&mut runner);
        let pipeline_result = runner
            .run_yaml(yaml, text_envelope("Hello"))
            .expect("conditional pipeline should run");
        assert!(pipeline_result.success);
        assert_eq!(pipeline_result.stages_executed, 1);
        assert_eq!(pipeline_result.stages_skipped, 1);
        let second = &pipeline_result.stages["second"];
        assert!(!second.executed);
        assert!(second.skip_reason.is_some());
        assert!(second.target.is_none());
    }

    #[test]
    fn test_stage_result_tracking() {
        let yaml = r#"
name: "Multi-Stage Pipeline"
version: "1.0"
input:
  type: audio
  sample_rate: 16000
  channels: 1
  format: float32
stages:
  - id: asr
    model: wav2vec2
    target: device
  - id: process
    model: processor
    target: device
"#;
        let mut runner = PipelineRunner::new();
        runner.register_local_model("wav2vec2", true);
        runner.register_local_model("processor", true);
        preload_mock_local_adapter(&mut runner);
        let pipeline_result = runner
            .run_yaml(yaml, audio_envelope(&[0u8; 32000]))
            .expect("multi-stage pipeline should run");
        assert!(pipeline_result.stages.contains_key("asr"));
        assert!(pipeline_result.stages.contains_key("process"));
    }

    #[test]
    fn test_invalid_pipeline_yaml() {
        let yaml = r#"
name: "Invalid"
stages: []
"#;
        let mut runner = PipelineRunner::new();
        let result = runner.run_yaml(yaml, text_envelope("test"));
        // Both parse and validation failures map to ValidationFailed.
        assert!(matches!(
            result,
            Err(PipelineRunnerError::ValidationFailed(_))
        ));
    }

    #[test]
    fn test_output_result_types() {
        let runner = PipelineRunner::new();

        let (output_type, _) = runner.extract_final_output(&text_envelope("Hello"));
        assert_eq!(output_type, OutputResultType::Text);

        let envelope = Envelope::new(EnvelopeKind::Embedding(vec![0.1, 0.2, 0.3]));
        let (output_type, _) = runner.extract_final_output(&envelope);
        assert_eq!(output_type, OutputResultType::Embedding);

        let (output_type, output) = runner.extract_final_output(&audio_envelope(&[1, 2, 3]));
        assert_eq!(output_type, OutputResultType::Audio);
        match output {
            OutputResult::Audio { bytes, sample_rate } => {
                assert_eq!(bytes, vec![1u8, 2, 3]);
                assert_eq!(sample_rate, 16000);
            }
            other => panic!("expected audio output, got {:?}", other),
        }
    }

    #[test]
    fn test_model_registration() {
        let mut runner = PipelineRunner::new();
        runner.register_local_model("wav2vec2", true);
        runner.register_server_model("whisper-large", true);
        runner.register_integration(IntegrationProvider::OpenAI, true);
        let config = runner.config();
        assert_eq!(config.local_models.get("wav2vec2").copied(), Some(true));
        assert_eq!(config.server_models.get("whisper-large").copied(), Some(true));
        assert_eq!(
            config.integrations.get(&IntegrationProvider::OpenAI).copied(),
            Some(true)
        );
    }

    #[test]
    fn test_pipeline_result_serialization() {
        let result = PipelineResult {
            name: Some("Test".to_string()),
            success: true,
            error: None,
            stages: HashMap::new(),
            output_type: OutputResultType::Text,
            output: OutputResult::Text("Hello".to_string()),
            total_latency_ms: 100,
            stages_executed: 1,
            stages_skipped: 0,
        };
        let json = serde_json::to_string(&result).expect("result should serialize");
        assert!(json.contains("\"success\":true"));
        assert!(json.contains("\"output_type\":\"text\""));
        // Untagged text output serializes as a bare string.
        assert!(json.contains("\"output\":\"Hello\""));
    }

    #[test]
    fn test_error_display() {
        let err = PipelineRunnerError::ValidationFailed("missing stages".to_string());
        assert_eq!(err.to_string(), "Pipeline validation failed: missing stages");
    }
}