use crate::State;
use crate::agents::streaming::AgentEvent;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
/// A side-effect request that travels alongside a state value in a
/// [`DirectiveResult`].
///
/// Serialized as an internally tagged object: the `"type"` field holds the
/// snake_case variant name (`"emit"`, `"spawn_agent"`, `"stop_child"`,
/// `"schedule"`, `"run_instruction"`, `"cron"`, `"stop"`, `"spawn_task"`,
/// `"stop_task"`, `"custom"`), and the variant's fields sit next to it in the
/// same object. Because the tags are derived from the variant names, renaming a
/// variant changes its wire tag.
///
/// How each directive is actually carried out is decided by whatever code
/// consumes it; that code is not part of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Directive {
    /// Carries an [`AgentEvent`] to be emitted.
    Emit {
        /// The event payload.
        event: AgentEvent,
    },
    /// Asks for a child agent called `name` to be spawned.
    SpawnAgent {
        /// Identifier for the child agent.
        name: String,
        /// Opaque JSON configuration; its schema is defined by the consumer.
        config: Value,
    },
    /// Asks for the child agent called `name` to be stopped.
    StopChild {
        /// Identifier of the child agent.
        name: String,
    },
    /// Asks for `action` to be run after `delay`.
    Schedule {
        /// Name of the action to run.
        action: String,
        /// Delay before running `action`. It is (de)serialized as a
        /// human-readable duration string through `humantime_serde`, not with
        /// serde's default `Duration` representation.
        #[serde(with = "humantime_serde")]
        delay: Duration,
    },
    /// Asks for `instruction` to be run with `input`.
    RunInstruction {
        /// Instruction identifier or text (interpretation is up to the consumer).
        instruction: String,
        /// Opaque JSON input for the instruction.
        input: Value,
    },
    /// Asks for `action` to run on a recurring schedule.
    Cron {
        /// Schedule expression. It is presumably cron syntax; nothing here
        /// validates it.
        expression: String,
        /// Name of the action to run.
        action: String,
    },
    /// Asks for processing to stop, with an optional human-readable reason.
    Stop {
        /// Optional explanation. An absent field deserializes as `None`.
        reason: Option<String>,
    },
    /// Asks for a task described by `description` to be spawned with `input`.
    SpawnTask {
        /// Free-form task description.
        description: String,
        /// Opaque JSON input for the task.
        input: Value,
    },
    /// Asks for the task identified by `task_id` to be stopped.
    StopTask {
        /// Identifier of the task to stop.
        task_id: String,
    },
    /// Application-defined directive.
    ///
    /// The payload's own fields are flattened into the same JSON object. Its
    /// concrete type is identified by the `"custom_type"` tag that
    /// [`DirectivePayload`]'s `typetag` registration adds.
    Custom {
        /// The boxed application-defined payload.
        #[serde(flatten)]
        payload: Box<dyn DirectivePayload>,
    },
}
/// Extension point for application-defined payloads carried by
/// [`Directive::Custom`].
///
/// Implementations must be annotated with `#[typetag::serde]` so that a
/// `Box<dyn DirectivePayload>` can be (de)serialized. The concrete type's name is
/// written under the `"custom_type"` key. `DynClone` is required so that the
/// `clone_trait_object!` invocation below can make boxed payloads (and therefore
/// `Directive`) cloneable.
#[typetag::serde(tag = "custom_type")]
pub trait DirectivePayload: dyn_clone::DynClone + std::fmt::Debug + Send + Sync {}

