agent-discord-rs 0.2.7

A high-performance Discord bot daemon supporting multiple AI agents (pi, opencode, copilot, kilo).
use crate::agent::{AgentType, AiAgent, CopilotAgent, KiloAgent, OpencodeAgent, PiAgent};
use crate::config::Config;
use crate::migrate;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Caches one AI agent session per Discord channel and creates new ones on demand.
pub struct SessionManager {
    /// Configuration consulted when constructing agents.
    config: Arc<Config>,
    /// Live agent sessions keyed by Discord channel id.
    sessions: Arc<RwLock<HashMap<u64, Arc<dyn AiAgent>>>>,
}

impl SessionManager {
    /// Creates a manager with an empty session cache.
    ///
    /// `config` is consulted when agents are built (e.g. the opencode password).
    pub fn new(config: Arc<Config>) -> Self {
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Returns the agent session for `channel_id`, creating it if necessary.
    ///
    /// A cached session is reused when its `agent_type()` matches `agent_type`.
    /// Otherwise a new agent of the requested type is built and installed in the
    /// cache, replacing any cached session of a different type. Non-Pi agents are
    /// seeded with the session id and model (provider + id) stored for the channel
    /// in `ChannelConfig`; the Pi agent uses neither.
    ///
    /// For opencode, copilot and kilo the new agent's session id is written back to
    /// `ChannelConfig`; Pi sessions are not persisted here.
    ///
    /// Returns `(session, is_brand_new)`. `is_brand_new` is `false` for a reused
    /// session. For a newly built one it is `true` when `get_state()` reports zero
    /// messages or fails.
    ///
    /// # Errors
    ///
    /// Propagates failures from:
    /// - creating the Pi session directory,
    /// - starting a backend via `backend_manager`,
    /// - constructing the agent,
    /// - persisting its session id.
    ///
    /// If persisting fails, the new session is not left in the cache.
    pub async fn get_or_create_session(
        &self,
        channel_id: u64,
        agent_type: AgentType,
        backend_manager: &crate::agent::manager::BackendManager,
    ) -> anyhow::Result<(Arc<dyn AiAgent>, bool)> {
        {
            let sessions = self.sessions.read().await;
            if let Some(session) = Self::matching_session(&sessions, channel_id, &agent_type) {
                return Ok((session, false));
            }
        }

        let channel_id_str = channel_id.to_string();
        // A config that cannot be loaded is treated as "no stored session/model".
        let channel_config = crate::commands::agent::ChannelConfig::load()
            .await
            .unwrap_or_default();
        let entry = channel_config.channels.get(&channel_id_str);

        // A stored model is only used when both the provider and the model id are set.
        let model_opt = entry.and_then(|e| {
            e.model_provider
                .as_ref()
                .zip(e.model_id.as_ref())
                .map(|(provider, model)| (provider.clone(), model.clone()))
        });
        let existing_sid = entry.and_then(|e| e.session_id.clone());

        // Build the agent and remember which session id (if any) must be persisted.
        // Persisting is deferred until this call has actually installed the agent.
        // That way a caller that loses the race below never overwrites the stored id.
        let (session, sid_to_persist): (Arc<dyn AiAgent>, Option<String>) = match agent_type {
            AgentType::Pi => {
                let session_dir = migrate::get_sessions_dir("pi");
                std::fs::create_dir_all(&session_dir)?;
                let (pi_agent, _) = PiAgent::new(channel_id, &session_dir).await?;
                (pi_agent, None)
            }
            AgentType::Opencode => {
                let port = backend_manager.ensure_backend(&AgentType::Opencode).await?;
                let api_url = format!("http://127.0.0.1:{}", port);
                let api_key = self.config.opencode.password.clone().unwrap_or_default();

                let agent = OpencodeAgent::new(
                    channel_id,
                    api_url,
                    api_key,
                    existing_sid,
                    model_opt,
                    "opencode",
                )
                .await?;
                let sid = agent.session_id.clone();
                (agent, Some(sid))
            }
            AgentType::Copilot => {
                let agent = CopilotAgent::new(channel_id, existing_sid, model_opt).await?;
                let sid = agent.session_id();
                (agent, Some(sid))
            }
            AgentType::Kilo => {
                let port = backend_manager.ensure_backend(&AgentType::Kilo).await?;
                let api_url = format!("http://127.0.0.1:{}", port);

                let agent = KiloAgent::new(channel_id, api_url, existing_sid, model_opt).await?;
                let sid = agent.session_id();
                (agent, Some(sid))
            }
        };

        // The read lock was released while the agent was being built, so another task
        // may have installed a matching session in the meantime. Re-check under the
        // write lock so all callers share one agent. Otherwise the last writer would
        // silently replace the first; in that case our freshly built agent is dropped.
        {
            let mut sessions = self.sessions.write().await;
            if let Some(existing) = Self::matching_session(&sessions, channel_id, &agent_type) {
                return Ok((existing, false));
            }
            sessions.insert(channel_id, Arc::clone(&session));
        }

        let persisted = match sid_to_persist {
            Some(sid) => self.persist_sid(channel_id, agent_type, sid).await,
            None => Ok(()),
        };
        if let Err(err) = persisted {
            // A session whose id could not be stored is not kept in the cache.
            // Evict it only if it is still the one installed above.
            let mut sessions = self.sessions.write().await;
            let still_ours =
                matches!(sessions.get(&channel_id), Some(cached) if Arc::ptr_eq(cached, &session));
            if still_ours {
                sessions.remove(&channel_id);
            }
            return Err(err);
        }

        // The new agent may have been built around a previously stored session id.
        // So it only counts as brand new if it reports no messages, or its state is
        // unavailable.
        let is_brand_new = session
            .get_state()
            .await
            .map_or(true, |state| state.message_count == 0);

        Ok((session, is_brand_new))
    }

    /// Returns the session cached for `channel_id` if its `agent_type()` matches
    /// `agent_type`'s string form.
    fn matching_session(
        sessions: &HashMap<u64, Arc<dyn AiAgent>>,
        channel_id: u64,
        agent_type: &AgentType,
    ) -> Option<Arc<dyn AiAgent>> {
        sessions
            .get(&channel_id)
            .filter(|session| session.agent_type() == agent_type.to_string())
            .cloned()
    }

    /// Records `sid` as the session id of `channel_id` in `channel_config`.
    ///
    /// If the channel has no entry yet, one is created for `agent_type` with
    /// `mention_only` enabled, the current UTC time (RFC 3339) as `authorized_at`,
    /// and no model or assistant overrides. An existing entry keeps every other
    /// field, including its `agent_type`; only `session_id` is overwritten.
    fn apply_sid(
        channel_config: &mut crate::commands::agent::ChannelConfig,
        channel_id: &str,
        agent_type: AgentType,
        sid: String,
    ) {
        let entry = channel_config
            .channels
            .entry(channel_id.to_string())
            .or_insert_with(|| crate::commands::agent::ChannelEntry {
                // `agent_type` is owned and only needed here, so it is moved in.
                agent_type,
                authorized_at: chrono::Utc::now().to_rfc3339(),
                mention_only: true,
                session_id: None,
                model_provider: None,
                model_id: None,
                assistant_name: None,
            });

        entry.session_id = Some(sid);
    }

    /// Loads `ChannelConfig`, records `sid` for `channel_id` via `apply_sid`, and
    /// saves it.
    ///
    /// # Errors
    ///
    /// Returns the error from `ChannelConfig::save`.
    async fn persist_sid(
        &self,
        channel_id: u64,
        agent_type: AgentType,
        sid: String,
    ) -> anyhow::Result<()> {
        let channel_id_str = channel_id.to_string();
        // NOTE(review): if `load()` fails for a reason other than a missing file,
        // this starts from an empty default config. The `save()` below may then drop
        // other channels' entries. Confirm what `load` errors on.
        let mut channel_config = crate::commands::agent::ChannelConfig::load()
            .await
            .unwrap_or_default();

        Self::apply_sid(&mut channel_config, &channel_id_str, agent_type, sid);
        channel_config.save().await?;
        Ok(())
    }

    /// Drops the cached session for `channel_id`, if any.
    ///
    /// The persisted `ChannelConfig` is left untouched.
    pub async fn remove_session(&self, channel_id: u64) {
        let mut sessions = self.sessions.write().await;
        sessions.remove(&channel_id);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AiAgent, MockAgent};

    /// `remove_session` evicts a cached agent.
    #[tokio::test]
    async fn test_remove_session_clears_cached_agent() {
        let config = Arc::new(Config::default());
        let manager = SessionManager::new(config);
        let channel_id = 42_u64;
        let mock_agent: Arc<dyn AiAgent> = Arc::new(MockAgent::new());

        {
            let mut sessions = manager.sessions.write().await;
            sessions.insert(channel_id, mock_agent);
            assert!(sessions.contains_key(&channel_id));
        }

        manager.remove_session(channel_id).await;

        let sessions = manager.sessions.read().await;
        assert!(!sessions.contains_key(&channel_id));
    }

    /// A missing channel gets a fresh entry populated with the documented defaults.
    #[test]
    fn test_apply_sid_creates_channel_entry_when_missing() {
        let mut cfg = crate::commands::agent::ChannelConfig::default();
        SessionManager::apply_sid(&mut cfg, "1001", AgentType::Copilot, "sid-1".to_string());
        let entry = cfg.channels.get("1001").expect("entry exists");
        assert_eq!(entry.agent_type, AgentType::Copilot);
        assert_eq!(entry.session_id.as_deref(), Some("sid-1"));
        assert!(entry.mention_only);
        // authorized_at is stamped with the current time in RFC 3339 form.
        assert!(chrono::DateTime::parse_from_rfc3339(&entry.authorized_at).is_ok());
        // A fresh entry carries no model or assistant overrides.
        assert!(entry.model_provider.is_none());
        assert!(entry.model_id.is_none());
        assert!(entry.assistant_name.is_none());
    }

    /// An existing entry only has its session id replaced; every other field survives.
    #[test]
    fn test_apply_sid_overwrites_existing_sid_only() {
        let mut cfg = crate::commands::agent::ChannelConfig::default();
        cfg.channels.insert(
            "1002".to_string(),
            crate::commands::agent::ChannelEntry {
                agent_type: AgentType::Pi,
                authorized_at: "2026-01-01T00:00:00Z".to_string(),
                mention_only: false,
                session_id: Some("old".to_string()),
                model_provider: Some("p".to_string()),
                model_id: Some("m".to_string()),
                assistant_name: Some("a".to_string()),
            },
        );
        SessionManager::apply_sid(&mut cfg, "1002", AgentType::Kilo, "new-sid".to_string());
        let entry = cfg.channels.get("1002").expect("entry exists");
        assert_eq!(entry.session_id.as_deref(), Some("new-sid"));
        assert_eq!(entry.agent_type, AgentType::Pi);
        assert!(!entry.mention_only);
        assert_eq!(entry.authorized_at, "2026-01-01T00:00:00Z");
        assert_eq!(entry.model_provider.as_deref(), Some("p"));
        assert_eq!(entry.model_id.as_deref(), Some("m"));
        assert_eq!(entry.assistant_name.as_deref(), Some("a"));
    }
}