use serde::{Deserialize, Serialize};
use crate::base::entity::node::Node;
use super::{BattalionError, GroveError};
/// Strategy a grove uses to pick the tree that should handle a request.
///
/// Serialized in `snake_case` (`"keyword_match"`, `"semantic_similarity"`,
/// `"llm_routing"`). The default is [`RoutingStrategy::KeywordMatch`].
///
/// The enum carries no data, so it is `Copy` and can be used as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Keyword-based routing (the default). Presumably compares requests
    /// against [`TreeAgent::expertise_keywords`]; the router is not in this module.
    #[default]
    KeywordMatch,
    /// Similarity-based routing. Presumably uses [`TreeAgent::expertise_embedding`]
    /// together with `GroveConfig::similarity_threshold`; confirm against the router.
    SemanticSimilarity,
    /// Routing delegated to an LLM; the routing logic lives outside this module.
    LlmRouting,
}
/// One agent entry inside a [`Tree`]: which paladin it stands for and what it is good at.
///
/// Build with [`TreeAgent::new`] plus the `with_*` methods.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TreeAgent {
/// Identifier of the paladin this entry refers to (presumably resolved elsewhere in the crate).
pub paladin_id: String,
/// Keywords describing the agent's expertise; empty unless set via `with_keywords`.
pub expertise_keywords: Vec<String>,
/// Optional expertise embedding; left out of the serialized form when `None`.
/// NOTE(review): the vector length is not checked here — presumably it must match
/// the embedding model used for routing; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub expertise_embedding: Option<Vec<f32>>,
}
impl TreeAgent {
    /// Creates an agent entry for `paladin_id` with no keywords and no embedding.
    pub fn new(paladin_id: impl Into<String>) -> Self {
        Self {
            paladin_id: paladin_id.into(),
            expertise_keywords: Vec::new(),
            expertise_embedding: None,
        }
    }

