use crate::spec_ai_collective::capability::{CapabilityTracker, ExpertiseProfile};
use crate::spec_ai_collective::types::{Domain, InstanceId};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// How far an instance has progressed in a domain; variants are declared from
/// least to most proficient.
///
/// The lower bound of each tier is given by [`SpecializationStatus::min_proficiency`].
/// Serialized in snake_case (e.g. `"learning"`, `"expert"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SpecializationStatus {
/// Proficiency below 0.5.
Learning,
/// Proficiency of at least 0.5 and below 0.8.
Proficient,
/// Proficiency of at least 0.8 and below 0.95.
Specialist,
/// Proficiency of at least 0.95.
Expert,
}
impl SpecializationStatus {
pub fn min_proficiency(&self) -> f32 {
match self {
SpecializationStatus::Learning => 0.0,
SpecializationStatus::Proficient => 0.5,
SpecializationStatus::Specialist => 0.8,
SpecializationStatus::Expert => 0.95,
}
}
pub fn routing_priority(&self) -> u32 {
match self {
SpecializationStatus::Learning => 1,
SpecializationStatus::Proficient => 2,
SpecializationStatus::Specialist => 3,
SpecializationStatus::Expert => 4,
}
}
}
/// One instance's standing in one domain.
///
/// Used both for remote instances registered with a `SpecializationEngine` and for
/// the local instance's own specializations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Specialist {
/// Instance that holds this specialization.
pub instance_id: InstanceId,
/// Domain the specialization applies to.
pub domain: Domain,
/// Tier derived from `proficiency` whenever it is set through `new` or `update`.
pub status: SpecializationStatus,
/// Proficiency score compared against the status thresholds (0.5 / 0.8 / 0.95).
/// Not clamped here; NOTE(review): presumably expected in `[0.0, 1.0]` — confirm with producers.
pub proficiency: f32,
/// Task count reported for this domain (for the local instance it is fed from the
/// capability's `experience_count`).
pub task_count: u64,
/// Success rate reported by the source; starts at 0.0 for a new record.
/// NOTE(review): presumably a fraction in `[0.0, 1.0]` — confirm.
pub success_rate: f32,
/// Creation time of the record (set by `new`).
pub detected_at: DateTime<Utc>,
/// Time the metrics were last set (by `new` or `update`).
pub updated_at: DateTime<Utc>,
/// Whether the specialist may be routed to; cleared by `mark_unavailable` and
/// set again by `mark_active`.
pub available: bool,
/// Last time the specialist was seen active; `is_stale` measures from here.
pub last_active: DateTime<Utc>,
}
impl Specialist {
    /// Creates a newly detected specialist in `domain` with the given proficiency.
    ///
    /// The status is derived from `proficiency`. Task count and success rate start at
    /// zero, the specialist starts out available, and all timestamps are set to now.
    pub fn new(instance_id: InstanceId, domain: Domain, proficiency: f32) -> Self {
        let status = Self::status_for_proficiency(proficiency);
        let now = Utc::now();
        Self {
            instance_id,
            domain,
            status,
            proficiency,
            task_count: 0,
            success_rate: 0.0,
            detected_at: now,
            updated_at: now,
            available: true,
            last_active: now,
        }
    }

    /// Maps a proficiency score to the highest status whose
    /// [`SpecializationStatus::min_proficiency`] it meets.
    ///
    /// The thresholds are read from `min_proficiency` instead of being repeated here,
    /// so the two definitions cannot drift apart. Scores that meet no higher tier
    /// (including negative values and NaN) map to `Learning`.
    fn status_for_proficiency(proficiency: f32) -> SpecializationStatus {
        // Highest tier first: the first threshold met wins.
        [
            SpecializationStatus::Expert,
            SpecializationStatus::Specialist,
            SpecializationStatus::Proficient,
        ]
        .iter()
        .find(|status| proficiency >= status.min_proficiency())
        .cloned()
        .unwrap_or(SpecializationStatus::Learning)
    }

