m2m/protocol/
capabilities.rs

1//! Agent capabilities for protocol negotiation.
2//!
3//! Capabilities are advertised during the HELLO/ACCEPT handshake
4//! to establish what compression algorithms and features both
5//! agents support.
6
7use serde::{Deserialize, Serialize};
8
9use crate::codec::Algorithm;
10use crate::models::Encoding;
11
/// Compression-related capabilities advertised by an agent.
///
/// Fields marked `#[serde(default)]` may be missing from the serialized
/// form; on deserialization they then take their type's `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionCaps {
    /// Supported algorithms in preference order (most preferred first)
    pub algorithms: Vec<Algorithm>,
    /// Maximum payload size in bytes (0 = unlimited)
    pub max_payload: usize,
    /// Supports streaming compression
    pub streaming: bool,
    /// Has ML routing capability
    pub ml_routing: bool,
    /// Supported tokenizer encodings (for TokenNative), in preference order
    #[serde(default)]
    pub encodings: Vec<Encoding>,
    /// Preferred tokenizer encoding; tried first by `negotiate_encoding`
    #[serde(default)]
    pub preferred_encoding: Encoding,
}
30
31impl Default for CompressionCaps {
32    fn default() -> Self {
33        Self {
34            // M2M is first preference (100% JSON fidelity with routing headers)
35            // TokenNative is second (good for small-medium JSON)
36            // Brotli is third (best for large content)
37            algorithms: vec![
38                Algorithm::M2M,
39                Algorithm::TokenNative,
40                Algorithm::Brotli,
41                Algorithm::None,
42            ],
43            max_payload: 0, // unlimited
44            streaming: true,
45            ml_routing: false,
46            encodings: vec![Encoding::Cl100kBase, Encoding::O200kBase],
47            preferred_encoding: Encoding::Cl100kBase,
48        }
49    }
50}
51
52impl CompressionCaps {
53    /// Create with ML routing enabled
54    pub fn with_ml_routing(mut self) -> Self {
55        self.ml_routing = true;
56        self
57    }
58
59    /// Create with specific algorithms
60    pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
61        self.algorithms = algorithms;
62        self
63    }
64
65    /// Create with specific encodings
66    pub fn with_encodings(mut self, encodings: Vec<Encoding>) -> Self {
67        self.encodings = encodings;
68        self
69    }
70
71    /// Set preferred encoding
72    pub fn with_preferred_encoding(mut self, encoding: Encoding) -> Self {
73        self.preferred_encoding = encoding;
74        self
75    }
76
77    /// Check if algorithm is supported
78    pub fn supports(&self, algorithm: Algorithm) -> bool {
79        self.algorithms.contains(&algorithm)
80    }
81
82    /// Check if encoding is supported
83    pub fn supports_encoding(&self, encoding: Encoding) -> bool {
84        self.encodings.contains(&encoding)
85    }
86
87    /// Get best mutually supported algorithm
88    pub fn negotiate(&self, other: &CompressionCaps) -> Option<Algorithm> {
89        // Find first algorithm supported by both (preference order is ours)
90        for algo in &self.algorithms {
91            if other.supports(*algo) {
92                return Some(*algo);
93            }
94        }
95        None
96    }
97
98    /// Negotiate tokenizer encoding
99    pub fn negotiate_encoding(&self, other: &CompressionCaps) -> Encoding {
100        // Prefer our preferred encoding if other supports it
101        if other.supports_encoding(self.preferred_encoding) {
102            return self.preferred_encoding;
103        }
104        // Otherwise, find first mutually supported
105        for enc in &self.encodings {
106            if other.supports_encoding(*enc) {
107                return *enc;
108            }
109        }
110        // Fallback to canonical cl100k
111        Encoding::Cl100kBase
112    }
113}
114
/// Security-related capabilities advertised by an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityCaps {
    /// Has security threat detection
    pub threat_detection: bool,
    /// Security model version (`None` until threat detection is configured
    /// through `with_threat_detection`)
    pub model_version: Option<String>,
    /// Blocks detected threats (vs just flagging)
    pub blocking_mode: bool,
    /// Minimum confidence threshold for blocking (0.0 - 1.0)
    pub block_threshold: f32,
}
127
128impl Default for SecurityCaps {
129    fn default() -> Self {
130        Self {
131            threat_detection: false,
132            model_version: None,
133            blocking_mode: false,
134            block_threshold: 0.8,
135        }
136    }
137}
138
139impl SecurityCaps {
140    /// Enable threat detection
141    pub fn with_threat_detection(mut self, model_version: &str) -> Self {
142        self.threat_detection = true;
143        self.model_version = Some(model_version.to_string());
144        self
145    }
146
147    /// Enable blocking mode
148    pub fn with_blocking(mut self, threshold: f32) -> Self {
149        self.blocking_mode = true;
150        self.block_threshold = threshold.clamp(0.0, 1.0);
151        self
152    }
153}
154
/// Full agent capabilities: identity, protocol version, and the
/// compression and security capability sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Protocol version; `is_compatible` compares only the part before the
    /// first `.`
    pub version: String,
    /// Agent identifier
    pub agent_id: String,
    /// Agent type/name
    pub agent_type: String,
    /// Compression capabilities
    pub compression: CompressionCaps,
    /// Security capabilities
    pub security: SecurityCaps,
    /// Custom extensions (key-value pairs); may be absent from the
    /// serialized form, in which case the map is empty
    #[serde(default)]
    pub extensions: std::collections::HashMap<String, String>,
}
172
173impl Default for Capabilities {
174    fn default() -> Self {
175        Self {
176            version: super::PROTOCOL_VERSION.to_string(),
177            agent_id: uuid::Uuid::new_v4().to_string(),
178            agent_type: "m2m-rust".to_string(),
179            compression: CompressionCaps::default(),
180            security: SecurityCaps::default(),
181            extensions: std::collections::HashMap::new(),
182        }
183    }
184}
185
186impl Capabilities {
187    /// Create new capabilities with custom agent type
188    pub fn new(agent_type: &str) -> Self {
189        Self {
190            agent_type: agent_type.to_string(),
191            ..Default::default()
192        }
193    }
194
195    /// Add compression capabilities
196    pub fn with_compression(mut self, caps: CompressionCaps) -> Self {
197        self.compression = caps;
198        self
199    }
200
201    /// Add security capabilities
202    pub fn with_security(mut self, caps: SecurityCaps) -> Self {
203        self.security = caps;
204        self
205    }
206
207    /// Add extension
208    pub fn with_extension(mut self, key: &str, value: &str) -> Self {
209        self.extensions.insert(key.to_string(), value.to_string());
210        self
211    }
212
213    /// Check version compatibility
214    pub fn is_compatible(&self, other: &Capabilities) -> bool {
215        // Major version must match
216        let self_major = self.version.split('.').next().unwrap_or("0");
217        let other_major = other.version.split('.').next().unwrap_or("0");
218        self_major == other_major
219    }
220
221    /// Negotiate capabilities with peer
222    pub fn negotiate(&self, peer: &Capabilities) -> Option<NegotiatedCaps> {
223        if !self.is_compatible(peer) {
224            return None;
225        }
226
227        let algorithm = self.compression.negotiate(&peer.compression)?;
228        let encoding = self.compression.negotiate_encoding(&peer.compression);
229
230        Some(NegotiatedCaps {
231            algorithm,
232            encoding,
233            streaming: self.compression.streaming && peer.compression.streaming,
234            ml_routing: self.compression.ml_routing && peer.compression.ml_routing,
235            threat_detection: self.security.threat_detection || peer.security.threat_detection,
236            blocking_mode: self.security.blocking_mode || peer.security.blocking_mode,
237        })
238    }
239}
240
/// Result of capability negotiation, as produced by
/// `Capabilities::negotiate`.
#[derive(Debug, Clone)]
pub struct NegotiatedCaps {
    /// Agreed compression algorithm
    pub algorithm: Algorithm,
    /// Agreed tokenizer encoding (for TokenNative)
    pub encoding: Encoding,
    /// Both support streaming
    pub streaming: bool,
    /// Both have ML routing
    pub ml_routing: bool,
    /// Either has threat detection
    pub threat_detection: bool,
    /// Either has blocking mode
    pub blocking_mode: bool,
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// The first algorithm in our preference order that the peer also
    /// lists is the one selected.
    #[test]
    fn test_algorithm_negotiation() {
        let ours = CompressionCaps::default();
        let theirs =
            CompressionCaps::default().with_algorithms(vec![Algorithm::Brotli, Algorithm::None]);

        // M2M and TokenNative are skipped: the peer does not list them.
        let agreed = ours.negotiate(&theirs);
        assert_eq!(agreed, Some(Algorithm::Brotli));
    }

