use serde::{Deserialize, Serialize};
/// Serde `skip_serializing_if` predicate: reports whether a flag is unset.
///
/// Takes `&bool` rather than `bool` because serde hands the predicate a reference to
/// the field it is about to serialize.
fn is_false(value: &bool) -> bool {
    // `Not` is implemented for `&bool`, so no explicit dereference is needed.
    !value
}
/// A plan for reaching `goal`, expressed as an ordered list of instructions.
///
/// Two representations coexist:
/// - `elements` is the current format and may hold steps, loops and terminate instructions.
/// - `steps` is the legacy format: a flat list of steps only. It defaults to empty when
///   absent and, unlike `elements`, is serialized even when empty.
///
/// NOTE(review): `steps` is only kept in sync by `add_step`/`add_instruction` and by
/// `migrate_legacy_steps`; editing `elements` in place (directly or through the `*_mut`
/// accessors) leaves `steps` stale until `migrate_legacy_steps` is called again.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyMap {
    /// Text of the goal this strategy works toward.
    pub goal: String,
    /// Ordered top-level instructions in the current format. Defaults to empty when absent
    /// (so legacy `steps`-only documents still load) and is omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub elements: Vec<StrategyInstruction>,
    /// Legacy flat step list, intended to mirror the top-level `Step` entries of `elements`.
    /// Defaults to empty when absent; always written on serialization.
    #[serde(default)]
    pub steps: Vec<StrategyStep>,
}
/// One step of a strategy, identified by `step_id` and assigned to the agent named in
/// `assigned_agent`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct StrategyStep {
    /// Identifier of the step (e.g. `"step_1"`). Uniqueness is not checked by the
    /// `validate` methods in this module.
    pub step_id: String,
    /// Human-readable description of the step.
    pub description: String,
    /// Name of the agent the step is assigned to (e.g. `"WorldConceptAgent"`).
    pub assigned_agent: String,
    /// Template text for the step's intent; tests in this file use placeholders such as
    /// `{{ user_request }}`. NOTE(review): rendering happens outside this file — confirm
    /// the template syntax there.
    pub intent_template: String,
    /// Free-text description of the output the step is expected to produce.
    pub expected_output: String,
    /// Whether the step's output requires validation. Defaults to `false` when absent and
    /// is omitted from serialized output while `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub requires_validation: bool,
    /// Optional key naming this step's output; omitted from serialized output when `None`.
    /// Presumably the name later templates use to refer to the output — confirm with the
    /// executor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_key: Option<String>,
}
impl StrategyMap {
    /// Creates a strategy for `goal` with no instructions and an empty legacy `steps` list.
    pub fn new(goal: String) -> Self {
        let elements = Vec::new();
        let steps = Vec::new();
        Self {
            goal,
            elements,
            steps,
        }
    }

    /// Upgrades a map loaded from the legacy `steps`-only format and resynchronises `steps`.
    ///
    /// If `elements` is empty while `steps` is not, a warning is logged and each legacy step
    /// is wrapped in [`StrategyInstruction::Step`] to populate `elements`. Otherwise `steps`
    /// is rebuilt from the top-level `Step` entries of `elements`, which also repairs a stale
    /// `steps` list after `elements` was edited in place (e.g. through [`Self::get_step_mut`]).
    /// Steps nested inside loop bodies are never copied into `steps`.
    pub fn migrate_legacy_steps(&mut self) {
        let only_legacy_data = self.elements.is_empty() && !self.steps.is_empty();
        if only_legacy_data {
            tracing::warn!(
                "Loading legacy StrategyMap format with {} steps. Consider migrating to 'elements' format.",
                self.steps.len()
            );
            self.elements = self
                .steps
                .iter()
                .cloned()
                .map(StrategyInstruction::Step)
                .collect();
            // `steps` already equals the top-level steps just copied into `elements`,
            // so rebuilding it would only produce an identical list.
            return;
        }
        let rebuilt: Vec<StrategyStep> = self.steps().into_iter().cloned().collect();
        self.steps = rebuilt;
    }