    /// Replaces the metrics with freshly reported values.
    ///
    /// This re-derives `status` from the new proficiency and stamps `updated_at`.
    /// It does not touch `available` or `last_active`.
    pub fn update(&mut self, proficiency: f32, task_count: u64, success_rate: f32) {
        self.proficiency = proficiency;
        self.task_count = task_count;
        self.success_rate = success_rate;
        self.status = Self::status_for_proficiency(proficiency);
        self.updated_at = Utc::now();
    }

    /// Records activity now and makes the specialist available again.
    pub fn mark_active(&mut self) {
        self.last_active = Utc::now();
        self.available = true;
    }

    /// Excludes the specialist from routing until it is next marked active.
    pub fn mark_unavailable(&mut self) {
        self.available = false;
    }

    /// Returns `true` when strictly more than `timeout` has elapsed since `last_active`.
    pub fn is_stale(&self, timeout: Duration) -> bool {
        Utc::now() - self.last_active > timeout
    }
}
/// Tracks which instances specialize in which domains, including the local
/// instance's own specializations, and answers routing and coverage queries.
#[derive(Debug)]
pub struct SpecializationEngine {
/// Id of the local instance; remote registrations carrying this id are ignored.
instance_id: InstanceId,
/// Remote specialists grouped by domain (trimmed to `max_per_domain` on insert).
specialists: HashMap<Domain, Vec<Specialist>>,
/// The local instance's own specializations, derived from its expertise profile.
my_specializations: HashMap<Domain, Specialist>,
/// Proficiency at or above which the local instance counts as a specialist
/// (0.8 by default; the setter clamps it to `[0.0, 1.0]`).
specialist_threshold: f32,
/// Inactivity period after which a remote specialist is treated as stale
/// (30 minutes by default).
stale_timeout: Duration,
/// Maximum number of remote specialists kept per domain (10 by default).
max_per_domain: usize,
}
impl SpecializationEngine {
    /// Creates an engine for the local instance `instance_id`.
    ///
    /// The defaults are a specialist threshold of 0.8, a stale timeout of 30 minutes,
    /// and at most 10 remote specialists tracked per domain.
    pub fn new(instance_id: InstanceId) -> Self {
        Self {
            instance_id,
            specialists: HashMap::new(),
            my_specializations: HashMap::new(),
            specialist_threshold: 0.8,
            stale_timeout: Duration::minutes(30),
            max_per_domain: 10,
        }
    }

    /// Returns the id of the local instance this engine belongs to.
    pub fn instance_id(&self) -> &str {
        &self.instance_id
    }

    /// Sets the proficiency at or above which the local instance counts as a
    /// specialist in a domain, clamped into `[0.0, 1.0]`.
    ///
    /// Existing local entries are re-evaluated on the next profile update, not
    /// immediately.
    pub fn set_specialist_threshold(&mut self, threshold: f32) {
        self.specialist_threshold = threshold.clamp(0.0, 1.0);
    }

    /// Refreshes the local specializations from `tracker`'s current expertise profile.
    ///
    /// See [`Self::update_from_profile`].
    pub fn update_from_tracker(&mut self, tracker: &CapabilityTracker) {
        self.update_from_profile(tracker.profile());
    }

    /// Synchronises the local instance's own specializations with `profile`.
    ///
    /// A domain whose proficiency is at or above the specialist threshold is
    /// inserted or refreshed with the profile's proficiency, experience count and
    /// success rate. A domain listed in the profile but below the threshold is
    /// removed. Without that removal the engine would keep advertising, and routing
    /// to, a specialization with stale, higher-than-actual numbers. Domains absent
    /// from `profile` are left as they are.
    pub fn update_from_profile(&mut self, profile: &ExpertiseProfile) {
        for (domain, capability) in &profile.capabilities {
            if capability.proficiency >= self.specialist_threshold {
                let entry = self
                    .my_specializations
                    .entry(domain.clone())
                    .or_insert_with(|| {
                        Specialist::new(
                            self.instance_id.clone(),
                            domain.clone(),
                            capability.proficiency,
                        )
                    });
                entry.update(
                    capability.proficiency,
                    capability.experience_count,
                    capability.success_rate,
                );
            } else {
                self.my_specializations.remove(domain);
            }
        }
    }

