use crate::spec_ai_collective::types::{Domain, InstanceId};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Running performance statistics for one domain of work on an instance.
///
/// Created by [`Capability::new`] and updated incrementally by
/// [`Capability::update`] from [`TaskOutcome`] values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capability {
/// Domain this capability describes.
pub domain: Domain,
/// Derived score combining `experience_count` and `success_rate`; recomputed on
/// every `update`.
pub proficiency: f32,
/// Number of outcomes folded in via `update`.
pub experience_count: u64,
/// Exponential moving average (alpha = 0.1, starting from 0.0) of per-outcome
/// success values.
pub success_rate: f32,
/// Exponential moving average of the `duration_ms` reported by `Success`
/// outcomes; `None` until the first success is recorded.
pub avg_duration_ms: Option<u64>,
/// UTC time of creation or of the most recent `update`.
pub last_updated: DateTime<Utc>,
}
impl Capability {
pub fn new(domain: Domain) -> Self {
Self {
domain,
proficiency: 0.0,
experience_count: 0,
success_rate: 0.0,
avg_duration_ms: None,
last_updated: Utc::now(),
}
}
pub fn update(&mut self, outcome: &TaskOutcome) {
self.experience_count += 1;
let success_value = match outcome {
TaskOutcome::Success { .. } => 1.0,
TaskOutcome::Partial { completion_ratio } => *completion_ratio,
TaskOutcome::Failure { .. } => 0.0,
};
let alpha = 0.1; self.success_rate = (1.0 - alpha) * self.success_rate + alpha * success_value;
if let TaskOutcome::Success { duration_ms, .. } = outcome {
self.avg_duration_ms = Some(match self.avg_duration_ms {
Some(avg) => ((avg as f64 * 0.9) + (*duration_ms as f64 * 0.1)) as u64,
None => *duration_ms,
});
}
self.proficiency = self.calculate_proficiency();
self.last_updated = Utc::now();
}
fn calculate_proficiency(&self) -> f32 {
let experience_factor = (1.0 - (-0.01 * self.experience_count as f32).exp()).min(1.0);
(experience_factor * self.success_rate).min(1.0)
}
pub fn is_specialist(&self) -> bool {
self.proficiency > 0.8
}
}
/// Result of a single task, as consumed by [`Capability::update`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskOutcome {
/// Task completed; counts as a success value of 1.0 in `Capability::update`.
Success {
/// Caller-reported confidence; not read by `Capability::update`.
confidence: f32,
/// Task duration as reported by the caller; feeds `Capability::avg_duration_ms`.
duration_ms: u64,
},
/// Task failed; counts as a success value of 0.0 in `Capability::update`.
Failure {
/// Caller-supplied failure classification; not read by `Capability::update`.
error_category: String,
/// Whether the caller considers the failure recoverable; not read by
/// `Capability::update`.
recoverable: bool,
},
/// Task partly completed.
Partial {
/// Fraction of the task completed; feeds the success value in
/// `Capability::update`. Presumably meant to lie in `[0.0, 1.0]`; this type
/// does not enforce that.
completion_ratio: f32,
},
}
/// One entry in an [`ExpertiseProfile`]'s learning history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningEvent {
/// Kind of task; `ExpertiseProfile::record_outcome` stores the domain name here.
pub task_type: String,
/// Outcome reported for the task.
pub outcome: TaskOutcome,
/// Caller-supplied label for the strategy that produced the outcome.
pub strategy_used: String,
/// Optional embedding of the task context; `record_outcome` always leaves it `None`.
pub context_embedding: Option<Vec<f32>>,
/// UTC time the event was recorded.
pub timestamp: DateTime<Utc>,
}
/// Expertise of one instance: capability statistics per domain plus a bounded
/// history of recorded outcomes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertiseProfile {
/// Instance this profile describes.
pub instance_id: InstanceId,
/// Capability statistics keyed by domain name.
pub capabilities: HashMap<Domain, Capability>,
/// Domains whose capability satisfies `Capability::is_specialist`; recomputed by
/// `record_outcome`.
pub specializations: Vec<Domain>,
/// Recorded events, oldest first; `record_outcome` trims it to `max_history` entries.
pub learning_history: Vec<LearningEvent>,
/// Maximum length of `learning_history`; serde fills in `default_max_history()`
/// (100) when the field is missing.
#[serde(default = "default_max_history")]
pub max_history: usize,
}
/// Number of learning events an [`ExpertiseProfile`] keeps unless configured otherwise.
const DEFAULT_MAX_HISTORY: usize = 100;

