use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use terraphim_automata::Automata;
use terraphim_gen_agent::{GenAgent, GenAgentResult};
use terraphim_rolegraph::RoleGraph;
use terraphim_task_decomposition::Task;
use crate::{KgAgentError, KgAgentResult};
/// Messages understood by the knowledge-graph worker agent.
///
/// The same enum is used for both synchronous calls and fire-and-forget
/// casts; which variants each handler acts on is decided by the handlers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WorkerMessage {
/// Run the given task on this worker.
ExecuteTask { task: Task },
/// Ask how well this worker matches the given task without running it.
CheckCompatibility { task: Task },
/// Set the worker's expertise level for a domain.
UpdateSpecialization {
/// Name of the domain whose specialization is updated.
domain: String,
/// Requested expertise level; the cast handler in this file clamps it
/// to `0.0..=1.0` before storing it.
expertise_level: f64,
},
/// Query the worker's current status.
GetStatus,
/// Ask the worker to pause.
Pause,
/// Ask the worker to resume.
Resume,
}
/// Complete mutable state of a worker agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerState {
/// Current lifecycle status of the worker.
pub status: WorkerStatus,
/// Per-domain specialization records, keyed by domain name.
pub specializations: HashMap<String, DomainSpecialization>,
/// Records of past task executions, oldest first.
pub execution_history: Vec<TaskExecution>,
/// Aggregate counters and averages over all executed tasks.
pub metrics: WorkerMetrics,
/// Tunable settings for this worker.
pub config: WorkerConfig,
}
/// Lifecycle status of a worker agent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
/// Not running a task; the only status in which a new task is accepted.
Idle,
/// A task is currently being executed.
Executing,
/// The worker has been paused via a `Pause` message.
Paused,
/// Error state. NOTE(review): nothing in this file assigns this variant;
/// confirm whether other modules use it.
Error,
}
/// What a worker has learned about one domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainSpecialization {
/// Name of the domain this record describes.
pub domain: String,
/// Expertise estimate; values written by this file never exceed `1.0`.
pub expertise_level: f64,
/// Number of executions recorded for this domain (successes and failures).
pub tasks_completed: u64,
/// Fraction of recorded executions in this domain that succeeded.
pub success_rate: f64,
/// Running mean of execution durations in this domain.
pub average_execution_time: std::time::Duration,
/// Distinct concepts seen in this domain's executions; capped at 100 entries
/// by the learning step in this file.
pub knowledge_concepts: Vec<String>,
}
/// Record of a single task execution attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskExecution {
/// Identifier of the executed task.
pub task_id: String,
/// Domain the execution is attributed to: the task's first required
/// domain, or `"general"` when the task lists none.
pub domain: String,
/// Wall-clock time at which execution started.
pub start_time: std::time::SystemTime,
/// How long the execution took.
pub duration: std::time::Duration,
/// Whether the execution step completed without error.
pub success: bool,
/// Text of the execution error, present only when `success` is `false`.
pub error_message: Option<String>,
/// Knowledge-context concepts extracted for the task.
pub concepts_used: Vec<String>,
/// Compatibility score computed for the task before it was run.
pub confidence: f64,
}
/// Aggregate statistics over every task a worker has executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetrics {
/// Number of executions recorded, successful or not.
pub total_tasks: u64,
/// Number of executions that succeeded.
pub successful_tasks: u64,
/// Running mean of execution durations.
pub average_execution_time: std::time::Duration,
/// `successful_tasks / total_tasks`; `0.0` before any task has run.
pub success_rate: f64,
/// Number of executions recorded per domain name.
pub domain_distribution: HashMap<String, u64>,
}
impl Default for WorkerMetrics {
    /// Returns metrics for a worker that has not executed anything yet:
    /// all counters and rates are zero and no domain has been seen.
    fn default() -> Self {
        let no_time = std::time::Duration::default();
        let no_domains = HashMap::default();
        Self {
            domain_distribution: no_domains,
            average_execution_time: no_time,
            success_rate: 0.0,
            successful_tasks: 0,
            total_tasks: 0,
        }
    }
}
/// Tunable settings for a worker agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
/// Intended limit on simultaneously running tasks.
/// NOTE(review): not read anywhere in this file; confirm where it is enforced.
pub max_concurrent_tasks: usize,
/// Intended upper bound on a single task's execution time.
/// NOTE(review): not read anywhere in this file; confirm where it is enforced.
pub execution_timeout: std::time::Duration,
/// Minimum compatibility score a task must reach to be executed.
pub min_confidence_threshold: f64,
/// When `false`, executions do not update the per-domain specializations.
pub enable_domain_learning: bool,
/// Intended timeout for knowledge-graph queries.
/// NOTE(review): not read anywhere in this file; confirm where it is enforced.
pub kg_query_timeout: std::time::Duration,
}
impl Default for WorkerConfig {
    /// Returns the stock configuration: five concurrent tasks, a 300 second
    /// execution timeout, a 0.5 confidence threshold, domain learning
    /// enabled and a 10 second knowledge-graph query timeout.
    fn default() -> Self {
        use std::time::Duration;

        const EXECUTION_TIMEOUT_SECS: u64 = 300;
        const KG_QUERY_TIMEOUT_SECS: u64 = 10;

        Self {
            kg_query_timeout: Duration::from_secs(KG_QUERY_TIMEOUT_SECS),
            execution_timeout: Duration::from_secs(EXECUTION_TIMEOUT_SECS),
            enable_domain_learning: true,
            min_confidence_threshold: 0.5,
            max_concurrent_tasks: 5,
        }
    }
}
impl Default for WorkerState {
    /// Returns the state of a freshly created worker: idle, with no
    /// specializations or history, zeroed metrics and the default config.
    fn default() -> Self {
        Self {
            config: Default::default(),
            metrics: Default::default(),
            execution_history: Vec::default(),
            specializations: HashMap::default(),
            status: WorkerStatus::Idle,
        }
    }
}
/// Worker agent that executes tasks and tracks its own per-domain track record.
pub struct KnowledgeGraphWorkerAgent {
/// Identifier used in log lines and error messages.
agent_id: String,
/// Shared automata handle.
/// NOTE(review): stored but not read by any method in this file.
automata: Arc<Automata>,
/// Shared role graph handle.
/// NOTE(review): stored but not read by any method in this file.
role_graph: Arc<RoleGraph>,
/// Status, specializations, history, metrics and configuration.
state: WorkerState,
}
impl KnowledgeGraphWorkerAgent {
pub fn new(
agent_id: String,
automata: Arc<Automata>,
role_graph: Arc<RoleGraph>,
config: WorkerConfig,
) -> Self {
let state = WorkerState {
status: WorkerStatus::Idle,
specializations: HashMap::new(),
execution_history: Vec::new(),
metrics: WorkerMetrics::default(),
config,
};
Self {
agent_id,
automata,
role_graph,
state,
}
}
async fn execute_task(&mut self, task: Task) -> KgAgentResult<TaskExecution> {
info!("Executing task: {}", task.task_id);
if self.state.status != WorkerStatus::Idle {
return Err(KgAgentError::WorkerError(format!(
"Worker {} is not idle (status: {:?})",
self.agent_id, self.state.status
)));
}
self.state.status = WorkerStatus::Executing;
let start_time = std::time::SystemTime::now();
let compatibility = self.check_task_compatibility(&task).await?;
if compatibility < self.state.config.min_confidence_threshold {
self.state.status = WorkerStatus::Idle;
return Err(KgAgentError::CompatibilityError(format!(
"Task {} compatibility {} below threshold {}",
task.task_id, compatibility, self.state.config.min_confidence_threshold
)));
}
let knowledge_context = self.extract_knowledge_context(&task).await?;
let execution_result = self.perform_task_execution(&task, &knowledge_context).await;
let duration = start_time.elapsed().unwrap_or(std::time::Duration::ZERO);
let success = execution_result.is_ok();
let error_message = if let Err(ref e) = execution_result {
Some(e.to_string())
} else {
None
};
let execution = TaskExecution {
task_id: task.task_id.clone(),
domain: task
.required_domains
.first()
.unwrap_or(&"general".to_string())
.clone(),
start_time,
duration,
success,
error_message,
concepts_used: knowledge_context,
confidence: compatibility,
};
self.update_metrics(&execution);
self.update_specializations(&execution);
self.state.execution_history.push(execution.clone());
if self.state.execution_history.len() > 1000 {
self.state.execution_history.remove(0);
}
self.state.status = WorkerStatus::Idle;
if success {
info!(
"Task {} executed successfully in {:.2}s",
task.task_id,
duration.as_secs_f64()
);
} else {
warn!(
"Task {} execution failed after {:.2}s: {:?}",
task.task_id,
duration.as_secs_f64(),
error_message
);
}
Ok(execution)
}
/// Scores how well this worker matches `task`.
///
/// The score is the arithmetic mean of these factors:
/// - for each required domain the worker has a specialization in,
///   `expertise_level * success_rate` (unknown domains contribute nothing);
/// - the concept connectivity score, when the task lists any concepts;
/// - one capability score per required capability.
///
/// When no factor applies the neutral score `0.5` is returned.
///
/// # Errors
///
/// Propagates any error from the concept connectivity analysis.
async fn check_task_compatibility(&self, task: &Task) -> KgAgentResult<f64> {
    debug!("Checking compatibility for task: {}", task.task_id);

    // Factor scores are collected in evaluation order: domains, then
    // concept connectivity, then capabilities.
    let mut factor_scores: Vec<f64> = Vec::new();

    let known_domains = task
        .required_domains
        .iter()
        .filter_map(|domain| self.state.specializations.get(domain));
    for known in known_domains {
        factor_scores.push(known.expertise_level * known.success_rate);
    }

    if !task.concepts.is_empty() {
        let connectivity = self.analyze_concept_connectivity(&task.concepts).await?;
        factor_scores.push(connectivity);
    }

    for capability in &task.required_capabilities {
        factor_scores.push(self.assess_capability_compatibility(capability));
    }

    let factor_count = factor_scores.len();
    let final_score = if factor_count == 0 {
        0.5
    } else {
        let total = factor_scores.iter().fold(0.0, |acc, score| acc + score);
        total / factor_count as f64
    };

    debug!(
        "Task {} compatibility: {:.2} (based on {} factors)",
        task.task_id, final_score, factor_count
    );
    Ok(final_score)
}
/// Builds the knowledge context for `task`.
///
/// The task description, context keywords and concepts are joined into one
/// text; the first ten whitespace-separated words of that text, lowercased,
/// form the context. The automata and role graph are not consulted.
///
/// # Errors
///
/// Never fails at present; the `Result` is part of the signature so callers
/// already handle errors.
async fn extract_knowledge_context(&self, task: &Task) -> KgAgentResult<Vec<String>> {
    const MAX_CONTEXT_CONCEPTS: usize = 10;

    let context_text = format!(
        "{} {} {}",
        task.description,
        task.context_keywords.join(" "),
        task.concepts.join(" ")
    );
    // The explicit type is required: `len()` is called below, before the
    // return value could tell the compiler what `collect` should build.
    let concepts: Vec<String> = context_text
        .split_whitespace()
        .take(MAX_CONTEXT_CONCEPTS)
        .map(str::to_lowercase)
        .collect();
    debug!(
        "Extracted {} knowledge concepts for task {}",
        concepts.len(),
        task.task_id
    );
    Ok(concepts)
}
async fn perform_task_execution(
&self,
task: &Task,
knowledge_context: &[String],
) -> KgAgentResult<String> {
debug!(
"Performing execution for task {} with {} context concepts",
task.task_id,
knowledge_context.len()
);
let execution_time = match task.complexity {
terraphim_task_decomposition::TaskComplexity::Simple => {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
terraphim_task_decomposition::TaskComplexity::Moderate => {
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
terraphim_task_decomposition::TaskComplexity::Complex => {
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
}
terraphim_task_decomposition::TaskComplexity::VeryComplex => {
tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
}
};
let success_probability = if knowledge_context.len() > 5 {
0.9
} else {
0.7
};
let random_value: f64 = fastrand::f64();
if random_value < success_probability {
Ok(format!("Task {} completed successfully", task.task_id))
} else {
Err(KgAgentError::ExecutionFailed(format!(
"Task {} execution failed due to insufficient context",
task.task_id
)))
}
}
/// Estimates how interconnected `concepts` are.
///
/// Two concepts count as connected when they share at least one character.
/// The score is the fraction of unordered concept pairs that are connected.
/// Fewer than two concepts yields `1.0`.
///
/// # Errors
///
/// Never fails at present.
async fn analyze_concept_connectivity(&self, concepts: &[String]) -> KgAgentResult<f64> {
    if concepts.len() < 2 {
        return Ok(1.0);
    }

    // Walk every unordered pair once, counting pairs and connected pairs.
    let (connected, pairs) = concepts
        .iter()
        .enumerate()
        .flat_map(|(idx, first)| {
            concepts[idx + 1..]
                .iter()
                .map(move |second| (first, second))
        })
        .fold((0_u64, 0_u64), |(connected, pairs), (first, second)| {
            let shares_char = first.chars().any(|ch| second.contains(ch));
            (connected + u64::from(shares_char), pairs + 1)
        });

    let final_score = if pairs == 0 {
        0.0
    } else {
        connected as f64 / pairs as f64
    };
    Ok(final_score)
}
/// Scores how well this worker covers `capability`.
///
/// - `0.8` if any successful past execution used a concept containing
///   `capability` as a substring;
/// - otherwise the highest `expertise_level` among specializations whose
///   known concepts contain `capability`;
/// - otherwise `0.3`.
///
/// Taking the highest matching expertise keeps the result independent of
/// the hash map's iteration order when several specializations match.
fn assess_capability_compatibility(&self, capability: &str) -> f64 {
    const PROVEN_SCORE: f64 = 0.8;
    const UNKNOWN_SCORE: f64 = 0.3;

    /// True when any concept contains `capability` as a substring.
    fn mentions(concepts: &[String], capability: &str) -> bool {
        concepts.iter().any(|concept| concept.contains(capability))
    }

    let proven = self
        .state
        .execution_history
        .iter()
        .any(|execution| execution.success && mentions(&execution.concepts_used, capability));
    if proven {
        return PROVEN_SCORE;
    }

    self.state
        .specializations
        .values()
        .filter(|specialization| mentions(&specialization.knowledge_concepts, capability))
        .map(|specialization| specialization.expertise_level)
        .max_by(|a, b| a.total_cmp(b))
        .unwrap_or(UNKNOWN_SCORE)
}
/// Folds one execution into the aggregate worker metrics.
///
/// Updates the task counters, the success rate, the running mean of the
/// execution time and the per-domain execution count.
fn update_metrics(&mut self, execution: &TaskExecution) {
    let metrics = &mut self.state.metrics;

    metrics.total_tasks += 1;
    if execution.success {
        metrics.successful_tasks += 1;
    }
    let task_count = metrics.total_tasks as f64;
    metrics.success_rate = metrics.successful_tasks as f64 / task_count;

    // Rebuild the summed time from the previous mean, add this run, re-average.
    let summed_secs = metrics.average_execution_time.as_secs_f64()
        * (metrics.total_tasks - 1) as f64
        + execution.duration.as_secs_f64();
    metrics.average_execution_time =
        std::time::Duration::from_secs_f64(summed_secs / task_count);

    let domain_count = metrics
        .domain_distribution
        .entry(execution.domain.clone())
        .or_default();
    *domain_count += 1;
}
/// Folds one execution into the specialization for its domain.
///
/// Does nothing when `config.enable_domain_learning` is off. Otherwise the
/// domain's record is created on first use and its task count, success
/// rate, expertise level, mean execution time and known concepts are
/// updated.
///
/// Expertise is `0.7 * success_rate + 0.3 * experience`, capped at `1.0`,
/// where `experience` is `max(ln(tasks_completed), 1) / 10`.
fn update_specializations(&mut self, execution: &TaskExecution) {
    const MAX_KNOWLEDGE_CONCEPTS: usize = 100;

    if !self.state.config.enable_domain_learning {
        return;
    }
    let specialization = self
        .state
        .specializations
        .entry(execution.domain.clone())
        .or_insert_with(|| DomainSpecialization {
            domain: execution.domain.clone(),
            expertise_level: 0.1,
            tasks_completed: 0,
            success_rate: 0.0,
            average_execution_time: std::time::Duration::ZERO,
            knowledge_concepts: Vec::new(),
        });

    let previous_tasks = specialization.tasks_completed;
    // The success count is not stored, so it is rebuilt from the rate.
    // Round rather than truncate: `rate * n` can land just below the true
    // integer, and truncation would then lose one success. The `min` guards
    // against an out-of-range rate in externally supplied state.
    let previous_successes = ((specialization.success_rate * previous_tasks as f64).round()
        as u64)
        .min(previous_tasks);

    specialization.tasks_completed = previous_tasks + 1;
    let completed = specialization.tasks_completed as f64;
    let successes = previous_successes + u64::from(execution.success);
    specialization.success_rate = successes as f64 / completed;

    let experience_factor = completed.ln().max(1.0) / 10.0;
    specialization.expertise_level =
        (specialization.success_rate * 0.7 + experience_factor * 0.3).min(1.0);

    let summed_secs = specialization.average_execution_time.as_secs_f64()
        * previous_tasks as f64
        + execution.duration.as_secs_f64();
    specialization.average_execution_time =
        std::time::Duration::from_secs_f64(summed_secs / completed);

    for concept in &execution.concepts_used {
        if !specialization.knowledge_concepts.contains(concept) {
            specialization.knowledge_concepts.push(concept.clone());
        }
    }
    // Once the cap is reached the earliest concepts are kept and newer ones
    // are discarded.
    specialization
        .knowledge_concepts
        .truncate(MAX_KNOWLEDGE_CONCEPTS);
}
}
#[async_trait]
impl GenAgent<WorkerState> for KnowledgeGraphWorkerAgent {
type Message = WorkerMessage;
/// Initializes the agent: logs the start-up and puts the worker in `Idle`.
///
/// `_init_args` is accepted for the trait's sake and ignored.
async fn init(&mut self, _init_args: serde_json::Value) -> GenAgentResult<()> {
info!("Initializing worker agent: {}", self.agent_id);
self.state.status = WorkerStatus::Idle;
Ok(())
}
/// Handles a synchronous request and returns its JSON-encoded reply.
///
/// - `ExecuteTask` runs the task and replies with the `TaskExecution`.
/// - `CheckCompatibility` replies with the compatibility score.
/// - `GetStatus` replies with the current `WorkerStatus`.
/// - `UpdateSpecialization`, `Pause` and `Resume` are cast-only messages;
///   as calls they do nothing and reply with JSON `null`.
///
/// # Errors
///
/// Returns `GenAgentError::ExecutionError` when the task or compatibility
/// check fails, or when the reply cannot be serialized to JSON.
async fn handle_call(&mut self, message: Self::Message) -> GenAgentResult<serde_json::Value> {
    /// Wraps any displayable failure in this agent's execution error.
    fn execution_error(
        agent_id: &str,
        err: impl std::fmt::Display,
    ) -> terraphim_gen_agent::GenAgentError {
        terraphim_gen_agent::GenAgentError::ExecutionError(
            agent_id.to_string(),
            err.to_string(),
        )
    }

    match message {
        WorkerMessage::ExecuteTask { task } => {
            let execution = self
                .execute_task(task)
                .await
                .map_err(|err| execution_error(&self.agent_id, err))?;
            // Serialization failures are reported to the caller, not panicked on.
            serde_json::to_value(execution).map_err(|err| execution_error(&self.agent_id, err))
        }
        WorkerMessage::CheckCompatibility { task } => {
            let compatibility = self
                .check_task_compatibility(&task)
                .await
                .map_err(|err| execution_error(&self.agent_id, err))?;
            serde_json::to_value(compatibility)
                .map_err(|err| execution_error(&self.agent_id, err))
        }
        WorkerMessage::GetStatus => serde_json::to_value(&self.state.status)
            .map_err(|err| execution_error(&self.agent_id, err)),
        // Listed explicitly so that adding a variant forces a decision here.
        WorkerMessage::UpdateSpecialization { .. }
        | WorkerMessage::Pause
        | WorkerMessage::Resume => Ok(serde_json::Value::Null),
    }
}
async fn handle_cast(&mut self, message: Self::Message) -> GenAgentResult<()> {
match message {
WorkerMessage::ExecuteTask { task } => {
let _ = self.execute_task(task).await;
}
WorkerMessage::UpdateSpecialization {
domain,
expertise_level,
} => {
let specialization = self
.state
.specializations
.entry(domain.clone())
.or_insert_with(|| DomainSpecialization {
domain: domain.clone(),
expertise_level: 0.1,
tasks_completed: 0,
success_rate: 0.0,
average_execution_time: std::time::Duration::ZERO,
knowledge_concepts: Vec::new(),
});
specialization.expertise_level = expertise_level.clamp(0.0, 1.0);
}
WorkerMessage::Pause => {
if self.state.status == WorkerStatus::Executing {
self.state.status = WorkerStatus::Paused;
}
}
WorkerMessage::Resume => {
if self.state.status == WorkerStatus::Paused {
self.state.status = WorkerStatus::Executing;
}
}
_ => {
}
}
Ok(())
}
/// Handles an out-of-band informational message.
///
/// The message is ignored and `Ok(())` is always returned.
async fn handle_info(&mut self, _message: serde_json::Value) -> GenAgentResult<()> {
Ok(())
}
/// Shuts the agent down: logs the termination and resets the status to `Idle`.
///
/// `_reason` is accepted for the trait's sake and ignored.
async fn terminate(&mut self, _reason: String) -> GenAgentResult<()> {
info!("Terminating worker agent: {}", self.agent_id);
self.state.status = WorkerStatus::Idle;
Ok(())
}
/// Returns a shared reference to the worker's state.
fn get_state(&self) -> &WorkerState {
&self.state
}
/// Returns a mutable reference to the worker's state.
fn get_state_mut(&mut self) -> &mut WorkerState {
&mut self.state
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use terraphim_task_decomposition::TaskComplexity;

    /// Builds a simple task in the "testing" domain with two concepts and
    /// one required capability.
    fn create_test_task() -> Task {
        let mut sample = Task::new(
            "test_task".to_string(),
            "Test task for worker".to_string(),
            TaskComplexity::Simple,
            1,
        );
        sample.required_domains = vec!["testing".to_string()];
        sample.required_capabilities = vec!["test_execution".to_string()];
        sample.concepts = vec!["test".to_string(), "execution".to_string()];
        sample
    }

    /// Builds a worker named "test_worker" with default configuration,
    /// backed by the local example thesaurus.
    async fn create_test_agent() -> KnowledgeGraphWorkerAgent {
        use terraphim_automata::{load_thesaurus, AutomataPath};
        use terraphim_types::RoleName;

        let thesaurus = load_thesaurus(&AutomataPath::local_example())
            .await
            .unwrap();
        let graph = RoleGraph::new(RoleName::new("worker"), thesaurus)
            .await
            .unwrap();
        let automata = terraphim_automata::Automata::default();

        KnowledgeGraphWorkerAgent::new(
            "test_worker".to_string(),
            Arc::new(automata),
            Arc::new(graph),
            WorkerConfig::default(),
        )
    }

    /// A new agent keeps its id and starts idle.
    #[tokio::test]
    async fn test_worker_agent_creation() {
        let worker = create_test_agent().await;
        assert_eq!(worker.agent_id, "test_worker");
        assert_eq!(worker.state.status, WorkerStatus::Idle);
    }

    /// The compatibility score is a fraction in `[0, 1]`.
    #[tokio::test]
    async fn test_task_compatibility_check() {
        let worker = create_test_agent().await;
        let sample = create_test_task();
        let score = worker.check_task_compatibility(&sample).await.unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    /// A task with a description yields a non-empty knowledge context.
    #[tokio::test]
    async fn test_knowledge_context_extraction() {
        let worker = create_test_agent().await;
        let sample = create_test_task();
        let extracted = worker.extract_knowledge_context(&sample).await.unwrap();
        assert!(!extracted.is_empty());
    }

    /// The connectivity score is a fraction in `[0, 1]`.
    #[tokio::test]
    async fn test_concept_connectivity_analysis() {
        let worker = create_test_agent().await;
        let pair = vec!["test".to_string(), "execution".to_string()];
        let score = worker.analyze_concept_connectivity(&pair).await.unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    /// The capability score is a fraction in `[0, 1]`.
    #[tokio::test]
    async fn test_capability_compatibility() {
        let worker = create_test_agent().await;
        let score = worker.assess_capability_compatibility("test_execution");
        assert!((0.0..=1.0).contains(&score));
    }

    /// The `GenAgent` lifecycle (init, call, cast, terminate) succeeds.
    #[tokio::test]
    async fn test_gen_agent_interface() {
        let mut worker = create_test_agent().await;

        let started = worker.init(serde_json::json!({})).await;
        assert!(started.is_ok());

        let request = WorkerMessage::CheckCompatibility {
            task: create_test_task(),
        };
        let reply = worker.handle_call(request).await;
        assert!(reply.is_ok());

        let paused = worker.handle_cast(WorkerMessage::Pause).await;
        assert!(paused.is_ok());

        let stopped = worker.terminate("test".to_string()).await;
        assert!(stopped.is_ok());
    }
}