use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::{AgentCapability, CapabilityMetrics, RegistryError, RegistryResult, ResourceUsage};
/// In-memory index of agent capabilities, searchable by category and by scored queries.
#[derive(Debug, Clone)]
pub struct CapabilityRegistry {
    /// Registered capabilities keyed by their `capability_id`.
    capabilities: HashMap<String, AgentCapability>,
    /// Category name -> IDs of the capabilities filed under that category.
    categories: HashMap<String, Vec<String>>,
    /// Capability ID -> IDs it depends on; only capabilities with a non-empty dependency
    /// list get an entry.
    dependencies: HashMap<String, Vec<String>>,
    /// Symmetric similarity overrides: `compatibility[a][b]` is used instead of the
    /// string-similarity heuristic when comparing capability IDs `a` and `b`.
    compatibility: HashMap<String, HashMap<String, f64>>,
}
/// Search criteria for `CapabilityRegistry::find_capabilities`.
///
/// Derives `Default` (no requirements, no thresholds, all categories, no I/O constraints)
/// so callers can set only the fields they need with `..Default::default()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilityQuery {
    /// Capability IDs the candidate should provide. Each one is scored as an exact match,
    /// a partial match (similarity above the registry's threshold), or reported as missing.
    pub required_capabilities: Vec<String>,
    /// Capability IDs that only add a bonus to the requirement score when matched.
    pub optional_capabilities: Vec<String>,
    /// Minimum performance metrics; `None` makes the performance score 1.0.
    pub min_performance: Option<CapabilityMetrics>,
    /// Resource ceiling; `None` makes the resource score 1.0.
    pub max_resources: Option<ResourceUsage>,
    /// Categories to search; an empty list searches every registered capability.
    pub categories: Vec<String>,
    /// Input/output type constraints used for the I/O score.
    pub io_requirements: IORequirements,
}
/// Input/output type constraints of a `CapabilityQuery`.
///
/// The default value (both type lists empty) places no constraint, giving an I/O score of 1.0.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct IORequirements {
    /// Types the capability must accept; each is looked up in the capability's `input_types`.
    pub input_types: Vec<String>,
    /// Types the capability must produce; each is looked up in the capability's `output_types`.
    pub output_types: Vec<String>,
    /// Required type -> alternative types that also satisfy it when the capability lists
    /// any of them on the corresponding (input or output) side.
    pub compatibility_matrix: HashMap<String, Vec<String>>,
}
/// A single scored result returned by `CapabilityRegistry::find_capabilities`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CapabilityMatch {
    /// Clone of the matched capability.
    pub capability: AgentCapability,
    /// Weighted overall score; the registry clamps it to `[0.0, 1.0]` (higher is better).
    pub match_score: f64,
    /// Component scores and per-requirement outcomes behind `match_score`.
    pub match_details: CapabilityMatchDetails,
    /// Human-readable summary of the match.
    pub explanation: String,
}
/// Breakdown of how a capability was scored against a `CapabilityQuery`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CapabilityMatchDetails {
    /// Required capability IDs identical to the candidate's ID.
    pub exact_matches: Vec<String>,
    /// Required capability IDs matched by similarity, paired with that similarity.
    pub partial_matches: Vec<(String, f64)>,
    /// Required capability IDs that were neither exactly nor partially matched.
    pub missing_requirements: Vec<String>,
    /// Performance fit in `[0.0, 1.0]` (1.0 when the query sets no minimum).
    pub performance_score: f64,
    /// Resource fit in `[0.0, 1.0]` (1.0 when the query sets no ceiling).
    pub resource_score: f64,
    /// Fraction of required input/output types satisfied, averaged over both sides.
    pub io_score: f64,
}
/// A reusable description for defining capabilities.
///
/// NOTE(review): not referenced by the registry in this module; the field descriptions
/// below are inferred from their names — confirm against the code that consumes templates.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CapabilityTemplate {
    /// Template name.
    pub name: String,
    /// Free-form description of the template.
    pub description: String,
    /// Presumably the category given to capabilities created from this template — verify.
    pub default_category: String,
    /// Presumably the names of fields a capability built from this template must set — verify.
    pub required_fields: Vec<String>,
    /// Optional field name -> JSON value (presumably a default) — verify against callers.
    pub optional_fields: HashMap<String, serde_json::Value>,
    /// Reference performance metrics associated with the template.
    pub performance_benchmarks: CapabilityMetrics,
}
impl CapabilityRegistry {
pub fn new() -> Self {
Self {
capabilities: HashMap::new(),
categories: HashMap::new(),
dependencies: HashMap::new(),
compatibility: HashMap::new(),
}
}
pub fn register_capability(&mut self, capability: AgentCapability) -> RegistryResult<()> {
let capability_id = capability.capability_id.clone();
self.validate_capability(&capability)?;
self.categories
.entry(capability.category.clone())
.or_default()
.push(capability_id.clone());
if !capability.dependencies.is_empty() {
self.dependencies
.insert(capability_id.clone(), capability.dependencies.clone());
}
self.capabilities.insert(capability_id, capability);
Ok(())
}
pub fn unregister_capability(&mut self, capability_id: &str) -> RegistryResult<()> {
if let Some(capability) = self.capabilities.remove(capability_id) {
if let Some(category_capabilities) = self.categories.get_mut(&capability.category) {
category_capabilities.retain(|id| id != capability_id);
if category_capabilities.is_empty() {
self.categories.remove(&capability.category);
}
}
self.dependencies.remove(capability_id);
self.compatibility.remove(capability_id);
for compatibility_map in self.compatibility.values_mut() {
compatibility_map.remove(capability_id);
}
Ok(())
} else {
Err(RegistryError::System(format!(
"Capability {} not found",
capability_id
)))
}
}
pub fn get_capability(&self, capability_id: &str) -> Option<&AgentCapability> {
self.capabilities.get(capability_id)
}
pub fn list_capabilities(&self) -> Vec<&AgentCapability> {
self.capabilities.values().collect()
}
pub fn list_capabilities_by_category(&self, category: &str) -> Vec<&AgentCapability> {
if let Some(capability_ids) = self.categories.get(category) {
capability_ids
.iter()
.filter_map(|id| self.capabilities.get(id))
.collect()
} else {
Vec::new()
}
}
pub fn find_capabilities(
&self,
query: &CapabilityQuery,
) -> RegistryResult<Vec<CapabilityMatch>> {
let mut matches = Vec::new();
let candidates = if query.categories.is_empty() {
self.list_capabilities()
} else {
let mut candidates = Vec::new();
for category in &query.categories {
candidates.extend(self.list_capabilities_by_category(category));
}
candidates
};
for capability in candidates {
if let Ok(capability_match) = self.score_capability_match(capability, query)
&& capability_match.match_score > 0.0
{
matches.push(capability_match);
}
}
#[allow(clippy::unnecessary_sort_by)]
matches.sort_by(|a, b| {
b.match_score
.partial_cmp(&a.match_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(matches)
}
fn score_capability_match(
&self,
capability: &AgentCapability,
query: &CapabilityQuery,
) -> RegistryResult<CapabilityMatch> {
let mut exact_matches = Vec::new();
let mut partial_matches = Vec::new();
let mut missing_requirements = Vec::new();
let mut requirement_score = 0.0;
let total_requirements = query.required_capabilities.len();
for required_cap in &query.required_capabilities {
if capability.capability_id == *required_cap {
exact_matches.push(required_cap.clone());
requirement_score += 1.0;
} else {
let similarity =
self.calculate_capability_similarity(&capability.capability_id, required_cap);
if similarity > 0.5 {
partial_matches.push((required_cap.clone(), similarity));
requirement_score += similarity;
} else {
missing_requirements.push(required_cap.clone());
}
}
}
for optional_cap in &query.optional_capabilities {
if capability.capability_id == *optional_cap {
requirement_score += 0.5; } else {
let similarity =
self.calculate_capability_similarity(&capability.capability_id, optional_cap);
if similarity > 0.5 {
requirement_score += similarity * 0.3; }
}
}
if total_requirements > 0 {
requirement_score = (requirement_score / total_requirements as f64).min(1.0);
} else {
requirement_score = 1.0;
}
let performance_score = if let Some(min_performance) = &query.min_performance {
self.calculate_performance_score(&capability.performance_metrics, min_performance)
} else {
1.0
};
let resource_score = if let Some(max_resources) = &query.max_resources {
self.calculate_resource_score(
&capability.performance_metrics.resource_usage,
max_resources,
)
} else {
1.0
};
let io_score = self.calculate_io_score(capability, &query.io_requirements);
let match_score = (requirement_score * 0.4
+ performance_score * 0.25
+ resource_score * 0.2
+ io_score * 0.15)
.clamp(0.0, 1.0);
let match_details = CapabilityMatchDetails {
exact_matches,
partial_matches,
missing_requirements,
performance_score,
resource_score,
io_score,
};
let explanation = self.generate_match_explanation(capability, &match_details, match_score);
Ok(CapabilityMatch {
capability: capability.clone(),
match_score,
match_details,
explanation,
})
}
fn calculate_capability_similarity(&self, cap1: &str, cap2: &str) -> f64 {
if let Some(cap1_compat) = self.compatibility.get(cap1)
&& let Some(similarity) = cap1_compat.get(cap2)
{
return *similarity;
}
self.string_similarity(cap1, cap2)
}
fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
let s1_lower = s1.to_lowercase();
let s2_lower = s2.to_lowercase();
if s1_lower == s2_lower {
return 1.0;
}
if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
return 0.7;
}
let s1_words: HashSet<&str> = s1_lower.split_whitespace().collect();
let s2_words: HashSet<&str> = s2_lower.split_whitespace().collect();
let intersection = s1_words.intersection(&s2_words).count();
let union = s1_words.union(&s2_words).count();
if union > 0 {
intersection as f64 / union as f64
} else {
0.0
}
}
fn calculate_performance_score(
&self,
actual: &CapabilityMetrics,
required: &CapabilityMetrics,
) -> f64 {
let mut score = 1.0;
if actual.success_rate < required.success_rate {
score *= actual.success_rate / required.success_rate;
}
if actual.avg_execution_time > required.avg_execution_time {
let time_ratio =
required.avg_execution_time.as_secs_f64() / actual.avg_execution_time.as_secs_f64();
score *= time_ratio.min(1.0);
}
if actual.quality_score < required.quality_score {
score *= actual.quality_score / required.quality_score;
}
score.clamp(0.0, 1.0)
}
fn calculate_resource_score(&self, actual: &ResourceUsage, max_allowed: &ResourceUsage) -> f64 {
let mut score = 1.0;
if actual.memory_mb > max_allowed.memory_mb {
score *= max_allowed.memory_mb / actual.memory_mb;
}
if actual.cpu_percent > max_allowed.cpu_percent {
score *= max_allowed.cpu_percent / actual.cpu_percent;
}
if actual.network_kbps > max_allowed.network_kbps {
score *= max_allowed.network_kbps / actual.network_kbps;
}
if actual.storage_mb > max_allowed.storage_mb {
score *= max_allowed.storage_mb / actual.storage_mb;
}
score.clamp(0.0, 1.0)
}
fn calculate_io_score(
&self,
capability: &AgentCapability,
requirements: &IORequirements,
) -> f64 {
if requirements.input_types.is_empty() && requirements.output_types.is_empty() {
return 1.0;
}
let mut input_score = 1.0;
let mut output_score = 1.0;
if !requirements.input_types.is_empty() {
let mut matching_inputs = 0;
for required_input in &requirements.input_types {
if capability.input_types.contains(required_input) {
matching_inputs += 1;
} else {
if let Some(compatible_types) =
requirements.compatibility_matrix.get(required_input)
&& capability
.input_types
.iter()
.any(|input| compatible_types.contains(input))
{
matching_inputs += 1;
}
}
}
input_score = matching_inputs as f64 / requirements.input_types.len() as f64;
}
if !requirements.output_types.is_empty() {
let mut matching_outputs = 0;
for required_output in &requirements.output_types {
if capability.output_types.contains(required_output) {
matching_outputs += 1;
} else {
if let Some(compatible_types) =
requirements.compatibility_matrix.get(required_output)
&& capability
.output_types
.iter()
.any(|output| compatible_types.contains(output))
{
matching_outputs += 1;
}
}
}
output_score = matching_outputs as f64 / requirements.output_types.len() as f64;
}
(input_score + output_score) / 2.0
}
fn generate_match_explanation(
&self,
capability: &AgentCapability,
details: &CapabilityMatchDetails,
match_score: f64,
) -> String {
let mut explanation = format!("Capability '{}' ", capability.name);
if !details.exact_matches.is_empty() {
explanation.push_str(&format!(
"exactly matches {} requirements",
details.exact_matches.len()
));
}
if !details.partial_matches.is_empty() {
if !details.exact_matches.is_empty() {
explanation.push_str(" and ");
}
explanation.push_str(&format!(
"partially matches {} requirements",
details.partial_matches.len()
));
}
if !details.missing_requirements.is_empty() {
explanation.push_str(&format!(
", missing {} requirements",
details.missing_requirements.len()
));
}
explanation.push_str(&format!(
". Performance: {:.1}%, Resources: {:.1}%, I/O: {:.1}%. Overall match: {:.1}%",
details.performance_score * 100.0,
details.resource_score * 100.0,
details.io_score * 100.0,
match_score * 100.0
));
explanation
}
fn validate_capability(&self, capability: &AgentCapability) -> RegistryResult<()> {
if capability.capability_id.is_empty() {
return Err(RegistryError::System(
"Capability ID cannot be empty".to_string(),
));
}
if capability.name.is_empty() {
return Err(RegistryError::System(
"Capability name cannot be empty".to_string(),
));
}
if capability.category.is_empty() {
return Err(RegistryError::System(
"Capability category cannot be empty".to_string(),
));
}
if capability.performance_metrics.success_rate < 0.0
|| capability.performance_metrics.success_rate > 1.0
{
return Err(RegistryError::System(
"Success rate must be between 0.0 and 1.0".to_string(),
));
}
if capability.performance_metrics.quality_score < 0.0
|| capability.performance_metrics.quality_score > 1.0
{
return Err(RegistryError::System(
"Quality score must be between 0.0 and 1.0".to_string(),
));
}
Ok(())
}
pub fn set_capability_compatibility(&mut self, cap1: &str, cap2: &str, similarity: f64) {
self.compatibility
.entry(cap1.to_string())
.or_default()
.insert(cap2.to_string(), similarity);
self.compatibility
.entry(cap2.to_string())
.or_default()
.insert(cap1.to_string(), similarity);
}
pub fn get_dependencies(&self, capability_id: &str) -> Vec<String> {
self.dependencies
.get(capability_id)
.cloned()
.unwrap_or_default()
}
pub fn check_dependencies(
&self,
capability_id: &str,
available_capabilities: &[String],
) -> bool {
if let Some(dependencies) = self.dependencies.get(capability_id) {
dependencies
.iter()
.all(|dep| available_capabilities.contains(dep))
} else {
true }
}
pub fn get_statistics(&self) -> CapabilityRegistryStats {
let mut categories_count = HashMap::new();
for (category, capabilities) in &self.categories {
categories_count.insert(category.clone(), capabilities.len());
}
CapabilityRegistryStats {
total_capabilities: self.capabilities.len(),
categories_count,
total_dependencies: self.dependencies.len(),
compatibility_entries: self.compatibility.values().map(|m| m.len()).sum(),
}
}
}
/// Summary counts produced by `CapabilityRegistry::get_statistics`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CapabilityRegistryStats {
    /// Number of registered capabilities.
    pub total_capabilities: usize,
    /// Category name -> number of capability IDs indexed under it.
    pub categories_count: HashMap<String, usize>,
    /// Number of entries in the registry's dependency index (one per capability registered
    /// with a non-empty dependency list), not the number of dependency edges.
    pub total_dependencies: usize,
    /// Total directed compatibility overrides. One `set_capability_compatibility` call
    /// between two distinct IDs stores two.
    pub compatibility_entries: usize,
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a capability with default metrics and no domains, I/O types or dependencies.
    /// Tests override the fields they care about with struct-update syntax.
    fn sample_capability(id: &str, name: &str, description: &str, category: &str) -> AgentCapability {
        AgentCapability {
            capability_id: id.to_string(),
            name: name.to_string(),
            description: description.to_string(),
            category: category.to_string(),
            required_domains: Vec::new(),
            input_types: Vec::new(),
            output_types: Vec::new(),
            performance_metrics: CapabilityMetrics::default(),
            dependencies: Vec::new(),
        }
    }

