1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use tokio::sync::RwLock;
8
9use super::config::ServerConfig;
10use crate::codec::CodecEngine;
11use crate::inference::HydraModel;
12use crate::protocol::{Capabilities, Session};
13use crate::security::SecurityScanner;
14
15pub struct AppState {
17 pub config: ServerConfig,
19 pub sessions: SessionManager,
21 pub codec: CodecEngine,
23 pub scanner: SecurityScanner,
25 pub model: Option<HydraModel>,
27 pub start_time: Instant,
29}
30
31impl AppState {
32 pub fn new(config: ServerConfig) -> Self {
34 let scanner = if config.security_enabled {
35 if config.security_blocking {
36 SecurityScanner::new().with_blocking(config.block_threshold)
37 } else {
38 SecurityScanner::new()
39 }
40 } else {
41 SecurityScanner::new()
42 };
43
44 let model = config
45 .model_path
46 .as_ref()
47 .and_then(|path| HydraModel::load(path).ok());
48
49 Self {
50 config,
51 sessions: SessionManager::new(),
52 codec: CodecEngine::new(),
53 scanner,
54 model,
55 start_time: Instant::now(),
56 }
57 }
58
59 pub fn uptime(&self) -> Duration {
61 self.start_time.elapsed()
62 }
63
64 pub fn capabilities(&self) -> Capabilities {
66 let mut caps = Capabilities::new("m2m-server");
67
68 if self.config.security_enabled {
69 caps = caps.with_security(
70 crate::protocol::SecurityCaps::default()
71 .with_threat_detection(crate::security::SECURITY_VERSION),
72 );
73 }
74
75 if self.model.is_some() {
76 caps.compression = caps.compression.with_ml_routing();
77 }
78
79 caps
80 }
81}
82
83pub struct SessionManager {
85 sessions: Arc<RwLock<HashMap<String, SessionEntry>>>,
87 timeout: Duration,
89}
90
91struct SessionEntry {
93 session: Session,
95 last_access: Instant,
97}
98
99impl Default for SessionManager {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl SessionManager {
106 pub fn new() -> Self {
108 Self {
109 sessions: Arc::new(RwLock::new(HashMap::new())),
110 timeout: Duration::from_secs(300),
111 }
112 }
113
114 pub fn with_timeout(mut self, timeout: Duration) -> Self {
116 self.timeout = timeout;
117 self
118 }
119
120 pub async fn create(&self, capabilities: Capabilities) -> Session {
122 let session = Session::new(capabilities);
123 let id = session.id().to_string();
124
125 let entry = SessionEntry {
126 session: session.clone(),
127 last_access: Instant::now(),
128 };
129
130 self.sessions.write().await.insert(id, entry);
131 session
132 }
133
134 pub async fn get(&self, id: &str) -> Option<Session> {
136 let mut sessions = self.sessions.write().await;
137
138 if let Some(entry) = sessions.get_mut(id) {
139 if entry.last_access.elapsed() > self.timeout {
141 sessions.remove(id);
142 return None;
143 }
144
145 entry.last_access = Instant::now();
146 Some(entry.session.clone())
147 } else {
148 None
149 }
150 }
151
152 pub async fn update(&self, session: &Session) {
154 let mut sessions = self.sessions.write().await;
155
156 if let Some(entry) = sessions.get_mut(session.id()) {
157 entry.session = session.clone();
158 entry.last_access = Instant::now();
159 }
160 }
161
162 pub async fn remove(&self, id: &str) {
164 self.sessions.write().await.remove(id);
165 }
166
167 pub async fn count(&self) -> usize {
169 self.sessions.read().await.len()
170 }
171
172 pub async fn cleanup(&self) -> usize {
174 let mut sessions = self.sessions.write().await;
175 let before = sessions.len();
176
177 sessions.retain(|_, entry| entry.last_access.elapsed() < self.timeout);
178
179 before - sessions.len()
180 }
181
182 pub async fn list_ids(&self) -> Vec<String> {
184 self.sessions.read().await.keys().cloned().collect()
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_create_and_get() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id(), id);
    }

    #[tokio::test]
    async fn test_session_remove() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        manager.remove(&id).await;

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_session_count() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        assert_eq!(manager.count().await, 0);

        manager.create(caps.clone()).await;
        manager.create(caps.clone()).await;
        manager.create(caps).await;

        assert_eq!(manager.count().await, 3);
    }

    #[tokio::test]
    async fn test_session_expiry() {
        let manager = SessionManager::new().with_timeout(Duration::from_millis(10));
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        tokio::time::sleep(Duration::from_millis(20)).await;

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_session_cleanup_removes_expired() {
        let manager = SessionManager::new().with_timeout(Duration::from_millis(10));

        manager.create(Capabilities::default()).await;
        manager.create(Capabilities::default()).await;

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(manager.cleanup().await, 2);
        assert_eq!(manager.count().await, 0);
        assert!(manager.list_ids().await.is_empty());
    }

    #[tokio::test]
    async fn test_session_cleanup_keeps_live() {
        let manager = SessionManager::new();
        let session = manager.create(Capabilities::default()).await;

        assert_eq!(manager.cleanup().await, 0);
        assert_eq!(manager.list_ids().await, vec![session.id().to_string()]);
    }
}