    /// Appends a step; shorthand for `add_instruction(StrategyInstruction::Step(step))`.
    pub fn add_step(&mut self, step: StrategyStep) {
        self.add_instruction(StrategyInstruction::Step(step));
    }

    /// Appends `instruction` to `elements`, mirroring top-level steps into legacy `steps`.
    pub fn add_instruction(&mut self, instruction: StrategyInstruction) {
        match &instruction {
            StrategyInstruction::Step(step) => self.steps.push(step.clone()),
            StrategyInstruction::Loop(_) | StrategyInstruction::Terminate(_) => {}
        }
        self.elements.push(instruction);
    }

    /// Number of top-level instructions (steps, loops and terminates alike).
    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// Returns `true` when there are no top-level instructions; legacy `steps` are not
    /// consulted, so an unmigrated legacy map reports empty.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Top-level instruction at `index`, or `None` if out of range.
    pub fn get_instruction(&self, index: usize) -> Option<&StrategyInstruction> {
        self.elements.get(index)
    }

    /// Mutable top-level instruction at `index`, or `None` if out of range.
    ///
    /// Edits made through this reference are not mirrored into `steps`; call
    /// [`Self::migrate_legacy_steps`] afterwards if the legacy list must stay in sync.
    pub fn get_instruction_mut(&mut self, index: usize) -> Option<&mut StrategyInstruction> {
        self.elements.get_mut(index)
    }

    /// Step at top-level `index`, or `None` if out of range or not a `Step` instruction.
    pub fn get_step(&self, index: usize) -> Option<&StrategyStep> {
        if let Some(StrategyInstruction::Step(step)) = self.elements.get(index) {
            Some(step)
        } else {
            None
        }
    }

    /// Mutable step at top-level `index`, or `None` if out of range or not a `Step`.
    ///
    /// As with [`Self::get_instruction_mut`], edits are not mirrored into `steps`.
    pub fn get_step_mut(&mut self, index: usize) -> Option<&mut StrategyStep> {
        if let Some(StrategyInstruction::Step(step)) = self.elements.get_mut(index) {
            Some(step)
        } else {
            None
        }
    }

    /// References to every top-level `Step`, in order; steps inside loop bodies are excluded.
    pub fn steps(&self) -> Vec<&StrategyStep> {
        let mut found = Vec::new();
        for instruction in &self.elements {
            if let StrategyInstruction::Step(step) = instruction {
                found.push(step);
            }
        }
        found
    }