    /// Replaces the agent's expertise keywords.
    ///
    /// Accepts any iterable of string-like items: a `Vec<&str>` or `Vec<String>`
    /// (as before), an array such as `["rust", "api"]`, or an iterator.
    /// Previously set keywords are discarded, not appended to.
    pub fn with_keywords(
        mut self,
        keywords: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.expertise_keywords = keywords.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the agent's expertise embedding, replacing any previous one.
    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
        self.expertise_embedding = Some(embedding);
        self
    }
}
/// A named group of [`TreeAgent`]s inside a grove.
///
/// `GroveBuilder::build` requires every tree to hold at least one agent, and it
/// resolves `GroveConfig::fallback_tree` by comparing against `name`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Tree {
/// Tree name; also the value `fallback_tree` must match to select this tree.
pub name: String,
/// Agents belonging to this tree, in insertion order.
pub agents: Vec<TreeAgent>,
}
impl Tree {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
agents: Vec::new(),
}
}
pub fn add_agent(mut self, agent: TreeAgent) -> Self {
self.agents.push(agent);
self
}
}
/// Routing configuration for a grove.
///
/// The container-level `#[serde(default)]` fills any field missing from the
/// serialized form from [`GroveConfig::default`]. A stored config that omits,
/// for example, `routing_fallback` or `min_confidence` therefore still loads.
///
/// Deserialization does not perform the range/value checks that
/// `GroveBuilder::build` applies.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct GroveConfig {
    /// How requests are routed to trees.
    pub routing_strategy: RoutingStrategy,
    /// Name of a tree to fall back to; left out of the serialized form when `None`.
    /// `GroveBuilder::build` checks that a tree with this name exists.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fallback_tree: Option<String>,
    /// Similarity cutoff; `GroveBuilder::build` requires it to lie in `0.0..=1.0`.
    pub similarity_threshold: f32,
    /// Fallback behavior; `GroveBuilder::build` accepts only `"keyword"` or `"error"`.
    pub routing_fallback: String,
    /// Minimum confidence; `GroveBuilder::build` requires it to lie in `0.0..=1.0`.
    pub min_confidence: f32,
}
impl Default for GroveConfig {
fn default() -> Self {
Self {
routing_strategy: RoutingStrategy::default(),
fallback_tree: None,
similarity_threshold: 0.7,
routing_fallback: "keyword".to_string(),
min_confidence: 0.5,
}
}
}
/// Payload of a [`Grove`] node: the grove's name, its trees and its routing configuration.
///
/// `GroveBuilder::build` is the validated way to produce one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroveData {
/// Grove name; `GroveBuilder::build` also passes it as the wrapping node's name.
pub name: String,
/// The trees that make up the grove.
pub trees: Vec<Tree>,
/// Routing configuration for this grove.
pub config: GroveConfig,
}
/// A grove entity: [`GroveData`] wrapped in the crate's generic [`Node`].
///
/// The payload is reachable through the node's `node` field (`grove.node`).
/// Build one with [`GroveBuilder`], which validates the configuration first.
pub type Grove = Node<GroveData>;
/// Fluent builder for [`Grove`].
///
/// Holds the grove name, its trees, and a flattened copy of every [`GroveConfig`] field.
/// `build` validates them all before assembling the [`GroveData`] payload.
pub struct GroveBuilder {
/// Grove name; `build` rejects an empty one.
name: String,
/// Trees in insertion order; `build` requires at least one, each with at least one agent.
trees: Vec<Tree>,
/// Mirrors `GroveConfig::routing_strategy`.
routing_strategy: RoutingStrategy,
/// Mirrors `GroveConfig::fallback_tree`; must name one of `trees` when set.
fallback_tree: Option<String>,
/// Mirrors `GroveConfig::similarity_threshold`; must lie in `0.0..=1.0`.
similarity_threshold: f32,
/// Mirrors `GroveConfig::routing_fallback`; must be `"keyword"` or `"error"`.
routing_fallback: String,
/// Mirrors `GroveConfig::min_confidence`; must lie in `0.0..=1.0`.
min_confidence: f32,
}
impl GroveBuilder {
pub fn new() -> Self {
Self {
name: String::new(),
trees: Vec::new(),
routing_strategy: RoutingStrategy::default(),
fallback_tree: None,
similarity_threshold: 0.7,
routing_fallback: "keyword".to_string(),
min_confidence: 0.5,
}
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
pub fn add_tree(mut self, tree: Tree) -> Self {
self.trees.push(tree);
self
}
pub fn routing_strategy(mut self, strategy: RoutingStrategy) -> Self {
self.routing_strategy = strategy;
self
}
pub fn fallback_tree(mut self, tree_name: impl Into<String>) -> Self {
self.fallback_tree = Some(tree_name.into());
self
}
pub fn similarity_threshold(mut self, threshold: f32) -> Self {
self.similarity_threshold = threshold;
self
}
pub fn routing_fallback(mut self, fallback: impl Into<String>) -> Self {
self.routing_fallback = fallback.into();
self
}
pub fn min_confidence(mut self, threshold: f32) -> Self {
self.min_confidence = threshold;
self
}
pub fn config(mut self, config: GroveConfig) -> Self {
self.routing_strategy = config.routing_strategy;
self.fallback_tree = config.fallback_tree;
self.similarity_threshold = config.similarity_threshold;
self.routing_fallback = config.routing_fallback;
self.min_confidence = config.min_confidence;
self
}
pub fn build(self) -> Result<Grove, BattalionError> {
if self.name.is_empty() {
return Err(BattalionError::ValidationError(
"Grove name cannot be empty".to_string(),
));
}
if self.trees.is_empty() {
return Err(GroveError::NoTrees.into());
}
let total_agents: usize = self.trees.iter().map(|t| t.agents.len()).sum();
if total_agents == 0 {
return Err(GroveError::NoAgents.into());
}
for tree in &self.trees {
if tree.agents.is_empty() {
return Err(BattalionError::ValidationError(format!(
"Tree '{}' must have at least one agent",
tree.name
)));
}
}
if !(0.0..=1.0).contains(&self.similarity_threshold) {
return Err(GroveError::InvalidSimilarityThreshold(self.similarity_threshold).into());
}
if self.routing_fallback != "keyword" && self.routing_fallback != "error" {
return Err(BattalionError::ValidationError(format!(
"routing_fallback must be 'keyword' or 'error', got '{}'",
self.routing_fallback
)));
}
if !(0.0..=1.0).contains(&self.min_confidence) {
return Err(BattalionError::ValidationError(format!(
"min_confidence must be between 0.0 and 1.0, got {}",
self.min_confidence
)));
}
if let Some(ref fallback) = self.fallback_tree
&& !self.trees.iter().any(|t| &t.name == fallback)
{
return Err(BattalionError::ValidationError(format!(
"Fallback tree '{}' not found in Grove trees",
fallback
)));
}
let name = self.name.clone();
let data = GroveData {
name: self.name,
trees: self.trees,
config: GroveConfig {
routing_strategy: self.routing_strategy,
fallback_tree: self.fallback_tree,
similarity_threshold: self.similarity_threshold,
routing_fallback: self.routing_fallback,
min_confidence: self.min_confidence,
},
};
Ok(Node::new(data, Some(name)))
}
}
impl Default for GroveBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal valid tree: one tree holding a single agent.
    fn support_tree() -> Tree {
        Tree::new("support").add_agent(TreeAgent::new("agent1"))
    }

    /// A builder that passes every validation check; each test tweaks one setting.
    fn valid_builder() -> GroveBuilder {
        GroveBuilder::new().name("Test Grove").add_tree(support_tree())
    }

