// hehe_server/state.rs

use hehe_agent::{Agent, Session};
use hehe_core::Id;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

7#[derive(Clone)]
8pub struct AppState {
9    pub agent: Arc<Agent>,
10    sessions: Arc<RwLock<HashMap<Id, Session>>>,
11}
12
13impl AppState {
14    pub fn new(agent: Agent) -> Self {
15        Self {
16            agent: Arc::new(agent),
17            sessions: Arc::new(RwLock::new(HashMap::new())),
18        }
19    }
20
21    pub async fn get_or_create_session(&self, session_id: Option<Id>) -> Session {
22        match session_id {
23            Some(id) => {
24                let sessions = self.sessions.read().await;
25                if let Some(session) = sessions.get(&id) {
26                    return session.clone();
27                }
28                drop(sessions);
29
30                let session = Session::with_id(id.clone());
31                self.sessions.write().await.insert(id, session.clone());
32                session
33            }
34            None => {
35                let session = self.agent.create_session();
36                self.sessions
37                    .write()
38                    .await
39                    .insert(session.id().clone(), session.clone());
40                session
41            }
42        }
43    }
44
45    pub async fn get_session(&self, session_id: &Id) -> Option<Session> {
46        self.sessions.read().await.get(session_id).cloned()
47    }
48
49    pub async fn remove_session(&self, session_id: &Id) -> Option<Session> {
50        self.sessions.write().await.remove(session_id)
51    }
52
53    pub async fn session_count(&self) -> usize {
54        self.sessions.read().await.len()
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_core::Message;
    use hehe_llm::{
        BoxStream, CompletionRequest, CompletionResponse, LlmError, LlmProvider, ModelInfo,
    };

    /// Minimal LLM provider: `complete` always answers "Hi", streaming yields
    /// nothing, and no models are listed.
    struct MockLlm;

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            static CAPS: std::sync::OnceLock<Capabilities> = std::sync::OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(
            &self,
            _: CompletionRequest,
        ) -> std::result::Result<CompletionResponse, LlmError> {
            Ok(CompletionResponse::new("id", "mock", Message::assistant("Hi")))
        }

        async fn complete_stream(
            &self,
            _: CompletionRequest,
        ) -> std::result::Result<BoxStream<StreamChunk>, LlmError> {
            use futures::stream;
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, LlmError> {
            Ok(vec![])
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    fn create_test_agent() -> Agent {
        Agent::builder()
            .system_prompt("Test")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_app_state_create_session() {
        let state = AppState::new(create_test_agent());

        let session = state.get_or_create_session(None).await;
        assert_eq!(state.session_count().await, 1);

        let session2 = state.get_or_create_session(Some(session.id().clone())).await;
        assert_eq!(session.id(), session2.id());
        assert_eq!(state.session_count().await, 1);
    }

    #[tokio::test]
    async fn test_app_state_remove_session() {
        let state = AppState::new(create_test_agent());
        let session = state.get_or_create_session(None).await;

        assert!(state.remove_session(session.id()).await.is_some());
        assert_eq!(state.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_app_state_get_session() {
        let state = AppState::new(create_test_agent());
        let session = state.get_or_create_session(None).await;

        let found = state.get_session(session.id()).await;
        assert_eq!(found.map(|s| s.id().clone()), Some(session.id().clone()));
    }

    #[tokio::test]
    async fn test_app_state_create_session_with_unknown_id() {
        let state = AppState::new(create_test_agent());

        // Obtain a valid id, then forget its session so the id is unknown.
        let id = state.get_or_create_session(None).await.id().clone();
        assert!(state.remove_session(&id).await.is_some());
        assert!(state.get_session(&id).await.is_none());

        // An unknown explicit id must create and register a session under it.
        let session = state.get_or_create_session(Some(id.clone())).await;
        assert_eq!(session.id(), &id);
        assert_eq!(state.session_count().await, 1);
        assert!(state.get_session(&id).await.is_some());
    }

    #[tokio::test]
    async fn test_app_state_remove_missing_session() {
        let state = AppState::new(create_test_agent());
        let id = state.get_or_create_session(None).await.id().clone();

        assert!(state.remove_session(&id).await.is_some());
        assert!(state.remove_session(&id).await.is_none());
        assert_eq!(state.session_count().await, 0);
    }
}