use super::session::ResearchSession;
use crate::utils::{DeepResearch, ResearchOptions};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Registry map key: `(connection_id, session_id)`.
type RegistryKey = (String, u32);
/// Registry of research sessions, keyed by `(connection_id, session_id)`.
///
/// Cloning is cheap: every clone shares the same underlying map, because the map sits
/// behind an `Arc`.
#[derive(Clone)]
pub struct ResearchRegistry {
/// Shared session map, guarded by a tokio (async) mutex so it can be locked from async code.
sessions: Arc<Mutex<HashMap<RegistryKey, Arc<ResearchSession>>>>,
}
/// Summary of all sessions registered for one connection, as produced by
/// `ResearchRegistry::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchListOutput {
/// The connection whose sessions are listed.
pub connection_id: String,
/// Per-session summaries; `ResearchRegistry::list` sorts them by session number.
pub sessions: Vec<SessionInfo>,
/// Number of entries in `sessions`.
pub total: usize,
}
/// Snapshot of a single research session, as reported by `ResearchRegistry::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
/// Session number within its connection (second half of the registry key).
pub session: u32,
/// Query text, taken from the session's `read` output at listing time.
pub query: String,
/// Value of the session's `is_complete()` at listing time.
pub completed: bool,
/// Value of the session's `results_count()` at listing time.
pub results_count: usize,
}
impl ResearchRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            sessions: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers a new research session under `(connection_id, session_id)`.
    ///
    /// If a session was already registered under the same key, it is replaced. The previous
    /// session's `kill()` is then awaited and its result ignored. The kill happens after the
    /// registry lock is released, so a slow shutdown does not block other registry operations.
    ///
    /// Returns a shared handle to the newly registered session.
    pub async fn create(
        &self,
        connection_id: &str,
        session_id: u32,
        research: DeepResearch,
        query: String,
        options: Option<ResearchOptions>,
    ) -> Arc<ResearchSession> {
        let key = (connection_id.to_string(), session_id);
        let session = Arc::new(ResearchSession::new(research, query, options));
        let previous = {
            let mut sessions = self.sessions.lock().await;
            sessions.insert(key, Arc::clone(&session))
        };
        if let Some(old) = previous {
            // The old session is no longer reachable through the registry; stop it.
            let _ = old.kill().await;
        }
        session
    }

    /// Returns the session registered under `(connection_id, session_id)`, if any.
    pub async fn get(
        &self,
        connection_id: &str,
        session_id: u32,
    ) -> Option<Arc<ResearchSession>> {
        let key = (connection_id.to_string(), session_id);
        let sessions = self.sessions.lock().await;
        sessions.get(&key).cloned()
    }

    /// Unregisters and returns the session under `(connection_id, session_id)`.
    ///
    /// The session is not killed; the caller receives the handle and decides what to do with it.
    pub async fn remove(&self, connection_id: &str, session_id: u32) -> Option<Arc<ResearchSession>> {
        let key = (connection_id.to_string(), session_id);
        let mut sessions = self.sessions.lock().await;
        sessions.remove(&key)
    }

    /// Lists every session registered for `connection_id`, sorted by session number.
    ///
    /// Each session is queried (`is_complete`, `results_count`, `read`) after the registry
    /// lock has been released, so listing does not block other registry operations.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept for API compatibility.
    pub async fn list(&self, connection_id: &str) -> Result<ResearchListOutput> {
        let mut session_infos = Vec::new();
        for (session_num, session) in self.sessions_for(connection_id).await {
            let completed = session.is_complete().await;
            let results_count = session.results_count().await;
            let output = session.read(session_num).await;
            session_infos.push(SessionInfo {
                session: session_num,
                query: output.query,
                completed,
                results_count,
            });
        }
        // Session numbers are unique per connection, so an unstable sort is sufficient.
        session_infos.sort_unstable_by_key(|info| info.session);
        let total = session_infos.len();
        Ok(ResearchListOutput {
            connection_id: connection_id.to_string(),
            sessions: session_infos,
            total,
        })
    }

    /// Unregisters every session of `connection_id` that reports `is_complete()`.
    ///
    /// Completion is checked without holding the registry lock. A session is then removed
    /// only if it is still the registered entry for its key, which guards against `create`
    /// replacing it in between. Returns the number of sessions actually removed.
    pub async fn cleanup_completed(&self, connection_id: &str) -> usize {
        let mut completed = Vec::new();
        for (session_num, session) in self.sessions_for(connection_id).await {
            if session.is_complete().await {
                completed.push((session_num, session));
            }
        }

        let mut sessions = self.sessions.lock().await;
        let mut removed = 0;
        for (session_num, session) in completed {
            let key = (connection_id.to_string(), session_num);
            // Skip entries that were replaced while the lock was released.
            if matches!(sessions.get(&key), Some(current) if Arc::ptr_eq(current, &session)) {
                sessions.remove(&key);
                removed += 1;
            }
        }
        removed
    }

    /// Unregisters and kills every session belonging to `connection_id`.
    ///
    /// Dropping the registry's `Arc` alone would not stop a session that is still referenced
    /// elsewhere. Each removed session is therefore killed, matching how `create` treats the
    /// sessions it replaces. Kills run after the registry lock is released. Returns the number
    /// of sessions removed.
    pub async fn cleanup_connection(&self, connection_id: &str) -> usize {
        let removed: Vec<(u32, Arc<ResearchSession>)> = {
            let mut sessions = self.sessions.lock().await;
            let keys: Vec<RegistryKey> = sessions
                .keys()
                .filter(|(conn_id, _)| conn_id == connection_id)
                .cloned()
                .collect();
            let mut removed = Vec::with_capacity(keys.len());
            for key in keys {
                if let Some(session) = sessions.remove(&key) {
                    removed.push((key.1, session));
                }
            }
            removed
        };

        let count = removed.len();
        for (session_num, session) in removed {
            log::debug!(
                "Cleaning up research session {} for connection {}",
                session_num,
                connection_id
            );
            let _ = session.kill().await;
        }
        count
    }

    /// Snapshots the `(session number, handle)` pairs registered for `connection_id`.
    ///
    /// The registry lock is held only while the handles are cloned, so callers may await on
    /// the returned sessions without blocking the registry.
    async fn sessions_for(&self, connection_id: &str) -> Vec<(u32, Arc<ResearchSession>)> {
        let sessions = self.sessions.lock().await;
        sessions
            .iter()
            .filter(|((conn_id, _), _)| conn_id == connection_id)
            .map(|((_, session_num), session)| (*session_num, Arc::clone(session)))
            .collect()
    }
}
impl Default for ResearchRegistry {
fn default() -> Self {
Self::new()
}
}