use crate::platform::container::battalion::{BattalionConfig, BattalionError};
use crate::platform::container::paladin::Paladin;
use petgraph::graph::{DiGraph, NodeIndex};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;
/// Condition stored on a [`CampaignEdge`].
///
/// NOTE(review): this module only stores the condition; it is never evaluated here.
/// The per-variant semantics below are inferred from the variant names — confirm them
/// against whatever executes the campaign.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeCondition {
/// Presumably: the edge is always taken — TODO confirm.
Always,
/// Presumably: taken when the upstream output contains this substring — TODO confirm.
Contains(String),
/// Presumably: a regular-expression pattern tested against the upstream output — TODO confirm.
Regex(String),
/// Caller-defined condition; its meaning is not visible in this module.
Custom(String),
}
/// A directed, conditional edge between two paladins of a [`Campaign`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CampaignEdge {
/// UUID of the upstream paladin; `Campaign::add_edge` rejects it if it is not in the campaign.
pub source: Uuid,
/// UUID of the downstream paladin; `Campaign::add_edge` rejects it if it is not in the campaign.
pub target: Uuid,
/// Condition attached to this edge.
pub condition: EdgeCondition,
/// Optional transform, `None` when built with `CampaignEdge::new`.
/// NOTE(review): the format/meaning of this string is not defined in this module.
pub transform: Option<String>,
}
impl CampaignEdge {
    /// Creates an edge from `source` to `target` guarded by `condition`.
    ///
    /// The edge starts without a transform; attach one with [`CampaignEdge::with_transform`].
    pub fn new(source: Uuid, target: Uuid, condition: EdgeCondition) -> Self {
        let transform = None;
        Self {
            source,
            target,
            condition,
            transform,
        }
    }

