use std::collections::HashMap;
use std::sync::Arc;
use error_stack::ResultExt as _;
use serde::{Deserialize, Serialize};
use super::Flow;
// Set of overrides to apply to a workflow, keyed by the ID of the step each
// override targets.
//
// NOTE: plain `//` comments are used on this type on purpose; `///` doc
// comments on a `schemars::JsonSchema` type are emitted as descriptions in the
// generated schema and would change its output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkflowOverrides {
// Override specification per step; the map key is matched against a step's `id`.
pub steps: HashMap<String, StepOverride>,
}
impl WorkflowOverrides {
    /// Creates an override set that targets no steps.
    pub fn new() -> Self {
        let steps = HashMap::new();
        Self { steps }
    }

    /// Returns `true` when no step override has been registered.
    pub fn is_empty(&self) -> bool {
        let Self { steps } = self;
        steps.is_empty()
    }

    /// Registers `override_spec` for the step identified by `step_id`.
    ///
    /// An override previously registered under the same `step_id` is replaced.
    pub fn add_step_override(&mut self, step_id: String, override_spec: StepOverride) {
        // The displaced entry (if any) is intentionally discarded.
        let _previous = self.steps.insert(step_id, override_spec);
    }
}
impl Default for WorkflowOverrides {
fn default() -> Self {
Self::new()
}
}
// A single override for one workflow step: a payload plus the strategy used to
// apply it.
//
// NOTE: plain `//` comments are used on this type on purpose; `///` doc
// comments on a `schemars::JsonSchema` type are emitted as descriptions in the
// generated schema and would change its output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
pub struct StepOverride {
// Strategy used to apply `value`. (De)serialized under the key `$type`; when
// the key is absent, `default_override_type` supplies the value.
#[serde(rename = "$type", default = "default_override_type")]
pub override_type: OverrideType,
// Override payload, interpreted according to `override_type`.
pub value: serde_json::Value,
}
impl StepOverride {
pub fn merge_patch(value: serde_json::Value) -> Self {
Self {
override_type: OverrideType::MergePatch,
value,
}
}
pub fn with_type(override_type: OverrideType, value: serde_json::Value) -> Self {
Self {
override_type,
value,
}
}
}
// Strategy for applying a step override's payload. Variant names are
// (de)serialized in snake_case (`merge_patch`, `json_patch`).
//
// NOTE: plain `//` comments are used on this type on purpose; `///` doc
// comments on a `schemars::JsonSchema` type are emitted as descriptions in the
// generated schema and would change its output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OverrideType {
// The payload is merged into the step's JSON form via `json_patch::merge`.
MergePatch,
// Accepted when deserializing, but `DefaultOverrideProcessor` rejects it with
// `OverrideError::UnsupportedOverrideType`.
// NOTE(review): presumably reserved for RFC 6902 JSON Patch support — confirm.
#[allow(dead_code)]
JsonPatch,
}
/// Serde default for `StepOverride::override_type`: an override that omits
/// `$type` is treated as a merge patch.
fn default_override_type() -> OverrideType {
OverrideType::MergePatch
}
/// Errors that can occur while validating or applying workflow overrides.
#[derive(Debug, thiserror::Error)]
pub enum OverrideError {
/// An override targets a step ID that does not exist in the workflow.
#[error("Step '{step_id}' not found in workflow")]
StepNotFound { step_id: String },
/// Applying the override for `step_id` failed; `reason` describes why.
#[error("Invalid override value for step '{step_id}': {reason}")]
InvalidOverrideValue { step_id: String, reason: String },
/// The override requested a strategy that is not implemented.
#[error("Unsupported override type: {override_type:?}")]
UnsupportedOverrideType { override_type: OverrideType },
/// Serializing the step to JSON, or rebuilding it from the patched JSON, failed.
#[error("JSON merge patch failed for step '{step_id}': {reason}")]
MergePatchFailed { step_id: String, reason: String },
}
/// Result type used by override processing: the error side is an
/// [`error_stack::Report`] whose context is an [`OverrideError`].
///
/// Spelled out in full rather than through the `error_stack::Result` alias
/// (deprecated in recent error-stack releases); the two are the same type.
pub type OverrideResult<T> = std::result::Result<T, error_stack::Report<OverrideError>>;
/// Strategy for applying a set of step overrides to a workflow.
pub trait OverrideProcessor {
/// Applies `overrides` to `flow` and returns the resulting flow.
///
/// The flow is passed and returned as an `Arc`, so an implementation may
/// return the input unchanged or a newly built flow.
///
/// # Errors
///
/// Returns a report carrying an [`OverrideError`] when the overrides cannot
/// be applied.
fn apply_overrides(
&self,
flow: Arc<Flow>,
overrides: &WorkflowOverrides,
) -> OverrideResult<Arc<Flow>>;
}
/// Stateless [`OverrideProcessor`] implementation.
///
/// It carries no data, so it is trivially `Copy`; `Debug` is derived so the
/// processor can appear in logs and in `Debug` output of types that embed it.
#[derive(Debug, Clone, Copy)]
pub struct DefaultOverrideProcessor;
impl OverrideProcessor for DefaultOverrideProcessor {
    /// Applies `overrides` to a copy of `flow`.
    ///
    /// * With no overrides, the input `Arc` is returned untouched (no copy).
    /// * Otherwise every override key must match a step ID in `flow`; the flow
    ///   is then copied via `slow_clone` and each targeted step is patched in
    ///   place on the copy.
    ///
    /// # Errors
    ///
    /// * [`OverrideError::StepNotFound`] when an override targets an unknown step.
    /// * [`OverrideError::InvalidOverrideValue`] (wrapping the underlying cause)
    ///   when patching a step fails.
    fn apply_overrides(
        &self,
        flow: Arc<Flow>,
        overrides: &WorkflowOverrides,
    ) -> OverrideResult<Arc<Flow>> {
        if overrides.is_empty() {
            return Ok(flow);
        }

        log::debug!(
            "Applying {} step overrides to workflow",
            overrides.steps.len()
        );

        // Validate before cloning so a bad target never pays for the copy.
        self.validate_override_targets(&flow, overrides)?;

        let mut cloned_flow = flow.slow_clone();
        for step in &mut cloned_flow.steps {
            if let Some(step_override) = overrides.steps.get(&step.id) {
                log::debug!(
                    "Applying override to step '{}' with type '{:?}'",
                    step.id,
                    step_override.override_type
                );
                // Build the error context lazily: the eager form allocated two
                // strings for every overridden step even when nothing failed.
                self.apply_step_override(step, step_override)
                    .change_context_lazy(|| OverrideError::InvalidOverrideValue {
                        step_id: step.id.clone(),
                        reason: "Failed to apply step override".to_string(),
                    })?;
            }
        }

        Ok(Arc::new(cloned_flow))
    }
}
impl DefaultOverrideProcessor {
    /// Creates a processor; the type is stateless, so this is free.
    pub fn new() -> Self {
        Self
    }

