use chrono::Utc;
use log::{debug, info, warn};
use petgraph::algo::toposort;
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use regex::Regex;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::time::{Duration, timeout};
use uuid::Uuid;
use paladin_core::platform::container::battalion::campaign::{Campaign, EdgeCondition};
use paladin_core::platform::container::battalion::{BattalionError, BattalionResult};
use paladin_ports::output::paladin_port::{PaladinPort, PaladinResult};
/// Service that executes a [`Campaign`] by running its paladins through a
/// [`PaladinPort`] and collecting the per-paladin results.
pub struct CampaignExecutionService {
/// Port used to execute each individual paladin of a campaign.
paladin_port: Arc<dyn PaladinPort>,
}
impl CampaignExecutionService {
pub fn new(paladin_port: Arc<dyn PaladinPort>) -> Self {
info!("Creating CampaignExecutionService");
Self { paladin_port }
}
/// Validates and runs `campaign`, bounded by the campaign's configured timeout.
///
/// A fresh battalion id is generated for the run and used in all log lines.
///
/// # Arguments
/// * `campaign` - the campaign to execute; `campaign.validate()` is called first.
/// * `initial_input` - input forwarded to the internal execution.
///
/// # Errors
/// * Propagates the error returned by `campaign.validate()`.
/// * Returns [`BattalionError::Timeout`] (carrying the configured number of
///   seconds) when the run does not finish within `timeout_seconds`.
/// * Propagates any error produced by the internal execution.
pub async fn execute(
    &self,
    campaign: &Campaign,
    initial_input: &str,
) -> Result<BattalionResult, BattalionError> {
    campaign.validate()?;

    let battalion_id = Uuid::new_v4();
    let timeout_seconds = campaign.config().timeout_seconds;
    // Elapsed time is measured with a monotonic clock: a wall-clock adjustment
    // during the run could otherwise yield a negative duration, which would wrap
    // to a huge value when converted to an unsigned integer.
    let started = std::time::Instant::now();

    info!(
        "Starting Campaign execution: {} (ID: {}) with {} Paladins",
        campaign.config().name,
        battalion_id,
        campaign.paladin_count()
    );

    match timeout(
        Duration::from_secs(timeout_seconds),
        self.execute_internal(campaign, initial_input, battalion_id),
    )
    .await
    {
        Ok(result) => {
            let duration_ms = started.elapsed().as_millis();
            // Finishing before the deadline does not imply success: the inner
            // run may itself have returned an error, so log the two cases apart.
            if result.is_ok() {
                info!("Campaign {} completed in {}ms", battalion_id, duration_ms);
            } else {
                warn!("Campaign {} failed after {}ms", battalion_id, duration_ms);
            }
            result
        }
        Err(_) => {
            warn!(
                "Campaign {} timed out after {} seconds",
                battalion_id, timeout_seconds
            );
            Err(BattalionError::Timeout(timeout_seconds))
        }
    }
}
/// Runs the campaign graph once, visiting nodes in topological order.
///
/// Only nodes in the "ready" set are executed. The set starts as the campaign's
/// entry points; a target node is added when an outgoing edge's condition holds
/// for the source's output and every source of the target's incoming edges has
/// already executed. Entry points receive `initial_input`; other nodes receive
/// the aggregated outputs of their executed predecessors.
///
/// # Errors
/// * [`BattalionError::InvalidGraph`] when the graph contains a cycle, when a
///   node has no matching paladin, or when input aggregation / edge-condition
///   evaluation reports an invalid graph.
/// * [`BattalionError::PaladinError`] when the port fails to execute a paladin.
async fn execute_internal(
    &self,
    campaign: &Campaign,
    initial_input: &str,
    battalion_id: Uuid,
) -> Result<BattalionResult, BattalionError> {
    let started_at = Utc::now();
    let entry_points = campaign.entry_points();
    debug!("Campaign entry points: {} nodes", entry_points.len());

    let sorted_nodes = toposort(campaign.graph(), None).map_err(|cycle| {
        BattalionError::InvalidGraph(format!(
            "Cycle detected in campaign graph at node {:?}",
            cycle.node_id()
        ))
    })?;
    debug!("Topological sort completed: {} nodes", sorted_nodes.len());

    let mut node_outputs: HashMap<Uuid, String> = HashMap::new();
    let mut all_results: Vec<PaladinResult> = Vec::new();
    let mut ready_nodes: HashSet<Uuid> = entry_points.clone();
    let mut executed_nodes: HashSet<Uuid> = HashSet::new();

    for node_index in sorted_nodes {
        let node_id = campaign.graph()[node_index];
        // Nodes never activated by an upstream edge are skipped entirely.
        if !ready_nodes.contains(&node_id) {
            continue;
        }

        let paladin = campaign.get_paladin(&node_id).ok_or_else(|| {
            BattalionError::InvalidGraph(format!(
                "Node {:?} not found in paladins map",
                node_id
            ))
        })?;

        let input = if entry_points.contains(&node_id) {
            initial_input.to_string()
        } else {
            self.aggregate_inputs_for_node(campaign, node_id, &node_outputs)?
        };

        debug!("Executing Paladin: {} ({})", paladin.node.name, node_id);
        let result = self
            .paladin_port
            .execute(paladin, &input)
            .await
            .map_err(|e| BattalionError::PaladinError(e.to_string()))?;
        debug!(
            "Paladin {} completed: {} tokens, {} loops",
            paladin.node.name, result.token_count, result.loop_count
        );

        node_outputs.insert(node_id, result.output.clone());
        executed_nodes.insert(node_id);

        for edge in campaign.graph().edges(node_index) {
            let edge_data = edge.weight();
            let target_id = campaign.graph()[edge.target()];

            if self.evaluate_edge_condition(&edge_data.condition, &result.output)? {
                debug!("Edge condition satisfied: {} -> {}", node_id, target_id);
                // NOTE(review): the transform is only logged here; nothing in
                // this function applies it to the data passed downstream.
                if let Some(transform_name) = &edge_data.transform {
                    debug!("Applying edge transform: {}", transform_name);
                }
                if self.are_dependencies_satisfied(campaign, edge.target(), &executed_nodes) {
                    ready_nodes.insert(target_id);
                }
            } else {
                debug!("Edge condition NOT satisfied: {} -> {}", node_id, target_id);
            }
        }

        // The result is moved (not cloned) into the list once the outgoing
        // edges no longer need to read its output.
        all_results.push(result);
    }

    let final_output = self.compute_final_output(&all_results);
    Ok(BattalionResult::new(
        battalion_id,
        campaign.config().name.clone(),
        started_at,
        final_output,
        all_results,
    ))
}
/// Builds the input text for `node_id` from the outputs of its predecessors.
///
/// Every incoming edge whose source has an entry in `node_outputs` contributes
/// that output. Contributions are joined with a `---` separator line; no
/// contribution yields an empty string and a single one is returned as-is.
///
/// # Errors
/// Returns [`BattalionError::InvalidGraph`] when `node_id` is missing from the
/// campaign's node-index map.
fn aggregate_inputs_for_node(
    &self,
    campaign: &Campaign,
    node_id: Uuid,
    node_outputs: &HashMap<Uuid, String>,
) -> Result<String, BattalionError> {
    let target_index = *campaign.node_indices().get(&node_id).ok_or_else(|| {
        BattalionError::InvalidGraph(format!("Node {:?} not in indices map", node_id))
    })?;

    let graph = campaign.graph();
    let upstream_outputs: Vec<&str> = graph
        .edges_directed(target_index, petgraph::Direction::Incoming)
        .filter_map(|edge| node_outputs.get(&graph[edge.source()]))
        .map(String::as_str)
        .collect();

    // `join` already yields "" for no parts and the part itself for one part.
    Ok(upstream_outputs.join("\n\n---\n\n"))
}
/// Decides whether an edge with `condition` should fire for a node's `output`.
///
/// * `Always` fires unconditionally.
/// * `Contains` fires when `output` contains the given substring.
/// * `Regex` fires when the pattern matches anywhere in `output`.
/// * `Custom` is not implemented: it logs a warning and fires.
///
/// # Errors
/// Returns [`BattalionError::InvalidGraph`] when a `Regex` pattern fails to compile.
fn evaluate_edge_condition(
    &self,
    condition: &EdgeCondition,
    output: &str,
) -> Result<bool, BattalionError> {
    let is_satisfied = match condition {
        EdgeCondition::Always => true,
        EdgeCondition::Contains(needle) => output.contains(needle),
        EdgeCondition::Regex(pattern) => Regex::new(pattern)
            .map_err(|e| BattalionError::InvalidGraph(format!("Invalid regex pattern: {}", e)))?
            .is_match(output),
        EdgeCondition::Custom(_) => {
            warn!("Custom edge condition not yet implemented, defaulting to true");
            true
        }
    };
    Ok(is_satisfied)
}
/// Reports whether every source of an edge coming into `target_index` is in
/// `executed_nodes`.
///
/// A node with no incoming edges is trivially satisfied.
fn are_dependencies_satisfied(
    &self,
    campaign: &Campaign,
    target_index: NodeIndex,
    executed_nodes: &HashSet<Uuid>,
) -> bool {
    let graph = campaign.graph();
    graph
        .edges_directed(target_index, petgraph::Direction::Incoming)
        .map(|edge| graph[edge.source()])
        .all(|source_id| executed_nodes.contains(&source_id))
}
/// Returns a copy of the output of the last entry in `results`, or an empty
/// string when `results` is empty.
fn compute_final_output(&self, results: &[PaladinResult]) -> String {
    results
        .last()
        .map(|last| last.output.clone())
        .unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use paladin_core::platform::container::paladin::Paladin;
    use paladin_core::platform::container::paladin_error::PaladinError;
    use paladin_ports::output::paladin_port::{PaladinStreamChunk, StopReason};
    use tokio::sync::mpsc::Receiver;

    /// Port stub: `execute` returns a fixed successful result, `validate`
    /// always succeeds, and streaming is left unimplemented.
    struct MockPort;

    #[async_trait]
    impl PaladinPort for MockPort {
        async fn execute(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<PaladinResult, PaladinError> {
            let canned = PaladinResult {
                output: "test".to_string(),
                token_count: 0,
                execution_time_ms: 0,
                loop_count: 1,
                stop_reason: StopReason::Completed,
                ..Default::default()
            };
            Ok(canned)
        }

        async fn execute_stream(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<Receiver<Result<PaladinStreamChunk, PaladinError>>, PaladinError> {
            unimplemented!()
        }

        fn validate(&self, _paladin: &Paladin) -> Result<(), PaladinError> {
            Ok(())
        }
    }

    /// Smoke test: the service can be constructed from any `PaladinPort`.
    #[test]
    fn test_service_creation() {
        let _service = CampaignExecutionService::new(Arc::new(MockPort));
    }
}