    #[test]
    fn test_routing_strategy_default() {
        assert_eq!(RoutingStrategy::default(), RoutingStrategy::KeywordMatch);
    }

    #[test]
    fn test_routing_strategy_serde_names_are_snake_case() {
        assert_eq!(
            serde_json::to_string(&RoutingStrategy::SemanticSimilarity).unwrap(),
            "\"semantic_similarity\""
        );
        let parsed: RoutingStrategy = serde_json::from_str("\"llm_routing\"").unwrap();
        assert_eq!(parsed, RoutingStrategy::LlmRouting);
    }

    #[test]
    fn test_tree_agent_creation() {
        let agent = TreeAgent::new("test_agent")
            .with_keywords(vec!["rust", "backend", "api"])
            .with_embedding(vec![0.1, 0.2, 0.3]);
        assert_eq!(agent.paladin_id, "test_agent");
        assert_eq!(agent.expertise_keywords, vec!["rust", "backend", "api"]);
        assert_eq!(agent.expertise_embedding.map(|e| e.len()), Some(3));
    }

    #[test]
    fn test_tree_creation() {
        let tree = Tree::new("technical_support")
            .add_agent(TreeAgent::new("backend_expert"))
            .add_agent(TreeAgent::new("frontend_expert"));
        assert_eq!(tree.name, "technical_support");
        assert_eq!(tree.agents.len(), 2);
    }

    #[test]
    fn test_grove_config_default() {
        let config = GroveConfig::default();
        assert_eq!(config.routing_strategy, RoutingStrategy::KeywordMatch);
        assert!(config.fallback_tree.is_none());
        assert_eq!(config.similarity_threshold, 0.7);
        assert_eq!(config.routing_fallback, "keyword");
        assert_eq!(config.min_confidence, 0.5);
    }

    #[test]
    fn test_grove_builder_defaults_match_grove_config_default() {
        let grove = valid_builder().build().expect("valid grove should build");
        assert_eq!(grove.node.config, GroveConfig::default());
    }

    #[test]
    fn test_grove_builder_basic() {
        let grove = GroveBuilder::new()
            .name("Test Grove")
            .add_tree(
                Tree::new("support")
                    .add_agent(TreeAgent::new("agent1").with_keywords(vec!["help", "support"])),
            )
            .build()
            .expect("valid grove should build");
        assert_eq!(grove.node.name, "Test Grove");
        assert_eq!(grove.node.trees.len(), 1);
    }

    #[test]
    fn test_grove_builder_validation_empty_name() {
        let result = GroveBuilder::new().add_tree(support_tree()).build();
        assert!(matches!(result, Err(BattalionError::ValidationError(_))));
    }

    #[test]
    fn test_grove_builder_validation_no_trees() {
        let result = GroveBuilder::new().name("Test Grove").build();
        assert!(matches!(
            result,
            Err(BattalionError::GroveError(GroveError::NoTrees))
        ));
    }

    #[test]
    fn test_grove_builder_validation_tree_without_agents() {
        // When every tree is empty, the grove-level NoAgents error wins.
        let result = GroveBuilder::new()
            .name("Test Grove")
            .add_tree(Tree::new("empty_tree"))
            .build();
        match result {
            Err(BattalionError::GroveError(GroveError::NoAgents)) => {}
            other => panic!("Expected GroveError::NoAgents, got {:?}", other),
        }
    }

    #[test]
    fn test_grove_builder_validation_empty_tree_among_populated() {
        // Only some trees are empty: the per-tree error names the offending tree.
        let result = valid_builder().add_tree(Tree::new("empty_tree")).build();
        match result {
            Err(BattalionError::ValidationError(msg)) => {
                assert!(msg.contains("empty_tree"));
                assert!(msg.contains("at least one agent"));
            }
            other => panic!("Expected ValidationError for empty tree, got {:?}", other),
        }
    }

    #[test]
    fn test_grove_builder_validation_invalid_threshold() {
        for bad in [1.5_f32, -0.1] {
            match valid_builder().similarity_threshold(bad).build() {
                Err(BattalionError::GroveError(GroveError::InvalidSimilarityThreshold(
                    threshold,
                ))) => {
                    assert_eq!(threshold, bad);
                }
                other => panic!(
                    "Expected GroveError::InvalidSimilarityThreshold, got {:?}",
                    other
                ),
            }
        }
    }

    #[test]
    fn test_grove_builder_threshold_rejects_nan() {
        let result = valid_builder().similarity_threshold(f32::NAN).build();
        assert!(matches!(
            result,
            Err(BattalionError::GroveError(
                GroveError::InvalidSimilarityThreshold(_)
            ))
        ));
    }

