1use std::time::{Duration, Instant};
7
8use super::capabilities::{Capabilities, NegotiatedCaps};
9use super::message::{Message, MessageType, RejectionCode};
10use super::SESSION_TIMEOUT_SECS;
11use crate::codec::{Algorithm, CodecEngine};
12use crate::error::{M2MError, Result};
13
/// Lifecycle stage of a `Session`, from creation through the handshake to shutdown.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Freshly created; no handshake message has been produced or accepted yet.
    Initial,
    /// This side emitted a HELLO and is waiting for the peer's ACCEPT.
    HelloSent,
    /// Capabilities were negotiated; DATA may now be compressed/decompressed.
    Established,
    /// `close()` was called locally and a CLOSE message was produced.
    Closing,
    /// The session ended (a CLOSE or REJECT was received from the peer).
    Closed,
}
28
29pub struct Session {
31 id: String,
33 state: SessionState,
35 local_caps: Capabilities,
37 remote_caps: Option<Capabilities>,
39 negotiated: Option<NegotiatedCaps>,
41 codec: CodecEngine,
43 created_at: Instant,
45 last_activity: Instant,
47 timeout: Duration,
49 messages_sent: u64,
51 messages_received: u64,
53 bytes_compressed: u64,
55 bytes_saved: u64,
57}
58
59impl Session {
60 pub fn new(capabilities: Capabilities) -> Self {
62 let now = Instant::now();
63 Self {
64 id: uuid::Uuid::new_v4().to_string(),
65 state: SessionState::Initial,
66 local_caps: capabilities,
67 remote_caps: None,
68 negotiated: None,
69 codec: CodecEngine::new(),
70 created_at: now,
71 last_activity: now,
72 timeout: Duration::from_secs(SESSION_TIMEOUT_SECS),
73 messages_sent: 0,
74 messages_received: 0,
75 bytes_compressed: 0,
76 bytes_saved: 0,
77 }
78 }
79
80 pub fn with_id(id: &str, capabilities: Capabilities) -> Self {
82 let mut session = Self::new(capabilities);
83 session.id = id.to_string();
84 session
85 }
86
87 pub fn id(&self) -> &str {
89 &self.id
90 }
91
92 pub fn state(&self) -> SessionState {
94 self.state
95 }
96
97 pub fn is_established(&self) -> bool {
99 self.state == SessionState::Established
100 }
101
102 pub fn is_expired(&self) -> bool {
104 self.last_activity.elapsed() > self.timeout
105 }
106
107 pub fn algorithm(&self) -> Option<Algorithm> {
109 self.negotiated.as_ref().map(|n| n.algorithm)
110 }
111
112 pub fn encoding(&self) -> Option<crate::models::Encoding> {
114 self.negotiated.as_ref().map(|n| n.encoding)
115 }
116
117 pub fn create_hello(&mut self) -> Message {
119 self.state = SessionState::HelloSent;
120 self.messages_sent += 1;
121 self.touch();
122 Message::hello(self.local_caps.clone())
123 }
124
125 pub fn process_hello(&mut self, hello: &Message) -> Result<Message> {
127 if self.state != SessionState::Initial {
128 return Err(M2MError::Protocol(format!(
129 "Cannot process HELLO in state {:?}",
130 self.state
131 )));
132 }
133
134 let remote_caps = hello
135 .get_capabilities()
136 .ok_or_else(|| M2MError::InvalidMessage("HELLO missing capabilities".to_string()))?;
137
138 self.messages_received += 1;
139 self.touch();
140
141 if !self.local_caps.is_compatible(remote_caps) {
143 return Ok(Message::reject(
144 RejectionCode::VersionMismatch,
145 &format!(
146 "Version {} not compatible with {}",
147 remote_caps.version, self.local_caps.version
148 ),
149 ));
150 }
151
152 match self.local_caps.negotiate(remote_caps) {
154 Some(negotiated) => {
155 self.remote_caps = Some(remote_caps.clone());
156 self.negotiated = Some(negotiated);
157 self.state = SessionState::Established;
158
159 if let Some(ref neg) = self.negotiated {
161 self.codec = self
162 .codec
163 .clone()
164 .with_ml_routing(neg.ml_routing)
165 .with_encoding(neg.encoding);
166 }
167
168 self.messages_sent += 1;
169 Ok(Message::accept(&self.id, self.local_caps.clone()))
170 },
171 None => Ok(Message::reject(
172 RejectionCode::NoCommonAlgorithm,
173 "No common compression algorithm",
174 )),
175 }
176 }
177
178 pub fn process_accept(&mut self, accept: &Message) -> Result<()> {
180 if self.state != SessionState::HelloSent {
181 return Err(M2MError::Protocol(format!(
182 "Cannot process ACCEPT in state {:?}",
183 self.state
184 )));
185 }
186
187 let remote_caps = accept
188 .get_capabilities()
189 .ok_or_else(|| M2MError::InvalidMessage("ACCEPT missing capabilities".to_string()))?;
190
191 let session_id = accept
192 .session_id
193 .as_ref()
194 .ok_or_else(|| M2MError::InvalidMessage("ACCEPT missing session ID".to_string()))?;
195
196 self.messages_received += 1;
197 self.touch();
198
199 self.id = session_id.clone();
201
202 match self.local_caps.negotiate(remote_caps) {
204 Some(negotiated) => {
205 self.remote_caps = Some(remote_caps.clone());
206 self.negotiated = Some(negotiated);
207 self.state = SessionState::Established;
208
209 if let Some(ref neg) = self.negotiated {
211 self.codec = self
212 .codec
213 .clone()
214 .with_ml_routing(neg.ml_routing)
215 .with_encoding(neg.encoding);
216 }
217
218 Ok(())
219 },
220 None => Err(M2MError::NegotiationFailed(
221 "Failed to negotiate capabilities".to_string(),
222 )),
223 }
224 }
225
226 pub fn process_reject(&mut self, reject: &Message) -> Result<()> {
228 self.messages_received += 1;
229 self.state = SessionState::Closed;
230
231 let rejection = reject.get_rejection();
232 let reason = rejection
233 .map(|r| format!("{:?}: {}", r.code, r.message))
234 .unwrap_or_else(|| "Unknown rejection".to_string());
235
236 Err(M2MError::NegotiationFailed(reason))
237 }
238
239 pub fn compress(&mut self, content: &str) -> Result<Message> {
241 if !self.is_established() {
242 return Err(M2MError::SessionNotEstablished);
243 }
244
245 if self.is_expired() {
246 return Err(M2MError::SessionExpired);
247 }
248
249 let algorithm = self.algorithm().unwrap_or(Algorithm::M2M);
250 let result = self.codec.compress(content, algorithm)?;
251
252 self.bytes_compressed += result.compressed_bytes as u64;
254 if result.original_bytes > result.compressed_bytes {
255 self.bytes_saved += (result.original_bytes - result.compressed_bytes) as u64;
256 }
257 self.messages_sent += 1;
258 self.touch();
259
260 Ok(Message::data(&self.id, algorithm, result.data))
261 }
262
263 pub fn decompress(&mut self, message: &Message) -> Result<String> {
265 if !self.is_established() {
266 return Err(M2MError::SessionNotEstablished);
267 }
268
269 if self.is_expired() {
270 return Err(M2MError::SessionExpired);
271 }
272
273 let data = message
274 .get_data()
275 .ok_or_else(|| M2MError::InvalidMessage("Not a DATA message".to_string()))?;
276
277 self.messages_received += 1;
278 self.touch();
279
280 self.codec.decompress(&data.content)
281 }
282
283 pub fn process_message(&mut self, message: &Message) -> Result<Option<Message>> {
285 self.touch();
286
287 match message.msg_type {
288 MessageType::Hello => {
289 let response = self.process_hello(message)?;
290 Ok(Some(response))
291 },
292 MessageType::Accept => {
293 self.process_accept(message)?;
294 Ok(None)
295 },
296 MessageType::Reject => {
297 self.process_reject(message)?;
298 Ok(None)
299 },
300 MessageType::Ping => {
301 self.messages_received += 1;
302 self.messages_sent += 1;
303 Ok(Some(Message::pong(&self.id)))
304 },
305 MessageType::Pong => {
306 self.messages_received += 1;
307 Ok(None)
308 },
309 MessageType::Close => {
310 self.messages_received += 1;
311 self.state = SessionState::Closed;
312 Ok(None)
313 },
314 MessageType::Data => {
315 Ok(None)
317 },
318 }
319 }
320
321 pub fn close(&mut self) -> Message {
323 self.state = SessionState::Closing;
324 self.messages_sent += 1;
325 Message::close(&self.id)
326 }
327
328 pub fn stats(&self) -> SessionStats {
330 SessionStats {
331 session_id: self.id.clone(),
332 state: self.state,
333 messages_sent: self.messages_sent,
334 messages_received: self.messages_received,
335 bytes_compressed: self.bytes_compressed,
336 bytes_saved: self.bytes_saved,
337 uptime_secs: self.created_at.elapsed().as_secs(),
338 }
339 }
340
341 fn touch(&mut self) {
343 self.last_activity = Instant::now();
344 }
345}
346
347impl Clone for Session {
348 fn clone(&self) -> Self {
349 let mut codec = CodecEngine::new();
351 if let Some(ref neg) = self.negotiated {
352 codec = codec
353 .with_ml_routing(neg.ml_routing)
354 .with_encoding(neg.encoding);
355 }
356
357 let now = Instant::now();
358 Self {
359 id: self.id.clone(),
360 state: self.state,
361 local_caps: self.local_caps.clone(),
362 remote_caps: self.remote_caps.clone(),
363 negotiated: self.negotiated.clone(),
364 codec,
365 created_at: now,
366 last_activity: now,
367 timeout: self.timeout,
368 messages_sent: 0,
371 messages_received: 0,
372 bytes_compressed: 0,
373 bytes_saved: 0,
374 }
375 }
376}
377
378#[derive(Debug, Clone)]
380pub struct SessionStats {
381 pub session_id: String,
383 pub state: SessionState,
385 pub messages_sent: u64,
387 pub messages_received: u64,
389 pub bytes_compressed: u64,
391 pub bytes_saved: u64,
393 pub uptime_secs: u64,
395}
396
397impl SessionStats {
398 pub fn compression_ratio(&self) -> f64 {
400 if self.bytes_compressed == 0 {
401 1.0
402 } else {
403 (self.bytes_compressed + self.bytes_saved) as f64 / self.bytes_compressed as f64
404 }
405 }
406
407 pub fn savings_percent(&self) -> f64 {
409 let total = self.bytes_compressed + self.bytes_saved;
410 if total == 0 {
411 0.0
412 } else {
413 self.bytes_saved as f64 / total as f64 * 100.0
414 }
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Encoding;
    use crate::protocol::capabilities::CompressionCaps;

    /// Runs a HELLO/ACCEPT handshake between two default-capability sessions
    /// and returns `(client, server)`, both established.
    fn establish_pair() -> (Session, Session) {
        let mut client = Session::new(Capabilities::default());
        let mut server = Session::new(Capabilities::default());

        let hello = client.create_hello();
        let accept = server.process_hello(&hello).unwrap();
        client.process_accept(&accept).unwrap();

        (client, server)
    }

    #[test]
    fn test_session_handshake() {
        let mut client = Session::new(Capabilities::default());
        let hello = client.create_hello();
        assert_eq!(client.state(), SessionState::HelloSent);

        let mut server = Session::new(Capabilities::default());
        let accept = server.process_hello(&hello).unwrap();
        assert_eq!(server.state(), SessionState::Established);
        assert_eq!(accept.msg_type, MessageType::Accept);

        client.process_accept(&accept).unwrap();
        assert_eq!(client.state(), SessionState::Established);
        // The client adopts the session id assigned by the server.
        assert_eq!(client.id(), server.id());
    }

    #[test]
    fn test_session_reject() {
        let mut client = Session::new(Capabilities::new("client"));
        let hello = client.create_hello();

        let server_caps = Capabilities {
            version: "4.0".to_string(),
            ..Default::default()
        };
        let mut server = Session::new(server_caps);

        let response = server.process_hello(&hello).unwrap();
        assert_eq!(response.msg_type, MessageType::Reject);

        let result = client.process_reject(&response);
        assert!(result.is_err());
        assert_eq!(client.state(), SessionState::Closed);
    }

    #[test]
    fn test_session_data_exchange() {
        let (mut client, mut server) = establish_pair();

        let content = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello"}]}"#;
        let data_msg = client.compress(content).unwrap();

        let decompressed = server.decompress(&data_msg).unwrap();
        let original: serde_json::Value = serde_json::from_str(content).unwrap();
        let recovered: serde_json::Value = serde_json::from_str(&decompressed).unwrap();

        assert_eq!(
            original["messages"][0]["content"],
            recovered["messages"][0]["content"]
        );
    }

    #[test]
    fn test_session_stats() {
        let (mut client, _server) = establish_pair();

        for _ in 0..5 {
            client
                .compress(r#"{"test":"data"}"#)
                .expect("compress on an established session should succeed");
        }

        let stats = client.stats();
        // One HELLO plus five DATA messages.
        assert_eq!(stats.messages_sent, 6);
        assert!(stats.bytes_compressed > 0);
    }

    #[test]
    fn test_encoding_negotiation() {
        let client_caps = Capabilities::default().with_compression(
            CompressionCaps::default().with_preferred_encoding(Encoding::O200kBase),
        );
        let mut client = Session::new(client_caps);

        let server_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_preferred_encoding(Encoding::Cl100kBase)
                .with_encodings(vec![Encoding::Cl100kBase, Encoding::O200kBase]),
        );
        let mut server = Session::new(server_caps);

        let hello = client.create_hello();
        let accept = server.process_hello(&hello).unwrap();
        client.process_accept(&accept).unwrap();

        // Each side negotiates from its own preferences, so here the two ends
        // settle on different encodings.
        assert_eq!(server.encoding(), Some(Encoding::Cl100kBase));
        assert_eq!(client.encoding(), Some(Encoding::O200kBase));
    }

    #[test]
    fn test_encoding_negotiation_fallback() {
        let client_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_encodings(vec![Encoding::O200kBase])
                .with_preferred_encoding(Encoding::O200kBase),
        );
        let mut client = Session::new(client_caps);

        let server_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_preferred_encoding(Encoding::Cl100kBase)
                .with_encodings(vec![Encoding::Cl100kBase, Encoding::O200kBase]),
        );
        let mut server = Session::new(server_caps);

        let hello = client.create_hello();
        let accept = server.process_hello(&hello).unwrap();
        client.process_accept(&accept).unwrap();

        assert_eq!(server.encoding(), Some(Encoding::O200kBase));
        assert_eq!(client.encoding(), Some(Encoding::O200kBase));
    }

    #[test]
    fn test_token_native_algorithm_negotiated() {
        let (client, server) = establish_pair();

        assert_eq!(client.algorithm(), Some(Algorithm::M2M));
        assert_eq!(server.algorithm(), Some(Algorithm::M2M));

        assert_eq!(client.encoding(), Some(Encoding::Cl100kBase));
        assert_eq!(server.encoding(), Some(Encoding::Cl100kBase));
    }

    #[test]
    fn test_session_clone_preserves_encoding() {
        let (client, _server) = establish_pair();

        let cloned = client.clone();

        assert_eq!(cloned.algorithm(), client.algorithm());
        assert_eq!(cloned.encoding(), client.encoding());
    }

    #[test]
    fn test_compress_requires_established_session() {
        let mut session = Session::new(Capabilities::default());
        assert!(matches!(
            session.compress(r#"{"test":"data"}"#),
            Err(M2MError::SessionNotEstablished)
        ));
    }

    #[test]
    fn test_second_hello_is_protocol_error() {
        let mut client = Session::new(Capabilities::default());
        let mut server = Session::new(Capabilities::default());

        let hello = client.create_hello();
        server.process_hello(&hello).unwrap();

        assert!(matches!(
            server.process_hello(&hello),
            Err(M2MError::Protocol(_))
        ));
        assert_eq!(server.state(), SessionState::Established);
    }

    #[test]
    fn test_accept_without_hello_is_protocol_error() {
        let mut client = Session::new(Capabilities::default());
        let mut server = Session::new(Capabilities::default());
        let hello = client.create_hello();
        let accept = server.process_hello(&hello).unwrap();

        let mut fresh = Session::new(Capabilities::default());
        assert!(matches!(
            fresh.process_accept(&accept),
            Err(M2MError::Protocol(_))
        ));
        assert_eq!(fresh.state(), SessionState::Initial);
    }

    #[test]
    fn test_close_round_trip() {
        let (mut client, mut server) = establish_pair();

        let close = client.close();
        assert_eq!(client.state(), SessionState::Closing);

        assert!(server.process_message(&close).unwrap().is_none());
        assert_eq!(server.state(), SessionState::Closed);
    }
}