use crate::cook::environment::{EnvProfile, SecretValue};
use crate::cook::execution::variable_capture::CaptureConfig;
use crate::cook::execution::{MapPhase, MapReduceConfig, ReducePhase, SetupPhase};
use crate::cook::workflow::{WorkflowErrorPolicy, WorkflowStep};
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Top-level YAML schema for a MapReduce workflow definition.
///
/// Parsed from workflow files via [`parse_mapreduce_workflow`]. The loose
/// error-handling fields (`on_item_failure`, `continue_on_failure`,
/// `max_failures`, `failure_threshold`, `error_collection`) are convenience
/// overrides that get folded into `error_policy` by `get_error_policy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapReduceWorkflowConfig {
    /// Workflow name (required in YAML).
    pub name: String,
    /// Execution mode; defaults to "mapreduce" (see `default_mode`).
    #[serde(default = "default_mode")]
    pub mode: String,
    /// Workflow-level environment variables; consulted before the process
    /// environment when resolving `$VAR` references (see `resolve_env_or_parse`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
    /// Named secret values available to the workflow.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub secrets: Option<HashMap<String, SecretValue>>,
    /// Extra files to load environment variables from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env_files: Option<Vec<PathBuf>>,
    /// Named environment profiles.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profiles: Option<HashMap<String, EnvProfile>>,
    /// Optional setup phase; accepts either a bare command list or a full
    /// config map (see `deserialize_setup_phase_option`).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_setup_phase_option"
    )]
    pub setup: Option<SetupPhaseConfig>,
    /// Map phase (required).
    pub map: MapPhaseYaml,
    /// Optional reduce phase.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reduce: Option<ReducePhaseYaml>,
    /// Base error policy; note the skip predicate currently always returns
    /// false, so this field is always serialized.
    #[serde(default, skip_serializing_if = "is_default_error_policy")]
    pub error_policy: WorkflowErrorPolicy,
    /// Override: action on item failure ("dlq"/"retry"/"skip"/"stop"/custom).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub on_item_failure: Option<String>,
    /// Override: keep going after individual item failures.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_on_failure: Option<bool>,
    /// Override: absolute cap on failed items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_failures: Option<usize>,
    /// Override: fractional failure threshold.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub failure_threshold: Option<f64>,
    /// Override: error collection strategy ("aggregate"/"immediate"/"batched:N").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_collection: Option<String>,
    /// Optional merge workflow run after map/reduce complete.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub merge: Option<MergeWorkflow>,
}
/// Commands to run when merging worktree results back to the target branch.
///
/// Deserialization (custom impl below) accepts either a bare YAML list of
/// steps or a `{ commands: [...], timeout: ... }` map.
#[derive(Debug, Clone, Serialize)]
pub struct MergeWorkflow {
    /// Steps executed in order during the merge.
    pub commands: Vec<WorkflowStep>,
    /// Optional timeout in seconds; `None` when the simplified list syntax
    /// is used or the field is omitted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
}
impl<'de> Deserialize<'de> for MergeWorkflow {
    /// Accepts either a bare command list (`merge: [- shell: ...]`) or a
    /// full mapping with `commands` and an optional `timeout`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Untagged: serde tries each shape in order.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum MergeValue {
            Commands(Vec<WorkflowStep>),
            Config {
                commands: Vec<WorkflowStep>,
                #[serde(default)]
                timeout: Option<u64>,
            },
        }
        let (commands, timeout) = match MergeValue::deserialize(deserializer)? {
            // Simplified list syntax carries no timeout.
            MergeValue::Commands(commands) => (commands, None),
            MergeValue::Config { commands, timeout } => (commands, timeout),
        };
        Ok(MergeWorkflow { commands, timeout })
    }
}
fn is_default_error_policy(policy: &WorkflowErrorPolicy) -> bool {
matches!(policy, WorkflowErrorPolicy { .. } if false) }
/// Default value for `MapReduceWorkflowConfig::mode` when omitted from YAML.
fn default_mode() -> String {
    String::from("mapreduce")
}
/// YAML schema for the optional setup phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetupPhaseConfig {
    /// Steps executed before the map phase starts.
    pub commands: Vec<WorkflowStep>,
    /// Timeout kept as a string so it can be a number, a quoted number, or a
    /// `$VAR` reference; resolved to seconds later by `to_setup_phase`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_optional_u64_or_string"
    )]
    pub timeout: Option<String>,
    /// Outputs to capture from setup commands; accepts both legacy numeric
    /// indices and full capture configs (see `deserialize_capture_outputs`).
    #[serde(
        default,
        skip_serializing_if = "HashMap::is_empty",
        deserialize_with = "deserialize_capture_outputs"
    )]
    pub capture_outputs: HashMap<String, CaptureConfig>,
}
/// Deserialize `capture_outputs`, accepting the legacy bare-index form
/// (`name: 0`) alongside full `CaptureConfig` values.
fn deserialize_capture_outputs<'de, D>(
    deserializer: D,
) -> Result<HashMap<String, CaptureConfig>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::Deserialize;
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum CaptureValue {
        LegacyIndex(usize),
        Config(CaptureConfig),
    }
    let raw: HashMap<String, CaptureValue> = HashMap::deserialize(deserializer)?;
    // Normalize every entry to a CaptureConfig, mapping legacy indices to
    // the Simple variant.
    Ok(raw
        .into_iter()
        .map(|(name, value)| {
            let config = match value {
                CaptureValue::LegacyIndex(idx) => CaptureConfig::Simple(idx),
                CaptureValue::Config(cfg) => cfg,
            };
            (name, config)
        })
        .collect())
}
/// Deserialize the optional setup phase, accepting either a bare command
/// list or a full `SetupPhaseConfig` mapping.
fn deserialize_setup_phase_option<'de, D>(
    deserializer: D,
) -> Result<Option<SetupPhaseConfig>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum SetupValue {
        Commands(Vec<WorkflowStep>),
        Config(SetupPhaseConfig),
    }
    let parsed = Option::<SetupValue>::deserialize(deserializer)?;
    Ok(parsed.map(|value| match value {
        // Bare list: wrap with no timeout and no captures.
        SetupValue::Commands(commands) => SetupPhaseConfig {
            commands,
            timeout: None,
            capture_outputs: HashMap::new(),
        },
        SetupValue::Config(config) => config,
    }))
}
/// YAML schema for the map phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapPhaseYaml {
    /// Input source for work items (e.g. a JSON file path).
    pub input: String,
    /// JSONPath expression selecting items from the input; empty when omitted.
    #[serde(default)]
    pub json_path: String,
    /// Commands run per item; accepts simplified list or nested `commands:`.
    pub agent_template: AgentTemplate,
    /// Kept as a string so it can be a number or a `$VAR` reference;
    /// resolved to a usize by `to_map_phase`. Defaults to "10".
    #[serde(
        default = "default_max_parallel_string",
        deserialize_with = "deserialize_usize_or_string"
    )]
    pub max_parallel: String,
    /// Optional filter expression applied to items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    /// Optional sort key for item ordering.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort_by: Option<String>,
    /// Optional cap on the number of items processed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_items: Option<usize>,
    /// Optional number of items to skip from the start.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub offset: Option<usize>,
    /// Optional field for de-duplicating items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub distinct: Option<String>,
    /// Per-agent timeout; string so it can hold a `$VAR` reference,
    /// resolved to seconds by `to_map_phase`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_optional_u64_or_string"
    )]
    pub agent_timeout_secs: Option<String>,
    /// Optional fine-grained timeout configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_config: Option<crate::cook::execution::mapreduce::timeout::TimeoutConfig>,
}
/// Default for `MapPhaseYaml::max_parallel` when omitted from YAML.
fn default_max_parallel_string() -> String {
    String::from("10")
}
/// Per-item command template for map agents.
///
/// Deserialization (custom impl below) accepts a bare list of steps or the
/// deprecated nested `{ commands: [...] }` form.
#[derive(Debug, Clone, Serialize)]
pub struct AgentTemplate {
    /// Steps executed for each work item.
    pub commands: Vec<WorkflowStep>,
}
impl<'de> Deserialize<'de> for AgentTemplate {
    /// Accepts the simplified bare-list syntax or the deprecated nested
    /// `commands:` form (which logs a deprecation warning).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum AgentTemplateValue {
            Commands(Vec<WorkflowStep>),
            Nested { commands: Vec<WorkflowStep> },
        }
        let commands = match AgentTemplateValue::deserialize(deserializer)? {
            AgentTemplateValue::Commands(commands) => commands,
            AgentTemplateValue::Nested { commands } => {
                // Deprecated form: keep accepting it but nudge users to migrate.
                tracing::warn!("Using deprecated nested 'commands' syntax in agent_template. Consider using the simplified array format directly under 'agent_template'.");
                commands
            }
        };
        Ok(AgentTemplate { commands })
    }
}
/// YAML schema for the reduce phase.
///
/// Deserialization (custom impl below) accepts a bare list of steps or the
/// deprecated nested `{ commands: [...] }` form.
#[derive(Debug, Clone, Serialize)]
pub struct ReducePhaseYaml {
    /// Steps run once after all map agents complete.
    pub commands: Vec<WorkflowStep>,
}
impl<'de> Deserialize<'de> for ReducePhaseYaml {
    /// Accepts the simplified bare-list syntax or the deprecated nested
    /// `commands:` form (which logs a deprecation warning).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum ReduceValue {
            Commands(Vec<WorkflowStep>),
            Nested { commands: Vec<WorkflowStep> },
        }
        let commands = match ReduceValue::deserialize(deserializer)? {
            ReduceValue::Commands(commands) => commands,
            ReduceValue::Nested { commands } => {
                // Deprecated form: keep accepting it but nudge users to migrate.
                tracing::warn!("Using deprecated nested 'commands' syntax in reduce. Consider using the simplified array format directly under 'reduce'.");
                commands
            }
        };
        Ok(ReducePhaseYaml { commands })
    }
}
/// Deserialize a field that may be a YAML number or a string (e.g. a `$VAR`
/// reference), normalizing it to a `String` for later resolution.
fn deserialize_usize_or_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: serde::Deserializer<'de>,
{
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum UsizeOrString {
        Number(usize),
        String(String),
    }
    let normalized = match UsizeOrString::deserialize(deserializer)? {
        UsizeOrString::Number(n) => n.to_string(),
        UsizeOrString::String(s) => s,
    };
    Ok(normalized)
}
/// Deserialize an optional field that may be a YAML number or a string
/// (e.g. a `$VAR` reference), normalizing the value to a `String`.
fn deserialize_optional_u64_or_string<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum U64OrString {
        Number(u64),
        String(String),
    }
    let parsed = Option::<U64OrString>::deserialize(deserializer)?;
    Ok(parsed.map(|value| match value {
        U64OrString::Number(n) => n.to_string(),
        U64OrString::String(s) => s,
    }))
}
impl MapReduceWorkflowConfig {
    /// Parse `value` as `T`, first resolving `$VAR` / `${VAR}` references.
    ///
    /// Lookup order: the workflow-level `env` map, then the process
    /// environment. Returns an error when the referenced variable is missing
    /// from both, or when the value (or literal) fails to parse as `T`.
    fn resolve_env_or_parse<T>(&self, value: &str) -> Result<T, anyhow::Error>
    where
        T: std::str::FromStr,
        <T as std::str::FromStr>::Err: std::fmt::Display,
    {
        if value.starts_with('$') {
            // Accept both ${VAR} and $VAR; a missing closing brace is
            // tolerated (the rest of the string is used as the name).
            let var_name = if let Some(stripped) = value.strip_prefix("${") {
                stripped.strip_suffix('}').unwrap_or(stripped)
            } else if let Some(stripped) = value.strip_prefix('$') {
                stripped
            } else {
                value
            };
            // Workflow-level env takes precedence over the process env.
            if let Some(ref env) = self.env {
                if let Some(env_value) = env.get(var_name) {
                    return env_value.parse::<T>().map_err(|e| {
                        anyhow::anyhow!(
                            "Failed to parse environment variable '{}' value '{}': {}",
                            var_name,
                            env_value,
                            e
                        )
                    });
                }
            }
            if let Ok(env_value) = std::env::var(var_name) {
                return env_value.parse::<T>().map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to parse environment variable '{}' value '{}': {}",
                        var_name,
                        env_value,
                        e
                    )
                });
            }
            return Err(anyhow::anyhow!(
                "Environment variable '{}' not found in workflow env or system environment",
                var_name
            ));
        }
        // Not a variable reference: parse the literal directly.
        value
            .parse::<T>()
            .map_err(|e| anyhow::anyhow!("Failed to parse numeric value '{}': {}", value, e))
    }

    /// Build the effective error policy by layering the per-field overrides
    /// (`on_item_failure`, `continue_on_failure`, `max_failures`,
    /// `failure_threshold`, `error_collection`) on top of `error_policy`.
    pub fn get_error_policy(&self) -> WorkflowErrorPolicy {
        use crate::cook::workflow::{ErrorCollectionStrategy, ItemFailureAction};
        let mut policy = self.error_policy.clone();
        if let Some(ref action_str) = self.on_item_failure {
            // Unrecognized actions pass through as custom handler names.
            policy.on_item_failure = match action_str.as_str() {
                "dlq" => ItemFailureAction::Dlq,
                "retry" => ItemFailureAction::Retry,
                "skip" => ItemFailureAction::Skip,
                "stop" => ItemFailureAction::Stop,
                custom => ItemFailureAction::Custom(custom.to_string()),
            };
        }
        if let Some(continue_on_failure) = self.continue_on_failure {
            policy.continue_on_failure = continue_on_failure;
        }
        if let Some(max_failures) = self.max_failures {
            policy.max_failures = Some(max_failures);
        }
        if let Some(failure_threshold) = self.failure_threshold {
            policy.failure_threshold = Some(failure_threshold);
        }
        if let Some(ref collection_str) = self.error_collection {
            // "batched:N" selects batched collection; a malformed size (and
            // any unrecognized string) falls back to Aggregate.
            policy.error_collection = match collection_str.as_str() {
                "aggregate" => ErrorCollectionStrategy::Aggregate,
                "immediate" => ErrorCollectionStrategy::Immediate,
                _ if collection_str.starts_with("batched:") => {
                    if let Some(size_str) = collection_str.strip_prefix("batched:") {
                        if let Ok(size) = size_str.parse::<usize>() {
                            ErrorCollectionStrategy::Batched { size }
                        } else {
                            ErrorCollectionStrategy::Aggregate
                        }
                    } else {
                        // Unreachable in practice: the guard guarantees the prefix.
                        ErrorCollectionStrategy::Aggregate
                    }
                }
                _ => ErrorCollectionStrategy::Aggregate,
            };
        }
        policy
    }

    /// Convert the YAML setup config (if any) into a runtime `SetupPhase`,
    /// resolving the timeout string (possibly a `$VAR` reference) to seconds.
    ///
    /// # Errors
    /// Returns an error when the timeout cannot be resolved or parsed.
    pub fn to_setup_phase(&self) -> Result<Option<SetupPhase>, anyhow::Error> {
        if let Some(ref s) = self.setup {
            let timeout = if let Some(ref timeout_str) = s.timeout {
                Some(
                    self.resolve_env_or_parse::<u64>(timeout_str)
                        .context("Failed to resolve setup timeout")?,
                )
            } else {
                None
            };
            Ok(Some(SetupPhase {
                commands: s.commands.clone(),
                timeout,
                capture_outputs: s.capture_outputs.clone(),
            }))
        } else {
            Ok(None)
        }
    }

    /// Convert the YAML map config into a runtime `MapPhase`, resolving
    /// `max_parallel` and `agent_timeout_secs` (either may be `$VAR`
    /// references) to numbers.
    ///
    /// # Errors
    /// Returns an error when either value cannot be resolved or parsed.
    pub fn to_map_phase(&self) -> Result<MapPhase, anyhow::Error> {
        let max_parallel = self
            .resolve_env_or_parse::<usize>(&self.map.max_parallel)
            .context("Failed to resolve max_parallel")?;
        let agent_timeout_secs = if let Some(ref timeout_str) = self.map.agent_timeout_secs {
            Some(
                self.resolve_env_or_parse::<u64>(timeout_str)
                    .context("Failed to resolve agent_timeout_secs")?,
            )
        } else {
            None
        };
        Ok(MapPhase {
            config: MapReduceConfig {
                input: self.map.input.clone(),
                json_path: self.map.json_path.clone(),
                max_parallel,
                agent_timeout_secs,
                // NOTE(review): hard-coded false even though the workflow has
                // a top-level continue_on_failure field — presumably that
                // field is applied via get_error_policy instead; confirm.
                continue_on_failure: false,
                batch_size: None,
                enable_checkpoints: true,
                max_items: self.map.max_items,
                offset: self.map.offset,
            },
            // Empty json_path is treated as "not specified".
            json_path: Some(self.map.json_path.clone()).filter(|s| !s.is_empty()),
            agent_template: self.map.agent_template.commands.clone(),
            filter: self.map.filter.clone(),
            sort_by: self.map.sort_by.clone(),
            max_items: self.map.max_items,
            distinct: self.map.distinct.clone(),
            timeout_config: self.map.timeout_config.clone(),
            workflow_env: self.env.clone().unwrap_or_default(),
        })
    }

    /// Convert the YAML reduce config (if any) into a runtime `ReducePhase`.
    /// The YAML schema carries no reduce timeout, so `timeout_secs` is `None`.
    pub fn to_reduce_phase(&self) -> Option<ReducePhase> {
        self.reduce.as_ref().map(|r| ReducePhase {
            commands: r.commands.clone(),
            timeout_secs: None,
        })
    }

    /// True when `mode` is "mapreduce" (case-insensitive).
    pub fn is_mapreduce(&self) -> bool {
        self.mode.to_lowercase() == "mapreduce"
    }
}
/// Parse a MapReduce workflow definition from YAML text.
///
/// Thin wrapper over `serde_yaml::from_str`; the syntax flexibility
/// (simplified vs. nested forms) lives in the custom `Deserialize` impls.
pub fn parse_mapreduce_workflow(
    content: &str,
) -> Result<MapReduceWorkflowConfig, serde_yaml::Error> {
    serde_yaml::from_str::<MapReduceWorkflowConfig>(content)
}
#[cfg(test)]
mod tests {
    //! Parsing tests for the workflow YAML schema: both the legacy nested
    //! `commands:` syntax and the simplified bare-list syntax for setup,
    //! agent_template, and reduce, plus workflow-level env handling.
    use super::*;

