use paladin_core::platform::container::battalion::BattalionError;
use paladin_core::platform::container::battalion::chain_of_command::{
ChainOfCommand, DelegationStrategy,
};
use paladin_core::platform::container::paladin_error::PaladinError;
use paladin_ports::output::paladin_port::{PaladinPort, PaladinResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Outcome of running a chain of command against one input.
///
/// `selected_specialists[i]` is the name of the specialist that produced
/// `outputs[i]`; every delegation strategy fills both vectors in lockstep.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DelegationResult {
    /// Names of the specialists that handled the input, aligned with `outputs`.
    pub selected_specialists: Vec<String>,
    /// Human-readable explanation of why these specialists were chosen.
    pub reasoning: String,
    /// Raw output of each specialist, aligned with `selected_specialists`.
    pub outputs: Vec<String>,
}
/// Executes a `ChainOfCommand` by routing input from its commander to its
/// specialists according to the chain's `DelegationStrategy`.
///
/// Both fields are reference-counted, so cloning the service is cheap. Clones
/// share the same port and the same round-robin counters.
#[derive(Clone)]
pub struct ChainOfCommandExecutionService {
    /// Backend used to run every paladin (commander and specialists).
    paladin_port: Arc<dyn PaladinPort>,
    /// Next specialist index per chain, keyed by the commander name and the
    /// comma-joined specialist names.
    round_robin_state: Arc<Mutex<HashMap<String, usize>>>,
}
impl ChainOfCommandExecutionService {
    /// Creates a service that runs every paladin through `paladin_port`.
    ///
    /// Round-robin rotation state starts empty and is filled lazily, one
    /// counter per distinct chain.
    pub fn new(paladin_port: Arc<dyn PaladinPort>) -> Self {
        Self {
            paladin_port,
            round_robin_state: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Checks that `chain` is well-formed by delegating to its own `validate`.
    ///
    /// # Errors
    /// Propagates the error reported by `chain.validate()`, converted into
    /// [`BattalionError`].
    pub async fn validate(&self, chain: &ChainOfCommand) -> Result<(), BattalionError> {
        chain.validate()?;
        Ok(())
    }

    /// Validates `chain`, then routes `input` according to the chain's
    /// [`DelegationStrategy`].
    ///
    /// # Errors
    /// Returns a [`BattalionError`] when validation fails, when the strategy
    /// cannot select a specialist, or when any paladin execution fails.
    pub async fn execute(
        &self,
        chain: &ChainOfCommand,
        input: &str,
    ) -> Result<DelegationResult, BattalionError> {
        self.validate(chain).await?;
        match chain.delegation_strategy() {
            DelegationStrategy::Automatic => self.execute_automatic(chain, input).await,
            DelegationStrategy::Broadcast => self.execute_broadcast(chain, input).await,
            DelegationStrategy::RoundRobin => self.execute_round_robin(chain, input).await,
            DelegationStrategy::Custom(logic) => self.execute_custom(chain, input, logic).await,
        }
    }

    /// Asks the commander which specialists should handle `input`, then runs
    /// each selected specialist on the original `input`.
    ///
    /// Specialists run sequentially, in the order the commander listed them.
    /// The first failure aborts the remaining ones.
    ///
    /// # Errors
    /// Returns [`BattalionError::ExecutionError`] if the commander selects
    /// nothing or names a specialist that is not in the chain. Paladin
    /// failures are propagated.
    async fn execute_automatic(
        &self,
        chain: &ChainOfCommand,
        input: &str,
    ) -> Result<DelegationResult, BattalionError> {
        let commander_prompt = Self::build_commander_prompt(chain, input);
        let commander_result = self
            .paladin_port
            .execute(chain.commander(), &commander_prompt)
            .await?;
        let (selected_names, reasoning) =
            self.parse_commander_response(&commander_result.output)?;

        // Resolve every name before running anything, so an invalid selection
        // fails fast without spending any specialist calls.
        let selected_specialists = selected_names
            .iter()
            .map(|name| {
                chain
                    .specialists()
                    .iter()
                    .find(|s| s.node.name == *name)
                    .ok_or_else(|| {
                        BattalionError::ExecutionError(format!(
                            "Commander selected non-existent specialist: {}",
                            name
                        ))
                    })
            })
            .collect::<Result<Vec<_>, _>>()?;

        let mut outputs = Vec::with_capacity(selected_specialists.len());
        for specialist in selected_specialists {
            let result = self.paladin_port.execute(specialist, input).await?;
            outputs.push(result.output);
        }

        Ok(DelegationResult {
            selected_specialists: selected_names,
            reasoning,
            outputs,
        })
    }

    /// Builds the prompt that asks the commander to pick specialists.
    ///
    /// Each specialist is described by its name and the first line of its
    /// system prompt.
    fn build_commander_prompt(chain: &ChainOfCommand, input: &str) -> String {
        let specialist_descriptions: Vec<String> = chain
            .specialists()
            .iter()
            .map(|p| {
                format!(
                    "- {}: {}",
                    p.node.name,
                    p.node.system_prompt.lines().next().unwrap_or("")
                )
            })
            .collect();
        format!(
            r#"You are a commander coordinating a team of specialists. Your task is to analyze the following request and select the appropriate specialist(s) to handle it.
Available Specialists:
{}
User Request:
{}
Instructions:
1. Analyze the request carefully
2. Select one or more specialists best suited for this task
3. Respond EXACTLY in this format:
SELECT: specialist_name1, specialist_name2
REASON: Brief explanation of your selection
Important: Use the exact specialist names shown above. Separate multiple specialists with commas."#,
            specialist_descriptions.join("\n"),
            input
        )
    }

    /// Parses the commander's reply into `(specialist names, reasoning)`.
    ///
    /// Lines are trimmed, and the `SELECT:` / `REASON:` labels are matched
    /// ignoring ASCII case. When a label appears more than once, the last
    /// occurrence wins. Selected names are split on commas, trimmed, stripped
    /// of empties and de-duplicated (first occurrence kept), so a specialist
    /// is never executed twice for one request. A missing or empty reason
    /// becomes `"No reasoning provided"`.
    ///
    /// # Errors
    /// Returns [`BattalionError::ExecutionError`] when no specialist name was
    /// selected.
    fn parse_commander_response(
        &self,
        response: &str,
    ) -> Result<(Vec<String>, String), BattalionError> {
        let mut selected: Vec<String> = Vec::new();
        let mut reasoning = String::new();
        for line in response.lines().map(str::trim) {
            if let Some(selection) = Self::strip_label(line, "SELECT:") {
                selected.clear();
                for name in selection
                    .split(',')
                    .map(str::trim)
                    .filter(|n| !n.is_empty())
                {
                    if !selected.iter().any(|s| s.as_str() == name) {
                        selected.push(name.to_string());
                    }
                }
            } else if let Some(reason) = Self::strip_label(line, "REASON:") {
                reasoning = reason.trim().to_string();
            }
        }
        if selected.is_empty() {
            return Err(BattalionError::ExecutionError(
                "Commander did not select any specialists".to_string(),
            ));
        }
        if reasoning.is_empty() {
            reasoning = "No reasoning provided".to_string();
        }
        Ok((selected, reasoning))
    }

    /// Returns the remainder of `line` after `label` if `line` starts with
    /// `label`, comparing ASCII characters case-insensitively.
    ///
    /// Uses `str::get` so a label length that falls inside a multi-byte
    /// character yields `None` instead of panicking.
    fn strip_label<'a>(line: &'a str, label: &str) -> Option<&'a str> {
        let head = line.get(..label.len())?;
        let rest = line.get(label.len()..)?;
        if head.eq_ignore_ascii_case(label) {
            Some(rest)
        } else {
            None
        }
    }