    /// Checks that every override key names a step that exists in `flow`.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideError::StepNotFound`] for a missing target. When
    /// several targets are missing, the lexicographically smallest ID is
    /// reported so the error does not depend on `HashMap` iteration order.
    fn validate_override_targets(
        &self,
        flow: &Flow,
        overrides: &WorkflowOverrides,
    ) -> OverrideResult<()> {
        let known_ids: std::collections::HashSet<&str> =
            flow.steps().iter().map(|step| step.id.as_str()).collect();

        let first_missing = overrides
            .steps
            .keys()
            .filter(|step_id| !known_ids.contains(step_id.as_str()))
            .min();

        match first_missing {
            Some(step_id) => Err(error_stack::report!(OverrideError::StepNotFound {
                step_id: step_id.clone(),
            })),
            None => Ok(()),
        }
    }

    /// Dispatches one override to the routine that implements its strategy.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideError::UnsupportedOverrideType`] for
    /// [`OverrideType::JsonPatch`], which is not implemented, and forwards any
    /// error from the merge-patch routine.
    fn apply_step_override(
        &self,
        step: &mut super::Step,
        step_override: &StepOverride,
    ) -> OverrideResult<()> {
        match step_override.override_type {
            OverrideType::MergePatch => self.apply_merge_patch(step, &step_override.value),
            OverrideType::JsonPatch => Err(error_stack::report!(
                OverrideError::UnsupportedOverrideType {
                    override_type: step_override.override_type.clone(),
                }
            )),
        }
    }

