use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use paladin_core::platform::container::battalion::BattalionError;
use paladin_core::platform::container::battalion::grove::{
Grove, RoutingStrategy, Tree, TreeAgent,
};
use paladin_core::platform::container::prompt::{PromptItem, PromptType, UserPrompt};
use paladin_ports::output::embedding_port::EmbeddingPort;
use paladin_ports::output::llm_port::{LlmPort, LlmRequest};
use paladin_ports::output::paladin_port::PaladinPort;
use paladin_ports::output::paladin_registry::PaladinRegistry;
/// Outcome of routing a task to a single agent of a grove.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
/// Name of the tree reported by the routing strategy for the selected agent.
pub selected_tree: String,
/// Identifier of the selected agent; `GroveExecutionService::execute` uses it as the
/// registry lookup key.
pub selected_agent: String,
/// Score assigned by the strategy that produced this decision.
pub confidence: f32,
/// Human-readable explanation of how the decision was reached.
pub reasoning: String,
}
/// Result of routing a task through a grove and executing the selected agent.
#[derive(Debug, Clone)]
pub struct GroveResult {
/// The routing decision that selected the executed agent.
pub routing_decision: RoutingDecision,
/// Output text returned by the executed agent.
pub execution_result: String,
/// String key/value details about the run; `GroveExecutionService::execute` fills in
/// `grove_name`, `routing_strategy` and `confidence`.
pub metadata: HashMap<String, String>,
}
/// JSON payload the routing LLM is asked to return; deserialized from the LLM response content.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RoutingResponse {
/// Tree name chosen by the LLM.
tree_name: String,
/// Agent identifier chosen by the LLM; checked against the grove's agents before use.
agent_id: String,
/// Confidence reported by the LLM; compared against the grove's `min_confidence`.
confidence: f32,
/// The LLM's explanation for its selection.
reasoning: String,
}
/// Routes a task to one agent of a [`Grove`] and executes that agent.
pub struct GroveExecutionService {
    /// Executes the paladin resolved for the routed agent.
    paladin_port: Arc<dyn PaladinPort>,
    /// Embeds the task text for `SemanticSimilarity` routing; that strategy fails when this
    /// is `None`.
    embedding_port: Option<Arc<dyn EmbeddingPort>>,
    /// Asked to pick an agent for `LlmRouting`; that strategy fails when this is `None`.
    llm_port: Option<Arc<dyn LlmPort>>,
    /// Resolves the routed agent identifier to a paladin.
    registry: Arc<dyn PaladinRegistry>,
}
impl GroveExecutionService {
pub fn new(
paladin_port: Arc<dyn PaladinPort>,
embedding_port: Option<Arc<dyn EmbeddingPort>>,
llm_port: Option<Arc<dyn LlmPort>>,
registry: Arc<dyn PaladinRegistry>,
) -> Self {
Self {
paladin_port,
embedding_port,
llm_port,
registry,
}
}
/// Routes `task` to an agent of `grove`, executes that agent and returns its output.
///
/// The returned metadata carries the grove name, the `Debug` form of the routing strategy
/// and the decision's confidence.
///
/// # Errors
/// Propagates routing failures, returns `BattalionError::PaladinNotFound` when the routed
/// agent is missing from the registry, and `BattalionError::ExecutionError` when the agent
/// run fails.
pub async fn execute(&self, grove: &Grove, task: &str) -> Result<GroveResult, BattalionError> {
    info!("Grove '{}' routing task: {}", grove.node.name, task);

    let decision = self.route_task(grove, task).await?;
    debug!(
        "Routed to agent '{}' in tree '{}' (confidence: {:.2})",
        decision.selected_agent, decision.selected_tree, decision.confidence
    );

    // Resolve the routed agent id to a concrete paladin.
    let Some(paladin) = self.registry.get(&decision.selected_agent) else {
        return Err(BattalionError::PaladinNotFound(format!(
            "Agent '{}' not found in registry",
            decision.selected_agent
        )));
    };

    let output = match self.execute_agent(&paladin, task).await {
        Ok(text) => text,
        Err(e) => {
            return Err(BattalionError::ExecutionError(format!(
                "Failed to execute agent '{}': {}",
                decision.selected_agent, e
            )));
        }
    };

    let metadata = HashMap::from([
        ("grove_name".to_string(), grove.node.name.clone()),
        (
            "routing_strategy".to_string(),
            format!("{:?}", grove.node.config.routing_strategy),
        ),
        ("confidence".to_string(), decision.confidence.to_string()),
    ]);

    Ok(GroveResult {
        routing_decision: decision,
        execution_result: output,
        metadata,
    })
}
/// Picks an agent for `task` using the grove's configured routing strategy, with fallbacks.
///
/// When the primary strategy fails, the configured `fallback_tree` is tried first (if it
/// exists and has agents), then the first tree in the grove that has at least one agent.
/// Trees without agents are skipped, so an empty fallback or leading tree no longer turns a
/// recoverable situation into an error. Fallback decisions come from `select_from_tree`.
///
/// NOTE(review): this fallback also applies when LLM routing failed with
/// `routing_fallback = "error"`; confirm that overriding that setting here is intended.
///
/// # Errors
/// Returns `BattalionError::RoutingError` only when the primary strategy failed and no tree
/// in the grove has any agent.
async fn route_task(
    &self,
    grove: &Grove,
    task: &str,
) -> Result<RoutingDecision, BattalionError> {
    let primary = match &grove.node.config.routing_strategy {
        RoutingStrategy::KeywordMatch => self.route_by_keywords(grove, task),
        RoutingStrategy::SemanticSimilarity => {
            self.route_by_semantic_similarity(grove, task).await
        }
        RoutingStrategy::LlmRouting => self.route_by_llm(grove, task).await,
    };

    let primary_error = match primary {
        Ok(decision) => return Ok(decision),
        Err(e) => e,
    };
    warn!(
        "Primary routing strategy failed: {}. Attempting fallback.",
        primary_error
    );

    // First choice: the explicitly configured fallback tree, if it exists in this grove.
    let configured_fallback = grove
        .node
        .config
        .fallback_tree
        .as_ref()
        .and_then(|name| grove.node.trees.iter().find(|tree| &tree.name == name));
    if let Some(fallback_tree) = configured_fallback {
        if fallback_tree.agents.is_empty() {
            warn!(
                "Fallback tree '{}' has no agents; trying other trees",
                fallback_tree.name
            );
        } else {
            info!("Using fallback tree: {}", fallback_tree.name);
            return self.select_from_tree(fallback_tree, task, "fallback");
        }
    }

    // Last resort: the first tree that can actually supply an agent.
    if let Some(tree) = grove
        .node
        .trees
        .iter()
        .find(|tree| !tree.agents.is_empty())
    {
        warn!(
            "Using default fallback: first agent in tree '{}'",
            tree.name
        );
        return self.select_from_tree(tree, task, "default_fallback");
    }

    Err(BattalionError::RoutingError(
        "No agents available for routing".to_string(),
    ))
}
/// Routes `task` to the agent whose expertise keywords match the most task tokens.
///
/// The task is lowercased, split on whitespace, and each word is stripped of leading and
/// trailing non-alphanumeric characters. The earliest agent with the highest score wins;
/// when nothing matches, the first agent of the grove is still returned. Confidence is the
/// matched share of task tokens (capped at 1.0), or 0.5 when the task yields no tokens.
///
/// # Errors
/// Returns `BattalionError::RoutingError` when the grove contains no agents at all.
fn route_by_keywords(
    &self,
    grove: &Grove,
    task: &str,
) -> Result<RoutingDecision, BattalionError> {
    debug!("Routing by keyword matching");

    let task_tokens: Vec<String> = task
        .to_lowercase()
        .split_whitespace()
        .filter_map(|word| {
            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        })
        .collect();

    // Current leader as (tree, agent, score); only a strictly higher score replaces it.
    let mut leader: Option<(&Tree, &TreeAgent, usize)> = None;
    for tree in &grove.node.trees {
        for agent in &tree.agents {
            let score = self.calculate_keyword_score(&task_tokens, agent);
            let improves = match leader {
                None => true,
                Some((_, _, top)) => score > top,
            };
            if improves {
                leader = Some((tree, agent, score));
            }
        }
    }

    let Some((tree, agent, best_score)) = leader else {
        return Err(BattalionError::RoutingError(
            "No agents found for keyword matching".to_string(),
        ));
    };

    let confidence = match task_tokens.len() {
        0 => 0.5,
        total => (best_score as f32 / total as f32).min(1.0),
    };

    Ok(RoutingDecision {
        selected_tree: tree.name.clone(),
        selected_agent: agent.paladin_id.clone(),
        confidence,
        reasoning: format!(
            "Matched {} keywords from agent expertise: {:?}",
            best_score, agent.expertise_keywords
        ),
    })
}
/// Counts how many entries of `task_tokens` equal one of the agent's expertise keywords.
///
/// Keywords are lowercased before comparison; tokens are compared as given, so callers are
/// expected to pass lowercase tokens. A token that occurs several times is counted each time.
fn calculate_keyword_score(&self, task_tokens: &[String], agent: &TreeAgent) -> usize {
    let lowered: Vec<String> = agent
        .expertise_keywords
        .iter()
        .map(|keyword| keyword.to_lowercase())
        .collect();

    let mut hits = 0;
    for token in task_tokens {
        if lowered.iter().any(|keyword| keyword == token) {
            hits += 1;
        }
    }
    hits
}
/// Routes `task` to the agent whose expertise embedding is most similar to the task embedding.
///
/// Only agents that carry an `expertise_embedding` are considered, and a candidate must
/// score strictly above the grove's `similarity_threshold`. Among equal scores the earliest
/// agent wins. The winning similarity becomes the decision's confidence.
///
/// # Errors
/// Returns `BattalionError::RoutingError` when no embedding port is configured, when
/// embedding the task fails, or when no agent exceeds the threshold.
async fn route_by_semantic_similarity(
    &self,
    grove: &Grove,
    task: &str,
) -> Result<RoutingDecision, BattalionError> {
    debug!("Routing by semantic similarity");

    let Some(embedder) = self.embedding_port.as_ref() else {
        return Err(BattalionError::RoutingError(
            "EmbeddingPort required for SemanticSimilarity routing".to_string(),
        ));
    };

    let task_embedding = match embedder.embed_text(task).await {
        Ok(embedding) => embedding,
        Err(e) => {
            return Err(BattalionError::RoutingError(format!(
                "Failed to embed task: {}",
                e
            )));
        }
    };

    let threshold = grove.node.config.similarity_threshold;

    // Best candidate so far as (tree, agent, similarity); the bar starts at the threshold.
    let mut winner: Option<(&Tree, &TreeAgent, f32)> = None;
    for tree in &grove.node.trees {
        for agent in &tree.agents {
            let Some(agent_embedding) = &agent.expertise_embedding else {
                continue;
            };
            let score = self.cosine_similarity(&task_embedding.vector, agent_embedding);
            let bar = winner.map_or(threshold, |(_, _, best)| best);
            if score > bar {
                winner = Some((tree, agent, score));
            }
        }
    }

    match winner {
        Some((tree, agent, similarity)) => Ok(RoutingDecision {
            selected_tree: tree.name.clone(),
            selected_agent: agent.paladin_id.clone(),
            confidence: similarity,
            reasoning: format!(
                "Semantic similarity score: {:.3} (threshold: {:.3})",
                similarity, threshold
            ),
        }),
        None => Err(BattalionError::RoutingError(format!(
            "No agents found with similarity above threshold {:.3}",
            threshold
        ))),
    }
}
/// Cosine similarity of two vectors, clamped to the range `[0.0, 1.0]`.
///
/// Returns `0.0` (after logging a warning) when the lengths differ, and `0.0` when either
/// vector has zero magnitude. Negative similarities are clamped to `0.0`.
fn cosine_similarity(&self, vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let (len_a, len_b) = (vec_a.len(), vec_b.len());
    if len_a != len_b {
        warn!(
            "Dimension mismatch in cosine similarity: {} vs {}",
            len_a, len_b
        );
        return 0.0;
    }

    // Euclidean norm of a vector.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(vec_a), norm(vec_b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum();
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}
async fn route_by_llm(
&self,
grove: &Grove,
task: &str,
) -> Result<RoutingDecision, BattalionError> {
debug!("Routing by LLM analysis");
let llm_port = self.llm_port.as_ref().ok_or_else(|| {
BattalionError::RoutingError(
"LLM port not configured for LLM-based routing".to_string(),
)
})?;
let mut prompt = format!(
"You are a task routing system. Given the following task and available specialized agents, \
select the most appropriate agent and explain your reasoning.\n\n\
Task: {}\n\n\
Available Agents:\n",
task
);
for tree in &grove.node.trees {
prompt.push_str(&format!("\nTree: {}\n", tree.name));
for agent in &tree.agents {
prompt.push_str(&format!(
" - Agent ID: {}\n Expertise: {}\n",
agent.paladin_id,
agent.expertise_keywords.join(", ")
));
}
}
prompt.push_str(
"\n\nRespond in JSON format with the following structure:\n\
{\n \
\"tree_name\": \"selected tree name\",\n \
\"agent_id\": \"selected agent ID\",\n \
\"confidence\": 0.85,\n \
\"reasoning\": \"explanation for selection\"\n\
}\n",
);
debug!("LLM routing prompt prepared: {} characters", prompt.len());
let user_prompt = UserPrompt {
query: prompt,
context: None,
};
let prompt_item = PromptItem::new(PromptType::User(user_prompt))
.map_err(|e| BattalionError::RoutingError(format!("Failed to create prompt: {}", e)))?;
let llm_request = LlmRequest {
id: uuid::Uuid::new_v4(),
model: "gpt-4".to_string(), prompt: prompt_item,
attachments: vec![],
stream: false,
metadata: HashMap::new(),
};
let llm_response = llm_port.generate(llm_request).await.map_err(|e| {
let msg = format!("LLM call failed: {}", e);
warn!("{}", msg);
e });
let llm_response = match llm_response {
Ok(resp) => resp,
Err(_) => {
return self.handle_routing_failure(grove, task, "LLM call failed");
}
};
debug!(
"LLM response received: {} characters",
llm_response.content.len()
);
let routing_response: RoutingResponse = match serde_json::from_str(&llm_response.content) {
Ok(resp) => resp,
Err(e) => {
warn!("Failed to parse LLM JSON response: {}", e);
return self.handle_routing_failure(
grove,
task,
&format!("Failed to parse JSON: {}", e),
);
}
};
debug!(
"Parsed routing response: agent={}, confidence={}",
routing_response.agent_id, routing_response.confidence
);
if routing_response.confidence < grove.node.config.min_confidence {
warn!(
"Confidence {} below threshold {}",
routing_response.confidence, grove.node.config.min_confidence
);
return self.handle_routing_failure(
grove,
task,
&format!(
"Confidence {} below threshold {}",
routing_response.confidence, grove.node.config.min_confidence
),
);
}
let agent_exists = grove.node.trees.iter().any(|tree| {
tree.agents
.iter()
.any(|agent| agent.paladin_id == routing_response.agent_id)
});
if !agent_exists {
warn!(
"Unknown agent_id in LLM response: {}",
routing_response.agent_id
);
return self.handle_routing_failure(
grove,
task,
&format!("Unknown agent_id: {}", routing_response.agent_id),
);
}
info!(
"LLM routing successful: {} (confidence: {})",
routing_response.agent_id, routing_response.confidence
);
Ok(RoutingDecision {
selected_tree: routing_response.tree_name,
selected_agent: routing_response.agent_id,
confidence: routing_response.confidence,
reasoning: routing_response.reasoning,
})
}
/// Applies the grove's `routing_fallback` policy after LLM routing failed for `reason`.
///
/// With `"keyword"` the task is re-routed by keyword matching and the decision's reasoning
/// is prefixed with the failure reason. With `"error"`, or any other value (which is logged
/// as unknown), a `BattalionError::RoutingError` carrying `reason` is returned.
fn handle_routing_failure(
    &self,
    grove: &Grove,
    task: &str,
    reason: &str,
) -> Result<RoutingDecision, BattalionError> {
    let policy = grove.node.config.routing_fallback.as_str();

    if policy == "keyword" {
        warn!("Falling back to keyword matching: {}", reason);
        let mut decision = self.route_by_keywords(grove, task)?;
        decision.reasoning = format!(
            "LLM routing failed ({}), fell back to keyword matching: {}",
            reason, decision.reasoning
        );
        return Ok(decision);
    }

    if policy != "error" {
        warn!(
            "Unknown routing_fallback value '{}', treating as 'error'",
            policy
        );
    }
    Err(BattalionError::RoutingError(format!(
        "LLM routing failed: {}",
        reason
    )))
}
/// Selects the first agent of `tree` with a fixed confidence of 0.5.
///
/// `reason` names the fallback being applied and is embedded, together with `task`, in the
/// decision's reasoning text.
///
/// # Errors
/// Returns `BattalionError::RoutingError` when `tree` has no agents.
fn select_from_tree(
    &self,
    tree: &Tree,
    task: &str,
    reason: &str,
) -> Result<RoutingDecision, BattalionError> {
    match tree.agents.first() {
        Some(agent) => Ok(RoutingDecision {
            selected_tree: tree.name.clone(),
            selected_agent: agent.paladin_id.clone(),
            confidence: 0.5,
            reasoning: format!("Using {} strategy for task: {}", reason, task),
        }),
        None => Err(BattalionError::RoutingError(format!(
            "Tree '{}' has no agents",
            tree.name
        ))),
    }
}
/// Runs `paladin` on `task` through the paladin port and returns the produced output text.
///
/// # Errors
/// Returns `BattalionError::ExecutionError` naming the paladin when the port reports a failure.
async fn execute_agent(
    &self,
    paladin: &paladin_core::platform::container::paladin::Paladin,
    task: &str,
) -> Result<String, BattalionError> {
    debug!(
        "Executing agent '{}' with task: {}",
        paladin.node.name, task
    );

    match self.paladin_port.execute(paladin, task).await {
        Ok(outcome) => Ok(outcome.output),
        Err(e) => Err(BattalionError::ExecutionError(format!(
            "Paladin '{}' execution failed: {}",
            paladin.node.name, e
        ))),
    }
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for keyword scoring, cosine similarity, keyword routing, LLM routing and
    //! end-to-end grove execution, using in-file test doubles for every port.

    use super::*;
    use crate::in_memory_registry::HashMapPaladinRegistry;
    use paladin_core::platform::container::battalion::grove::{GroveBuilder, Tree, TreeAgent};
    use paladin_core::platform::container::paladin::Paladin;
    use paladin_core::platform::container::paladin_error::PaladinError;
    use paladin_ports::output::embedding_port::{Embedding, EmbeddingError};
    use paladin_ports::output::llm_port::{
        FinishReason, LlmError, LlmResponse, ProviderCapabilities, StreamingResponse, TokenUsage,
    };
    use paladin_ports::output::paladin_port::{PaladinResult, PaladinStream};

    /// Boxed stream type returned by `LlmPort::generate_stream`.
    type LlmStream = Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>;

    // ---------------------------------------------------------------------------------------
    // Test doubles
    // ---------------------------------------------------------------------------------------

    /// Embedding double: a 128-dimensional vector whose components all equal `len / 100`.
    struct MockEmbeddingPort;

    #[async_trait::async_trait]
    impl EmbeddingPort for MockEmbeddingPort {
        async fn embed_text(&self, text: &str) -> Result<Embedding, EmbeddingError> {
            let vector = vec![text.len() as f32 / 100.0; 128];
            Ok(Embedding {
                vector,
                model: "mock-model".to_string(),
                dimension: 128,
                token_count: Some(text.split_whitespace().count() as u32),
            })
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
            let mut embeddings = Vec::new();
            for text in texts {
                embeddings.push(self.embed_text(text).await?);
            }
            Ok(embeddings)
        }

        fn dimension(&self) -> usize {
            128
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }
    }

    /// Paladin port double for tests that never execute an agent.
    struct MockPaladinPort;

    #[async_trait::async_trait]
    impl PaladinPort for MockPaladinPort {
        async fn execute(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<PaladinResult, PaladinError> {
            unimplemented!("Mock not needed for these tests")
        }

        async fn execute_stream(
            &self,
            _paladin: &Paladin,
            _input: &str,
        ) -> Result<PaladinStream, PaladinError> {
            unimplemented!("Mock not needed for these tests")
        }

        fn validate(&self, _paladin: &Paladin) -> Result<(), PaladinError> {
            Ok(())
        }
    }

    /// LLM port double for tests that never call the LLM.
    struct MockLlmPort;

    #[async_trait::async_trait]
    impl LlmPort for MockLlmPort {
        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            unimplemented!("Mock not needed for these tests")
        }

        async fn generate_stream(&self, _request: LlmRequest) -> Result<LlmStream, LlmError> {
            unimplemented!("Mock not needed for these tests")
        }

        async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
            Ok(true)
        }

        async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
            Ok(vec!["mock-model".to_string()])
        }

