m2m/server/
state.rs

1//! Server state and session management.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use tokio::sync::RwLock;
8
9use super::config::ServerConfig;
10use crate::codec::CodecEngine;
11use crate::inference::HydraModel;
12use crate::protocol::{Capabilities, Session};
13use crate::security::SecurityScanner;
14
15/// Application state shared across handlers
16pub struct AppState {
17    /// Server configuration
18    pub config: ServerConfig,
19    /// Session manager
20    pub sessions: SessionManager,
21    /// Codec engine
22    pub codec: CodecEngine,
23    /// Security scanner
24    pub scanner: SecurityScanner,
25    /// Hydra model (optional)
26    pub model: Option<HydraModel>,
27    /// Server start time
28    pub start_time: Instant,
29}
30
31impl AppState {
32    /// Create new application state
33    pub fn new(config: ServerConfig) -> Self {
34        let scanner = if config.security_enabled {
35            if config.security_blocking {
36                SecurityScanner::new().with_blocking(config.block_threshold)
37            } else {
38                SecurityScanner::new()
39            }
40        } else {
41            SecurityScanner::new()
42        };
43
44        let model = config
45            .model_path
46            .as_ref()
47            .and_then(|path| HydraModel::load(path).ok());
48
49        Self {
50            config,
51            sessions: SessionManager::new(),
52            codec: CodecEngine::new(),
53            scanner,
54            model,
55            start_time: Instant::now(),
56        }
57    }
58
59    /// Get server uptime
60    pub fn uptime(&self) -> Duration {
61        self.start_time.elapsed()
62    }
63
64    /// Get server capabilities
65    pub fn capabilities(&self) -> Capabilities {
66        let mut caps = Capabilities::new("m2m-server");
67
68        if self.config.security_enabled {
69            caps = caps.with_security(
70                crate::protocol::SecurityCaps::default()
71                    .with_threat_detection(crate::security::SECURITY_VERSION),
72            );
73        }
74
75        if self.model.is_some() {
76            caps.compression = caps.compression.with_ml_routing();
77        }
78
79        caps
80    }
81}
82
83/// Manages active sessions
84pub struct SessionManager {
85    /// Active sessions by ID
86    sessions: Arc<RwLock<HashMap<String, SessionEntry>>>,
87    /// Session timeout
88    timeout: Duration,
89}
90
91/// Session entry with metadata
92struct SessionEntry {
93    /// The session
94    session: Session,
95    /// Last access time
96    last_access: Instant,
97}
98
99impl Default for SessionManager {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl SessionManager {
106    /// Create new session manager
107    pub fn new() -> Self {
108        Self {
109            sessions: Arc::new(RwLock::new(HashMap::new())),
110            timeout: Duration::from_secs(300),
111        }
112    }
113
114    /// Set session timeout
115    pub fn with_timeout(mut self, timeout: Duration) -> Self {
116        self.timeout = timeout;
117        self
118    }
119
120    /// Create a new session
121    pub async fn create(&self, capabilities: Capabilities) -> Session {
122        let session = Session::new(capabilities);
123        let id = session.id().to_string();
124
125        let entry = SessionEntry {
126            session: session.clone(),
127            last_access: Instant::now(),
128        };
129
130        self.sessions.write().await.insert(id, entry);
131        session
132    }
133
134    /// Get session by ID
135    pub async fn get(&self, id: &str) -> Option<Session> {
136        let mut sessions = self.sessions.write().await;
137
138        if let Some(entry) = sessions.get_mut(id) {
139            // Check expiry
140            if entry.last_access.elapsed() > self.timeout {
141                sessions.remove(id);
142                return None;
143            }
144
145            entry.last_access = Instant::now();
146            Some(entry.session.clone())
147        } else {
148            None
149        }
150    }
151
152    /// Update session
153    pub async fn update(&self, session: &Session) {
154        let mut sessions = self.sessions.write().await;
155
156        if let Some(entry) = sessions.get_mut(session.id()) {
157            entry.session = session.clone();
158            entry.last_access = Instant::now();
159        }
160    }
161
162    /// Remove session
163    pub async fn remove(&self, id: &str) {
164        self.sessions.write().await.remove(id);
165    }
166
167    /// Get session count
168    pub async fn count(&self) -> usize {
169        self.sessions.read().await.len()
170    }
171
172    /// Clean up expired sessions
173    pub async fn cleanup(&self) -> usize {
174        let mut sessions = self.sessions.write().await;
175        let before = sessions.len();
176
177        sessions.retain(|_, entry| entry.last_access.elapsed() < self.timeout);
178
179        before - sessions.len()
180    }
181
182    /// Get all session IDs
183    pub async fn list_ids(&self) -> Vec<String> {
184        self.sessions.read().await.keys().cloned().collect()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_create_and_get() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id(), id);
    }

    #[tokio::test]
    async fn test_session_remove() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        manager.remove(&id).await;

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_session_count() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        assert_eq!(manager.count().await, 0);

        manager.create(caps.clone()).await;
        manager.create(caps.clone()).await;
        manager.create(caps).await;

        assert_eq!(manager.count().await, 3);
    }

    #[tokio::test]
    async fn test_session_expiry() {
        let manager = SessionManager::new().with_timeout(Duration::from_millis(10));
        let caps = Capabilities::default();

        let session = manager.create(caps).await;
        let id = session.id().to_string();

        // Wait for expiry
        tokio::time::sleep(Duration::from_millis(20)).await;

        let retrieved = manager.get(&id).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_removes_expired_sessions() {
        let manager = SessionManager::new().with_timeout(Duration::from_millis(10));
        let caps = Capabilities::default();

        manager.create(caps.clone()).await;
        manager.create(caps).await;

        // Wait past the timeout so both sessions are expired
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(manager.cleanup().await, 2);
        assert_eq!(manager.count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_fresh_sessions() {
        let manager = SessionManager::new();

        manager.create(Capabilities::default()).await;

        assert_eq!(manager.cleanup().await, 0);
        assert_eq!(manager.count().await, 1);
    }

    #[tokio::test]
    async fn test_list_ids_returns_all_sessions() {
        let manager = SessionManager::new();
        let caps = Capabilities::default();

        let first = manager.create(caps.clone()).await.id().to_string();
        let second = manager.create(caps).await.id().to_string();

        // list_ids has unspecified order, so compare sorted
        let mut ids = manager.list_ids().await;
        ids.sort();
        let mut expected = vec![first, second];
        expected.sort();
        assert_eq!(ids, expected);
    }

    #[tokio::test]
    async fn test_update_unknown_session_is_noop() {
        let manager = SessionManager::new();

        let session = manager.create(Capabilities::default()).await;
        manager.remove(session.id()).await;

        // Updating a session that is no longer stored must not re-insert it
        manager.update(&session).await;
        assert_eq!(manager.count().await, 0);
    }
}