    #[test]
    fn test_capability_registry_creation() {
        let registry = CapabilityRegistry::new();
        assert_eq!(registry.list_capabilities().len(), 0);
    }

    #[test]
    fn test_capability_registration() {
        let mut registry = CapabilityRegistry::new();
        let capability = AgentCapability {
            required_domains: vec!["test_domain".to_string()],
            input_types: vec!["text".to_string()],
            output_types: vec!["result".to_string()],
            ..sample_capability("test_capability", "Test Capability", "A test capability", "testing")
        };
        registry.register_capability(capability).unwrap();
        assert_eq!(registry.list_capabilities().len(), 1);
        assert!(registry.get_capability("test_capability").is_some());
        assert_eq!(registry.list_capabilities_by_category("testing").len(), 1);
    }

    #[test]
    fn test_capability_matching() {
        let mut registry = CapabilityRegistry::new();
        let capability = AgentCapability {
            required_domains: vec!["project_management".to_string()],
            input_types: vec!["requirements".to_string()],
            output_types: vec!["plan".to_string()],
            performance_metrics: CapabilityMetrics {
                avg_execution_time: Duration::from_secs(5),
                success_rate: 0.9,
                resource_usage: ResourceUsage {
                    memory_mb: 100.0,
                    cpu_percent: 20.0,
                    network_kbps: 10.0,
                    storage_mb: 50.0,
                },
                quality_score: 0.85,
                last_updated: chrono::Utc::now(),
            },
            ..sample_capability("planning", "Task Planning", "Plan and organize tasks", "planning")
        };
        registry.register_capability(capability).unwrap();
        let query = CapabilityQuery {
            required_capabilities: vec!["planning".to_string()],
            optional_capabilities: Vec::new(),
            min_performance: None,
            max_resources: None,
            categories: Vec::new(),
            io_requirements: IORequirements::default(),
        };
        let matches = registry.find_capabilities(&query).unwrap();
        assert_eq!(matches.len(), 1);
        assert!(matches[0].match_score > 0.0);
    }

