use crate::agent::capabilities::{AgentCapabilities, AgentRequirements};
use crate::agent::context::AgentContext;
use crate::agent::core::MoFAAgent;
use crate::agent::error::{AgentError, AgentResult};
use crate::agent::traits::AgentMetadata;
use crate::agent::types::AgentState;
use mofa_kernel::agent::config::AgentConfig;
use mofa_kernel::agent::registry::AgentFactory;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// One registered agent together with the metadata cached for it at registration time.
struct AgentEntry {
    /// Shared handle to the live agent; the same `Arc` is returned by `AgentRegistry::get`.
    agent: Arc<RwLock<dyn MoFAAgent>>,
    /// Snapshot of id/name/capabilities/state taken when the agent was registered.
    metadata: AgentMetadata,
    /// Milliseconds since the Unix epoch at registration.
    /// NOTE(review): written by `register` but not read anywhere in this file.
    registered_at: u64,
}
/// Secondary index from capability attributes to the ids of agents that declare them.
///
/// Lets tag lookups avoid scanning every registered agent. Each id appears at most once
/// per key, and keys whose id list becomes empty are removed.
struct CapabilityIndex {
    /// Tag -> ids of agents whose capabilities list that tag.
    by_tag: HashMap<String, Vec<String>>,
    /// `Debug` rendering of a reasoning strategy -> ids of agents that declare it.
    /// NOTE(review): maintained by `index`/`unindex` but not queried in this file.
    by_strategy: HashMap<String, Vec<String>>,
}

impl CapabilityIndex {
    /// Creates an empty index.
    fn new() -> Self {
        Self {
            by_tag: HashMap::new(),
            by_strategy: HashMap::new(),
        }
    }

    /// Records `agent_id` under every tag and reasoning strategy in `capabilities`.
    ///
    /// Idempotent per key: repeated tags, or indexing the same id twice, never produce a
    /// duplicate id in a bucket (which would make lookups return the agent twice).
    fn index(&mut self, agent_id: &str, capabilities: &AgentCapabilities) {
        for tag in &capabilities.tags {
            Self::push_unique(self.by_tag.entry(tag.clone()).or_default(), agent_id);
        }
        for strategy in &capabilities.reasoning_strategies {
            // Strategies are keyed by their `Debug` form.
            let strategy_name = format!("{:?}", strategy);
            Self::push_unique(self.by_strategy.entry(strategy_name).or_default(), agent_id);
        }
    }

    /// Appends `agent_id` to `ids` unless it is already present.
    fn push_unique(ids: &mut Vec<String>, agent_id: &str) {
        if !ids.iter().any(|id| id == agent_id) {
            ids.push(agent_id.to_string());
        }
    }

    /// Removes `agent_id` from every bucket and drops buckets that become empty, so
    /// unregistered tags do not accumulate as dead keys.
    fn unindex(&mut self, agent_id: &str) {
        let prune = |ids: &mut Vec<String>| {
            ids.retain(|id| id != agent_id);
            !ids.is_empty()
        };
        self.by_tag.retain(|_, ids| prune(ids));
        self.by_strategy.retain(|_, ids| prune(ids));
    }

    /// Ids of agents indexed under `tag`; empty if the tag is unknown.
    fn find_by_tag(&self, tag: &str) -> Vec<String> {
        self.by_tag.get(tag).cloned().unwrap_or_default()
    }