/// Default for `ExpertiseProfile::max_history`, used by `ExpertiseProfile::new`
/// and as the serde default when the field is absent.
fn default_max_history() -> usize {
    DEFAULT_MAX_HISTORY
}
impl ExpertiseProfile {
pub fn new(instance_id: InstanceId) -> Self {
Self {
instance_id,
capabilities: HashMap::new(),
specializations: Vec::new(),
learning_history: Vec::new(),
max_history: default_max_history(),
}
}
pub fn get_or_create_capability(&mut self, domain: &str) -> &mut Capability {
self.capabilities
.entry(domain.to_string())
.or_insert_with(|| Capability::new(domain.to_string()))
}
pub fn record_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
let capability = self.get_or_create_capability(domain);
capability.update(&outcome);
let event = LearningEvent {
task_type: domain.to_string(),
outcome,
strategy_used: strategy,
context_embedding: None,
timestamp: Utc::now(),
};
self.learning_history.push(event);
while self.learning_history.len() > self.max_history {
self.learning_history.remove(0);
}
self.update_specializations();
}
fn update_specializations(&mut self) {
self.specializations = self
.capabilities
.iter()
.filter(|(_, cap)| cap.is_specialist())
.map(|(domain, _)| domain.clone())
.collect();
}
pub fn match_score(&self, required: &[String]) -> f32 {
if required.is_empty() {
return 1.0;
}
let total: f32 = required
.iter()
.map(|domain| {
self.capabilities
.get(domain)
.map(|c| c.proficiency)
.unwrap_or(0.0)
})
.sum();
total / required.len() as f32
}
}
/// Tracks this instance's own expertise alongside the most recently supplied
/// profiles of peer instances, and ranks instances for task routing.
#[derive(Debug)]
pub struct CapabilityTracker {
/// Id of the local instance.
instance_id: InstanceId,
/// Expertise profile of the local instance.
profile: ExpertiseProfile,
/// Peer profiles keyed by their `instance_id`, as supplied via `update_peer_profile`.
peers: HashMap<InstanceId, ExpertiseProfile>,
}
impl CapabilityTracker {
    /// Creates a tracker for `instance_id` with an empty local profile and no peers.
    pub fn new(instance_id: InstanceId) -> Self {
        Self {
            instance_id: instance_id.clone(),
            profile: ExpertiseProfile::new(instance_id),
            peers: HashMap::new(),
        }
    }

    /// Returns the local instance's id.
    pub fn instance_id(&self) -> &str {
        &self.instance_id
    }

    /// Returns the local expertise profile.
    pub fn profile(&self) -> &ExpertiseProfile {
        &self.profile
    }

    /// Returns the local expertise profile mutably.
    pub fn profile_mut(&mut self) -> &mut ExpertiseProfile {
        &mut self.profile
    }

