1use hehe_agent::{Agent, Session};
2use hehe_core::Id;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
/// Shared application state: an agent plus an in-memory registry of sessions.
///
/// Both fields are `Arc`s, so cloning an `AppState` is cheap and every clone
/// shares the same agent and the same session map.
#[derive(Clone)]
pub struct AppState {
    /// Agent used to create new sessions when the caller supplies no session id.
    pub agent: Arc<Agent>,
    /// Registered sessions keyed by their id. Guarded by an async `RwLock` so
    /// lookups can proceed concurrently while inserts/removals take exclusive access.
    sessions: Arc<RwLock<HashMap<Id, Session>>>,
}
12
13impl AppState {
14 pub fn new(agent: Agent) -> Self {
15 Self {
16 agent: Arc::new(agent),
17 sessions: Arc::new(RwLock::new(HashMap::new())),
18 }
19 }
20
21 pub async fn get_or_create_session(&self, session_id: Option<Id>) -> Session {
22 match session_id {
23 Some(id) => {
24 let sessions = self.sessions.read().await;
25 if let Some(session) = sessions.get(&id) {
26 return session.clone();
27 }
28 drop(sessions);
29
30 let session = Session::with_id(id.clone());
31 self.sessions.write().await.insert(id, session.clone());
32 session
33 }
34 None => {
35 let session = self.agent.create_session();
36 self.sessions
37 .write()
38 .await
39 .insert(session.id().clone(), session.clone());
40 session
41 }
42 }
43 }
44
45 pub async fn get_session(&self, session_id: &Id) -> Option<Session> {
46 self.sessions.read().await.get(session_id).cloned()
47 }
48
49 pub async fn remove_session(&self, session_id: &Id) -> Option<Session> {
50 self.sessions.write().await.remove(session_id)
51 }
52
53 pub async fn session_count(&self) -> usize {
54 self.sessions.read().await.len()
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_core::Message;
    use hehe_llm::{BoxStream, CompletionRequest, CompletionResponse, LlmError, LlmProvider, ModelInfo};

    /// Minimal provider: fixed completion, empty stream, no models.
    struct MockLlm;

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str { "mock" }
        fn capabilities(&self) -> &Capabilities {
            static CAPS: std::sync::OnceLock<Capabilities> = std::sync::OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }
        async fn complete(&self, _: CompletionRequest) -> std::result::Result<CompletionResponse, LlmError> {
            Ok(CompletionResponse::new("id", "mock", Message::assistant("Hi")))
        }
        async fn complete_stream(&self, _: CompletionRequest) -> std::result::Result<BoxStream<StreamChunk>, LlmError> {
            use futures::stream;
            Ok(Box::pin(stream::empty()))
        }
        async fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, LlmError> { Ok(vec![]) }
        fn default_model(&self) -> &str { "mock" }
    }

    fn create_test_agent() -> Agent {
        Agent::builder()
            .system_prompt("Test")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_app_state_create_session() {
        let state = AppState::new(create_test_agent());

        let session = state.get_or_create_session(None).await;
        assert_eq!(state.session_count().await, 1);

        let session2 = state.get_or_create_session(Some(session.id().clone())).await;
        assert_eq!(session.id(), session2.id());
        assert_eq!(state.session_count().await, 1);
    }

    #[tokio::test]
    async fn test_app_state_create_session_with_explicit_id() {
        let state = AppState::new(create_test_agent());
        let id = state.get_or_create_session(None).await.id().clone();
        state.remove_session(&id).await;
        assert_eq!(state.session_count().await, 0);

        // The id is no longer registered, so this exercises the creation branch
        // for a caller-supplied id rather than the cached-lookup branch.
        let recreated = state.get_or_create_session(Some(id.clone())).await;
        assert_eq!(recreated.id(), &id);
        assert_eq!(state.session_count().await, 1);
        assert!(state.get_session(&id).await.is_some());
    }

    #[tokio::test]
    async fn test_app_state_get_session() {
        let state = AppState::new(create_test_agent());
        let session = state.get_or_create_session(None).await;

        let found = state
            .get_session(session.id())
            .await
            .expect("session was just registered");
        assert_eq!(found.id(), session.id());
    }

    #[tokio::test]
    async fn test_app_state_remove_session() {
        let state = AppState::new(create_test_agent());
        let session = state.get_or_create_session(None).await;

        assert!(state.remove_session(session.id()).await.is_some());
        assert_eq!(state.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_app_state_remove_missing_session() {
        let state = AppState::new(create_test_agent());
        let id = state.get_or_create_session(None).await.id().clone();

        assert!(state.remove_session(&id).await.is_some());
        // A second removal and a lookup both see the id as absent.
        assert!(state.remove_session(&id).await.is_none());
        assert!(state.get_session(&id).await.is_none());
        assert_eq!(state.session_count().await, 0);
    }
}