m2m/protocol/
session.rs

//! Session management for the M2M protocol.
//!
//! Handles the lifecycle of agent-to-agent sessions: handshake,
//! data exchange, and termination.
5
6use std::time::{Duration, Instant};
7
8use super::capabilities::{Capabilities, NegotiatedCaps};
9use super::message::{Message, MessageType, RejectionCode};
10use super::SESSION_TIMEOUT_SECS;
11use crate::codec::{Algorithm, CodecEngine};
12use crate::error::{M2MError, Result};
13
/// Lifecycle states of an M2M session.
///
/// A session normally moves `Initial` -> (`HelloSent` on the client) ->
/// `Established` -> `Closing`/`Closed`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Freshly created; no handshake message exchanged yet
    Initial,
    /// Client has emitted HELLO and awaits the peer's ACCEPT or REJECT
    HelloSent,
    /// Handshake finished; DATA may be exchanged
    Established,
    /// Local side has emitted CLOSE
    Closing,
    /// Terminated (peer sent CLOSE or REJECT)
    Closed,
}
28
29/// M2M protocol session
30pub struct Session {
31    /// Session ID
32    id: String,
33    /// Current state
34    state: SessionState,
35    /// Local capabilities
36    local_caps: Capabilities,
37    /// Remote capabilities (after handshake)
38    remote_caps: Option<Capabilities>,
39    /// Negotiated capabilities
40    negotiated: Option<NegotiatedCaps>,
41    /// Codec engine
42    codec: CodecEngine,
43    /// Session creation timestamp
44    created_at: Instant,
45    /// Last activity timestamp
46    last_activity: Instant,
47    /// Session timeout duration
48    timeout: Duration,
49    /// Messages sent
50    messages_sent: u64,
51    /// Messages received
52    messages_received: u64,
53    /// Bytes compressed
54    bytes_compressed: u64,
55    /// Bytes saved
56    bytes_saved: u64,
57}
58
59impl Session {
60    /// Create new session with capabilities
61    pub fn new(capabilities: Capabilities) -> Self {
62        let now = Instant::now();
63        Self {
64            id: uuid::Uuid::new_v4().to_string(),
65            state: SessionState::Initial,
66            local_caps: capabilities,
67            remote_caps: None,
68            negotiated: None,
69            codec: CodecEngine::new(),
70            created_at: now,
71            last_activity: now,
72            timeout: Duration::from_secs(SESSION_TIMEOUT_SECS),
73            messages_sent: 0,
74            messages_received: 0,
75            bytes_compressed: 0,
76            bytes_saved: 0,
77        }
78    }
79
80    /// Create session with existing ID (for server-side)
81    pub fn with_id(id: &str, capabilities: Capabilities) -> Self {
82        let mut session = Self::new(capabilities);
83        session.id = id.to_string();
84        session
85    }
86
87    /// Get session ID
88    pub fn id(&self) -> &str {
89        &self.id
90    }
91
92    /// Get current state
93    pub fn state(&self) -> SessionState {
94        self.state
95    }
96
97    /// Check if session is established
98    pub fn is_established(&self) -> bool {
99        self.state == SessionState::Established
100    }
101
102    /// Check if session is expired
103    pub fn is_expired(&self) -> bool {
104        self.last_activity.elapsed() > self.timeout
105    }
106
107    /// Get negotiated algorithm
108    pub fn algorithm(&self) -> Option<Algorithm> {
109        self.negotiated.as_ref().map(|n| n.algorithm)
110    }
111
112    /// Get negotiated encoding (for TokenNative compression)
113    pub fn encoding(&self) -> Option<crate::models::Encoding> {
114        self.negotiated.as_ref().map(|n| n.encoding)
115    }
116
117    /// Create HELLO message to initiate handshake
118    pub fn create_hello(&mut self) -> Message {
119        self.state = SessionState::HelloSent;
120        self.messages_sent += 1;
121        self.touch();
122        Message::hello(self.local_caps.clone())
123    }
124
125    /// Process incoming HELLO and create ACCEPT/REJECT response
126    pub fn process_hello(&mut self, hello: &Message) -> Result<Message> {
127        if self.state != SessionState::Initial {
128            return Err(M2MError::Protocol(format!(
129                "Cannot process HELLO in state {:?}",
130                self.state
131            )));
132        }
133
134        let remote_caps = hello
135            .get_capabilities()
136            .ok_or_else(|| M2MError::InvalidMessage("HELLO missing capabilities".to_string()))?;
137
138        self.messages_received += 1;
139        self.touch();
140
141        // Check version compatibility
142        if !self.local_caps.is_compatible(remote_caps) {
143            return Ok(Message::reject(
144                RejectionCode::VersionMismatch,
145                &format!(
146                    "Version {} not compatible with {}",
147                    remote_caps.version, self.local_caps.version
148                ),
149            ));
150        }
151
152        // Negotiate capabilities
153        match self.local_caps.negotiate(remote_caps) {
154            Some(negotiated) => {
155                self.remote_caps = Some(remote_caps.clone());
156                self.negotiated = Some(negotiated);
157                self.state = SessionState::Established;
158
159                // Configure codec based on negotiated caps
160                if let Some(ref neg) = self.negotiated {
161                    self.codec = self
162                        .codec
163                        .clone()
164                        .with_ml_routing(neg.ml_routing)
165                        .with_encoding(neg.encoding);
166                }
167
168                self.messages_sent += 1;
169                Ok(Message::accept(&self.id, self.local_caps.clone()))
170            },
171            None => Ok(Message::reject(
172                RejectionCode::NoCommonAlgorithm,
173                "No common compression algorithm",
174            )),
175        }
176    }
177
178    /// Process incoming ACCEPT message
179    pub fn process_accept(&mut self, accept: &Message) -> Result<()> {
180        if self.state != SessionState::HelloSent {
181            return Err(M2MError::Protocol(format!(
182                "Cannot process ACCEPT in state {:?}",
183                self.state
184            )));
185        }
186
187        let remote_caps = accept
188            .get_capabilities()
189            .ok_or_else(|| M2MError::InvalidMessage("ACCEPT missing capabilities".to_string()))?;
190
191        let session_id = accept
192            .session_id
193            .as_ref()
194            .ok_or_else(|| M2MError::InvalidMessage("ACCEPT missing session ID".to_string()))?;
195
196        self.messages_received += 1;
197        self.touch();
198
199        // Update session ID from server
200        self.id = session_id.clone();
201
202        // Negotiate and store
203        match self.local_caps.negotiate(remote_caps) {
204            Some(negotiated) => {
205                self.remote_caps = Some(remote_caps.clone());
206                self.negotiated = Some(negotiated);
207                self.state = SessionState::Established;
208
209                // Configure codec
210                if let Some(ref neg) = self.negotiated {
211                    self.codec = self
212                        .codec
213                        .clone()
214                        .with_ml_routing(neg.ml_routing)
215                        .with_encoding(neg.encoding);
216                }
217
218                Ok(())
219            },
220            None => Err(M2MError::NegotiationFailed(
221                "Failed to negotiate capabilities".to_string(),
222            )),
223        }
224    }
225
226    /// Process incoming REJECT message
227    pub fn process_reject(&mut self, reject: &Message) -> Result<()> {
228        self.messages_received += 1;
229        self.state = SessionState::Closed;
230
231        let rejection = reject.get_rejection();
232        let reason = rejection
233            .map(|r| format!("{:?}: {}", r.code, r.message))
234            .unwrap_or_else(|| "Unknown rejection".to_string());
235
236        Err(M2MError::NegotiationFailed(reason))
237    }
238
239    /// Compress and create DATA message
240    pub fn compress(&mut self, content: &str) -> Result<Message> {
241        if !self.is_established() {
242            return Err(M2MError::SessionNotEstablished);
243        }
244
245        if self.is_expired() {
246            return Err(M2MError::SessionExpired);
247        }
248
249        let algorithm = self.algorithm().unwrap_or(Algorithm::M2M);
250        let result = self.codec.compress(content, algorithm)?;
251
252        // Update stats
253        self.bytes_compressed += result.compressed_bytes as u64;
254        if result.original_bytes > result.compressed_bytes {
255            self.bytes_saved += (result.original_bytes - result.compressed_bytes) as u64;
256        }
257        self.messages_sent += 1;
258        self.touch();
259
260        Ok(Message::data(&self.id, algorithm, result.data))
261    }
262
263    /// Decompress DATA message content
264    pub fn decompress(&mut self, message: &Message) -> Result<String> {
265        if !self.is_established() {
266            return Err(M2MError::SessionNotEstablished);
267        }
268
269        if self.is_expired() {
270            return Err(M2MError::SessionExpired);
271        }
272
273        let data = message
274            .get_data()
275            .ok_or_else(|| M2MError::InvalidMessage("Not a DATA message".to_string()))?;
276
277        self.messages_received += 1;
278        self.touch();
279
280        self.codec.decompress(&data.content)
281    }
282
283    /// Process any incoming message
284    pub fn process_message(&mut self, message: &Message) -> Result<Option<Message>> {
285        self.touch();
286
287        match message.msg_type {
288            MessageType::Hello => {
289                let response = self.process_hello(message)?;
290                Ok(Some(response))
291            },
292            MessageType::Accept => {
293                self.process_accept(message)?;
294                Ok(None)
295            },
296            MessageType::Reject => {
297                self.process_reject(message)?;
298                Ok(None)
299            },
300            MessageType::Ping => {
301                self.messages_received += 1;
302                self.messages_sent += 1;
303                Ok(Some(Message::pong(&self.id)))
304            },
305            MessageType::Pong => {
306                self.messages_received += 1;
307                Ok(None)
308            },
309            MessageType::Close => {
310                self.messages_received += 1;
311                self.state = SessionState::Closed;
312                Ok(None)
313            },
314            MessageType::Data => {
315                // Data messages are processed via decompress()
316                Ok(None)
317            },
318        }
319    }
320
321    /// Close the session
322    pub fn close(&mut self) -> Message {
323        self.state = SessionState::Closing;
324        self.messages_sent += 1;
325        Message::close(&self.id)
326    }
327
328    /// Get session statistics
329    pub fn stats(&self) -> SessionStats {
330        SessionStats {
331            session_id: self.id.clone(),
332            state: self.state,
333            messages_sent: self.messages_sent,
334            messages_received: self.messages_received,
335            bytes_compressed: self.bytes_compressed,
336            bytes_saved: self.bytes_saved,
337            uptime_secs: self.created_at.elapsed().as_secs(),
338        }
339    }
340
341    /// Update last activity timestamp
342    fn touch(&mut self) {
343        self.last_activity = Instant::now();
344    }
345}
346
347impl Clone for Session {
348    fn clone(&self) -> Self {
349        // Preserve ML routing and encoding configuration from negotiated capabilities
350        let mut codec = CodecEngine::new();
351        if let Some(ref neg) = self.negotiated {
352            codec = codec
353                .with_ml_routing(neg.ml_routing)
354                .with_encoding(neg.encoding);
355        }
356
357        let now = Instant::now();
358        Self {
359            id: self.id.clone(),
360            state: self.state,
361            local_caps: self.local_caps.clone(),
362            remote_caps: self.remote_caps.clone(),
363            negotiated: self.negotiated.clone(),
364            codec,
365            created_at: now,
366            last_activity: now,
367            timeout: self.timeout,
368            // Note: Stats are reset on clone as this is typically used
369            // for creating a new session handler, not duplicating state
370            messages_sent: 0,
371            messages_received: 0,
372            bytes_compressed: 0,
373            bytes_saved: 0,
374        }
375    }
376}
377
378/// Session statistics
379#[derive(Debug, Clone)]
380pub struct SessionStats {
381    /// Session ID
382    pub session_id: String,
383    /// Current state
384    pub state: SessionState,
385    /// Messages sent
386    pub messages_sent: u64,
387    /// Messages received
388    pub messages_received: u64,
389    /// Total bytes compressed
390    pub bytes_compressed: u64,
391    /// Bytes saved by compression
392    pub bytes_saved: u64,
393    /// Session uptime in seconds
394    pub uptime_secs: u64,
395}
396
397impl SessionStats {
398    /// Calculate compression ratio
399    pub fn compression_ratio(&self) -> f64 {
400        if self.bytes_compressed == 0 {
401            1.0
402        } else {
403            (self.bytes_compressed + self.bytes_saved) as f64 / self.bytes_compressed as f64
404        }
405    }
406
407    /// Calculate savings percentage
408    pub fn savings_percent(&self) -> f64 {
409        let total = self.bytes_compressed + self.bytes_saved;
410        if total == 0 {
411            0.0
412        } else {
413            self.bytes_saved as f64 / total as f64 * 100.0
414        }
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Encoding;
    use crate::protocol::capabilities::CompressionCaps;

    /// Run a complete HELLO -> ACCEPT exchange and return `(client, server)`.
    ///
    /// Panics if the server fails to process the HELLO or the client fails to
    /// process the ACCEPT.
    fn handshake(client_caps: Capabilities, server_caps: Capabilities) -> (Session, Session) {
        let mut client = Session::new(client_caps);
        let mut server = Session::new(server_caps);

        let hello = client.create_hello();
        let accept = server.process_hello(&hello).unwrap();
        client.process_accept(&accept).unwrap();

        (client, server)
    }

    #[test]
    fn test_session_handshake() {
        // Client opens the handshake
        let mut client = Session::new(Capabilities::default());
        let hello = client.create_hello();
        assert_eq!(client.state(), SessionState::HelloSent);

        // Server answers with ACCEPT and becomes established
        let mut server = Session::new(Capabilities::default());
        let accept = server.process_hello(&hello).unwrap();
        assert_eq!(server.state(), SessionState::Established);
        assert_eq!(accept.msg_type, MessageType::Accept);

        // Client completes the handshake and adopts the server's ID
        client.process_accept(&accept).unwrap();
        assert_eq!(client.state(), SessionState::Established);
        assert_eq!(client.id(), server.id());
    }

    #[test]
    fn test_session_reject() {
        let mut client = Session::new(Capabilities::new("client"));
        let hello = client.create_hello();

        // A server on an incompatible protocol version must reject
        let mut server = Session::new(Capabilities {
            version: "4.0".to_string(),
            ..Default::default()
        });
        let response = server.process_hello(&hello).unwrap();
        assert_eq!(response.msg_type, MessageType::Reject);

        // The client surfaces the rejection as an error and closes
        let outcome = client.process_reject(&response);
        assert!(outcome.is_err());
        assert_eq!(client.state(), SessionState::Closed);
    }

    #[test]
    fn test_session_data_exchange() {
        let (mut client, mut server) =
            handshake(Capabilities::default(), Capabilities::default());

        let content = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello"}]}"#;
        let wire = client.compress(content).unwrap();
        let decompressed = server.decompress(&wire).unwrap();

        let original: serde_json::Value = serde_json::from_str(content).unwrap();
        let recovered: serde_json::Value = serde_json::from_str(&decompressed).unwrap();
        assert_eq!(
            original["messages"][0]["content"],
            recovered["messages"][0]["content"]
        );
    }

    #[test]
    fn test_session_stats() {
        let (mut client, _server) = handshake(Capabilities::default(), Capabilities::default());

        for _ in 0..5 {
            let _ = client.compress(r#"{"test":"data"}"#);
        }

        let stats = client.stats();
        assert_eq!(stats.messages_sent, 6); // 1 hello + 5 data
        assert!(stats.bytes_compressed > 0);
    }

    #[test]
    fn test_encoding_negotiation() {
        // Client prefers o200k, server prefers cl100k; both support both encodings.
        // Each side negotiates from its own preference.
        let client_caps = Capabilities::default().with_compression(
            CompressionCaps::default().with_preferred_encoding(Encoding::O200kBase),
        );
        let server_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_preferred_encoding(Encoding::Cl100kBase)
                .with_encodings(vec![Encoding::Cl100kBase, Encoding::O200kBase]),
        );

        let (client, server) = handshake(client_caps, server_caps);

        // Server negotiates with its own preference (cl100k), which the client supports
        assert_eq!(server.encoding(), Some(Encoding::Cl100kBase));
        // Client re-negotiates with its own preference (o200k), which the server supports.
        // NOTE: known limitation - each side uses its own preference; the wire format
        // includes the tokenizer ID so decompression still works.
        assert_eq!(client.encoding(), Some(Encoding::O200kBase));
    }

    #[test]
    fn test_encoding_negotiation_fallback() {
        // Client only supports o200k
        let client_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_encodings(vec![Encoding::O200kBase])
                .with_preferred_encoding(Encoding::O200kBase),
        );
        // Server prefers cl100k but supports both
        let server_caps = Capabilities::default().with_compression(
            CompressionCaps::default()
                .with_preferred_encoding(Encoding::Cl100kBase)
                .with_encodings(vec![Encoding::Cl100kBase, Encoding::O200kBase]),
        );

        let (client, server) = handshake(client_caps, server_caps);

        // Server's preferred cl100k is unsupported by the client, so both land on o200k
        assert_eq!(server.encoding(), Some(Encoding::O200kBase));
        assert_eq!(client.encoding(), Some(Encoding::O200kBase));
    }

    #[test]
    fn test_token_native_algorithm_negotiated() {
        let (client, server) = handshake(Capabilities::default(), Capabilities::default());

        // M2M is the default negotiated algorithm
        assert_eq!(client.algorithm(), Some(Algorithm::M2M));
        assert_eq!(server.algorithm(), Some(Algorithm::M2M));

        // cl100k is the default encoding
        assert_eq!(client.encoding(), Some(Encoding::Cl100kBase));
        assert_eq!(server.encoding(), Some(Encoding::Cl100kBase));
    }

    #[test]
    fn test_session_clone_preserves_encoding() {
        let (client, _server) = handshake(Capabilities::default(), Capabilities::default());

        let cloned = client.clone();

        // Negotiated algorithm and encoding survive the clone
        assert_eq!(cloned.algorithm(), client.algorithm());
        assert_eq!(cloned.encoding(), client.encoding());
    }
}