    /// Canonical workflow using the nested `commands:` syntax throughout.
    #[test]
    fn test_parse_basic_mapreduce_workflow() {
        let yaml = r#"
name: parallel-debt-elimination
mode: mapreduce
map:
  input: items.json
  json_path: "$.debt_items[*]"
  agent_template:
    commands:
      - claude: "/fix-issue ${item.description}"
      - shell: "cargo test"
  max_parallel: 10
reduce:
  commands:
    - claude: "/summarize-fixes ${map.results}"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert_eq!(config.name, "parallel-debt-elimination");
        assert_eq!(config.mode, "mapreduce");
        // max_parallel stays a String until resolved by to_map_phase.
        assert_eq!(config.map.max_parallel, "10");
        assert_eq!(config.map.agent_template.commands.len(), 2);
    }

    /// agent_template given as a bare list (simplified syntax).
    #[test]
    fn test_simplified_agent_template_syntax() {
        let yaml = r#"
name: test-simplified
mode: mapreduce
map:
  input: items.json
  json_path: "$.items[*]"
  # New simplified syntax - direct array of commands
  agent_template:
    - claude: "/process '${item}'"
    - shell: "validate ${item}"
  max_parallel: 5
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert_eq!(config.name, "test-simplified");
        assert_eq!(config.map.agent_template.commands.len(), 2);
        let first_step = &config.map.agent_template.commands[0];
        assert!(first_step.claude.is_some());
        assert!(first_step.claude.as_ref().unwrap().contains("/process"));
    }