        fn get_provider_name(&self) -> &'static str {
            "mock"
        }

        fn get_capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                supports_streaming: false,
                supports_tool_calling: false,
                supports_function_calling: false,
                supports_vision: false,
                supports_embeddings: false,
                max_context_tokens: Some(4096),
                supports_system_messages: true,
            }
        }
    }

    /// LLM port double whose `generate` always answers with the fixed `content` text.
    ///
    /// Replaces the per-test mocks that differed only in the response body.
    struct CannedLlmMock {
        content: &'static str,
    }

    #[async_trait::async_trait]
    impl LlmPort for CannedLlmMock {
        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(LlmResponse {
                id: uuid::Uuid::new_v4(),
                request_id: uuid::Uuid::new_v4(),
                model: "mock-model".to_string(),
                content: self.content.to_string(),
                finish_reason: FinishReason::Stop,
                usage: TokenUsage {
                    prompt_tokens: 100,
                    completion_tokens: 50,
                    total_tokens: 150,
                },
                created_at: chrono::Utc::now(),
                metadata: std::collections::HashMap::new(),
                function_call: None,
            })
        }

        async fn generate_stream(&self, _request: LlmRequest) -> Result<LlmStream, LlmError> {
            unimplemented!()
        }

        async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
            Ok(true)
        }

        async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
            Ok(vec!["mock-model".to_string()])
        }

        fn get_provider_name(&self) -> &'static str {
            "mock"
        }

        fn get_capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }
    }

    // ---------------------------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------------------------

    /// Service with mock ports everywhere and an empty registry.
    fn create_test_service() -> GroveExecutionService {
        let registry = HashMapPaladinRegistry::new();
        GroveExecutionService::new(
            Arc::new(MockPaladinPort),
            Some(Arc::new(MockEmbeddingPort)),
            Some(Arc::new(MockLlmPort)),
            Arc::new(registry),
        )
    }

    /// Service with no embedding port whose LLM always answers with `content`.
    fn llm_routing_service(content: &'static str) -> GroveExecutionService {
        let registry = HashMapPaladinRegistry::new();
        GroveExecutionService::new(
            Arc::new(MockPaladinPort),
            None,
            Some(Arc::new(CannedLlmMock { content })),
            Arc::new(registry),
        )
    }

    /// Grove with a single `engineering` tree holding `backend_expert` and `frontend_expert`.
    fn create_test_grove() -> Grove {
        GroveBuilder::new()
            .name("Test Grove")
            .add_tree(
                Tree::new("engineering")
                    .add_agent(
                        TreeAgent::new("backend_expert")
                            .with_keywords(vec!["rust", "backend", "api", "database"]),
                    )
                    .add_agent(TreeAgent::new("frontend_expert").with_keywords(vec![
                        "react",
                        "ui",
                        "css",
                        "javascript",
                    ])),
            )
            .build()
            .unwrap()
    }

    // ---------------------------------------------------------------------------------------
    // Keyword scoring
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_calculate_keyword_score() {
        let service = create_test_service();
        let agent =
            TreeAgent::new("test_agent").with_keywords(vec!["rust", "backend", "api", "database"]);
        let task_tokens = vec!["rust".to_string(), "api".to_string(), "testing".to_string()];
        let score = service.calculate_keyword_score(&task_tokens, &agent);
        assert_eq!(score, 2);
    }

    #[test]
    fn test_calculate_keyword_score_case_insensitive() {
        let service = create_test_service();
        let agent = TreeAgent::new("test_agent").with_keywords(vec!["Rust", "Backend", "API"]);
        let task_tokens = vec!["rust".to_string(), "api".to_string()];
        let score = service.calculate_keyword_score(&task_tokens, &agent);
        assert_eq!(score, 2);
    }

    #[test]
    fn test_calculate_keyword_score_no_match() {
        let service = create_test_service();
        let agent = TreeAgent::new("test_agent").with_keywords(vec!["python", "django"]);
        let task_tokens = vec!["rust".to_string(), "api".to_string()];
        let score = service.calculate_keyword_score(&task_tokens, &agent);
        assert_eq!(score, 0);
    }

    // ---------------------------------------------------------------------------------------
    // Cosine similarity
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let service = create_test_service();
        let vec_a = vec![1.0, 2.0, 3.0];
        let vec_b = vec![1.0, 2.0, 3.0];
        let similarity = service.cosine_similarity(&vec_a, &vec_b);
        assert!((similarity - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let service = create_test_service();
        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![0.0, 1.0, 0.0];
        let similarity = service.cosine_similarity(&vec_a, &vec_b);
        assert!((similarity - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        // Negative similarity is clamped to zero.
        let service = create_test_service();
        let vec_a = vec![1.0, 2.0, 3.0];
        let vec_b = vec![-1.0, -2.0, -3.0];
        let similarity = service.cosine_similarity(&vec_a, &vec_b);
        assert!((similarity - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let service = create_test_service();
        let vec_a = vec![1.0, 2.0];
        let vec_b = vec![1.0, 2.0, 3.0];
        let similarity = service.cosine_similarity(&vec_a, &vec_b);
        assert_eq!(similarity, 0.0);
    }

    // ---------------------------------------------------------------------------------------
    // Keyword routing
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_route_by_keywords_basic() {
        let service = create_test_service();
        let grove = create_test_grove();
        let result = service.route_by_keywords(&grove, "rust backend api development");
        assert!(result.is_ok());
        let decision = result.unwrap();
        assert_eq!(decision.selected_agent, "backend_expert");
        assert!(decision.confidence > 0.0);
    }

    #[test]
    fn test_route_by_keywords_no_match() {
        // With no matching keyword the router still answers, with low confidence.
        let service = create_test_service();
        let grove = create_test_grove();
        let result = service.route_by_keywords(&grove, "quantum physics simulation");
        assert!(result.is_ok());
        let decision = result.unwrap();
        assert!(decision.confidence <= 0.5);
    }

    // ---------------------------------------------------------------------------------------
    // End-to-end execution
    // ---------------------------------------------------------------------------------------

    #[tokio::test]
    async fn test_grove_resolves_routed_agent() {
        use paladin_core::base::entity::node::Node;
        use paladin_core::platform::container::paladin::{MaxLoops, PaladinData, PaladinStatus};
        use paladin_ports::output::paladin_registry::PaladinRegistry;

        let backend_paladin = Node::new(
            PaladinData {
                system_prompt: "Backend expert".to_string(),
                name: "backend_expert".to_string(),
                user_name: "User".to_string(),
                model: "gpt-4".to_string(),
                temperature: 0.7,
                max_loops: MaxLoops::Fixed(3),
                stop_words: vec![],
                status: PaladinStatus::Idle,
                vision_enabled: false,
                ..Default::default()
            },
            Some("backend_expert".to_string()),
        );
        let registry = HashMapPaladinRegistry::new();
        registry
            .register("backend_expert".to_string(), Arc::new(backend_paladin))
            .expect("Should register backend_expert");

        /// Paladin port double that echoes the paladin name and the input.
        struct ExecutingMockPort;

        #[async_trait::async_trait]
        impl PaladinPort for ExecutingMockPort {
            async fn execute(
                &self,
                paladin: &Paladin,
                input: &str,
            ) -> Result<PaladinResult, PaladinError> {
                Ok(PaladinResult {
                    output: format!("[{}] Analyzed: {}", paladin.node.name, input),
                    token_count: 100,
                    execution_time_ms: 10,
                    loop_count: 1,
                    ..Default::default()
                })
            }

            async fn execute_stream(
                &self,
                _paladin: &Paladin,
                _input: &str,
            ) -> Result<PaladinStream, PaladinError> {
                let (_tx, rx) = tokio::sync::mpsc::channel(1);
                Ok(rx)
            }

            fn validate(&self, _paladin: &Paladin) -> Result<(), PaladinError> {
                Ok(())
            }
        }

        let service =
            GroveExecutionService::new(Arc::new(ExecutingMockPort), None, None, Arc::new(registry));
        let grove = create_test_grove();
        let result = service.execute(&grove, "rust backend task").await;
        assert!(result.is_ok(), "Grove should resolve and execute agent");
        let grove_result = result.unwrap();
        assert_eq!(
            grove_result.routing_decision.selected_agent,
            "backend_expert"
        );
        assert!(grove_result.execution_result.contains("backend_expert"));
    }

    #[tokio::test]
    async fn test_grove_paladin_not_found_error() {
        let registry = HashMapPaladinRegistry::new();
        let service =
            GroveExecutionService::new(Arc::new(MockPaladinPort), None, None, Arc::new(registry));
        let grove = create_test_grove();
        let result = service.execute(&grove, "rust backend task").await;
        assert!(result.is_err(), "Should return error when agent not found");
        match result {
            Err(BattalionError::PaladinNotFound(msg)) => {
                assert!(msg.contains("backend_expert"));
            }
            _ => panic!("Expected PaladinNotFound error"),
        }
    }

    // ---------------------------------------------------------------------------------------
    // LLM routing
    // ---------------------------------------------------------------------------------------

    #[tokio::test]
    async fn test_route_with_llm_successful() {
        let service = llm_routing_service(
            r#"{
                "tree_name": "engineering",
                "agent_id": "backend_expert",
                "confidence": 0.85,
                "reasoning": "Task mentions rust and backend, which are backend expert's core skills"
            }"#,
        );
        let grove = create_test_grove();
        let result = service
            .route_by_llm(&grove, "rust backend development task")
            .await;
        assert!(
            result.is_ok(),
            "Should route successfully with high confidence"
        );
        let decision = result.unwrap();
        assert_eq!(decision.selected_tree, "engineering");
        assert_eq!(decision.selected_agent, "backend_expert");
        assert!(decision.confidence >= 0.85);
        assert!(decision.reasoning.contains("rust"));
    }

    #[tokio::test]
    async fn test_route_with_llm_low_confidence() {
        let service = llm_routing_service(
            r#"{
                "tree_name": "engineering",
                "agent_id": "backend_expert",
                "confidence": 0.3,
                "reasoning": "Unclear task, best guess is backend"
            }"#,
        );
        let mut grove = create_test_grove();
        grove.node.config.routing_fallback = "error".to_string();
        grove.node.config.min_confidence = 0.5;
        let result = service.route_by_llm(&grove, "ambiguous task").await;
        assert!(
            result.is_err(),
            "Should return error when confidence below threshold and fallback is 'error'"
        );
        match result {
            Err(BattalionError::RoutingError(msg)) => {
                assert!(msg.contains("confidence") || msg.contains("threshold"));
            }
            _ => panic!("Expected RoutingError for low confidence with error fallback"),
        }
    }

    #[tokio::test]
    async fn test_route_with_llm_invalid_json() {
        let service = llm_routing_service("This is not JSON at all!");
        let mut grove = create_test_grove();
        grove.node.config.routing_fallback = "error".to_string();
        let result = service.route_by_llm(&grove, "test task").await;
        assert!(result.is_err(), "Should return error for invalid JSON");
        match result {
            Err(BattalionError::RoutingError(_)) => {}
            _ => panic!("Expected RoutingError for invalid JSON"),
        }
    }

    #[tokio::test]
    async fn test_route_with_llm_fallback_to_keyword() {
        let service = llm_routing_service(
            r#"{
                "tree_name": "engineering",
                "agent_id": "backend_expert",
                "confidence": 0.2,
                "reasoning": "Very uncertain"
            }"#,
        );
        let mut grove = create_test_grove();
        grove.node.config.routing_fallback = "keyword".to_string();
        grove.node.config.min_confidence = 0.5;
        let result = service
            .route_by_llm(&grove, "rust backend development")
            .await;
        assert!(
            result.is_ok(),
            "Should fallback to keyword matching when confidence below threshold"
        );
        let decision = result.unwrap();
        assert_eq!(decision.selected_agent, "backend_expert");
        assert!(decision.reasoning.contains("keyword") || decision.reasoning.contains("fallback"));
    }
}