    /// Records or refreshes a remote specialist.
    ///
    /// - A registration carrying this engine's own instance id is ignored; local
    ///   specializations come from [`Self::update_from_profile`].
    /// - If the instance is already known in the domain, its proficiency, task count
    ///   and success rate are updated. It is then marked active, so `last_active` is
    ///   reset to now regardless of the incoming value.
    /// - Otherwise the record is appended as given. When the domain then holds more
    ///   than `max_per_domain` entries, only the most proficient are kept, and the
    ///   newly added entry may itself be dropped.
    pub fn register_specialist(&mut self, specialist: Specialist) {
        if specialist.instance_id == self.instance_id {
            return;
        }
        let max_per_domain = self.max_per_domain;
        let domain_specialists = self
            .specialists
            .entry(specialist.domain.clone())
            .or_default();
        if let Some(existing) = domain_specialists
            .iter_mut()
            .find(|s| s.instance_id == specialist.instance_id)
        {
            existing.update(
                specialist.proficiency,
                specialist.task_count,
                specialist.success_rate,
            );
            existing.mark_active();
            return;
        }
        domain_specialists.push(specialist);
        if domain_specialists.len() > max_per_domain {
            domain_specialists.sort_by(Self::proficiency_desc);
            domain_specialists.truncate(max_per_domain);
        }
    }

    /// Marks the remote specialist `instance_id` in `domain` as unavailable.
    ///
    /// This is a no-op if that instance is not registered in the domain.
    pub fn mark_unavailable(&mut self, instance_id: &str, domain: &str) {
        if let Some(specialist) = self
            .specialists
            .get_mut(domain)
            .and_then(|list| list.iter_mut().find(|s| s.instance_id == instance_id))
        {
            specialist.mark_unavailable();
        }
    }

    /// Returns the specialists that can currently serve `domain`, most proficient first.
    ///
    /// The result contains the remote specialists that are available and not stale.
    /// It also contains the local instance's own specialization in `domain`, if any;
    /// the local entry is not filtered by availability or staleness.
    pub fn get_specialists(&self, domain: &str) -> Vec<&Specialist> {
        let mut specialists: Vec<&Specialist> = self
            .specialists
            .get(domain)
            .map(|list| list.iter().filter(|s| self.is_routable(s)).collect())
            .unwrap_or_default();
        if let Some(mine) = self.my_specializations.get(domain) {
            specialists.push(mine);
        }
        specialists.sort_by(|a, b| Self::proficiency_desc(a, b));
        specialists
    }

    /// Returns the most proficient entry of [`Self::get_specialists`] for `domain`, if any.
    pub fn get_best_specialist(&self, domain: &str) -> Option<&Specialist> {
        self.get_specialists(domain).first().copied()
    }