    #[test]
    fn test_grove_builder_threshold_bounds_are_inclusive() {
        for edge in [0.0_f32, 1.0] {
            assert!(
                valid_builder().similarity_threshold(edge).build().is_ok(),
                "similarity_threshold {edge} should be accepted"
            );
        }
    }

    #[test]
    fn test_grove_builder_validation_invalid_fallback() {
        let result = valid_builder().fallback_tree("nonexistent").build();
        match result {
            Err(BattalionError::ValidationError(msg)) => {
                assert!(msg.contains("Fallback tree"));
                assert!(msg.contains("not found"));
            }
            _ => panic!("Expected ValidationError"),
        }
    }

    #[test]
    fn test_grove_builder_with_valid_fallback() {
        let grove = valid_builder()
            .add_tree(Tree::new("general").add_agent(TreeAgent::new("agent2")))
            .fallback_tree("general")
            .build()
            .expect("grove with existing fallback tree should build");
        assert_eq!(grove.node.config.fallback_tree, Some("general".to_string()));
    }

    #[test]
    fn test_grove_builder_config_applies_all_settings() {
        let config = GroveConfig {
            routing_strategy: RoutingStrategy::LlmRouting,
            fallback_tree: Some("support".to_string()),
            similarity_threshold: 0.25,
            routing_fallback: "error".to_string(),
            min_confidence: 0.9,
        };
        let grove = valid_builder()
            .config(config.clone())
            .build()
            .expect("grove with valid config should build");
        assert_eq!(grove.node.config, config);
    }

    #[test]
    fn test_grove_builder_config_clears_fallback_tree() {
        let grove = valid_builder()
            .fallback_tree("support")
            .config(GroveConfig::default())
            .build()
            .expect("grove with default config should build");
        assert!(grove.node.config.fallback_tree.is_none());
    }

    #[test]
    fn test_grove_builder_full_config() {
        let grove = GroveBuilder::new()
            .name("Advanced Grove")
            .add_tree(
                Tree::new("technical")
                    .add_agent(
                        TreeAgent::new("expert1")
                            .with_keywords(vec!["rust", "systems"])
                            .with_embedding(vec![0.1, 0.2]),
                    )
                    .add_agent(TreeAgent::new("expert2").with_keywords(vec!["web", "api"])),
            )
            .add_tree(Tree::new("general").add_agent(TreeAgent::new("generalist")))
            .routing_strategy(RoutingStrategy::SemanticSimilarity)
            .similarity_threshold(0.8)
            .fallback_tree("general")
            .build()
            .expect("fully configured grove should build");
        assert_eq!(grove.node.name, "Advanced Grove");
        assert_eq!(grove.node.trees.len(), 2);
        assert_eq!(
            grove.node.config.routing_strategy,
            RoutingStrategy::SemanticSimilarity
        );
        assert_eq!(grove.node.config.similarity_threshold, 0.8);
        assert_eq!(grove.node.config.fallback_tree, Some("general".to_string()));
    }

    #[test]
    fn test_grove_serialization() {
        let grove = GroveBuilder::new()
            .name("Serialization Test")
            .add_tree(
                Tree::new("test_tree")
                    .add_agent(TreeAgent::new("agent1").with_keywords(vec!["test"])),
            )
            .build()
            .unwrap();
        let json = serde_json::to_string(&grove.node).unwrap();
        let deserialized: GroveData = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "Serialization Test");
        assert_eq!(deserialized.trees.len(), 1);
        assert_eq!(deserialized, grove.node);
    }

    #[test]
    fn test_grove_config_validation_routing_fallback() {
        for accepted in ["keyword", "error"] {
            assert!(
                valid_builder().routing_fallback(accepted).build().is_ok(),
                "routing_fallback {accepted:?} should be accepted"
            );
        }
        let result = valid_builder().routing_fallback("invalid").build();
        match result {
            Err(BattalionError::ValidationError(msg)) => {
                assert!(msg.contains("routing_fallback"));
                assert!(msg.contains("'keyword' or 'error'"));
            }
            _ => panic!("Expected ValidationError for invalid routing_fallback"),
        }
    }

    #[test]
    fn test_grove_config_validation_min_confidence() {
        for accepted in [0.0_f32, 0.5, 1.0] {
            assert!(
                valid_builder().min_confidence(accepted).build().is_ok(),
                "min_confidence {accepted} should be accepted"
            );
        }
        for rejected in [-0.1_f32, 1.5, f32::NAN] {
            match valid_builder().min_confidence(rejected).build() {
                Err(BattalionError::ValidationError(msg)) => {
                    assert!(msg.contains("min_confidence"));
                    assert!(msg.contains("between 0.0 and 1.0"));
                }
                other => panic!(
                    "Expected ValidationError for min_confidence {rejected}, got {other:?}"
                ),
            }
        }
    }
}