use std::sync::Arc;
use crate::lens::Lens;
use crate::llm::LlmConfig;
use crate::tool::Tool;
/// A named node in an agent workflow tree.
///
/// The node's behavior is selected by [`AgentKind`]: it is either a single
/// LLM-backed agent or a composite that owns further `AgentDefinition`s.
#[derive(Clone)]
pub struct AgentDefinition {
    /// Name of this node.
    // NOTE(review): nothing in this file enforces uniqueness of names — confirm
    // whether callers rely on it.
    pub name: String,
    /// What this node is: an LLM agent or a composite of child definitions.
    pub kind: AgentKind,
}
/// The shape of an [`AgentDefinition`] node.
///
/// `Llm` is the only variant without child definitions; the remaining
/// variants hold one or more nested [`AgentDefinition`]s, so definitions
/// form a tree.
// NOTE(review): how each variant is executed is decided by code outside this
// file; the per-variant notes below describe the data only.
#[derive(Clone)]
pub enum AgentKind {
    /// A single LLM-backed agent.
    // Boxed — presumably to keep the enum small relative to the config
    // struct; confirm before unboxing.
    Llm(Box<LlmAgentConfig>),
    /// An ordered list of child definitions.
    Sequential(Vec<AgentDefinition>),
    /// A list of child definitions.
    // NOTE(review): presumably run concurrently by the executor — confirm.
    Parallel(Vec<AgentDefinition>),
    /// A single child definition paired with an iteration cap.
    Loop {
        /// The definition to repeat. Boxed because the type is recursive.
        agent: Box<AgentDefinition>,
        /// Upper bound on the number of iterations.
        // NOTE(review): the handling of `0` is not visible here — confirm.
        max_iterations: usize,
    },
}
/// Configuration carried by [`AgentKind::Llm`].
///
/// Cloning is cheap for the trait-object fields: `tools` and
/// `experience_extractor` are held behind [`Arc`], so a clone shares the same
/// underlying objects rather than duplicating them.
#[derive(Clone)]
pub struct LlmAgentConfig {
    /// System prompt text for the agent.
    pub system_prompt: String,
    /// Tools available to the agent, shared via `Arc`.
    pub tools: Vec<Arc<dyn Tool>>,
    /// The agent's [`Lens`].
    // NOTE(review): semantics are defined in `crate::lens` — not visible here.
    pub lens: Lens,
    /// LLM settings; see [`LlmConfig`].
    pub llm_config: LlmConfig,
    /// Optional [`ExperienceExtractor`]; `None` means no extractor is
    /// configured for this agent.
    pub experience_extractor: Option<Arc<dyn ExperienceExtractor>>,
    /// Optional refresh interval, counted in tool calls.
    // NOTE(review): what is refreshed, and what `None` / `Some(0)` mean, is
    // decided by the consumer of this config — confirm against the runner.
    pub refresh_every_n_tool_calls: Option<usize>,
}
/// Side information passed to [`ExperienceExtractor::extract`] alongside the
/// conversation and its outcome.
#[derive(Debug, Clone)]
pub struct ExtractionContext {
    /// Identifier of the agent the extraction is performed for.
    pub agent_id: String,
    /// The `pulsedb` collective associated with this extraction.
    // NOTE(review): presumably the collective that extracted experiences are
    // stored into — confirm against the caller.
    pub collective_id: pulsedb::CollectiveId,
    /// Free-form description of the task.
    pub task_description: String,
}
/// Derives `pulsedb` experiences from a finished conversation.
///
/// The `Send + Sync` supertraits allow implementations to be stored as
/// `Arc<dyn ExperienceExtractor>` (as [`LlmAgentConfig`] does) and shared
/// across threads.
#[async_trait::async_trait]
pub trait ExperienceExtractor: Send + Sync {
    /// Produces zero or more new experiences.
    ///
    /// * `conversation` — the messages to inspect.
    /// * `outcome` — the [`AgentOutcome`] the conversation ended with.
    /// * `context` — agent, collective and task information.
    ///
    /// The return type carries no error channel: an implementation that has
    /// nothing to report returns an empty vector.
    async fn extract(
        &self,
        conversation: &[crate::llm::Message],
        outcome: &AgentOutcome,
        context: &ExtractionContext,
    ) -> Vec<pulsedb::NewExperience>;
}
/// Terminal result of an agent.
///
/// Serialized with serde as an internally tagged enum: the variant name is
/// written in `snake_case` under the `"status"` key, next to the variant's
/// own fields. For example `Complete { response }` becomes
/// `{"status":"complete","response":"..."}` and `MaxIterationsReached`
/// becomes `{"status":"max_iterations_reached"}`.
///
/// `PartialEq`/`Eq` are derived so outcomes can be compared directly
/// (e.g. with `assert_eq!`) instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum AgentOutcome {
    /// The agent finished and produced `response`.
    Complete { response: String },
    /// The agent failed; `error` holds the message.
    Error { error: String },
    /// The agent stopped because an iteration limit was hit.
    MaxIterationsReached,
}
/// Fieldless counterpart of [`AgentKind`]: one variant per `AgentKind`
/// variant, without the payload.
///
/// Serialized with serde as a `snake_case` string (`"llm"`, `"sequential"`,
/// `"parallel"`, `"loop"`).
///
/// `Hash` is derived alongside `Eq` so the tag can be used as a key in
/// `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentKindTag {
    /// Corresponds to [`AgentKind::Llm`].
    Llm,
    /// Corresponds to [`AgentKind::Sequential`].
    Sequential,
    /// Corresponds to [`AgentKind::Parallel`].
    Parallel,
    /// Corresponds to [`AgentKind::Loop`].
    Loop,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a childless `Sequential` node.
    ///
    /// Used as a stand-in leaf so the tests do not have to construct an
    /// `LlmAgentConfig`.
    fn leaf(name: &str) -> AgentDefinition {
        AgentDefinition {
            name: name.into(),
            kind: AgentKind::Sequential(vec![]),
        }
    }

    /// Collects the names of `defs` in order, for compact assertions.
    fn names(defs: &[AgentDefinition]) -> Vec<&str> {
        defs.iter().map(|def| def.name.as_str()).collect()
    }

    #[test]
    fn test_agent_outcome_debug_clone() {
        let outcome = AgentOutcome::Complete {
            response: "Done!".into(),
        };

        let cloned = outcome.clone();
        assert!(matches!(cloned, AgentOutcome::Complete { response } if response == "Done!"));

        let debug = format!("{:?}", outcome);
        assert!(debug.contains("Complete"));
    }

    #[test]
    fn test_agent_outcome_variants() {
        let complete = AgentOutcome::Complete {
            response: "result".into(),
        };
        assert!(matches!(complete, AgentOutcome::Complete { .. }));

        let error = AgentOutcome::Error {
            error: "timeout".into(),
        };
        assert!(matches!(error, AgentOutcome::Error { .. }));

        let max = AgentOutcome::MaxIterationsReached;
        assert!(matches!(max, AgentOutcome::MaxIterationsReached));
    }

    #[test]
    fn test_agent_kind_tag() {
        assert_ne!(AgentKindTag::Llm, AgentKindTag::Sequential);
        assert_eq!(AgentKindTag::Loop, AgentKindTag::Loop);

        // `tag` stays usable after the assignment because the type is `Copy`.
        let tag = AgentKindTag::Parallel;
        let copied = tag;
        assert_eq!(tag, copied);
    }

    #[test]
    fn test_sequential_workflow() {
        let workflow = AgentDefinition {
            name: "pipeline".into(),
            kind: AgentKind::Sequential(vec![leaf("step1"), leaf("step2")]),
        };

        assert_eq!(workflow.name, "pipeline");
        match workflow.kind {
            // Checks both the count and the order of the children.
            AgentKind::Sequential(children) => {
                assert_eq!(names(&children), ["step1", "step2"]);
            }
            _ => panic!("Expected Sequential"),
        }
    }

    #[test]
    fn test_nested_workflow() {
        let workflow = AgentDefinition {
            name: "complex".into(),
            kind: AgentKind::Sequential(vec![
                AgentDefinition {
                    name: "explore".into(),
                    kind: AgentKind::Parallel(vec![leaf("explorer_a"), leaf("explorer_b")]),
                },
                AgentDefinition {
                    name: "refine".into(),
                    kind: AgentKind::Loop {
                        agent: Box::new(leaf("refiner")),
                        max_iterations: 5,
                    },
                },
            ]),
        };

        assert_eq!(workflow.name, "complex");

        // Walk the tree level by level so a mis-nested variant is caught,
        // not just a wrong root name.
        let steps = match &workflow.kind {
            AgentKind::Sequential(steps) => steps,
            _ => panic!("Expected Sequential at the root"),
        };
        assert_eq!(names(steps), ["explore", "refine"]);

        match &steps[0].kind {
            AgentKind::Parallel(branches) => {
                assert_eq!(names(branches), ["explorer_a", "explorer_b"]);
            }
            _ => panic!("Expected Parallel"),
        }

        match &steps[1].kind {
            AgentKind::Loop {
                agent,
                max_iterations,
            } => {
                assert_eq!(agent.name, "refiner");
                assert_eq!(*max_iterations, 5);
            }
            _ => panic!("Expected Loop"),
        }
    }

    #[test]
    fn test_loop_workflow() {
        let looped = AgentDefinition {
            name: "iterator".into(),
            kind: AgentKind::Loop {
                agent: Box::new(leaf("worker")),
                max_iterations: 10,
            },
        };

        match looped.kind {
            AgentKind::Loop {
                agent,
                max_iterations,
            } => {
                assert_eq!(max_iterations, 10);
                assert_eq!(agent.name, "worker");
            }
            _ => panic!("Expected Loop"),
        }
    }

    #[test]
    fn test_definition_clone_keeps_children() {
        let original = AgentDefinition {
            name: "pipeline".into(),
            kind: AgentKind::Sequential(vec![leaf("step1"), leaf("step2")]),
        };

        let copy = original.clone();

        assert_eq!(copy.name, original.name);
        match (&original.kind, &copy.kind) {
            (AgentKind::Sequential(a), AgentKind::Sequential(b)) => {
                assert_eq!(names(a), names(b));
            }
            _ => panic!("Expected Sequential on both sides"),
        }
    }
}