    #[test]
    fn test_capability_compatibility() {
        let mut registry = CapabilityRegistry::new();
        registry.set_capability_compatibility("planning", "task_planning", 0.9);
        let similarity = registry.calculate_capability_similarity("planning", "task_planning");
        assert_eq!(similarity, 0.9);
    }

    #[test]
    fn test_dependency_checking() {
        let mut registry = CapabilityRegistry::new();
        let capability = AgentCapability {
            dependencies: vec!["basic_planning".to_string()],
            ..sample_capability(
                "advanced_planning",
                "Advanced Planning",
                "Advanced task planning",
                "planning",
            )
        };
        registry.register_capability(capability).unwrap();
        let available = vec!["basic_planning".to_string()];
        assert!(registry.check_dependencies("advanced_planning", &available));
        let unavailable = vec!["other_capability".to_string()];
        assert!(!registry.check_dependencies("advanced_planning", &unavailable));
    }

    #[test]
    fn test_unregister_capability() {
        let mut registry = CapabilityRegistry::new();
        registry
            .register_capability(sample_capability("planning", "Task Planning", "Plan tasks", "planning"))
            .unwrap();
        registry.set_capability_compatibility("planning", "task_planning", 0.9);

        registry.unregister_capability("planning").unwrap();

        assert!(registry.get_capability("planning").is_none());
        assert!(registry.list_capabilities_by_category("planning").is_empty());
        let stats = registry.get_statistics();
        assert_eq!(stats.total_capabilities, 0);
        assert!(!stats.categories_count.contains_key("planning"));
        assert_eq!(stats.compatibility_entries, 0);
        // A second removal of the same ID is reported, not silently ignored.
        assert!(registry.unregister_capability("planning").is_err());
    }

    #[test]
    fn test_validation_rejects_invalid_capabilities() {
        let mut registry = CapabilityRegistry::new();
        assert!(registry
            .register_capability(sample_capability("", "Nameless", "No ID", "testing"))
            .is_err());
        let out_of_range = AgentCapability {
            performance_metrics: CapabilityMetrics {
                success_rate: 1.5,
                ..CapabilityMetrics::default()
            },
            ..sample_capability("bad_rate", "Bad Rate", "Success rate above 1.0", "testing")
        };
        assert!(registry.register_capability(out_of_range).is_err());
        // Rejected capabilities must leave no trace in the registry.
        assert!(registry.list_capabilities().is_empty());
        assert!(registry.list_capabilities_by_category("testing").is_empty());
    }
}