    /// Disjoint algorithm lists yield no agreement.
    #[test]
    fn test_no_common_algorithm() {
        let token_only = CompressionCaps::default().with_algorithms(vec![Algorithm::TokenNative]);
        let brotli_only = CompressionCaps::default().with_algorithms(vec![Algorithm::Brotli]);

        assert_eq!(token_only.negotiate(&brotli_only), None);
    }

    /// Only the major version component decides compatibility.
    #[test]
    fn test_version_compatibility() {
        let local = Capabilities::default();
        let peer_at = |version: &str| Capabilities {
            version: version.to_string(),
            ..Capabilities::default()
        };

        // Same version on both sides.
        assert!(local.is_compatible(&Capabilities::default()));
        // Minor version diff OK
        assert!(local.is_compatible(&peer_at("3.1")));
        // Major version diff NOT OK
        assert!(!local.is_compatible(&peer_at("4.0")));
    }

    /// End-to-end negotiation between two default agents, one of which
    /// has threat detection enabled.
    #[test]
    fn test_full_negotiation() {
        let security = SecurityCaps::default().with_threat_detection("1.0");
        let with_detection = Capabilities::default().with_security(security);
        let plain = Capabilities::default();

        let negotiated = with_detection.negotiate(&plain).unwrap();

        assert_eq!(negotiated.algorithm, Algorithm::M2M); // New default
        assert_eq!(negotiated.encoding, Encoding::Cl100kBase);
        assert!(negotiated.threat_detection); // One has it
    }
}
314}