    /// agent_template given with the deprecated nested `commands:` key.
    #[test]
    fn test_nested_agent_template_syntax() {
        let yaml = r#"
name: test-nested
mode: mapreduce
map:
  input: items.json
  json_path: "$.items[*]"
  # Old nested syntax with 'commands' key
  agent_template:
    commands:
      - claude: "/process '${item}'"
      - shell: "validate ${item}"
  max_parallel: 5
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert_eq!(config.name, "test-nested");
        assert_eq!(config.map.agent_template.commands.len(), 2);
    }

    /// reduce given as a bare list (simplified syntax).
    #[test]
    fn test_simplified_reduce_syntax() {
        let yaml = r#"
name: test-reduce-simplified
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo processing"
# New simplified reduce syntax
reduce:
  - claude: "/summarize ${map.results}"
  - shell: "generate-report"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.reduce.is_some());
        assert_eq!(config.reduce.as_ref().unwrap().commands.len(), 2);
    }

    /// reduce given with the deprecated nested `commands:` key.
    #[test]
    fn test_nested_reduce_syntax() {
        let yaml = r#"
name: test-reduce-nested
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo processing"
# Old nested reduce syntax
reduce:
  commands:
    - claude: "/summarize ${map.results}"
    - shell: "generate-report"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.reduce.is_some());
        assert_eq!(config.reduce.as_ref().unwrap().commands.len(), 2);
    }

    /// Simplified and nested syntaxes can be mixed in one workflow.
    #[test]
    fn test_mixed_simplified_and_nested_syntax() {
        let yaml = r#"
name: test-mixed
mode: mapreduce
# Setup uses simple list format (already supported)
setup:
  - shell: "prepare-data"
  - claude: "/analyze-requirements"
map:
  input: items.json
  # Using new simplified syntax for agent_template
  agent_template:
    - claude: "/process ${item}"
    - shell: "test ${item}"
# Using old nested syntax for reduce
reduce:
  commands:
    - claude: "/summarize ${map.results}"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.setup.is_some());
        assert_eq!(config.setup.as_ref().unwrap().commands.len(), 2);
        assert_eq!(config.map.agent_template.commands.len(), 2);
        assert!(config.reduce.is_some());
        assert_eq!(config.reduce.as_ref().unwrap().commands.len(), 1);
    }

    /// Workflow-level env map round-trips through parsing.
    #[test]
    fn test_mapreduce_with_env_variables() {
        let yaml = r#"
name: test-env-vars
mode: mapreduce
env:
  PROJECT_NAME: "TestProject"
  CONFIG_PATH: ".test/config.json"
  OUTPUT_DIR: ".test/output"
map:
  input: items.json
  json_path: "$.items[*]"
  agent_template:
    - shell: "echo Processing $PROJECT_NAME"
    - claude: "/process --config $CONFIG_PATH"
reduce:
  - shell: "echo Saving to $OUTPUT_DIR"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.env.is_some());
        let env = config.env.unwrap();
        assert_eq!(env.get("PROJECT_NAME"), Some(&"TestProject".to_string()));
        assert_eq!(
            env.get("CONFIG_PATH"),
            Some(&".test/config.json".to_string())
        );
        assert_eq!(env.get("OUTPUT_DIR"), Some(&".test/output".to_string()));
    }

    /// Workflows without env/secrets/env_files/profiles still parse.
    #[test]
    fn test_mapreduce_backward_compatibility_without_env() {
        let yaml = r#"
name: test-no-env
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.env.is_none());
        assert!(config.secrets.is_none());
        assert!(config.env_files.is_none());
        assert!(config.profiles.is_none());
    }
}
#[cfg(test)]
mod merge_workflow_tests {
    //! Tests for the `merge:` section, covering the simplified list syntax,
    //! the full mapping syntax with `timeout`, and edge cases (empty list,
    //! absent section, invalid shape, nested on_failure handlers).
    use super::*;