// Generates `impl Clone for Box<dyn DirectivePayload>` on top of `DynClone`.
dyn_clone::clone_trait_object!(DirectivePayload);
/// Pairs a state value with zero or more [`Directive`]s.
///
/// Build one with `state_only`, `with_directive` or `with_directives`, or
/// convert a bare state through `From<S>`.
#[derive(Debug, Clone)]
pub struct DirectiveResult<S>
where
    S: State,
{
    /// The state value being returned.
    pub state: S,
    /// Directives accompanying `state`, kept in the order they were supplied.
    pub directives: Vec<Directive>,
}
impl<S> DirectiveResult<S>
where
    S: State,
{
    /// Wraps `state` with an empty directive list.
    #[must_use]
    pub const fn state_only(state: S) -> Self {
        Self::with_directives(state, Vec::new())
    }

    /// Wraps `state` together with exactly one `directive`.
    #[must_use]
    pub fn with_directive(state: S, directive: Directive) -> Self {
        Self::with_directives(state, vec![directive])
    }

    /// Wraps `state` together with `directives`.
    ///
    /// The vector is stored unchanged, so its order is preserved.
    #[must_use]
    pub const fn with_directives(state: S, directives: Vec<Directive>) -> Self {
        Self { state, directives }
    }
}
/// Converts a bare state into a [`DirectiveResult`] with no directives. This
/// allows `state.into()` wherever a `DirectiveResult<S>` is expected.
impl<S> From<S> for DirectiveResult<S>
where
    S: State,
{
    fn from(state: S) -> Self {
        Self {
            state,
            directives: Vec::new(),
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::redundant_clone
)]
mod tests {
    //! Unit tests for `DirectiveResult` constructors and `Directive` serde
    //! behavior.
    //!
    //! The round-trip tests check that field values survive serialization, not
    //! just the variant. The tag test pins the `"type"` strings, because they
    //! are a wire format.
    use super::*;
    use crate::agents::streaming::{AgentEvent, TerminationReason};

    #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TestState {
        count: u32,
    }

    impl State for TestState {
        fn to_value(&self) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::to_value(self)?)
        }
        fn from_value(value: Value) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::from_value(value)?)
        }
    }

    /// Serializes `directive` to a JSON string and parses it back.
    fn round_trip(directive: &Directive) -> Directive {
        let json = serde_json::to_string(directive).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    #[test]
    fn test_directive_result_state_only() {
        let state = TestState { count: 1 };
        let result = DirectiveResult::state_only(state.clone());
        assert_eq!(result.state, state);
        assert!(result.directives.is_empty());
    }

    #[test]
    fn test_directive_result_with_directive() {
        let state = TestState { count: 1 };
        let result =
            DirectiveResult::with_directive(state.clone(), Directive::Stop { reason: None });
        assert_eq!(result.state, state);
        // Check the stored directive itself, not just the list length.
        assert!(matches!(
            result.directives.as_slice(),
            [Directive::Stop { reason: None }]
        ));
    }

    #[test]
    fn test_directive_result_with_directives_preserves_order() {
        let state = TestState { count: 2 };
        let result = DirectiveResult::with_directives(
            state.clone(),
            vec![
                Directive::StopTask {
                    task_id: "t-1".to_string(),
                },
                Directive::Stop { reason: None },
            ],
        );
        assert_eq!(result.state, state);
        assert!(matches!(
            result.directives.as_slice(),
            [Directive::StopTask { task_id }, Directive::Stop { reason: None }]
                if task_id == "t-1"
        ));
    }

    #[test]
    fn test_directive_result_from_state() {
        let state = TestState { count: 1 };
        let result: DirectiveResult<TestState> = state.clone().into();
        assert_eq!(result.state, state);
        assert!(result.directives.is_empty());
    }

    #[test]
    fn test_directive_serde_emit() {
        let directive = Directive::Emit {
            event: AgentEvent::TurnComplete {
                reason: TerminationReason::Complete,
            },
        };
        assert!(matches!(
            round_trip(&directive),
            Directive::Emit {
                event: AgentEvent::TurnComplete { .. }
            }
        ));
    }

    #[test]
    fn test_directive_serde_spawn_agent() {
        let directive = Directive::SpawnAgent {
            name: "helper".to_string(),
            config: serde_json::json!({"model": "gpt-4"}),
        };
        match round_trip(&directive) {
            Directive::SpawnAgent { name, config } => {
                assert_eq!(name, "helper");
                assert_eq!(config, serde_json::json!({"model": "gpt-4"}));
            }
            other => panic!("expected SpawnAgent, got {other:?}"),
        }
    }

    #[test]
    fn test_directive_serde_stop() {
        let directive = Directive::Stop {
            reason: Some("Task complete".to_string()),
        };
        match round_trip(&directive) {
            Directive::Stop { reason } => assert_eq!(reason.as_deref(), Some("Task complete")),
            other => panic!("expected Stop, got {other:?}"),
        }
    }

    #[test]
    fn test_directive_deserialize_stop_without_reason() {
        // `reason` is an `Option`, so omitting it must yield `None` rather than an error.
        let directive: Directive =
            serde_json::from_str(r#"{"type":"stop"}"#).expect("deserialize");
        assert!(matches!(directive, Directive::Stop { reason: None }));
    }

    #[test]
    fn test_directive_serde_schedule_delay() {
        let directive = Directive::Schedule {
            action: "tick".to_string(),
            delay: Duration::from_secs(5),
        };
        match round_trip(&directive) {
            Directive::Schedule { action, delay } => {
                assert_eq!(action, "tick");
                assert_eq!(delay, Duration::from_secs(5));
            }
            other => panic!("expected Schedule, got {other:?}"),
        }
    }

    #[test]
    fn test_directive_serde_type_tags() {
        let cases = [
            (
                Directive::SpawnAgent {
                    name: "a".to_string(),
                    config: Value::Null,
                },
                "spawn_agent",
            ),
            (
                Directive::StopChild {
                    name: "c".to_string(),
                },
                "stop_child",
            ),
            (
                Directive::Schedule {
                    action: "a".to_string(),
                    delay: Duration::from_secs(1),
                },
                "schedule",
            ),
            (
                Directive::RunInstruction {
                    instruction: "i".to_string(),
                    input: Value::Null,
                },
                "run_instruction",
            ),
            (
                Directive::Cron {
                    expression: "* * * * *".to_string(),
                    action: "a".to_string(),
                },
                "cron",
            ),
            (Directive::Stop { reason: None }, "stop"),
            (
                Directive::SpawnTask {
                    description: "d".to_string(),
                    input: Value::Null,
                },
                "spawn_task",
            ),
            (
                Directive::StopTask {
                    task_id: "t".to_string(),
                },
                "stop_task",
            ),
        ];
        for (directive, tag) in cases {
            let value = serde_json::to_value(&directive).expect("serialize");
            assert_eq!(value["type"], tag, "wrong tag for {directive:?}");
        }
    }
}