use serde::{Deserialize, Serialize};
use crate::platform::container::battalion::{BattalionConfig, BattalionError};
use crate::platform::container::paladin::Paladin;
/// Names the aggregation strategy a [`Phalanx`] is configured with.
///
/// The aggregation logic itself is not part of this file; the per-variant
/// notes below that describe behavior are inferred from the variant names and
/// should be confirmed against the code that consumes this enum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub enum AggregationStrategy {
/// Default strategy (selected by `AggregationStrategy::default()`).
/// Presumably gathers the result of every Paladin — TODO confirm.
#[default]
CollectAll,
/// Presumably yields the first successful Paladin result — TODO confirm.
FirstSuccess,
/// Presumably picks the result most Paladins agree on — TODO confirm.
/// `Phalanx` validation requires at least 3 Paladins for this strategy.
Majority,
/// Strategy identified by a free-form string.
/// NOTE(review): how the string is interpreted is not visible here.
Custom(String),
}
/// A group of Paladins bundled with a shared configuration, an aggregation
/// strategy and an optional concurrency limit.
///
/// Instances built through `Phalanx::new` are validated (at least 2 Paladins).
/// NOTE(review): the derived `Deserialize` fills the fields directly and does
/// not go through that validation — confirm whether deserialized values need
/// an explicit check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Phalanx {
// Members of the group; `new` rejects fewer than 2.
paladins: Vec<Paladin>,
// Configuration supplied by the caller at construction time.
config: BattalionConfig,
// Starts as `AggregationStrategy::default()`; changed via `with_aggregation`.
aggregation_strategy: AggregationStrategy,
// `None` until `with_max_concurrency` is called.
// NOTE(review): presumably `None` means "no limit" — confirm with the executor.
max_concurrency: Option<usize>,
}
impl Phalanx {
    /// Builds a `Phalanx` from `paladins` and `config`.
    ///
    /// The aggregation strategy starts as [`AggregationStrategy::default()`]
    /// and no concurrency limit is set.
    ///
    /// # Errors
    ///
    /// Returns [`BattalionError::ValidationError`] when fewer than 2 Paladins
    /// are supplied.
    pub fn new(paladins: Vec<Paladin>, config: BattalionConfig) -> Result<Self, BattalionError> {
        let phalanx = Self {
            paladins,
            config,
            aggregation_strategy: AggregationStrategy::default(),
            max_concurrency: None,
        };
        phalanx.validate()?;
        Ok(phalanx)
    }

    /// Replaces the aggregation strategy and returns the updated `Phalanx`.
    ///
    /// This builder is infallible and therefore does **not** re-validate: it
    /// will happily store [`AggregationStrategy::Majority`] on a group of only
    /// 2 Paladins. Use [`Phalanx::try_with_aggregation`] (or call
    /// [`Phalanx::validate`] afterwards) when the strategy is not known to be
    /// compatible with the group size.
    #[must_use]
    pub fn with_aggregation(mut self, strategy: AggregationStrategy) -> Self {
        self.aggregation_strategy = strategy;
        self
    }

    /// Replaces the aggregation strategy and re-validates the result.
    ///
    /// # Errors
    ///
    /// Returns [`BattalionError::ValidationError`] when the new strategy is
    /// incompatible with the group, e.g. [`AggregationStrategy::Majority`]
    /// with fewer than 3 Paladins.
    pub fn try_with_aggregation(
        self,
        strategy: AggregationStrategy,
    ) -> Result<Self, BattalionError> {
        let phalanx = self.with_aggregation(strategy);
        phalanx.validate()?;
        Ok(phalanx)
    }

    /// Sets the concurrency limit and returns the updated `Phalanx`.
    ///
    /// The value is stored as given.
    /// NOTE(review): `0` is accepted here; confirm how the consumer of
    /// `max_concurrency` treats it.
    #[must_use]
    pub fn with_max_concurrency(mut self, max: usize) -> Self {
        self.max_concurrency = Some(max);
        self
    }

    /// Number of Paladins in the group.
    pub fn paladin_count(&self) -> usize {
        self.paladins.len()
    }

    /// The Paladins in the group, in the order they were supplied.
    pub fn paladins(&self) -> &[Paladin] {
        &self.paladins
    }

    /// The configuration supplied at construction time.
    pub fn config(&self) -> &BattalionConfig {
        &self.config
    }

    /// The currently selected aggregation strategy.
    pub fn aggregation_strategy(&self) -> &AggregationStrategy {
        &self.aggregation_strategy
    }

    /// The concurrency limit, or `None` if `with_max_concurrency` was never called.
    pub fn max_concurrency(&self) -> Option<usize> {
        self.max_concurrency
    }