    /// Applies `patch` to `step` by round-tripping the step through JSON:
    /// serialize, merge with `json_patch::merge`, deserialize back.
    ///
    /// `step` is only overwritten once the patched JSON deserializes
    /// successfully, so a failure leaves it unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideError::MergePatchFailed`] when the step cannot be
    /// serialized or the patched JSON no longer deserializes into a step.
    fn apply_merge_patch(
        &self,
        step: &mut super::Step,
        patch: &serde_json::Value,
    ) -> OverrideResult<()> {
        // Error contexts are built lazily so the success path allocates nothing
        // for them.
        let mut step_json = serde_json::to_value(&*step).change_context_lazy(|| {
            OverrideError::MergePatchFailed {
                step_id: step.id.clone(),
                reason: "Failed to serialize step to JSON".to_string(),
            }
        })?;

        json_patch::merge(&mut step_json, patch);

        // `step` still holds its pre-patch value while the right-hand side is
        // evaluated, so the error reports the original step ID.
        *step = serde_json::from_value(step_json).change_context_lazy(|| {
            OverrideError::MergePatchFailed {
                step_id: step.id.clone(),
                reason: "Failed to deserialize modified step from JSON".to_string(),
            }
        })?;

        Ok(())
    }
}
impl Default for DefaultOverrideProcessor {
fn default() -> Self {
Self::new()
}
}
pub fn apply_overrides(
flow: Arc<Flow>,
overrides: &WorkflowOverrides,
) -> OverrideResult<Arc<Flow>> {
DefaultOverrideProcessor::new().apply_overrides(flow, overrides)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ValueExpr;
    use serde_json::json;

    /// Builds a minimal flow containing exactly one step, `step1`, that runs
    /// `/test/component` with a null input.
    fn create_test_flow() -> Flow {
        let only_step = super::super::Step {
            id: "step1".to_string(),
            component: super::super::Component::from_string("/test/component"),
            on_error: None,
            input: ValueExpr::null(),
            must_execute: None,
            metadata: std::collections::HashMap::new(),
        };

        Flow {
            name: Some("test_flow".to_string()),
            description: None,
            version: None,
            schemas: super::super::FlowSchema::default(),
            steps: vec![only_step],
            output: ValueExpr::null(),
            test: None,
            examples: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// A new override set is empty; adding one override makes it non-empty.
    #[test]
    fn test_workflow_overrides_creation() {
        let empty = WorkflowOverrides::new();
        assert!(empty.is_empty());

        let mut populated = WorkflowOverrides::new();
        let patch = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        populated.add_step_override("step1".to_string(), patch);

        assert!(!populated.is_empty());
        assert_eq!(populated.steps.len(), 1);
    }

    /// Both constructors store the payload and the expected strategy.
    #[test]
    fn test_step_override_creation() {
        let via_helper = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        assert_eq!(via_helper.override_type, OverrideType::MergePatch);
        assert_eq!(via_helper.value, json!({"input": {"temperature": 0.8}}));

        let via_explicit_type = StepOverride::with_type(
            OverrideType::MergePatch,
            json!({"component": "/different/component"}),
        );
        assert_eq!(via_explicit_type.override_type, OverrideType::MergePatch);
        assert_eq!(
            via_explicit_type.value,
            json!({"component": "/different/component"})
        );
    }

    /// Applying an empty override set leaves the step count unchanged.
    #[test]
    fn test_apply_empty_overrides() {
        let flow = Arc::new(create_test_flow());
        let step_count_before = flow.steps().len();
        let no_overrides = WorkflowOverrides::new();

        let patched = apply_overrides(flow, &no_overrides).unwrap();

        assert_eq!(patched.steps().len(), step_count_before);
    }

    /// A merge patch can replace a step's component.
    #[test]
    fn test_apply_merge_patch_override() {
        let flow = Arc::new(create_test_flow());
        let patch = StepOverride::merge_patch(json!({
            "input": {"temperature": 0.8},
            "component": "/new/component"
        }));
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override("step1".to_string(), patch);

        let patched = apply_overrides(flow, &overrides).unwrap();

        let first_step = &patched.steps()[0];
        assert_eq!(first_step.component.to_string(), "/new/component");
    }

    /// Targeting a step that does not exist is reported as an error naming it.
    #[test]
    fn test_validate_override_targets_missing_step() {
        let flow = Arc::new(create_test_flow());
        let patch = StepOverride::merge_patch(json!({"input": {"temperature": 0.8}}));
        let mut overrides = WorkflowOverrides::new();
        overrides.add_step_override("nonexistent_step".to_string(), patch);

        let outcome = apply_overrides(flow, &overrides);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Step 'nonexistent_step' not found in workflow"));
    }

    /// Omitting `$type` falls back to the merge-patch strategy.
    #[test]
    fn test_serde_override_type_default() {
        let raw = r#"{"value": {"temperature": 0.8}}"#;

        let parsed = serde_json::from_str::<StepOverride>(raw).unwrap();

        assert_eq!(parsed.override_type, OverrideType::MergePatch);
        assert_eq!(parsed.value, json!({"temperature": 0.8}));
    }

    /// An explicit `$type` of `merge_patch` is parsed as the merge-patch strategy.
    #[test]
    fn test_serde_override_type_explicit() {
        let raw = r#"{"$type": "merge_patch", "value": {"temperature": 0.8}}"#;

        let parsed = serde_json::from_str::<StepOverride>(raw).unwrap();

        assert_eq!(parsed.override_type, OverrideType::MergePatch);
        assert_eq!(parsed.value, json!({"temperature": 0.8}));
    }
}