    /// Sends `input` to every specialist concurrently, each on its own tokio
    /// task.
    ///
    /// Results are reported in the chain's specialist order, regardless of
    /// which task finishes first.
    ///
    /// # Errors
    /// Returns [`BattalionError::PaladinError`] for the first specialist
    /// failure observed, or [`BattalionError::ExecutionError`] if a task
    /// cannot be joined. Returning early drops the `JoinSet`, which aborts
    /// the tasks still running.
    async fn execute_broadcast(
        &self,
        chain: &ChainOfCommand,
        input: &str,
    ) -> Result<DelegationResult, BattalionError> {
        use tokio::task::JoinSet;

        let specialists = chain.specialists();
        let mut join_set = JoinSet::new();
        for (position, specialist) in specialists.iter().enumerate() {
            // Spawned futures must be 'static, so each task owns its copies.
            let specialist: paladin_core::platform::container::paladin::Paladin =
                specialist.clone();
            let input = input.to_string();
            let port = Arc::clone(&self.paladin_port);
            join_set.spawn(async move {
                let result: Result<PaladinResult, PaladinError> =
                    port.execute(&specialist, &input).await;
                (position, specialist.node.name.clone(), result)
            });
        }

        // Tasks complete in arbitrary order, so store each result at its
        // original position to keep outputs aligned with the chain order.
        let mut slots: Vec<Option<(String, String)>> = vec![None; specialists.len()];
        while let Some(joined) = join_set.join_next().await {
            match joined {
                Ok((position, name, Ok(paladin_result))) => {
                    slots[position] = Some((name, paladin_result.output));
                }
                Ok((_, name, Err(e))) => {
                    return Err(BattalionError::PaladinError(format!(
                        "Specialist {} failed: {}",
                        name, e
                    )));
                }
                Err(join_error) => {
                    return Err(BattalionError::ExecutionError(format!(
                        "Task join error: {}",
                        join_error
                    )));
                }
            }
        }

        let (selected_specialists, outputs): (Vec<String>, Vec<String>) =
            slots.into_iter().flatten().unzip();
        Ok(DelegationResult {
            selected_specialists,
            reasoning: "Broadcast to all specialists concurrently".to_string(),
            outputs,
        })
    }