    /// Ids of agents indexed under *all* of `tags`, in the order of the first tag's bucket.
    ///
    /// An empty `tags` slice matches nothing. Membership is checked via a hash set, so
    /// the cost is linear in the bucket sizes rather than their product.
    fn find_by_tags(&self, tags: &[String]) -> Vec<String> {
        let (first, rest) = match tags.split_first() {
            Some(split) => split,
            None => return Vec::new(),
        };
        let mut result = self.find_by_tag(first);
        for tag in rest {
            if result.is_empty() {
                break;
            }
            let members: std::collections::HashSet<&str> = match self.by_tag.get(tag) {
                Some(ids) => ids.iter().map(String::as_str).collect(),
                // An unknown tag cannot be satisfied by any agent.
                None => return Vec::new(),
            };
            result.retain(|id| members.contains(id.as_str()));
        }
        result
    }
}
/// Async registry of agents, a tag index over their capabilities, and agent factories.
///
/// Each collection sits behind its own `Arc<RwLock<..>>`, so all methods take `&self`.
pub struct AgentRegistry {
    /// Registered agents keyed by the id each agent reported at registration.
    agents: Arc<RwLock<HashMap<String, AgentEntry>>>,
    /// Tag/strategy index kept in step with `agents` by `register`/`unregister`/`clear`.
    capability_index: Arc<RwLock<CapabilityIndex>>,
    /// Factories keyed by `AgentFactory::type_id()`.
    factories: Arc<RwLock<HashMap<String, Arc<dyn AgentFactory>>>>,
}
impl AgentRegistry {
pub fn new() -> Self {
Self {
agents: Arc::new(RwLock::new(HashMap::new())),
capability_index: Arc::new(RwLock::new(CapabilityIndex::new())),
factories: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register(&self, agent: Arc<RwLock<dyn MoFAAgent>>) -> AgentResult<()> {
let agent_guard = agent.read().await;
let id = agent_guard.id().to_string();
let name = agent_guard.name().to_string();
let capabilities = agent_guard.capabilities().clone();
let state = agent_guard.state();
drop(agent_guard);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let metadata = AgentMetadata {
id: id.clone(),
name,
description: None,
version: None,
capabilities: capabilities.clone(),
state,
};
let entry = AgentEntry {
agent,
metadata,
registered_at: now,
};
{
let mut index = self.capability_index.write().await;
index.index(&id, &capabilities);
}
{
let mut agents = self.agents.write().await;
agents.insert(id, entry);
}
Ok(())
}
pub async fn get(&self, id: &str) -> Option<Arc<RwLock<dyn MoFAAgent>>> {
let agents = self.agents.read().await;
agents.get(id).map(|e| e.agent.clone())
}
pub async fn unregister(&self, id: &str) -> AgentResult<bool> {
{
let mut index = self.capability_index.write().await;
index.unindex(id);
}
let mut agents = self.agents.write().await;
Ok(agents.remove(id).is_some())
}
pub async fn get_metadata(&self, id: &str) -> Option<AgentMetadata> {
let agents = self.agents.read().await;
agents.get(id).map(|e| e.metadata.clone())
}
pub async fn list(&self) -> Vec<AgentMetadata> {
let agents = self.agents.read().await;
agents.values().map(|e| e.metadata.clone()).collect()
}
pub async fn count(&self) -> usize {
let agents = self.agents.read().await;
agents.len()
}
pub async fn contains(&self, id: &str) -> bool {
let agents = self.agents.read().await;
agents.contains_key(id)
}
pub async fn find_by_capabilities(
&self,
requirements: &AgentRequirements,
) -> Vec<AgentMetadata> {
let agents = self.agents.read().await;
agents
.values()
.filter(|entry| requirements.matches(&entry.metadata.capabilities))
.map(|entry| entry.metadata.clone())
.collect()
}
pub async fn find_by_tag(&self, tag: &str) -> Vec<AgentMetadata> {
let index = self.capability_index.read().await;
let ids = index.find_by_tag(tag);
drop(index);
let agents = self.agents.read().await;
ids.iter()
.filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
.collect()
}
pub async fn find_by_tags(&self, tags: &[String]) -> Vec<AgentMetadata> {
let index = self.capability_index.read().await;
let ids = index.find_by_tags(tags);
drop(index);
let agents = self.agents.read().await;
ids.iter()
.filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
.collect()
}
pub async fn find_by_state(&self, state: AgentState) -> Vec<AgentMetadata> {
let agents = self.agents.read().await;
agents
.values()
.filter(|entry| entry.metadata.state == state)
.map(|entry| entry.metadata.clone())
.collect()
}
pub async fn register_factory(&self, factory: Arc<dyn AgentFactory>) -> AgentResult<()> {
let type_id = factory.type_id().to_string();
let mut factories = self.factories.write().await;
factories.insert(type_id, factory);
Ok(())
}
pub async fn get_factory(&self, type_id: &str) -> Option<Arc<dyn AgentFactory>> {
let factories = self.factories.read().await;
factories.get(type_id).cloned()
}
pub async fn unregister_factory(&self, type_id: &str) -> AgentResult<bool> {
let mut factories = self.factories.write().await;
Ok(factories.remove(type_id).is_some())
}
pub async fn list_factory_types(&self) -> Vec<String> {
let factories = self.factories.read().await;
factories.keys().cloned().collect()
}
pub async fn create(
&self,
type_id: &str,
config: AgentConfig,
) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
let factory = self
.get_factory(type_id)
.await
.ok_or_else(|| AgentError::NotFound(format!("Factory not found: {}", type_id)))?;
factory.validate_config(&config)?;
factory.create(config).await
}
pub async fn create_and_register(
&self,
type_id: &str,
config: AgentConfig,
) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
let agent = self.create(type_id, config).await?;
self.register(agent.clone()).await?;
Ok(agent)
}
pub async fn initialize_all(&self, ctx: &AgentContext) -> AgentResult<Vec<String>> {
let agents = self.agents.read().await;
let mut initialized = Vec::new();
for (id, entry) in agents.iter() {
let mut agent = entry.agent.write().await;
if agent.state() == AgentState::Created {
agent.initialize(ctx).await?;
initialized.push(id.clone());
}
}
Ok(initialized)
}
pub async fn shutdown_all(&self) -> AgentResult<Vec<String>> {
let agents = self.agents.read().await;
let mut shutdown = Vec::new();
for (id, entry) in agents.iter() {
let mut agent = entry.agent.write().await;
let state = agent.state();
if state != AgentState::Shutdown && state != AgentState::Failed {
agent.shutdown().await?;
shutdown.push(id.clone());
}
}
Ok(shutdown)
}
pub async fn clear(&self) -> AgentResult<usize> {
self.shutdown_all().await?;
{
let mut index = self.capability_index.write().await;
*index = CapabilityIndex::new();
}
let mut agents = self.agents.write().await;
let count = agents.len();
agents.clear();
Ok(count)
}
}
impl Default for AgentRegistry {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary of an [`AgentRegistry`], as produced by `AgentRegistry::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RegistryStats {
    /// Number of registered agents.
    pub total_agents: usize,
    /// Agent count per cached state, keyed by the state's `Debug` rendering.
    pub by_state: HashMap<String, usize>,
    /// Agent count per capability tag (an agent counts once per tag it lists).
    pub by_tag: HashMap<String, usize>,
    /// Number of registered factories.
    pub factory_count: usize,
}
impl AgentRegistry {
pub async fn stats(&self) -> RegistryStats {
let agents = self.agents.read().await;
let factories = self.factories.read().await;
let mut by_state: HashMap<String, usize> = HashMap::new();
let mut by_tag: HashMap<String, usize> = HashMap::new();
for entry in agents.values() {
let state_name = format!("{:?}", entry.metadata.state);
*by_state.entry(state_name).or_insert(0) += 1;
for tag in &entry.metadata.capabilities.tags {
*by_tag.entry(tag.clone()).or_insert(0) += 1;
}
}
RegistryStats {
total_agents: agents.len(),
by_state,
by_tag,
factory_count: factories.len(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::capabilities::AgentCapabilities;
    use crate::agent::context::AgentContext;
    use crate::agent::core::MoFAAgent;
    use crate::agent::error::AgentResult;
    use crate::agent::types::{AgentInput, AgentOutput, AgentState};
    use async_trait::async_trait;

    /// Minimal agent whose lifecycle methods only flip `state`.
    struct TestAgent {
        id: String,
        name: String,
        capabilities: AgentCapabilities,
        state: AgentState,
    }

    impl TestAgent {
        fn new(id: &str, name: &str) -> Self {
            Self {
                id: id.to_string(),
                name: name.to_string(),
                capabilities: AgentCapabilities::default(),
                state: AgentState::Created,
            }
        }
    }

    #[async_trait]
    impl MoFAAgent for TestAgent {
        fn id(&self) -> &str {
            &self.id
        }
        fn name(&self) -> &str {
            &self.name
        }
        fn capabilities(&self) -> &AgentCapabilities {
            &self.capabilities
        }
        fn state(&self) -> AgentState {
            self.state.clone()
        }
        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
            self.state = AgentState::Ready;
            Ok(())
        }
        async fn execute(
            &mut self,
            _input: AgentInput,
            _ctx: &AgentContext,
        ) -> AgentResult<AgentOutput> {
            Ok(AgentOutput::text("test output"))
        }
        async fn shutdown(&mut self) -> AgentResult<()> {
            self.state = AgentState::Shutdown;
            Ok(())
        }
    }

    /// Factory registered under type id `"test"` that builds `TestAgent`s from a config.
    struct TestAgentFactory;

    #[async_trait]
    impl AgentFactory for TestAgentFactory {
        async fn create(&self, config: AgentConfig) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
            let agent = TestAgent::new(&config.id, &config.name);
            Ok(Arc::new(RwLock::new(agent)))
        }
        fn type_id(&self) -> &str {
            "test"
        }
        fn default_capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::builder().with_tag("test").build()
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));
        registry.register(agent).await.unwrap();
        let found = registry.get("agent-1").await;
        assert!(found.is_some());
        let not_found = registry.get("nonexistent").await;
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_factory_create() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();
        let config = AgentConfig::new("agent-2", "Created Agent");
        let agent = registry.create("test", config).await.unwrap();
        let agent_guard = agent.read().await;
        assert_eq!(agent_guard.id(), "agent-2");
        assert_eq!(agent_guard.name(), "Created Agent");
    }

    #[tokio::test]
    async fn test_create_unknown_factory() {
        let registry = AgentRegistry::new();
        let result = registry
            .create("missing", AgentConfig::new("agent-x", "Agent X"))
            .await;
        assert!(matches!(result, Err(AgentError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_find_by_tag() {
        let registry = AgentRegistry::new();
        let mut agent1 = TestAgent::new("agent-1", "Agent 1");
        agent1.capabilities = AgentCapabilities::builder()
            .with_tag("llm")
            .with_tag("chat")
            .build();
        let mut agent2 = TestAgent::new("agent-2", "Agent 2");
        agent2.capabilities = AgentCapabilities::builder()
            .with_tag("react")
            .with_tag("chat")
            .build();
        registry
            .register(Arc::new(RwLock::new(agent1)))
            .await
            .unwrap();
        registry
            .register(Arc::new(RwLock::new(agent2)))
            .await
            .unwrap();
        let chat_agents = registry.find_by_tag("chat").await;
        assert_eq!(chat_agents.len(), 2);
        let llm_agents = registry.find_by_tag("llm").await;
        assert_eq!(llm_agents.len(), 1);
    }

    #[tokio::test]
    async fn test_find_by_tags_intersection() {
        let registry = AgentRegistry::new();
        let mut agent1 = TestAgent::new("agent-1", "Agent 1");
        agent1.capabilities = AgentCapabilities::builder()
            .with_tag("llm")
            .with_tag("chat")
            .build();
        let mut agent2 = TestAgent::new("agent-2", "Agent 2");
        agent2.capabilities = AgentCapabilities::builder()
            .with_tag("react")
            .with_tag("chat")
            .build();
        registry
            .register(Arc::new(RwLock::new(agent1)))
            .await
            .unwrap();
        registry
            .register(Arc::new(RwLock::new(agent2)))
            .await
            .unwrap();

        let both = registry
            .find_by_tags(&["chat".to_string(), "llm".to_string()])
            .await;
        assert_eq!(both.len(), 1);
        assert_eq!(both[0].id, "agent-1");

        let chat_only = registry.find_by_tags(&["chat".to_string()]).await;
        assert_eq!(chat_only.len(), 2);

        let with_unknown = registry
            .find_by_tags(&["chat".to_string(), "missing".to_string()])
            .await;
        assert!(with_unknown.is_empty());

        assert!(registry.find_by_tags(&[]).await.is_empty());
    }

    #[tokio::test]
    async fn test_unregister() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));
        registry.register(agent).await.unwrap();
        assert!(registry.contains("agent-1").await);
        assert!(registry.unregister("agent-1").await.unwrap());
        assert!(!registry.contains("agent-1").await);
        // A second removal reports that nothing was there.
        assert!(!registry.unregister("agent-1").await.unwrap());
    }

    #[tokio::test]
    async fn test_shutdown_all_and_clear() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test")));
        registry.register(agent.clone()).await.unwrap();

        let stopped = registry.shutdown_all().await.unwrap();
        assert_eq!(stopped, vec!["agent-1".to_string()]);
        assert!(agent.read().await.state() == AgentState::Shutdown);

        // Already-shut-down agents are skipped on a second pass.
        assert!(registry.shutdown_all().await.unwrap().is_empty());

        assert_eq!(registry.clear().await.unwrap(), 1);
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn test_stats() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test")));
        registry.register(agent).await.unwrap();
        let stats = registry.stats().await;
        assert_eq!(stats.total_agents, 1);
        assert_eq!(stats.factory_count, 1);
    }
}