use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::config::ResourceConstraints;
use super::decision::{LLMProvider, MonitoringLevel, SecurityLevel};
use super::error::TaskType;
/// Top-level configuration for request routing between SLM and LLM backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Master switch for the routing subsystem — presumably bypasses routing
    /// when `false`; confirm at the call site that consumes it.
    pub enabled: bool,
    /// Rule-based routing policy: global settings, ordered rules, the
    /// default action, and fallback behavior.
    pub policy: RoutingPolicyConfig,
    /// Keyword/regex task-classification settings.
    pub classification: TaskClassificationConfig,
    /// LLM provider connection settings keyed by provider name
    /// (the defaults register "openai" and "anthropic").
    pub llm_providers: HashMap<String, LLMProviderConfig>,
}
/// Rule-based routing policy.
///
/// `validate()` requires `rules` to be sorted by descending `priority`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPolicyConfig {
    /// Settings that apply across all rules.
    pub global_settings: GlobalRoutingSettings,
    /// Ordered rule list; must be sorted highest-priority first.
    pub rules: Vec<RoutingRule>,
    /// Action taken when no rule matches.
    pub default_action: RouteAction,
    /// Behavior when the chosen route fails.
    pub fallback_config: FallbackConfig,
}
/// Policy-wide settings shared by every routing rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRoutingSettings {
    /// Whether routing to a small language model is permitted at all.
    pub slm_routing_enabled: bool,
    /// When `true`, every routing decision is audited — verify against the
    /// audit sink's consumer.
    pub always_audit: bool,
    /// Global confidence threshold; `validate()` enforces the range
    /// `[0.0, 1.0]`. Default is 0.85.
    pub global_confidence_threshold: f64,
    /// Maximum SLM retry count before giving up (default 2).
    pub max_slm_retries: u32,
}
/// A single routing rule: conditions that select requests plus the action
/// to take when they match (see [`RoutingRule::matches`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRule {
    /// Human-readable rule name.
    pub name: String,
    /// Ordering key; rules in a policy must be sorted highest first.
    pub priority: u32,
    /// Matching conditions; `None` fields are skipped during matching.
    pub conditions: RoutingConditions,
    /// Action to take when the rule matches.
    pub action: RouteAction,
    /// Optional action add-ons; defaults to `None` when absent from the
    /// serialized form.
    #[serde(default)]
    pub action_extension: Option<ActionExtension>,
    /// Whether this rule may be overridden — consumer not visible here;
    /// confirm semantics with the policy engine.
    pub override_allowed: bool,
}
/// Conditions a request must satisfy for a rule to match.
///
/// Every field is optional; a `None` condition is treated as "always
/// matches" by [`RoutingRule::matches`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConditions {
    /// Allow-list of task types.
    pub task_types: Option<Vec<TaskType>>,
    /// Allow-list of agent ids (compared against `agent_id.to_string()`).
    pub agent_ids: Option<Vec<String>>,
    /// Memory cap: contexts whose `max_memory_mb` exceeds this fail the rule.
    pub resource_constraints: Option<ResourceConstraints>,
    /// Minimum agent security level (inclusive).
    pub security_level: Option<SecurityLevel>,
    /// Ad-hoc string expressions, e.g. `agent_id == "x"`; see
    /// `evaluate_custom_condition` for the supported grammar.
    pub custom_conditions: Option<Vec<String>>,
}
/// What to do with a request once a rule (or the default action) selects it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteAction {
    /// Route to a small language model.
    UseSLM {
        /// Which SLM to prefer (specialist, generalist, specific id, ...).
        model_preference: ModelPreference,
        /// How closely to monitor the SLM's output.
        monitoring_level: MonitoringLevel,
        /// Whether to fall back (presumably to an LLM — confirm with the
        /// executor) when confidence is below threshold.
        fallback_on_low_confidence: bool,
        /// Per-rule confidence threshold; `None` presumably defers to
        /// `GlobalRoutingSettings::global_confidence_threshold` — verify.
        confidence_threshold: Option<f64>,
    },
    /// Route to a large language model provider.
    UseLLM {
        /// Target provider.
        provider: LLMProvider,
        /// Model override; `None` presumably uses the provider's
        /// `default_model` — verify against the dispatcher.
        model: Option<String>,
    },
    /// Reject the request outright.
    Deny {
        /// Human-readable denial reason.
        reason: String,
    },
}
/// Which SLM to select when a rule routes to an SLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelPreference {
    /// Prefer a task-specialized model.
    Specialist,
    /// Prefer a general-purpose model.
    Generalist,
    /// Require one exact model by id.
    Specific { model_id: String },
    /// Let the selector pick whatever is best available.
    BestAvailable,
}
/// Optional add-ons for a route action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionExtension {
    /// Sandbox profile name, if execution should be sandboxed — the
    /// consumer is not visible here; confirm the expected values.
    pub sandbox: Option<String>,
}
impl Default for ActionExtension {
fn default() -> Self {
Self {
sandbox: None,
}
}
}
/// Fallback behavior when the primary route fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackConfig {
    /// Whether fallback is attempted at all.
    pub enabled: bool,
    /// Maximum fallback attempts (default 3).
    pub max_attempts: u32,
    /// Overall fallback timeout; serialized in humantime form
    /// (e.g. "30s") via `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
    /// Fallback providers keyed by role (the default registers "primary").
    pub providers: HashMap<String, LLMProviderConfig>,
}
/// Connection and retry settings for one LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    /// Name of the environment variable holding the API key — the key
    /// itself is never stored in config.
    pub api_key_env: String,
    /// Provider API base URL.
    pub base_url: String,
    /// Model used when a route does not name one explicitly.
    pub default_model: String,
    /// Per-request timeout; humantime-serialized (e.g. "60s").
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
    /// Maximum retry count per request.
    pub max_retries: u32,
    /// Optional client-side rate limiting; `None` means unlimited.
    pub rate_limit: Option<RateLimitConfig>,
}
/// Client-side rate limits for a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Hard cap on requests per minute.
    pub requests_per_minute: u32,
    /// Optional cap on tokens per minute; `None` means no token limit.
    pub tokens_per_minute: Option<u32>,
    /// Optional burst headroom above the steady rate — enforcement lives
    /// in the rate limiter, not here; confirm its interpretation there.
    pub burst_allowance: Option<u32>,
}
/// Configuration for keyword/regex-based task classification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskClassificationConfig {
    /// Whether classification runs at all.
    pub enabled: bool,
    /// Per-task-type match patterns.
    pub patterns: HashMap<TaskType, ClassificationPattern>,
    /// Minimum classification confidence (default 0.7) — scoring happens in
    /// the classifier, not visible here.
    pub confidence_threshold: f64,
    /// Task type assigned when classification is inconclusive
    /// (default: `TaskType::Custom("unknown")`).
    pub default_task_type: TaskType,
}
/// Match signals for one task type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationPattern {
    /// Literal keywords that suggest this task type.
    pub keywords: Vec<String>,
    /// Regex pattern strings (compiled by the classifier, not here).
    pub patterns: Vec<String>,
    /// Relative weight of this pattern in scoring (defaults use 1.0).
    pub weight: f64,
}
impl Default for RoutingConfig {
fn default() -> Self {
let mut llm_providers = HashMap::new();
llm_providers.insert("openai".to_string(), LLMProviderConfig {
api_key_env: "OPENAI_API_KEY".to_string(),
base_url: "https://api.openai.com/v1".to_string(),
default_model: "gpt-3.5-turbo".to_string(),
timeout: Duration::from_secs(60),
max_retries: 3,
rate_limit: Some(RateLimitConfig {
requests_per_minute: 60,
tokens_per_minute: Some(10000),
burst_allowance: Some(10),
}),
});
llm_providers.insert("anthropic".to_string(), LLMProviderConfig {
api_key_env: "ANTHROPIC_API_KEY".to_string(),
base_url: "https://api.anthropic.com".to_string(),
default_model: "claude-3-sonnet-20240229".to_string(),
timeout: Duration::from_secs(60),
max_retries: 3,
rate_limit: Some(RateLimitConfig {
requests_per_minute: 60,
tokens_per_minute: Some(10000),
burst_allowance: Some(10),
}),
});
Self {
enabled: true,
policy: RoutingPolicyConfig::default(),
classification: TaskClassificationConfig::default(),
llm_providers,
}
}
}
impl Default for RoutingPolicyConfig {
    /// Default policy: no rules, default global/fallback settings, and a
    /// default action routing to OpenAI's gpt-3.5-turbo.
    fn default() -> Self {
        // With an empty rule list, every request falls through to this.
        let default_action = RouteAction::UseLLM {
            provider: LLMProvider::OpenAI { model: None },
            model: Some("gpt-3.5-turbo".to_string()),
        };
        Self {
            global_settings: GlobalRoutingSettings::default(),
            rules: Vec::new(),
            default_action,
            fallback_config: FallbackConfig::default(),
        }
    }
}
impl Default for GlobalRoutingSettings {
fn default() -> Self {
Self {
slm_routing_enabled: true,
always_audit: true,
global_confidence_threshold: 0.85,
max_slm_retries: 2,
}
}
}
impl Default for FallbackConfig {
    /// Default fallback: enabled, 3 attempts within 30 seconds, with a
    /// single un-rate-limited "primary" provider pointing at OpenAI.
    fn default() -> Self {
        let primary = LLMProviderConfig {
            api_key_env: "OPENAI_API_KEY".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            default_model: "gpt-3.5-turbo".to_string(),
            timeout: Duration::from_secs(60),
            max_retries: 3,
            rate_limit: None,
        };
        Self {
            enabled: true,
            max_attempts: 3,
            timeout: Duration::from_secs(30),
            providers: HashMap::from([("primary".to_string(), primary)]),
        }
    }
}
impl Default for TaskClassificationConfig {
fn default() -> Self {
let mut patterns = HashMap::new();
patterns.insert(TaskType::Intent, ClassificationPattern {
keywords: vec!["intent".to_string(), "intention".to_string(), "purpose".to_string()],
patterns: vec![r"what.*intent".to_string(), r"user.*wants".to_string()],
weight: 1.0,
});
patterns.insert(TaskType::CodeGeneration, ClassificationPattern {
keywords: vec!["code".to_string(), "function".to_string(), "implement".to_string(), "generate".to_string()],
patterns: vec![r"write.*code".to_string(), r"implement.*function".to_string()],
weight: 1.0,
});
patterns.insert(TaskType::Analysis, ClassificationPattern {
keywords: vec!["analyze".to_string(), "analysis".to_string(), "examine".to_string(), "review".to_string()],
patterns: vec![r"analyze.*data".to_string(), r"perform.*analysis".to_string()],
weight: 1.0,
});
Self {
enabled: true,
patterns,
confidence_threshold: 0.7,
default_task_type: TaskType::Custom("unknown".to_string()),
}
}
}
impl RoutingRule {
    /// Returns `true` when this rule applies to `context`.
    ///
    /// Each condition in `self.conditions` that is `Some` must hold;
    /// `None` conditions are skipped (treated as always matching).
    /// Checks short-circuit in declaration order.
    pub fn matches(&self, context: &super::decision::RoutingContext) -> bool {
        // Task type must be in the allow-list, when one is given.
        if let Some(ref task_types) = self.conditions.task_types {
            if !task_types.contains(&context.task_type) {
                return false;
            }
        }
        // Agent id must be in the allow-list (string comparison).
        if let Some(ref agent_ids) = self.conditions.agent_ids {
            if !agent_ids.contains(&context.agent_id.to_string()) {
                return false;
            }
        }
        // Agent must meet or exceed the required security level
        // (relies on SecurityLevel's ordering, Low < ... < Critical
        // per evaluate_custom_condition's numeric mapping below).
        if let Some(ref required_level) = self.conditions.security_level {
            if context.agent_security_level < *required_level {
                return false;
            }
        }
        // Memory cap: enforced only when BOTH the rule and the context
        // specify limits — a context without limits is not disqualified.
        // NOTE(review): confirm that "no context limits" should match.
        if let Some(ref rule_constraints) = self.conditions.resource_constraints {
            if let Some(ref context_limits) = context.resource_limits {
                if context_limits.max_memory_mb > rule_constraints.max_memory_mb {
                    return false;
                }
            }
        }
        // Every custom condition expression must evaluate to true.
        if let Some(ref custom_conditions) = self.conditions.custom_conditions {
            for condition_expr in custom_conditions {
                if !self.evaluate_custom_condition(condition_expr, context) {
                    return false;
                }
            }
        }
        true
    }
    /// Evaluates one ad-hoc condition string against `context`.
    ///
    /// Supported forms (naive string parsing, tried in this order):
    /// - `agent_id == "<id>"`        — exact agent-id equality
    /// - `task_type == "<type>"`     — case-insensitive substring match on
    ///                                 the task type's Debug representation
    /// - `security_level >= <1..=4>` — numeric level comparison
    /// - `memory_limit <= <mb>`      — context memory-limit check
    /// - the literals `true` / `false`
    ///
    /// NOTE(review): parsing is fragile — an expression that *contains* a
    /// known keyword but doesn't match the exact `strip_prefix` form (e.g.
    /// `agent_id != "x"`, or different spacing around `==`) silently falls
    /// through and is ultimately treated as `true` with only a warning.
    /// Confirm this fail-open default is intended.
    fn evaluate_custom_condition(&self, condition_expr: &str, context: &super::decision::RoutingContext) -> bool {
        if condition_expr.contains("agent_id") {
            if let Some(expected_id) = condition_expr.strip_prefix("agent_id == ") {
                // Strip surrounding quotes, then compare verbatim.
                let expected_id = expected_id.trim_matches('"');
                return context.agent_id.to_string() == expected_id;
            }
        }
        if condition_expr.contains("task_type") {
            if let Some(expected_type) = condition_expr.strip_prefix("task_type == ") {
                let expected_type = expected_type.trim_matches('"');
                // Case-insensitive substring match against the Debug form,
                // so e.g. "custom" matches any Custom(_) variant.
                return format!("{:?}", context.task_type).to_lowercase().contains(&expected_type.to_lowercase());
            }
        }
        if condition_expr.contains("security_level") {
            if condition_expr.contains(">=") {
                if let Some(level_str) = condition_expr.strip_prefix("security_level >= ") {
                    if let Ok(required_level) = level_str.trim().parse::<u8>() {
                        // Map the enum onto 1..=4 for the comparison.
                        let current_level = match context.agent_security_level {
                            SecurityLevel::Low => 1,
                            SecurityLevel::Medium => 2,
                            SecurityLevel::High => 3,
                            SecurityLevel::Critical => 4,
                        };
                        return current_level >= required_level;
                    }
                }
            }
        }
        if condition_expr.contains("memory_limit") {
            // Only evaluable when the context carries resource limits;
            // otherwise this falls through to the fail-open default below.
            if let Some(ref resource_limits) = context.resource_limits {
                if condition_expr.contains("<=") {
                    if let Some(limit_str) = condition_expr.strip_prefix("memory_limit <= ") {
                        if let Ok(max_memory) = limit_str.trim().parse::<u64>() {
                            return resource_limits.max_memory_mb <= max_memory;
                        }
                    }
                }
            }
        }
        if condition_expr == "true" {
            return true;
        }
        if condition_expr == "false" {
            return false;
        }
        // Unrecognized expressions fail open (the rule still matches).
        tracing::warn!("Unrecognized custom condition: {}", condition_expr);
        true
    }
}
impl RoutingPolicyConfig {
    /// Validates the policy configuration.
    ///
    /// Checks that `rules` is sorted by descending priority (equal
    /// priorities are allowed) and that `global_confidence_threshold`
    /// lies within `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Returns `RoutingError::ConfigurationError` naming the offending
    /// config key when a check fails.
    pub fn validate(&self) -> Result<(), super::error::RoutingError> {
        // Rules are evaluated first-match, so ordering is significant.
        let mut prev_priority = u32::MAX;
        for rule in &self.rules {
            if rule.priority > prev_priority {
                return Err(super::error::RoutingError::ConfigurationError {
                    key: "policy.rules".to_string(),
                    reason: "Rules must be ordered by priority (highest first)".to_string(),
                });
            }
            prev_priority = rule.priority;
        }
        // Range check via `contains` also rejects NaN, which the previous
        // `< 0.0 || > 1.0` comparison silently accepted (NaN fails both).
        if !(0.0..=1.0).contains(&self.global_settings.global_confidence_threshold) {
            return Err(super::error::RoutingError::ConfigurationError {
                key: "policy.global_settings.global_confidence_threshold".to_string(),
                reason: "Confidence threshold must be between 0.0 and 1.0".to_string(),
            });
        }
        Ok(())
    }
}