    /// Checks the invariants of this `Phalanx`.
    ///
    /// Public so callers can re-check a value after using the infallible
    /// builders or after obtaining it by deserialization, neither of which
    /// runs this check.
    ///
    /// # Errors
    ///
    /// Returns [`BattalionError::ValidationError`] when:
    /// - the group has fewer than 2 Paladins, or
    /// - the strategy is [`AggregationStrategy::Majority`] and the group has
    ///   fewer than 3 Paladins.
    pub fn validate(&self) -> Result<(), BattalionError> {
        let paladin_count = self.paladins.len();

        if paladin_count < 2 {
            return Err(BattalionError::ValidationError(
                "Phalanx requires at least 2 Paladins".to_string(),
            ));
        }

        let wants_majority = matches!(self.aggregation_strategy, AggregationStrategy::Majority);
        if wants_majority && paladin_count < 3 {
            return Err(BattalionError::ValidationError(
                "Majority aggregation requires at least 3 Paladins".to_string(),
            ));
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::entity::node::Node;
    use crate::platform::container::paladin::{MaxLoops, PaladinData, PaladinStatus};

    /// Builds a Paladin fixture whose name and node label are both `name`.
    fn create_test_paladin(name: &str) -> Paladin {
        let data = PaladinData {
            system_prompt: format!("{} system prompt", name),
            name: name.to_string(),
            user_name: "TestUser".to_string(),
            model: "gpt-4".to_string(),
            temperature: 0.7,
            max_loops: MaxLoops::Fixed(3),
            stop_words: vec![],
            status: PaladinStatus::Idle,
            vision_enabled: false,
            ..Default::default()
        };
        Node::new(data, Some(name.to_string()))
    }

    /// Builds `count` Paladin fixtures named `Agent1`, `Agent2`, ...
    fn create_test_paladins(count: usize) -> Vec<Paladin> {
        (1..=count)
            .map(|index| create_test_paladin(&format!("Agent{}", index)))
            .collect()
    }

    #[test]
    fn test_phalanx_creation_valid() {
        let config = BattalionConfig::new("test_phalanx");
        let result = Phalanx::new(create_test_paladins(2), config);
        assert!(result.is_ok());
        let phalanx = result.unwrap();
        assert_eq!(phalanx.paladin_count(), 2);
    }

    #[test]
    fn test_phalanx_requires_minimum_two_paladins() {
        let config = BattalionConfig::new("test_phalanx");
        let result = Phalanx::new(create_test_paladins(1), config);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least 2 Paladins")
        );
    }

    #[test]
    fn test_phalanx_with_aggregation_strategy() {
        let config = BattalionConfig::new("test_phalanx");
        let phalanx = Phalanx::new(create_test_paladins(2), config)
            .unwrap()
            .with_aggregation(AggregationStrategy::FirstSuccess);
        assert_eq!(
            phalanx.aggregation_strategy(),
            &AggregationStrategy::FirstSuccess
        );
    }

    #[test]
    fn test_phalanx_with_max_concurrency() {
        let config = BattalionConfig::new("test_phalanx");
        let phalanx = Phalanx::new(create_test_paladins(2), config)
            .unwrap()
            .with_max_concurrency(5);
        assert_eq!(phalanx.max_concurrency(), Some(5));
    }

    #[test]
    fn test_majority_strategy_validation() {
        // `with_aggregation` does not re-validate, so the Majority rule is
        // exercised by calling `validate` directly on the updated value.
        let config = BattalionConfig::new("test_phalanx");
        let undersized = Phalanx::new(create_test_paladins(2), config)
            .unwrap()
            .with_aggregation(AggregationStrategy::Majority);
        let result = undersized.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least 3 Paladins")
        );
    }

    #[test]
    fn test_majority_strategy_accepts_three_paladins() {
        let config = BattalionConfig::new("test_phalanx");
        let phalanx = Phalanx::new(create_test_paladins(3), config)
            .unwrap()
            .with_aggregation(AggregationStrategy::Majority);
        assert!(phalanx.validate().is_ok());
        assert_eq!(phalanx.paladin_count(), 3);
    }

    #[test]
    fn test_phalanx_accessors() {
        let config = BattalionConfig::new("test_phalanx");
        let phalanx = Phalanx::new(create_test_paladins(2), config).unwrap();
        assert_eq!(phalanx.paladin_count(), 2);
        assert_eq!(phalanx.paladins().len(), 2);
        assert_eq!(phalanx.config().name, "test_phalanx");
        assert_eq!(
            phalanx.aggregation_strategy(),
            &AggregationStrategy::CollectAll
        );
        assert_eq!(phalanx.max_concurrency(), None);
    }
}