    /// Checks the structural rules of every top-level loop.
    ///
    /// # Errors
    /// Returns the first error reported by [`LoopBlock::validate`], i.e. when a loop body
    /// directly contains another loop.
    pub fn validate(&self) -> Result<(), &'static str> {
        self.elements
            .iter()
            .try_for_each(|instruction| match instruction {
                StrategyInstruction::Loop(loop_block) => loop_block.validate(),
                StrategyInstruction::Step(_) | StrategyInstruction::Terminate(_) => Ok(()),
            })
    }
}
impl StrategyStep {
    /// Builds a step from its five required text fields.
    ///
    /// `requires_validation` starts out `false` and `output_key` starts out `None`; both
    /// fields are public and can be set afterwards.
    pub fn new(
        step_id: String,
        description: String,
        assigned_agent: String,
        intent_template: String,
        expected_output: String,
    ) -> Self {
        let requires_validation = false;
        let output_key = None;
        Self {
            requires_validation,
            output_key,
            step_id,
            description,
            assigned_agent,
            intent_template,
            expected_output,
        }
    }
}
/// One entry in a strategy: a step, a loop block, or a terminate instruction.
///
/// Serialized as an internally tagged object whose `"type"` field is `"step"`, `"loop"`
/// or `"terminate"`, with the wrapped struct's fields alongside it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StrategyInstruction {
    /// A single agent step (tag `"step"`).
    Step(StrategyStep),
    /// A loop block whose body may not contain other loops; see `LoopBlock::validate`
    /// (tag `"loop"`).
    Loop(LoopBlock),
    /// A terminate instruction (tag `"terminate"`).
    Terminate(TerminateInstruction),
}
/// Kind of loop declared by a `LoopBlock`.
///
/// Serialized in snake_case: `"while"`, `"for_each"`, `"until_convergence"`.
/// A fieldless enum, so it is `Copy` and `Hash` and can be passed by value or used as a
/// map/set key.
///
/// NOTE(review): how each kind is executed is decided by the strategy executor, which is
/// not visible here — confirm the semantics there.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum LoopType {
    While,
    ForEach,
    UntilConvergence,
}
/// Aggregation settings attached to a `LoopBlock`: an [`AggregationMode`] and the key the
/// aggregated value is stored under.
///
/// NOTE(review): the aggregation itself is performed by the executor, not visible here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LoopAggregation {
    /// Aggregation mode, serialized under the key `"mode"` (the field's own name, so no
    /// serde rename is required).
    pub mode: AggregationMode,
    /// Key for the aggregated result.
    pub output_key: String,
}
/// How a loop's per-iteration results are reduced into `LoopAggregation::output_key`.
///
/// Serialized in snake_case: `"last_success"`, `"collect_all"`, `"first_success"`.
/// A fieldless enum, so it is `Copy` and `Hash` and can be passed by value or used as a
/// map/set key.
///
/// NOTE(review): the meaning of each mode (presumably keep the last successful result,
/// collect every result, keep the first successful result) is implemented by the
/// executor, which is not visible here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum AggregationMode {
    LastSuccess,
    CollectAll,
    FirstSuccess,
}
/// A loop over a body of instructions.
///
/// The body may hold steps and terminate instructions but not other loops. That rule is
/// enforced by `LoopBlock::validate`, not by deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LoopBlock {
    /// Identifier of the loop (e.g. `"refine_loop"`).
    pub loop_id: String,
    /// Optional human-readable description; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional loop kind; omitted from output when `None`.
    /// NOTE(review): how `None` is interpreted is up to the executor — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub loop_type: Option<LoopType>,
    /// Iteration cap. Required when deserializing (no serde default).
    /// Presumably enforced by the executor — confirm.
    pub max_iterations: usize,
    /// Optional condition template (tests use e.g. `"{{ needs_refinement }}"`); omitted from
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub condition_template: Option<String>,
    /// Instructions forming the loop body. Required when deserializing.
    pub body: Vec<StrategyInstruction>,
    /// Optional aggregation settings, presumably for combining results across iterations;
    /// omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aggregation: Option<LoopAggregation>,
}
impl LoopBlock {
    /// Checks that the loop body contains no other loop.
    ///
    /// Only direct children are inspected. That is sufficient, because any deeper nesting
    /// would require a direct child loop, and that child is already rejected.
    ///
    /// # Errors
    /// Returns a static message when any element of `body` is a `Loop` instruction.
    pub fn validate(&self) -> Result<(), &'static str> {
        let has_nested_loop = self
            .body
            .iter()
            .any(|instruction| matches!(instruction, StrategyInstruction::Loop(_)));
        if has_nested_loop {
            Err("Nested loops are not supported. Loop body cannot contain other Loop instructions.")
        } else {
            Ok(())
        }
    }
}
/// A terminate instruction, optionally guarded by a condition.
///
/// NOTE(review): presumably this ends the strategy when its condition holds; how
/// `condition_template` is evaluated and how `final_output_template` is used is decided by
/// the executor, which is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TerminateInstruction {
    /// Identifier of this instruction (e.g. `"early_exit"`).
    pub terminate_id: String,
    /// Optional human-readable description; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional condition template (tests use e.g. `"{{ approved == true }}"`); omitted from
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub condition_template: Option<String>,
    /// Optional template for the final output; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_output_template: Option<String>,
}
/// Response chosen when a strategy needs to be reworked.
///
/// The variant names suggest escalating levels of rework: retry as-is, adjust tactically,
/// or regenerate the whole strategy. NOTE(review): the semantics live with the caller,
/// which is not visible here — confirm.
///
/// Unlike the other enums in this module it has no serde derives, so it is not part of the
/// serialized strategy format. As a fieldless enum it is `Copy` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RedesignStrategy {
    Retry,
    TacticalRedesign,
    FullRegenerate,
}
#[cfg(test)]
mod tests {
    // Unit tests for the strategy data model: construction, accessors, the legacy `steps`
    // mirror, serde round-trips and loop validation.
    use super::*;

