use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use terraphim_rolegraph::RoleGraph;
use crate::{AgentMetadata, RegistryResult};
/// Agent discovery and scoring helper backed by a shared role graph.
///
/// Bundles the graph used for term-connectivity checks, the extraction and
/// threshold settings, and an in-memory cache of connectivity results.
pub struct KnowledgeGraphIntegration {
/// Shared role graph queried to decide whether a set of terms is connected.
role_graph: Arc<RoleGraph>,
/// Extraction settings; `context_window` bounds how much of a task description is scanned.
automata_config: AutomataConfig,
/// Connectivity results keyed by a string derived from the analysed terms; each entry carries its own expiry time.
query_cache: Arc<tokio::sync::RwLock<HashMap<String, QueryResult>>>,
/// Thresholds compared against the best match when query suggestions are generated.
similarity_thresholds: SimilarityThresholds,
}
/// Settings for extracting concepts from free-form text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutomataConfig {
/// Minimum confidence (default 0.7).
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub min_confidence: f64,
/// Maximum number of paragraphs (default 10).
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub max_paragraphs: usize,
/// Maximum number of characters of a task description scanned during concept extraction (default 512).
pub context_window: usize,
/// Names of language models (default: a single `"default"` entry).
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub language_models: Vec<String>,
}
impl Default for AutomataConfig {
fn default() -> Self {
Self {
min_confidence: 0.7,
max_paragraphs: 10,
context_window: 512,
language_models: vec!["default".to_string()],
}
}
}
/// Score thresholds used when judging how well a match fits a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityThresholds {
/// Compared with the best match's role score when generating suggestions (default 0.8).
pub role_similarity: f64,
/// Compared with the best match's capability score when generating suggestions (default 0.75).
pub capability_similarity: f64,
/// Domain threshold (default 0.7).
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub domain_similarity: f64,
/// Concept threshold (default 0.65).
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub concept_similarity: f64,
}
impl Default for SimilarityThresholds {
fn default() -> Self {
Self {
role_similarity: 0.8,
capability_similarity: 0.75,
domain_similarity: 0.7,
concept_similarity: 0.65,
}
}
}
/// One cached connectivity analysis together with its validity window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
/// The cache key under which this entry is stored.
pub query_hash: String,
/// The terms that were analysed.
pub concepts: Vec<String>,
/// The connectivity outcome computed for `concepts`.
pub connectivity: ConnectivityResult,
/// When the entry was written to the cache.
pub cached_at: chrono::DateTime<chrono::Utc>,
/// When the entry stops being served; entries are written with a one-hour lifetime
/// and removed by `cleanup_cache` once expired.
pub expires_at: chrono::DateTime<chrono::Utc>,
}
/// Outcome of checking whether a set of terms is connected in the role graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectivityResult {
/// Whether the role graph reports all terms as connected by a path
/// (also `true` for an empty term list).
pub all_connected: bool,
/// A single entry listing every term when connected; empty otherwise.
pub paths: Vec<Vec<String>>,
/// Every term when the set is not connected; empty otherwise.
pub disconnected: Vec<String>,
/// 1.0 when connected, 0.0 otherwise.
pub strength_score: f64,
}
/// Requirements and preferences describing the agent a caller is looking for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentDiscoveryQuery {
/// Role ids; an agent passes the basic filter when its primary or any secondary
/// role matches at least one of them. An empty list disables the filter.
pub required_roles: Vec<String>,
/// Capability ids; an agent passes the basic filter when it offers at least one
/// of them. An empty list disables the filter.
pub required_capabilities: Vec<String>,
/// Domains used for domain scoring and connectivity analysis (not part of the basic filter).
pub required_domains: Vec<String>,
/// Optional free text from which concepts and domains are extracted.
pub task_description: Option<String>,
/// Soft floor: agents whose success rate is below it get a halved performance score.
pub min_success_rate: Option<f64>,
/// Soft ceiling: agents whose latest recorded memory or CPU usage exceeds it get
/// their performance score multiplied by 0.7.
pub max_resource_usage: Option<crate::ResourceUsage>,
/// Preferred tags.
/// NOTE(review): not read by any code in this file; confirm where it is applied.
pub preferred_tags: Vec<String>,
}
/// Everything returned by a discovery run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentDiscoveryResult {
/// Scored agents, ordered from highest to lowest `match_score`.
pub matches: Vec<AgentMatch>,
/// What was derived from the query before scoring.
pub query_analysis: QueryAnalysis,
/// Human-readable hints for refining the query.
pub suggestions: Vec<String>,
}
/// A single agent together with its score for a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMatch {
/// A copy of the matched agent's metadata.
pub agent: AgentMetadata,
/// Weighted combination of the component scores, clamped to `[0.0, 1.0]`.
pub match_score: f64,
/// The component scores that were combined into `match_score`.
pub score_breakdown: ScoreBreakdown,
/// Short prose summary of role, capability and performance fit.
pub explanation: String,
}
/// Component scores behind an overall match score.
///
/// The overall score weights them 0.25 / 0.25 / 0.20 / 0.20 / 0.10 in field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreBreakdown {
/// How well the agent's roles cover the required roles.
pub role_score: f64,
/// How well the agent's capabilities cover the required capabilities.
pub capability_score: f64,
/// How well the agent's domains cover the required and identified domains.
pub domain_score: f64,
/// The agent's success rate after any penalties for query limits.
pub performance_score: f64,
/// Score derived from the agent's current status.
pub availability_score: f64,
}
/// Information derived from a query before agents are scored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
/// Concepts pulled from the task description (empty when there is none).
pub extracted_concepts: Vec<String>,
/// Domains inferred from `extracted_concepts`.
pub identified_domains: Vec<String>,
/// Roles reported as related to the required roles.
pub suggested_roles: Vec<String>,
/// Connectivity of the required roles, capabilities, domains and extracted concepts taken together.
pub connectivity_analysis: ConnectivityResult,
}
impl KnowledgeGraphIntegration {
pub fn new(
role_graph: Arc<RoleGraph>,
automata_config: AutomataConfig,
similarity_thresholds: SimilarityThresholds,
) -> Self {
Self {
role_graph,
automata_config,
query_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
similarity_thresholds,
}
}
pub async fn discover_agents(
&self,
query: AgentDiscoveryQuery,
available_agents: &[AgentMetadata],
) -> RegistryResult<AgentDiscoveryResult> {
let query_analysis = self.analyze_query(&query).await?;
let mut eligible_agents = Vec::new();
for agent in available_agents {
if self
.check_basic_requirements(agent, &query)
.await
.unwrap_or(false)
{
eligible_agents.push(agent);
}
}
let mut matches = Vec::new();
for agent in eligible_agents {
if let Ok(agent_match) = self.score_agent_match(agent, &query, &query_analysis).await {
matches.push(agent_match);
}
}
#[allow(clippy::unnecessary_sort_by)]
matches.sort_by(|a, b| {
b.match_score
.partial_cmp(&a.match_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let suggestions = self
.generate_query_suggestions(&query, &query_analysis, &matches)
.await;
Ok(AgentDiscoveryResult {
matches,
query_analysis,
suggestions,
})
}
/// Hard filter applied before scoring.
///
/// Returns `Ok(true)` when the agent holds at least one of the required
/// roles (as primary or secondary role) and offers at least one of the
/// required capabilities. An empty requirement list imposes no constraint.
async fn check_basic_requirements(
    &self,
    agent: &AgentMetadata,
    query: &AgentDiscoveryQuery,
) -> RegistryResult<bool> {
    let holds_role = |wanted: &String| {
        agent.primary_role.role_id == *wanted
            || agent
                .secondary_roles
                .iter()
                .any(|role| role.role_id == *wanted)
    };
    let offers_capability = |wanted: &String| {
        agent
            .capabilities
            .iter()
            .any(|cap| cap.capability_id == *wanted)
    };

    let roles_ok = query.required_roles.is_empty() || query.required_roles.iter().any(holds_role);
    if !roles_ok {
        return Ok(false);
    }

    let capabilities_ok = query.required_capabilities.is_empty()
        || query.required_capabilities.iter().any(offers_capability);

    // `max_resource_usage` is intentionally not a hard filter here; no check
    // is performed on it in this function.
    Ok(capabilities_ok)
}
/// Derives concepts, domains, related roles and term connectivity from `query`.
///
/// Concepts and domains come from the task description when one is given.
/// Connectivity is computed over the required roles, capabilities and
/// domains followed by the extracted concepts, in that order.
async fn analyze_query(&self, query: &AgentDiscoveryQuery) -> RegistryResult<QueryAnalysis> {
    let (extracted_concepts, identified_domains) = match query.task_description.as_deref() {
        Some(description) => {
            let concepts = self.extract_concepts_from_text(description).await?;
            let domains = self.identify_domains_from_concepts(&concepts).await?;
            (concepts, domains)
        }
        None => (Vec::new(), Vec::new()),
    };

    let mut suggested_roles: Vec<String> = Vec::new();
    for role_id in &query.required_roles {
        let related = self.find_related_roles(role_id).await?;
        suggested_roles.extend(related.unwrap_or_default());
    }

    let term_count = query.required_roles.len()
        + query.required_capabilities.len()
        + query.required_domains.len()
        + extracted_concepts.len();
    let mut all_terms: Vec<String> = Vec::with_capacity(term_count);
    all_terms.extend_from_slice(&query.required_roles);
    all_terms.extend_from_slice(&query.required_capabilities);
    all_terms.extend_from_slice(&query.required_domains);
    all_terms.extend_from_slice(&extracted_concepts);

    let connectivity_analysis = self.analyze_term_connectivity(&all_terms).await?;

    Ok(QueryAnalysis {
        extracted_concepts,
        identified_domains,
        suggested_roles,
        connectivity_analysis,
    })
}
/// Scores one agent against the query and builds its `AgentMatch`.
///
/// The overall score is the weighted sum of the role (0.25), capability
/// (0.25), domain (0.20), performance (0.20) and availability (0.10)
/// scores, clamped to `[0.0, 1.0]`.
async fn score_agent_match(
    &self,
    agent: &AgentMetadata,
    query: &AgentDiscoveryQuery,
    query_analysis: &QueryAnalysis,
) -> RegistryResult<AgentMatch> {
    const ROLE_WEIGHT: f64 = 0.25;
    const CAPABILITY_WEIGHT: f64 = 0.25;
    const DOMAIN_WEIGHT: f64 = 0.20;
    const PERFORMANCE_WEIGHT: f64 = 0.20;
    const AVAILABILITY_WEIGHT: f64 = 0.10;

    let score_breakdown = ScoreBreakdown {
        role_score: self
            .calculate_role_score(agent, &query.required_roles)
            .await?,
        capability_score: self
            .calculate_capability_score(agent, &query.required_capabilities)
            .await?,
        domain_score: self
            .calculate_domain_score(
                agent,
                &query.required_domains,
                &query_analysis.identified_domains,
            )
            .await?,
        performance_score: self.calculate_performance_score(agent, query).await?,
        availability_score: self.calculate_availability_score(agent).await?,
    };

    let weighted_total = score_breakdown.role_score * ROLE_WEIGHT
        + score_breakdown.capability_score * CAPABILITY_WEIGHT
        + score_breakdown.domain_score * DOMAIN_WEIGHT
        + score_breakdown.performance_score * PERFORMANCE_WEIGHT
        + score_breakdown.availability_score * AVAILABILITY_WEIGHT;
    let match_score = weighted_total.clamp(0.0, 1.0);

    let explanation = self.generate_match_explanation(agent, &score_breakdown, query_analysis);

    Ok(AgentMatch {
        agent: agent.clone(),
        match_score,
        score_breakdown,
        explanation,
    })
}
/// Extracts the distinct candidate concepts from the start of `text`.
///
/// Only the first `automata_config.context_window` characters are scanned.
/// The text is split on whitespace and on `,` `.` `;` `:`; each token is
/// lowercased and kept when it is longer than two bytes, is not a stop
/// word and is not made up solely of ASCII punctuation.
///
/// The result is de-duplicated and sorted, so the same text always yields
/// the same vector. Callers build cache keys from it, so an arbitrary
/// hash-set order would make identical queries miss the cache.
async fn extract_concepts_from_text(&self, text: &str) -> RegistryResult<Vec<String>> {
    const STOP_WORDS: &[&str] = &[
        "the", "and", "for", "with", "using", "from", "that", "this", "are", "was", "were",
        "been", "have", "has", "had", "not", "but", "they", "you", "their", "can", "will",
    ];

    // Cut on a character boundary so multi-byte text is never split mid-character.
    let scan_end = text
        .char_indices()
        .nth(self.automata_config.context_window)
        .map_or(text.len(), |(idx, _)| idx);
    let window = &text[..scan_end];

    let unique: HashSet<String> = window
        .split(|c: char| c.is_whitespace() || matches!(c, ',' | '.' | ';' | ':'))
        .map(|token| token.trim().to_lowercase())
        .filter(|word| {
            word.len() > 2
                && !STOP_WORDS.contains(&word.as_str())
                && !word.chars().all(|c| c.is_ascii_punctuation())
        })
        .collect();

    let mut concepts: Vec<String> = unique.into_iter().collect();
    concepts.sort_unstable();
    Ok(concepts)
}
async fn identify_domains_from_concepts(
&self,
concepts: &[String],
) -> RegistryResult<Vec<String>> {
let mut domains = HashSet::new();
for concept in concepts {
let concept_lower = concept.to_lowercase();
if concept_lower.contains("plan") || concept_lower.contains("strategy") {
domains.insert("planning".to_string());
}
if concept_lower.contains("data") || concept_lower.contains("analysis") {
domains.insert("data_analysis".to_string());
}
if concept_lower.contains("execute") || concept_lower.contains("implement") {
domains.insert("execution".to_string());
}
if concept_lower.contains("coordinate") || concept_lower.contains("manage") {
domains.insert("coordination".to_string());
}
if concept_lower.contains("neural")
|| concept_lower.contains("network")
|| concept_lower.contains("deep")
|| concept_lower.contains("learning")
|| concept_lower.contains("machine")
|| concept_lower.contains("model")
|| concept_lower.contains("classification")
|| concept_lower.contains("image")
{
domains.insert("machine_learning".to_string());
}
if concept_lower.contains("tensor") || concept_lower.contains("flow") {
domains.insert("tensorflow".to_string());
}
}
Ok(domains.into_iter().collect())
}
/// Looks up roles related to `_role_id`.
///
/// Placeholder: the argument is ignored and an empty list is always returned.
async fn find_related_roles(&self, _role_id: &str) -> RegistryResult<Option<Vec<String>>> {
    let related: Vec<String> = Vec::new();
    Ok(Some(related))
}
/// Checks whether `terms` are connected in the role graph, with caching.
///
/// An empty list counts as fully connected. Otherwise the terms are joined
/// with spaces and handed to the role graph. The outcome is cached for one
/// hour under a key derived from the exact term sequence.
async fn analyze_term_connectivity(
    &self,
    terms: &[String],
) -> RegistryResult<ConnectivityResult> {
    if terms.is_empty() {
        return Ok(ConnectivityResult {
            all_connected: true,
            paths: Vec::new(),
            disconnected: Vec::new(),
            strength_score: 1.0,
        });
    }

    // Length-prefix every term so two different term lists can never share a
    // key. A plain `join("_")` maps both ["a_b"] and ["a", "b"] to the same
    // key although they produce different graph queries.
    let cache_key = terms
        .iter()
        .fold(String::from("connectivity"), |mut key, term| {
            key.push('_');
            key.push_str(&term.len().to_string());
            key.push(':');
            key.push_str(term);
            key
        });

    {
        let cache = self.query_cache.read().await;
        if let Some(entry) = cache.get(&cache_key)
            && entry.expires_at > chrono::Utc::now()
        {
            return Ok(entry.connectivity.clone());
        }
    }

    let text = terms.join(" ");
    let all_connected = self.role_graph.is_all_terms_connected_by_path(&text);
    let (paths, disconnected) = if all_connected {
        (vec![terms.to_vec()], Vec::new())
    } else {
        (Vec::new(), terms.to_vec())
    };
    let connectivity_result = ConnectivityResult {
        all_connected,
        paths,
        disconnected,
        strength_score: if all_connected { 1.0 } else { 0.0 },
    };

    {
        // One timestamp for both fields keeps the lifetime at exactly one hour.
        let now = chrono::Utc::now();
        let mut cache = self.query_cache.write().await;
        cache.insert(
            cache_key.clone(),
            QueryResult {
                query_hash: cache_key,
                concepts: terms.to_vec(),
                connectivity: connectivity_result.clone(),
                cached_at: now,
                expires_at: now + chrono::Duration::hours(1),
            },
        );
    }

    Ok(connectivity_result)
}
/// Average, over the required roles, of how well the agent covers each one.
///
/// Per required role: 1.0 for the primary role, 0.9 for a secondary role,
/// 0.7 when the primary role is a related role, 0.6 when a secondary role is
/// a related role, otherwise 0.0 (the best applicable value wins). Returns
/// 1.0 when nothing is required.
async fn calculate_role_score(
    &self,
    agent: &AgentMetadata,
    required_roles: &[String],
) -> RegistryResult<f64> {
    if required_roles.is_empty() {
        return Ok(1.0);
    }

    let mut score_sum: f64 = 0.0;
    for wanted in required_roles {
        score_sum += if agent.primary_role.role_id == *wanted {
            1.0
        } else {
            let mut best: f64 = 0.0;
            if agent
                .secondary_roles
                .iter()
                .any(|role| role.role_id == *wanted)
            {
                best = 0.9;
            }
            if let Some(related) = self.find_related_roles(wanted).await? {
                if related.contains(&agent.primary_role.role_id) {
                    best = best.max(0.7);
                }
                if agent
                    .secondary_roles
                    .iter()
                    .any(|role| related.contains(&role.role_id))
                {
                    best = best.max(0.6);
                }
            }
            best
        };
    }

    // The list is non-empty here, so the division is always defined.
    Ok(score_sum / required_roles.len() as f64)
}
/// Average, over the required capabilities, of the agent's best coverage.
///
/// For each requirement the best of the following is taken across the
/// agent's capabilities, each scaled by that capability's success rate:
/// exact id match (x1.0); case-insensitive containment between requirement
/// and capability name in either direction (x0.7); capability category
/// containing the requirement (x0.5). Returns 1.0 when nothing is required.
///
/// Empty strings never take part in the fuzzy comparisons, because every
/// string contains the empty string and would otherwise match everything.
async fn calculate_capability_score(
    &self,
    agent: &AgentMetadata,
    required_capabilities: &[String],
) -> RegistryResult<f64> {
    if required_capabilities.is_empty() {
        return Ok(1.0);
    }

    // Lowercase each capability's name and category once, not once per comparison.
    let lowered: Vec<(String, String)> = agent
        .capabilities
        .iter()
        .map(|cap| (cap.name.to_lowercase(), cap.category.to_lowercase()))
        .collect();

    let mut total_score: f64 = 0.0;
    for required in required_capabilities {
        let wanted = required.to_lowercase();
        let mut best_score: f64 = 0.0;

        for (cap, (name, category)) in agent.capabilities.iter().zip(&lowered) {
            let weight = if cap.capability_id == *required {
                1.0
            } else if wanted.is_empty() {
                // An empty requirement can only be satisfied by an exact id match.
                continue;
            } else if name.contains(wanted.as_str())
                || (!name.is_empty() && wanted.contains(name.as_str()))
            {
                0.7
            } else if category.contains(wanted.as_str()) {
                0.5
            } else {
                continue;
            };
            best_score = best_score.max(cap.performance_metrics.success_rate * weight);
        }

        total_score += best_score;
    }

    // The list is non-empty here, so the division is always defined.
    Ok(total_score / required_capabilities.len() as f64)
}
/// Average, over the distinct required and identified domains, of how well
/// the agent covers each one.
///
/// A domain scores 1.0 when `agent.can_handle_domain` accepts it, 0.7 when
/// it and one of the agent's known domains contain each other
/// (case-insensitively), otherwise 0.0. Returns 1.0 when there are no domains.
///
/// Domains are de-duplicated in first-seen order, so the floating-point sum
/// is accumulated in the same order on every call. Empty strings never take
/// part in the containment check, since every string contains the empty string.
async fn calculate_domain_score(
    &self,
    agent: &AgentMetadata,
    required_domains: &[String],
    identified_domains: &[String],
) -> RegistryResult<f64> {
    let mut seen: HashSet<&String> = HashSet::new();
    let all_domains: Vec<&String> = required_domains
        .iter()
        .chain(identified_domains)
        .filter(|domain| seen.insert(*domain))
        .collect();
    if all_domains.is_empty() {
        return Ok(1.0);
    }

    // Lowercase the agent's domains once, not once per comparison.
    let agent_domains: Vec<String> = agent
        .knowledge_context
        .domains
        .iter()
        .map(|d| d.to_lowercase())
        .collect();

    let mut total_score: f64 = 0.0;
    for domain in &all_domains {
        total_score += if agent.can_handle_domain(domain) {
            1.0
        } else {
            let wanted = domain.to_lowercase();
            let partially_covered = !wanted.is_empty()
                && agent_domains.iter().any(|known| {
                    !known.is_empty()
                        && (known.contains(wanted.as_str()) || wanted.contains(known.as_str()))
                });
            if partially_covered { 0.7 } else { 0.0 }
        };
    }

    // The list is non-empty here, so the division is always defined.
    Ok(total_score / all_domains.len() as f64)
}
async fn calculate_performance_score(
&self,
agent: &AgentMetadata,
query: &AgentDiscoveryQuery,
) -> RegistryResult<f64> {
let mut score = agent.get_success_rate();
if let Some(min_success_rate) = query.min_success_rate
&& score < min_success_rate
{
score *= 0.5; }
if let Some(max_resource_usage) = &query.max_resource_usage
&& let Some((_, latest_usage)) = agent.statistics.resource_history.last()
&& (latest_usage.memory_mb > max_resource_usage.memory_mb
|| latest_usage.cpu_percent > max_resource_usage.cpu_percent)
{
score *= 0.7; }
Ok(score)
}
async fn calculate_availability_score(&self, agent: &AgentMetadata) -> RegistryResult<f64> {
match agent.status {
crate::AgentStatus::Active => Ok(1.0),
crate::AgentStatus::Idle => Ok(1.0),
crate::AgentStatus::Busy => Ok(0.5),
crate::AgentStatus::Hibernating => Ok(0.8),
crate::AgentStatus::Initializing => Ok(0.3),
crate::AgentStatus::Terminating => Ok(0.0),
crate::AgentStatus::Terminated => Ok(0.0),
crate::AgentStatus::Failed(_) => Ok(0.0),
}
}
/// Builds a one-line summary of an agent's role, capability and
/// performance fit. Each score is bucketed at 0.8 and 0.6 (exclusive).
fn generate_match_explanation(
    &self,
    agent: &AgentMetadata,
    score_breakdown: &ScoreBreakdown,
    _query_analysis: &QueryAnalysis,
) -> String {
    let role_phrase = match score_breakdown.role_score {
        s if s > 0.8 => " has excellent role compatibility",
        s if s > 0.6 => " has good role compatibility",
        _ => " has limited role compatibility",
    };
    let capability_phrase = match score_breakdown.capability_score {
        s if s > 0.8 => " and strong capability match",
        s if s > 0.6 => " and moderate capability match",
        _ => " but limited capability match",
    };
    let performance_phrase = match score_breakdown.performance_score {
        s if s > 0.8 => ". Performance history is excellent",
        s if s > 0.6 => ". Performance history is good",
        _ => ". Performance history needs improvement",
    };

    format!(
        "Agent {} ({}){}{}{}.",
        agent.agent_id, agent.primary_role.name, role_phrase, capability_phrase, performance_phrase
    )
}
async fn generate_query_suggestions(
&self,
query: &AgentDiscoveryQuery,
query_analysis: &QueryAnalysis,
matches: &[AgentMatch],
) -> Vec<String> {
let mut suggestions = Vec::new();
if !query_analysis.connectivity_analysis.all_connected {
suggestions
.push("Consider adding related roles to improve agent connectivity".to_string());
}
if matches.is_empty() || matches.iter().all(|m| m.match_score < 0.5) {
suggestions.push(
"Consider relaxing some requirements to find more suitable agents".to_string(),
);
}
if !query_analysis.identified_domains.is_empty() && query.required_capabilities.is_empty() {
suggestions.push(format!(
"Consider specifying capabilities for domains: {}",
query_analysis.identified_domains.join(", ")
));
}
if let Some(best) = matches.first()
&& (best.score_breakdown.role_score < self.similarity_thresholds.role_similarity
|| best.score_breakdown.capability_score
< self.similarity_thresholds.capability_similarity)
{
suggestions.push(
"Best match falls below configured similarity thresholds; consider broadening the query"
.to_string(),
);
}
suggestions
}
/// Drops every cached connectivity result whose expiry time has passed.
pub async fn cleanup_cache(&self) {
    let mut entries = self.query_cache.write().await;
    let cutoff = chrono::Utc::now();
    entries.retain(|_, entry| cutoff < entry.expires_at);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AgentMetadata, AgentPid, AgentRole, SupervisorId};
    use terraphim_types::{RoleName, Thesaurus};

    /// Builds an integration over a freshly created role graph with default settings.
    async fn default_integration() -> KnowledgeGraphIntegration {
        let graph = RoleGraph::new(
            RoleName::new("test_role"),
            Thesaurus::new("test_thesaurus".to_string()),
        )
        .await
        .unwrap();
        KnowledgeGraphIntegration::new(
            Arc::new(graph),
            AutomataConfig::default(),
            SimilarityThresholds::default(),
        )
    }

    /// The constructor stores the thresholds it is given.
    #[tokio::test]
    async fn test_knowledge_graph_integration_creation() {
        let integration = default_integration().await;
        assert_eq!(integration.similarity_thresholds.role_similarity, 0.8);
    }

    /// A query can be built field by field and keeps what it was given.
    #[tokio::test]
    async fn test_agent_discovery_query() {
        let description = "Plan and coordinate a software development project".to_string();
        let query = AgentDiscoveryQuery {
            required_roles: vec!["planner".to_string()],
            required_capabilities: vec!["task_planning".to_string()],
            required_domains: vec!["project_management".to_string()],
            task_description: Some(description),
            min_success_rate: Some(0.8),
            max_resource_usage: None,
            preferred_tags: vec!["experienced".to_string()],
        };

        assert_eq!(query.required_roles.len(), 1);
        assert_eq!(query.required_capabilities.len(), 1);
        assert!(query.task_description.is_some());
    }

    /// The availability score of a newly created agent lies within `[0.0, 1.0]`.
    #[tokio::test]
    async fn test_score_calculation() {
        let integration = default_integration().await;

        let planner_role = AgentRole::new(
            "planner".to_string(),
            "Planning Agent".to_string(),
            "Responsible for task planning".to_string(),
        );
        let agent = AgentMetadata::new(AgentPid::new(), SupervisorId::new(), planner_role);

        let availability_score = integration
            .calculate_availability_score(&agent)
            .await
            .unwrap();
        assert!((0.0..=1.0).contains(&availability_score));
    }
}