    /// Bare list form: commands parse, timeout defaults to None.
    #[test]
    fn test_deserialize_simplified_syntax() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  - shell: "git fetch origin"
  - claude: "/merge-master ${merge.source_branch}"
  - shell: "cargo test"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.commands.len(), 3);
        assert_eq!(merge.timeout, None);
        assert!(merge.commands[0].shell.is_some());
        assert_eq!(
            merge.commands[0].shell.as_ref().unwrap(),
            "git fetch origin"
        );
        assert!(merge.commands[1].claude.is_some());
        assert!(merge.commands[1]
            .claude
            .as_ref()
            .unwrap()
            .contains("${merge.source_branch}"));
    }

    /// Full mapping form: commands plus an explicit timeout.
    #[test]
    fn test_deserialize_full_syntax() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  commands:
    - shell: "git fetch origin"
    - claude: "/merge-master"
    - shell: "git push"
  timeout: 900
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.commands.len(), 3);
        assert_eq!(merge.timeout, Some(900));
        assert!(merge.commands[0].shell.is_some());
        assert!(merge.commands[1].claude.is_some());
        assert!(merge.commands[2].shell.is_some());
    }

    /// Mapping form without timeout: defaults to None.
    #[test]
    fn test_default_timeout() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  commands:
    - shell: "git merge"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.timeout, None);
    }

    /// An empty list is a valid (no-op) merge workflow.
    #[test]
    fn test_empty_merge_workflow() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge: []
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.commands.len(), 0);
    }

    /// Omitting the merge section leaves the field as None.
    #[test]
    fn test_no_merge_workflow() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_none());
    }

    /// `${merge.*}` placeholders survive parsing verbatim (interpolation
    /// happens at execution time, not parse time).
    #[test]
    fn test_merge_with_variable_interpolation() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  - shell: "echo Merging ${merge.worktree}"
  - shell: "git checkout ${merge.target_branch}"
  - shell: "git merge ${merge.source_branch}"
  - claude: "/log-merge ${merge.session_id}"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.commands.len(), 4);
        assert!(merge.commands[0]
            .shell
            .as_ref()
            .unwrap()
            .contains("${merge.worktree}"));
        assert!(merge.commands[1]
            .shell
            .as_ref()
            .unwrap()
            .contains("${merge.target_branch}"));
        assert!(merge.commands[2]
            .shell
            .as_ref()
            .unwrap()
            .contains("${merge.source_branch}"));
        assert!(merge.commands[3]
            .claude
            .as_ref()
            .unwrap()
            .contains("${merge.session_id}"));
    }

    /// A mapping without a `commands` key matches neither untagged variant
    /// and must be rejected.
    #[test]
    fn test_invalid_merge_syntax_handled_gracefully() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  invalid_key: "should not parse"
"#;
        let result = parse_mapreduce_workflow(yaml);
        assert!(result.is_err());
    }

    /// Merge steps may carry nested on_failure handlers.
    #[test]
    fn test_merge_workflow_with_on_failure() {
        let yaml = r#"
name: test
mode: mapreduce
map:
  input: items.json
  agent_template:
    - shell: "echo test"
merge:
  - shell: "cargo test"
    on_failure:
      claude: "/fix-test-failures"
  - claude: "/merge-worktree"
"#;
        let config = parse_mapreduce_workflow(yaml).unwrap();
        assert!(config.merge.is_some());
        let merge = config.merge.unwrap();
        assert_eq!(merge.commands.len(), 2);
        assert!(merge.commands[0].on_failure.is_some());
    }
}