    /// Sends `input` to exactly one specialist, advancing to the next
    /// specialist on each call for the same chain.
    ///
    /// Rotation state lives in `round_robin_state`, keyed by the commander
    /// name plus the comma-joined specialist names.
    ///
    /// # Errors
    /// Returns [`BattalionError::ValidationError`] when the chain has no
    /// specialists. Paladin failures are propagated.
    async fn execute_round_robin(
        &self,
        chain: &ChainOfCommand,
        input: &str,
    ) -> Result<DelegationResult, BattalionError> {
        let specialists = chain.specialists();
        if specialists.is_empty() {
            return Err(BattalionError::ValidationError(
                "No specialists available for round-robin delegation".to_string(),
            ));
        }
        let chain_id = format!(
            "{}:{}",
            chain.commander().node.name,
            specialists
                .iter()
                .map(|s| s.node.name.as_str())
                .collect::<Vec<_>>()
                .join(",")
        );

        // The guard is confined to this block so it is released before the
        // `.await` below.
        let current_index = {
            // A poisoned lock only means another task panicked mid-update; the
            // counter map is still usable, so recover rather than panic.
            let mut state = self
                .round_robin_state
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let next = state.entry(chain_id).or_insert(0);
            // Joined-name keys can collide between chains of different
            // lengths (e.g. a name containing ','), so bound the stored index
            // by this chain's length before using it.
            let current = *next % specialists.len();
            *next = (current + 1) % specialists.len();
            current
        };

        let selected_specialist = &specialists[current_index];
        let result = self
            .paladin_port
            .execute(selected_specialist, input)
            .await?;
        Ok(DelegationResult {
            selected_specialists: vec![selected_specialist.node.name.clone()],
            reasoning: format!(
                "Round-robin delegation selected specialist {} of {}",
                current_index + 1,
                specialists.len()
            ),
            outputs: vec![result.output],
        })
    }