    /// Returns every domain that has at least one routable specialist, remote or
    /// local, without duplicates.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn covered_domains(&self) -> Vec<&str> {
        // A set keeps de-duplication linear instead of scanning the output per domain.
        let mut seen = std::collections::HashSet::new();
        self.specialists
            .iter()
            .filter(|&(domain, list)| {
                list.iter().any(|s| self.is_routable(s))
                    || self.my_specializations.contains_key(domain)
            })
            .map(|(domain, _)| domain.as_str())
            .chain(self.my_specializations.keys().map(|domain| domain.as_str()))
            .filter(|domain| seen.insert(*domain))
            .collect()
    }

    /// Returns the entries of `required_domains` that no routable specialist covers.
    ///
    /// The input order, including any duplicates, is preserved.
    pub fn identify_gaps(&self, required_domains: &[String]) -> Vec<String> {
        let covered: std::collections::HashSet<_> = self.covered_domains().into_iter().collect();
        required_domains
            .iter()
            .filter(|d| !covered.contains(d.as_str()))
            .cloned()
            .collect()
    }

    /// Returns the local instance's own specializations, in unspecified order.
    pub fn my_specializations(&self) -> Vec<&Specialist> {
        self.my_specializations.values().collect()
    }

    /// Removes remote specialists whose last activity is older than the stale timeout.
    ///
    /// Domains left with no specialists are dropped from the map entirely.
    /// Specialists that are unavailable but not stale are kept.
    ///
    /// Returns the number of specialists removed.
    pub fn cleanup_stale(&mut self) -> usize {
        let timeout = self.stale_timeout;
        let mut removed = 0;
        self.specialists.retain(|_, list| {
            let before = list.len();
            list.retain(|s| !s.is_stale(timeout));
            removed += before - list.len();
            !list.is_empty()
        });
        removed
    }

    /// Aggregates statistics over [`Self::get_specialists`] for `domain`.
    ///
    /// Returns `None` when the domain has no routable specialist.
    pub fn domain_stats(&self, domain: &str) -> Option<DomainStats> {
        let specialists = self.get_specialists(domain);
        if specialists.is_empty() {
            return None;
        }
        let avg_proficiency =
            specialists.iter().map(|s| s.proficiency).sum::<f32>() / specialists.len() as f32;
        let total_tasks: u64 = specialists.iter().map(|s| s.task_count).sum();
        let expert_count = specialists
            .iter()
            .filter(|s| s.status == SpecializationStatus::Expert)
            .count();
        Some(DomainStats {
            domain: domain.to_string(),
            specialist_count: specialists.len(),
            expert_count,
            avg_proficiency,
            total_tasks,
        })
    }

    /// Reports whether a remote specialist may be routed to right now.
    ///
    /// It must be marked available and must not be stale.
    fn is_routable(&self, specialist: &Specialist) -> bool {
        specialist.available && !specialist.is_stale(self.stale_timeout)
    }

    /// Orders specialists by descending proficiency.
    ///
    /// Incomparable (NaN) scores compare as equal.
    fn proficiency_desc(a: &Specialist, b: &Specialist) -> std::cmp::Ordering {
        b.proficiency
            .partial_cmp(&a.proficiency)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
/// Aggregate statistics over the routable specialists (remote and local) of one
/// domain, as produced by `SpecializationEngine::domain_stats`.
#[derive(Debug, Clone)]
pub struct DomainStats {
/// Domain the statistics describe.
pub domain: String,
/// Number of specialists counted.
pub specialist_count: usize,
/// How many of those specialists have `Expert` status.
pub expert_count: usize,
/// Arithmetic mean of their proficiency scores.
pub avg_proficiency: f32,
/// Sum of their task counts.
pub total_tasks: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a freshly detected, available specialist.
    fn specialist(id: &str, domain: &str, proficiency: f32) -> Specialist {
        Specialist::new(id.to_string(), domain.to_string(), proficiency)
    }

    #[test]
    fn test_specialist_creation() {
        let specialist = Specialist::new("agent-1".to_string(), "code_review".to_string(), 0.85);
        assert_eq!(specialist.status, SpecializationStatus::Specialist);
        assert!(specialist.available);
    }

    #[test]
    fn test_status_thresholds() {
        let cases = [
            (0.0, SpecializationStatus::Learning),
            (0.49, SpecializationStatus::Learning),
            (0.5, SpecializationStatus::Proficient),
            (0.8, SpecializationStatus::Specialist),
            (0.95, SpecializationStatus::Expert),
            (1.0, SpecializationStatus::Expert),
        ];
        for (proficiency, expected) in cases {
            assert_eq!(
                specialist("agent-2", "code_review", proficiency).status,
                expected,
                "proficiency {proficiency}"
            );
        }
    }

    #[test]
    fn test_min_proficiency_matches_status() {
        // A score exactly at a tier's lower bound must land in that tier.
        for status in [
            SpecializationStatus::Learning,
            SpecializationStatus::Proficient,
            SpecializationStatus::Specialist,
            SpecializationStatus::Expert,
        ] {
            let at_threshold = specialist("agent-2", "code_review", status.min_proficiency());
            assert_eq!(at_threshold.status, status);
        }
    }

    #[test]
    fn test_specialization_engine() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-2", "code_review", 0.9));
        engine.register_specialist(specialist("agent-3", "code_review", 0.85));
        let specialists = engine.get_specialists("code_review");
        assert_eq!(specialists.len(), 2);
        assert_eq!(specialists[0].instance_id, "agent-2");
    }

    #[test]
    fn test_ignores_own_registration() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-1", "code_review", 0.9));
        assert!(engine.get_specialists("code_review").is_empty());
        assert!(engine.get_best_specialist("code_review").is_none());
    }

    #[test]
    fn test_reregistration_updates_in_place() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-2", "code_review", 0.6));
        engine.register_specialist(specialist("agent-2", "code_review", 0.97));
        let specialists = engine.get_specialists("code_review");
        assert_eq!(specialists.len(), 1);
        assert_eq!(specialists[0].status, SpecializationStatus::Expert);
    }

    #[test]
    fn test_mark_unavailable_hides_specialist() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-2", "code_review", 0.9));
        engine.mark_unavailable("agent-2", "code_review");
        assert!(engine.get_specialists("code_review").is_empty());
        assert!(engine.covered_domains().is_empty());

        // Registering again marks the specialist active and available.
        engine.register_specialist(specialist("agent-2", "code_review", 0.9));
        assert_eq!(engine.get_specialists("code_review").len(), 1);
        assert_eq!(engine.covered_domains(), vec!["code_review"]);
    }

    #[test]
    fn test_max_per_domain_keeps_most_proficient() {
        let mut engine = SpecializationEngine::new("self".to_string());
        for i in 0..11 {
            engine.register_specialist(specialist(
                &format!("agent-{i}"),
                "code_review",
                i as f32 / 20.0,
            ));
        }
        let specialists = engine.get_specialists("code_review");
        assert_eq!(specialists.len(), 10);
        assert!(specialists.iter().all(|s| s.instance_id != "agent-0"));
        assert_eq!(specialists[0].instance_id, "agent-10");
    }

    #[test]
    fn test_domain_stats() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-2", "code_review", 0.9));
        engine.register_specialist(specialist("agent-3", "code_review", 0.96));
        let stats = engine
            .domain_stats("code_review")
            .expect("domain has registered specialists");
        assert_eq!(stats.domain, "code_review");
        assert_eq!(stats.specialist_count, 2);
        assert_eq!(stats.expert_count, 1);
        assert_eq!(stats.total_tasks, 0);
        assert!((stats.avg_proficiency - 0.93).abs() < 1e-6);
        assert!(engine.domain_stats("testing").is_none());
    }

    #[test]
    fn test_cleanup_stale() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        let mut stale = specialist("agent-2", "code_review", 0.9);
        stale.last_active = Utc::now() - Duration::hours(1);
        engine.register_specialist(stale);
        engine.register_specialist(specialist("agent-3", "code_review", 0.85));

        // Stale entries are already hidden from routing before cleanup runs.
        assert_eq!(engine.get_specialists("code_review").len(), 1);
        assert_eq!(engine.cleanup_stale(), 1);
        assert_eq!(engine.cleanup_stale(), 0);
        assert_eq!(engine.get_specialists("code_review").len(), 1);
    }

    #[test]
    fn test_domain_gaps() {
        let mut engine = SpecializationEngine::new("agent-1".to_string());
        engine.register_specialist(specialist("agent-2", "code_review", 0.9));
        let gaps = engine.identify_gaps(&[
            "code_review".to_string(),
            "data_analysis".to_string(),
            "testing".to_string(),
        ]);
        assert_eq!(gaps.len(), 2);
        assert!(gaps.contains(&"data_analysis".to_string()));
        assert!(gaps.contains(&"testing".to_string()));
    }
}