    /// Returns this edge with `transform` set, replacing any transform it already had.
    pub fn with_transform(self, transform: String) -> Self {
        Self {
            transform: Some(transform),
            ..self
        }
    }
}
/// A graph of paladins connected by [`CampaignEdge`]s, together with its battalion config.
///
/// `Campaign::validate` requires at least one paladin and an acyclic graph.
#[derive(Debug, Clone)]
pub struct Campaign {
/// Configuration supplied to `Campaign::new`; exposed read-only via `config()`.
config: BattalionConfig,
/// Directed graph whose node weights are paladin UUIDs and whose edge weights are the edges.
graph: DiGraph<Uuid, CampaignEdge>,
/// Paladins keyed by their UUID.
paladins: HashMap<Uuid, Paladin>,
/// Maps each paladin UUID to its node in `graph`.
node_indices: HashMap<Uuid, NodeIndex>,
/// Explicitly chosen entry points; when empty, `entry_points()` infers roots from the graph.
entry_points: HashSet<Uuid>,
}
impl Campaign {
pub fn new(config: BattalionConfig) -> Self {
Self {
config,
graph: DiGraph::new(),
paladins: HashMap::new(),
node_indices: HashMap::new(),
entry_points: HashSet::new(),
}
}
pub fn add_paladin(&mut self, paladin: Paladin) -> Uuid {
let uuid = paladin.uuid;
let node_index = self.graph.add_node(uuid);
self.paladins.insert(uuid, paladin);
self.node_indices.insert(uuid, node_index);
uuid
}
pub fn add_edge(&mut self, edge: CampaignEdge) -> Result<(), BattalionError> {
let source_idx = self.node_indices.get(&edge.source).ok_or_else(|| {
BattalionError::InvalidGraph(format!(
"Source Paladin {} not found in campaign",
edge.source
))
})?;
let target_idx = self.node_indices.get(&edge.target).ok_or_else(|| {
BattalionError::InvalidGraph(format!(
"Target Paladin {} not found in campaign",
edge.target
))
})?;
self.graph.add_edge(*source_idx, *target_idx, edge);
Ok(())
}
pub fn set_entry_point(&mut self, paladin_id: Uuid) -> Result<(), BattalionError> {
if !self.paladins.contains_key(&paladin_id) {
return Err(BattalionError::InvalidGraph(format!(
"Paladin {} not found in campaign, cannot set as entry point",
paladin_id
)));
}
self.entry_points.insert(paladin_id);
Ok(())
}
pub fn entry_points(&self) -> HashSet<Uuid> {
if !self.entry_points.is_empty() {
return self.entry_points.clone();
}
let mut entries = HashSet::new();
for (uuid, &node_idx) in &self.node_indices {
let has_incoming = self
.graph
.edges_directed(node_idx, petgraph::Direction::Incoming)
.next()
.is_some();
if !has_incoming {
entries.insert(*uuid);
}
}
entries
}
pub fn validate(&self) -> Result<(), BattalionError> {
if self.paladins.is_empty() {
return Err(BattalionError::InvalidGraph(
"Campaign must have at least one Paladin".to_string(),
));
}
if petgraph::algo::toposort(&self.graph, None).is_err() {
return Err(BattalionError::InvalidGraph(
"Campaign graph contains a cycle, must be a DAG".to_string(),
));
}
Ok(())
}
pub fn has_paladin(&self, paladin_id: &Uuid) -> bool {
self.paladins.contains_key(paladin_id)
}
pub fn paladin_count(&self) -> usize {
self.paladins.len()
}
pub fn edge_count(&self) -> usize {
self.graph.edge_count()
}
pub fn config(&self) -> &BattalionConfig {
&self.config
}
pub fn get_paladin(&self, paladin_id: &Uuid) -> Option<&Paladin> {
self.paladins.get(paladin_id)
}
pub fn paladins(&self) -> &HashMap<Uuid, Paladin> {
&self.paladins
}
pub fn graph(&self) -> &DiGraph<Uuid, CampaignEdge> {
&self.graph
}
pub fn node_indices(&self) -> &HashMap<Uuid, NodeIndex> {
&self.node_indices
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::entity::node::Node;
    use crate::platform::container::paladin::{MaxLoops, PaladinData, PaladinStatus};

    /// Builds a minimal idle paladin named `name` for graph-structure tests.
    fn create_test_paladin(name: &str) -> Paladin {
        let data = PaladinData {
            system_prompt: format!("You are {}", name),
            name: name.to_string(),
            user_name: "User".to_string(),
            model: "gpt-4".to_string(),
            temperature: 0.7,
            max_loops: MaxLoops::Fixed(1),
            stop_words: vec![],
            status: PaladinStatus::Idle,
            vision_enabled: false,
            ..Default::default()
        };
        Node::new(data, Some(name.to_string()))
    }

    #[test]
    fn test_campaign_creation() {
        let config = BattalionConfig::new("test");
        let campaign = Campaign::new(config);
        assert_eq!(campaign.paladin_count(), 0);
        assert_eq!(campaign.edge_count(), 0);
    }

    #[test]
    fn test_add_paladin_returns_uuid() {
        let config = BattalionConfig::new("test");
        let mut campaign = Campaign::new(config);
        let paladin = create_test_paladin("Test");
        let uuid = campaign.add_paladin(paladin);
        assert!(campaign.has_paladin(&uuid));
        assert_eq!(campaign.paladin_count(), 1);
    }

    #[test]
    fn test_add_edge_success() {
        let config = BattalionConfig::new("test");
        let mut campaign = Campaign::new(config);
        let id1 = campaign.add_paladin(create_test_paladin("P1"));
        let id2 = campaign.add_paladin(create_test_paladin("P2"));
        let edge = CampaignEdge::new(id1, id2, EdgeCondition::Always);
        let result = campaign.add_edge(edge);
        assert!(result.is_ok());
        assert_eq!(campaign.edge_count(), 1);
    }

    #[test]
    fn test_add_edge_unknown_endpoint_fails() {
        let mut campaign = Campaign::new(BattalionConfig::new("test"));
        let known = campaign.add_paladin(create_test_paladin("P1"));
        // Built but never added, so the campaign does not know this UUID.
        let ghost = create_test_paladin("Ghost").uuid;
        assert!(campaign
            .add_edge(CampaignEdge::new(ghost, known, EdgeCondition::Always))
            .is_err());
        assert!(campaign
            .add_edge(CampaignEdge::new(known, ghost, EdgeCondition::Always))
            .is_err());
        assert_eq!(campaign.edge_count(), 0);
    }

    #[test]
    fn test_set_entry_point_unknown_fails() {
        let mut campaign = Campaign::new(BattalionConfig::new("test"));
        campaign.add_paladin(create_test_paladin("P1"));
        let ghost = create_test_paladin("Ghost").uuid;
        assert!(campaign.set_entry_point(ghost).is_err());
    }

    #[test]
    fn test_validate_empty_campaign_fails() {
        let config = BattalionConfig::new("test");
        let campaign = Campaign::new(config);
        let result = campaign.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_single_paladin_succeeds() {
        let config = BattalionConfig::new("test");
        let mut campaign = Campaign::new(config);
        campaign.add_paladin(create_test_paladin("P1"));
        let result = campaign.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_cycle_fails() {
        let mut campaign = Campaign::new(BattalionConfig::new("test"));
        let id1 = campaign.add_paladin(create_test_paladin("P1"));
        let id2 = campaign.add_paladin(create_test_paladin("P2"));
        assert!(campaign
            .add_edge(CampaignEdge::new(id1, id2, EdgeCondition::Always))
            .is_ok());
        assert!(campaign
            .add_edge(CampaignEdge::new(id2, id1, EdgeCondition::Always))
            .is_ok());
        assert!(campaign.validate().is_err());
    }

    #[test]
    fn test_entry_points_inferred_then_overridden() {
        let mut campaign = Campaign::new(BattalionConfig::new("test"));
        let id1 = campaign.add_paladin(create_test_paladin("P1"));
        let id2 = campaign.add_paladin(create_test_paladin("P2"));
        assert!(campaign
            .add_edge(CampaignEdge::new(id1, id2, EdgeCondition::Always))
            .is_ok());

        // With no explicit entry points, roots (no incoming edges) are inferred.
        let inferred = campaign.entry_points();
        assert_eq!(inferred.len(), 1);
        assert!(inferred.contains(&id1));

        // An explicit entry point replaces inference entirely.
        assert!(campaign.set_entry_point(id2).is_ok());
        let explicit = campaign.entry_points();
        assert_eq!(explicit.len(), 1);
        assert!(explicit.contains(&id2));
    }
}