    /// Message returned by `LoopBlock::validate` (and propagated by `StrategyMap::validate`)
    /// for nested loops.
    const NESTED_LOOP_ERROR: &str =
        "Nested loops are not supported. Loop body cannot contain other Loop instructions.";

    /// Builds a `StrategyStep` from string slices, leaving the optional fields at the
    /// defaults chosen by `StrategyStep::new`.
    fn make_step(
        step_id: &str,
        description: &str,
        assigned_agent: &str,
        intent_template: &str,
        expected_output: &str,
    ) -> StrategyStep {
        StrategyStep::new(
            step_id.to_string(),
            description.to_string(),
            assigned_agent.to_string(),
            intent_template.to_string(),
            expected_output.to_string(),
        )
    }

    #[test]
    fn test_strategy_map_creation() {
        let mut strategy = StrategyMap::new("Complete the task".to_string());
        assert_eq!(strategy.goal, "Complete the task");
        assert!(strategy.is_empty());
        strategy.add_step(make_step("step_1", "Do something", "AgentA", "Do {task}", "Result"));
        assert_eq!(strategy.len(), 1);
        assert!(!strategy.is_empty());
    }

    #[test]
    fn test_add_step_retains_legacy_steps_vector() {
        let mut strategy = StrategyMap::new("Ensure steps tracked".to_string());
        strategy.add_step(make_step("step_1", "Do something", "AgentA", "Perform task", "Result"));
        assert_eq!(strategy.elements.len(), 1);
        assert_eq!(strategy.steps.len(), 1);
    }

    #[test]
    fn test_strategy_step_access() {
        let mut strategy = StrategyMap::new("Goal".to_string());
        strategy.add_step(make_step("s1", "Description", "Agent", "Intent", "Output"));
        assert!(strategy.get_step(0).is_some());
        assert!(strategy.get_step(1).is_none());
        strategy
            .get_step_mut(0)
            .expect("step 0 was just added")
            .description = "Modified".to_string();
        assert_eq!(strategy.get_step(0).unwrap().description, "Modified");
    }

    #[test]
    fn test_get_step_returns_none_for_non_step_instruction() {
        let mut strategy = StrategyMap::new("Goal".to_string());
        strategy.add_instruction(StrategyInstruction::Terminate(TerminateInstruction {
            terminate_id: "exit".to_string(),
            description: None,
            condition_template: None,
            final_output_template: None,
        }));
        assert!(strategy.get_instruction(0).is_some());
        assert!(strategy.get_step(0).is_none());
        assert!(strategy.get_step_mut(0).is_none());
        assert!(strategy.steps().is_empty());
        // Only `Step` instructions are mirrored into the legacy list.
        assert!(strategy.steps.is_empty());
    }

    #[test]
    fn test_migrate_legacy_steps_resyncs_steps_after_in_place_edit() {
        let mut strategy = StrategyMap::new("Goal".to_string());
        strategy.add_step(make_step("s1", "Original", "Agent", "Intent", "Output"));
        strategy
            .get_step_mut(0)
            .expect("step 0 was just added")
            .description = "Edited".to_string();
        // Re-running the migration rebuilds `steps` from `elements`.
        strategy.migrate_legacy_steps();
        assert_eq!(strategy.elements.len(), 1);
        assert_eq!(strategy.steps.len(), 1);
        assert_eq!(strategy.steps[0].description, "Edited");
    }