    /// Handles `DelegationStrategy::Custom`.
    ///
    /// NOTE(review): `logic` is not interpreted yet. The first specialist is
    /// always chosen, and `logic` only appears in the reasoning text.
    ///
    /// # Errors
    /// Returns [`BattalionError::ValidationError`] when the chain has no
    /// specialists. Paladin failures are propagated.
    async fn execute_custom(
        &self,
        chain: &ChainOfCommand,
        input: &str,
        logic: &str,
    ) -> Result<DelegationResult, BattalionError> {
        let specialists = chain.specialists();
        if specialists.is_empty() {
            return Err(BattalionError::ValidationError(
                "No specialists available for custom delegation".to_string(),
            ));
        }
        let selected_specialist = &specialists[0];
        let result = self
            .paladin_port
            .execute(selected_specialist, input)
            .await?;
        Ok(DelegationResult {
            selected_specialists: vec![selected_specialist.node.name.clone()],
            reasoning: format!("Custom delegation using custom logic: {}", logic),
            outputs: vec![result.output],
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use paladin_core::platform::container::battalion::BattalionConfig;
    use paladin_core::platform::container::paladin::{Paladin, PaladinData};
    use paladin_ports::output::paladin_port::{PaladinResult, PaladinStream, StopReason};

    /// Port stub that answers every `execute` call with the same fixed output.
    struct MockPort {
        output: String,
    }

    #[async_trait]
    impl PaladinPort for MockPort {
        async fn execute(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<PaladinResult, PaladinError> {
            Ok(PaladinResult {
                output: self.output.clone(),
                token_count: 0,
                execution_time_ms: 0,
                loop_count: 1,
                stop_reason: StopReason::Completed,
                ..Default::default()
            })
        }

        async fn execute_stream(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<PaladinStream, PaladinError> {
            unimplemented!("streaming is not exercised by these tests")
        }

        fn validate(&self, _paladin: &Paladin) -> Result<(), PaladinError> {
            Ok(())
        }
    }

    /// Builds a service whose port returns `output` for every paladin.
    fn service_returning(output: &str) -> ChainOfCommandExecutionService {
        ChainOfCommandExecutionService::new(Arc::new(MockPort {
            output: output.to_string(),
        }))
    }

    fn create_test_paladin(name: &str) -> Paladin {
        let data = PaladinData {
            system_prompt: format!("{} system prompt", name),
            name: name.to_string(),
            user_name: "test_user".to_string(),
            ..Default::default()
        };
        Paladin::new(data, Some(name.to_string()))
    }

    /// Builds a chain with a "commander" and the given specialists.
    fn create_chain(specialist_names: &[&str]) -> ChainOfCommand {
        let specialists = specialist_names
            .iter()
            .copied()
            .map(create_test_paladin)
            .collect();
        ChainOfCommand::new(
            create_test_paladin("commander"),
            specialists,
            BattalionConfig::default(),
        )
        .expect("Should create valid chain")
    }

    /// Names of the chain's specialists, in chain order.
    fn specialist_names(chain: &ChainOfCommand) -> Vec<String> {
        chain
            .specialists()
            .iter()
            .map(|s| s.node.name.clone())
            .collect()
    }

    #[test]
    fn test_service_construction() {
        let _service = service_returning("");
    }

    #[tokio::test]
    async fn test_validate_valid_chain() {
        let service = service_returning("");
        let chain = create_chain(&["specialist"]);
        let result = service.validate(&chain).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_commander_response_extracts_selection_and_reason() {
        let service = service_returning("");
        let (selected, reasoning) = service
            .parse_commander_response("SELECT: alpha, beta\nREASON: both needed")
            .expect("response has a selection");
        assert_eq!(selected, vec!["alpha".to_string(), "beta".to_string()]);
        assert_eq!(reasoning, "both needed");
    }

    #[test]
    fn test_parse_commander_response_defaults_reasoning() {
        let service = service_returning("");
        let (selected, reasoning) = service
            .parse_commander_response("SELECT: alpha")
            .expect("response has a selection");
        assert_eq!(selected, vec!["alpha".to_string()]);
        assert_eq!(reasoning, "No reasoning provided");
    }

    #[test]
    fn test_parse_commander_response_requires_selection() {
        let service = service_returning("");
        let result = service.parse_commander_response("REASON: nothing fits");
        assert!(matches!(result, Err(BattalionError::ExecutionError(_))));
    }

    #[tokio::test]
    async fn test_round_robin_rotates_through_specialists() {
        let service = service_returning("done");
        let chain = create_chain(&["alpha", "beta"]);
        let names = specialist_names(&chain);
        let mut picked = Vec::new();
        for _ in 0..3 {
            let result = service
                .execute_round_robin(&chain, "task")
                .await
                .expect("round-robin succeeds");
            picked.extend(result.selected_specialists);
        }
        assert_eq!(
            picked,
            vec![names[0].clone(), names[1].clone(), names[0].clone()]
        );
    }

    #[tokio::test]
    async fn test_automatic_rejects_unknown_specialist() {
        let service = service_returning("SELECT: ghost\nREASON: guess");
        let chain = create_chain(&["alpha"]);
        let result = service.execute_automatic(&chain, "task").await;
        assert!(matches!(result, Err(BattalionError::ExecutionError(_))));
    }

    #[tokio::test]
    async fn test_broadcast_reaches_every_specialist() {
        let service = service_returning("done");
        let chain = create_chain(&["alpha", "beta"]);
        let result = service
            .execute_broadcast(&chain, "task")
            .await
            .expect("broadcast succeeds");
        let mut reached = result.selected_specialists.clone();
        reached.sort();
        let mut expected = specialist_names(&chain);
        expected.sort();
        assert_eq!(reached, expected);
        assert_eq!(result.outputs, vec!["done".to_string(); 2]);
    }
}