    /// Records a task outcome against the local profile
    /// (delegates to `ExpertiseProfile::record_outcome`).
    pub fn record_task_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
        self.profile.record_outcome(domain, outcome, strategy);
    }

    /// Stores or replaces a peer's profile, keyed by its `instance_id`.
    ///
    /// A profile carrying this tracker's own id is still stored. The routing
    /// methods ignore it in favour of the live local profile.
    pub fn update_peer_profile(&mut self, profile: ExpertiseProfile) {
        self.peers.insert(profile.instance_id.clone(), profile);
    }

    /// Returns the best-matching instance for `required_capabilities`, if any.
    ///
    /// Scores come from `ExpertiseProfile::match_score`, and only strictly
    /// positive scores are eligible. Ties prefer the local instance, then the
    /// lexicographically smallest peer id, so the result does not depend on
    /// `HashMap` iteration order.
    pub fn get_best_agent(
        &self,
        required_capabilities: &[String],
    ) -> Option<RoutingRecommendation> {
        self.scored_candidates(required_capabilities)
            // Strictly positive only; this also excludes NaN scores.
            .filter(|&(_, score, _)| score > 0.0)
            .max_by(|&a, &b| Self::candidate_order(a, b))
            .map(|(instance_id, score, is_self)| RoutingRecommendation {
                instance_id: instance_id.clone(),
                score,
                is_self,
            })
    }

    /// Returns every instance whose match score is at least `min_score`, best first.
    ///
    /// Results are ordered by descending score. On a tie the local instance comes
    /// first, then peers by ascending id.
    pub fn get_capable_agents(
        &self,
        required_capabilities: &[String],
        min_score: f32,
    ) -> Vec<RoutingRecommendation> {
        let mut agents: Vec<RoutingRecommendation> = self
            .scored_candidates(required_capabilities)
            .filter(|&(_, score, _)| score >= min_score)
            .map(|(instance_id, score, is_self)| RoutingRecommendation {
                instance_id: instance_id.clone(),
                score,
                is_self,
            })
            .collect();
        // Arguments are swapped so the preferred candidate sorts first.
        agents.sort_by(|a, b| {
            Self::candidate_order(
                (&b.instance_id, b.score, b.is_self),
                (&a.instance_id, a.score, a.is_self),
            )
        });
        agents
    }

    /// Returns all stored peer profiles, keyed by instance id.
    pub fn peers(&self) -> &HashMap<InstanceId, ExpertiseProfile> {
        &self.peers
    }

    /// Yields `(id, score, is_self)` for the local instance, followed by every
    /// peer whose id differs from ours.
    ///
    /// A peer entry under our own id would be a stale copy of the local profile.
    /// If kept, it would duplicate or outrank the live one.
    fn scored_candidates<'a>(
        &'a self,
        required_capabilities: &'a [String],
    ) -> impl Iterator<Item = (&'a InstanceId, f32, bool)> + 'a {
        let own = (
            &self.instance_id,
            self.profile.match_score(required_capabilities),
            true,
        );
        let remote = self
            .peers
            .iter()
            .filter(move |&(id, _)| *id != self.instance_id)
            .map(move |(id, profile)| (id, profile.match_score(required_capabilities), false));
        std::iter::once(own).chain(remote)
    }

    /// Orders two candidates so that the preferred one compares greater.
    ///
    /// Candidates are compared by score, then the local instance beats a peer,
    /// then the smaller id wins. Callers filter out NaN scores; an incomparable
    /// pair is treated as tied on score.
    fn candidate_order(
        (a_id, a_score, a_self): (&InstanceId, f32, bool),
        (b_id, b_score, b_self): (&InstanceId, f32, bool),
    ) -> std::cmp::Ordering {
        a_score
            .partial_cmp(&b_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a_self.cmp(&b_self))
            .then_with(|| b_id.cmp(a_id))
    }
}
/// A candidate instance for a task, produced by `CapabilityTracker::get_best_agent`
/// and `CapabilityTracker::get_capable_agents`.
#[derive(Debug, Clone)]
pub struct RoutingRecommendation {
/// Instance to route the task to.
pub instance_id: InstanceId,
/// `ExpertiseProfile::match_score` of that instance for the required capabilities.
pub score: f32,
/// Set by the routing methods to mark the local instance.
pub is_self: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful outcome with confidence 0.9 and the given duration.
    fn success(duration_ms: u64) -> TaskOutcome {
        TaskOutcome::Success {
            confidence: 0.9,
            duration_ms,
        }
    }

    #[test]
    fn test_capability_update() {
        let mut capability = Capability::new(String::from("code_review"));
        assert_eq!(capability.proficiency, 0.0);

        (0..10).for_each(|_| capability.update(&success(1000)));

        assert!(capability.proficiency > 0.0);
        assert!(capability.success_rate > 0.5);
        assert_eq!(capability.experience_count, 10);
    }

    #[test]
    fn test_expertise_profile_matching() {
        let mut profile = ExpertiseProfile::new(String::from("agent-1"));
        (0..20).for_each(|_| {
            profile.record_outcome(
                "code_review",
                success(1000),
                String::from("standard_review"),
            )
        });

        let known = profile.match_score(&[String::from("code_review")]);
        assert!(known > 0.0);

        let unknown = profile.match_score(&[String::from("unknown_domain")]);
        assert_eq!(unknown, 0.0);
    }

    #[test]
    fn test_capability_tracker_routing() {
        let mut tracker = CapabilityTracker::new(String::from("agent-1"));
        (0..10).for_each(|_| {
            tracker.record_task_outcome("data_analysis", success(500), String::from("standard"))
        });

        let recommendation = tracker
            .get_best_agent(&[String::from("data_analysis")])
            .expect("local instance has positive data_analysis proficiency");
        assert!(recommendation.is_self);
    }
}