    #[test]
    fn test_strategy_step_with_output_key() {
        let mut step = make_step(
            "step_1",
            "Create world concept",
            "WorldConceptAgent",
            "Create a concept for {{ user_request }}",
            "World concept data",
        );
        assert!(step.output_key.is_none());
        step.output_key = Some("world_concept".to_string());
        assert_eq!(step.output_key, Some("world_concept".to_string()));
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("world_concept"));
    }

    #[test]
    fn test_strategy_step_output_key_serialization() {
        let mut step = make_step("step_1", "Test step", "TestAgent", "Do something", "Result");
        step.output_key = Some("test_output".to_string());
        let json = serde_json::to_string(&step).unwrap();
        let deserialized: StrategyStep = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.output_key, Some("test_output".to_string()));
        assert_eq!(deserialized.step_id, "step_1");
    }

    #[test]
    fn test_strategy_step_without_output_key_serialization() {
        let step = make_step("step_1", "Test step", "TestAgent", "Do something", "Result");
        let json = serde_json::to_string(&step).unwrap();
        assert!(!json.contains("output_key"));
    }

    #[test]
    fn test_requires_validation_serialization() {
        let mut step = make_step("step_1", "Test step", "TestAgent", "Do something", "Result");
        // `false` is the default and is skipped on output via `is_false`.
        let json = serde_json::to_string(&step).unwrap();
        assert!(!json.contains("requires_validation"));
        step.requires_validation = true;
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains(r#""requires_validation":true"#));
        let deserialized: StrategyStep = serde_json::from_str(&json).unwrap();
        assert!(deserialized.requires_validation);
    }

    #[test]
    fn test_strategy_instruction_step_serialization() {
        let instruction = StrategyInstruction::Step(make_step(
            "step_1",
            "Test step",
            "TestAgent",
            "Do something",
            "Result",
        ));
        let json = serde_json::to_string_pretty(&instruction).unwrap();
        assert!(json.contains(r#""type": "step""#));
        assert!(json.contains("step_1"));
        let deserialized: StrategyInstruction = serde_json::from_str(&json).unwrap();
        match deserialized {
            StrategyInstruction::Step(s) => assert_eq!(s.step_id, "step_1"),
            _ => panic!("Expected Step variant"),
        }
    }

    #[test]
    fn test_loop_block_serialization() {
        let loop_block = LoopBlock {
            loop_id: "test_loop".to_string(),
            description: Some("Test loop".to_string()),
            loop_type: Some(LoopType::UntilConvergence),
            max_iterations: 5,
            condition_template: Some("{{ approved == false }}".to_string()),
            body: vec![StrategyInstruction::Step(make_step(
                "inner_step",
                "Inner step",
                "TestAgent",
                "Do something",
                "Result",
            ))],
            aggregation: Some(LoopAggregation {
                mode: AggregationMode::LastSuccess,
                output_key: "final_result".to_string(),
            }),
        };
        let instruction = StrategyInstruction::Loop(loop_block);
        let json = serde_json::to_string_pretty(&instruction).unwrap();
        assert!(json.contains(r#""type": "loop""#));
        assert!(json.contains("test_loop"));
        assert!(json.contains("until_convergence"));
        assert!(json.contains("approved == false"));
        let deserialized: StrategyInstruction = serde_json::from_str(&json).unwrap();
        match deserialized {
            StrategyInstruction::Loop(l) => {
                assert_eq!(l.loop_id, "test_loop");
                assert_eq!(l.max_iterations, 5);
                assert_eq!(l.body.len(), 1);
            }
            _ => panic!("Expected Loop variant"),
        }
    }

    #[test]
    fn test_loop_enums_serialize_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&LoopType::ForEach).unwrap(),
            r#""for_each""#
        );
        let aggregation = LoopAggregation {
            mode: AggregationMode::CollectAll,
            output_key: "all_results".to_string(),
        };
        let json = serde_json::to_string(&aggregation).unwrap();
        assert!(json.contains(r#""mode":"collect_all""#));
        let deserialized: LoopAggregation = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, aggregation);
    }

    #[test]
    fn test_terminate_instruction_serialization() {
        let terminate = TerminateInstruction {
            terminate_id: "early_exit".to_string(),
            description: Some("Exit early on approval".to_string()),
            condition_template: Some("{{ approved == true }}".to_string()),
            final_output_template: Some("Completed successfully".to_string()),
        };
        let instruction = StrategyInstruction::Terminate(terminate);
        let json = serde_json::to_string_pretty(&instruction).unwrap();
        assert!(json.contains(r#""type": "terminate""#));
        assert!(json.contains("early_exit"));
        assert!(json.contains("approved == true"));
        let deserialized: StrategyInstruction = serde_json::from_str(&json).unwrap();
        match deserialized {
            StrategyInstruction::Terminate(t) => {
                assert_eq!(t.terminate_id, "early_exit");
                assert!(t.condition_template.is_some());
            }
            _ => panic!("Expected Terminate variant"),
        }
    }

    #[test]
    fn test_strategy_map_with_elements_serialization() {
        let mut strategy = StrategyMap::new("Test goal".to_string());
        strategy.add_step(make_step("step_1", "First step", "Agent1", "Do task", "Result"));
        let json = serde_json::to_string_pretty(&strategy).unwrap();
        assert!(json.contains("Test goal"));
        assert!(json.contains("elements"));
        assert!(json.contains("step_1"));
        let deserialized: StrategyMap = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.goal, "Test goal");
        assert_eq!(deserialized.elements.len(), 1);
    }

    #[test]
    fn test_legacy_steps_format_deserialization() {
        let legacy_json = r#"{
            "goal": "Legacy workflow",
            "steps": [
                {
                    "step_id": "step_1",
                    "description": "Do something",
                    "assigned_agent": "TestAgent",
                    "intent_template": "Execute task",
                    "expected_output": "Result"
                }
            ]
        }"#;
        let mut strategy: StrategyMap = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(strategy.goal, "Legacy workflow");
        assert_eq!(strategy.steps.len(), 1);
        assert_eq!(strategy.elements.len(), 0);
        strategy.migrate_legacy_steps();
        assert_eq!(strategy.elements.len(), 1);
        assert_eq!(strategy.steps.len(), 1);
        match &strategy.elements[0] {
            StrategyInstruction::Step(s) => assert_eq!(s.step_id, "step_1"),
            _ => panic!("Expected Step variant"),
        }
    }

    #[test]
    fn test_round_trip_with_mixed_instructions() {
        let mut strategy = StrategyMap::new("Complex workflow".to_string());
        strategy.add_step(make_step("step_1", "Prepare", "Agent1", "Setup", "Config"));
        strategy.add_instruction(StrategyInstruction::Loop(LoopBlock {
            loop_id: "refine_loop".to_string(),
            description: Some("Refine output".to_string()),
            loop_type: Some(LoopType::UntilConvergence),
            max_iterations: 3,
            condition_template: Some("{{ needs_refinement }}".to_string()),
            body: vec![StrategyInstruction::Step(make_step(
                "refine_step",
                "Refine",
                "Agent2",
                "Improve",
                "Better result",
            ))],
            aggregation: None,
        }));
        strategy.add_instruction(StrategyInstruction::Terminate(TerminateInstruction {
            terminate_id: "success_exit".to_string(),
            description: Some("Exit on success".to_string()),
            condition_template: Some("{{ success }}".to_string()),
            final_output_template: None,
        }));
        let json = serde_json::to_string_pretty(&strategy).unwrap();
        let deserialized: StrategyMap = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.goal, "Complex workflow");
        assert_eq!(deserialized.elements.len(), 3);
        assert!(matches!(
            &deserialized.elements[0],
            StrategyInstruction::Step(_)
        ));
        assert!(matches!(
            &deserialized.elements[1],
            StrategyInstruction::Loop(_)
        ));
        assert!(matches!(
            &deserialized.elements[2],
            StrategyInstruction::Terminate(_)
        ));
    }

    #[test]
    fn test_loop_block_validation_success() {
        let loop_block = LoopBlock {
            loop_id: "valid_loop".to_string(),
            description: Some("Valid loop".to_string()),
            loop_type: Some(LoopType::While),
            max_iterations: 5,
            condition_template: Some("{{ continue }}".to_string()),
            body: vec![
                StrategyInstruction::Step(make_step("step_1", "Step", "Agent", "Do work", "Result")),
                StrategyInstruction::Terminate(TerminateInstruction {
                    terminate_id: "exit".to_string(),
                    description: Some("Exit".to_string()),
                    condition_template: Some("{{ done }}".to_string()),
                    final_output_template: None,
                }),
            ],
            aggregation: None,
        };
        assert!(loop_block.validate().is_ok());
    }

    #[test]
    fn test_loop_block_validation_nested_loop_fails() {
        let nested_loop_block = LoopBlock {
            loop_id: "nested_loop".to_string(),
            description: Some("Nested loop".to_string()),
            loop_type: Some(LoopType::While),
            max_iterations: 3,
            condition_template: Some("{{ inner }}".to_string()),
            body: vec![StrategyInstruction::Step(make_step(
                "inner_step",
                "Inner",
                "Agent",
                "Work",
                "Result",
            ))],
            aggregation: None,
        };
        let outer_loop_block = LoopBlock {
            loop_id: "outer_loop".to_string(),
            description: Some("Outer loop".to_string()),
            loop_type: Some(LoopType::UntilConvergence),
            max_iterations: 5,
            condition_template: Some("{{ outer }}".to_string()),
            body: vec![
                StrategyInstruction::Step(make_step("outer_step", "Outer", "Agent", "Work", "Result")),
                StrategyInstruction::Loop(nested_loop_block),
            ],
            aggregation: None,
        };
        let result = outer_loop_block.validate();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), NESTED_LOOP_ERROR);
    }

    #[test]
    fn test_strategy_map_validation_success() {
        let mut strategy = StrategyMap::new("Valid strategy".to_string());
        strategy.add_step(make_step("step_1", "Step", "Agent", "Work", "Result"));
        strategy.add_instruction(StrategyInstruction::Loop(LoopBlock {
            loop_id: "loop_1".to_string(),
            description: Some("Loop".to_string()),
            loop_type: Some(LoopType::While),
            max_iterations: 3,
            condition_template: Some("{{ continue }}".to_string()),
            body: vec![StrategyInstruction::Step(make_step(
                "loop_step",
                "Loop step",
                "Agent",
                "Loop work",
                "Loop result",
            ))],
            aggregation: None,
        }));
        assert!(strategy.validate().is_ok());
    }

    #[test]
    fn test_strategy_map_validation_nested_loop_fails() {
        let mut strategy = StrategyMap::new("Invalid strategy".to_string());
        let nested_loop = LoopBlock {
            loop_id: "nested".to_string(),
            description: Some("Nested".to_string()),
            loop_type: Some(LoopType::While),
            max_iterations: 2,
            condition_template: None,
            body: vec![StrategyInstruction::Step(make_step(
                "inner", "Inner", "Agent", "Work", "Result",
            ))],
            aggregation: None,
        };
        strategy.add_instruction(StrategyInstruction::Loop(LoopBlock {
            loop_id: "outer".to_string(),
            description: Some("Outer".to_string()),
            loop_type: Some(LoopType::UntilConvergence),
            max_iterations: 3,
            condition_template: None,
            body: vec![StrategyInstruction::Loop(nested_loop)],
            aggregation: None,
        }));
        let result = strategy.validate();
        assert!(result.is_err());
        // The map-level check propagates the loop-level error unchanged.
        assert_eq!(result.unwrap_err(), NESTED_LOOP_ERROR);
    }
}