use super::core::Agent;
use super::session::AgentSession;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Registry key: the connection id paired with the agent id used within that connection.
type RegistryKey = (String, u32);
/// Shared registry of agent sessions keyed by `(connection_id, agent_id)`.
///
/// Cloning is cheap: clones share the same map through an `Arc`, and all access to the
/// map goes through a single `tokio::sync::Mutex`.
#[derive(Clone)]
pub struct AgentRegistry {
// Sessions for every connection live in this one map behind one async lock.
sessions: Arc<Mutex<HashMap<RegistryKey, Arc<AgentSession>>>>,
}
/// Point-in-time summary of a single agent session, as produced by
/// [`AgentRegistry::list`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentInfo {
    /// Agent id within its connection (the second half of the registry key).
    pub agent: u32,
    /// Task text read back from the session.
    pub task: String,
    /// Whether the session reported itself complete.
    pub completed: bool,
    /// Whether the session reported an error.
    pub has_error: bool,
    /// Step count reported by the session.
    pub step_count: usize,
}
impl AgentRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            sessions: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the session for `(connection_id, agent_id)`, creating and registering one
    /// from `agent`, `task` and `max_steps` if the key is not yet present.
    ///
    /// If a session already exists it is returned unchanged, and `agent`, `task` and
    /// `max_steps` are dropped unused.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is kept as part of the public signature.
    pub async fn find_or_create(
        &self,
        connection_id: &str,
        agent_id: u32,
        agent: Agent,
        task: String,
        max_steps: usize,
    ) -> Result<Arc<AgentSession>> {
        let key = (connection_id.to_string(), agent_id);
        let mut sessions = self.sessions.lock().await;
        // Entry API: a single hash lookup, and the session is only constructed when the
        // key is vacant.
        let session = sessions
            .entry(key)
            .or_insert_with(|| Arc::new(AgentSession::new(agent, task, max_steps)));
        Ok(Arc::clone(session))
    }

    /// Returns a handle to the session for `(connection_id, agent_id)`, if registered.
    pub async fn get(&self, connection_id: &str, agent_id: u32) -> Option<Arc<AgentSession>> {
        let key = (connection_id.to_string(), agent_id);
        let sessions = self.sessions.lock().await;
        sessions.get(&key).cloned()
    }

    /// Unregisters and returns the session for `(connection_id, agent_id)`, if present.
    ///
    /// Any other holder of the returned `Arc` keeps the session itself alive.
    pub async fn remove(&self, connection_id: &str, agent_id: u32) -> Option<Arc<AgentSession>> {
        let key = (connection_id.to_string(), agent_id);
        let mut sessions = self.sessions.lock().await;
        sessions.remove(&key)
    }

    /// Builds an [`AgentInfo`] for every session registered under `connection_id`,
    /// sorted by agent id.
    ///
    /// The set of sessions is snapshotted under the registry lock, and the sessions are
    /// queried after the lock is released. A session removed concurrently may therefore
    /// still appear in the result.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is kept as part of the public signature.
    pub async fn list(&self, connection_id: &str) -> Result<Vec<AgentInfo>> {
        // Awaiting session methods while holding the registry-wide lock would block every
        // other registry operation (for all connections) until those awaits finish.
        let snapshot = self.sessions_for(connection_id).await;

        let mut agent_infos = Vec::with_capacity(snapshot.len());
        for (agent_num, session) in snapshot {
            let completed = session.is_complete().await;
            let has_error = session.has_error().await;
            let step_count = session.step_count().await;
            // NOTE(review): `read` receives the agent id; only its `task` field is used here.
            let output = session.read(agent_num).await;
            agent_infos.push(AgentInfo {
                agent: agent_num,
                task: output.task,
                completed,
                has_error,
                step_count,
            });
        }
        // Agent ids are unique within a connection, so an unstable sort is equivalent.
        agent_infos.sort_unstable_by_key(|info| info.agent);
        Ok(agent_infos)
    }

    /// Removes every session of `connection_id` that reports `is_complete()` and returns
    /// how many were removed.
    ///
    /// Completion is checked without holding the registry lock. A key whose session was
    /// removed or replaced in the meantime is left untouched.
    pub async fn cleanup_completed(&self, connection_id: &str) -> usize {
        let snapshot = self.sessions_for(connection_id).await;

        let mut completed = Vec::new();
        for (agent_num, session) in snapshot {
            if session.is_complete().await {
                completed.push((agent_num, session));
            }
        }
        if completed.is_empty() {
            return 0;
        }

        let mut sessions = self.sessions.lock().await;
        let mut removed = 0;
        for (agent_num, session) in completed {
            let key = (connection_id.to_string(), agent_num);
            // The lock was released while awaiting, so only evict the exact session that
            // was observed complete, not a fresh one inserted under the same key.
            if matches!(sessions.get(&key), Some(current) if Arc::ptr_eq(current, &session)) {
                sessions.remove(&key);
                removed += 1;
            }
        }
        removed
    }

    /// Removes every session registered under `connection_id` and returns how many were
    /// removed.
    pub async fn cleanup_connection(&self, connection_id: &str) -> usize {
        let mut sessions = self.sessions.lock().await;
        let before = sessions.len();
        sessions.retain(|(conn_id, agent_num), _| {
            let keep = conn_id.as_str() != connection_id;
            if !keep {
                log::debug!(
                    "Cleaning up agent session {} for connection {}",
                    agent_num,
                    connection_id
                );
            }
            keep
        });
        before - sessions.len()
    }

    /// Clones `(agent_id, session)` handles for every session of `connection_id`, holding
    /// the registry lock only for the duration of the scan.
    async fn sessions_for(&self, connection_id: &str) -> Vec<(u32, Arc<AgentSession>)> {
        let sessions = self.sessions.lock().await;
        sessions
            .iter()
            .filter(|((conn_id, _), _)| conn_id.as_str() == connection_id)
            .map(|((_, agent_num), session)| (*agent_num, Arc::clone(session)))
            .collect()
    }
}
impl Default for AgentRegistry {
fn default() -> Self {
Self::new()
}
}