1use serde::{Deserialize, Serialize};
8
9use crate::codec::Algorithm;
10use crate::models::Encoding;
11
/// Compression-related capabilities advertised by an agent and compared by
/// [`CompressionCaps::negotiate`] and [`CompressionCaps::negotiate_encoding`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionCaps {
    /// Supported compression algorithms, in preference order: `negotiate`
    /// returns the first entry the peer also supports.
    pub algorithms: Vec<Algorithm>,
    /// Maximum payload size.
    /// NOTE(review): the `Default` impl uses 0, which presumably means
    /// "no limit"; the unit (bytes?) is not visible here — confirm.
    pub max_payload: usize,
    /// Whether this side supports streaming. `Capabilities::negotiate` enables
    /// streaming only if both sides set this.
    pub streaming: bool,
    /// Whether ML routing is enabled. `Capabilities::negotiate` enables it
    /// only if both sides set this.
    pub ml_routing: bool,
    /// Supported encodings, in the fallback order used by
    /// `negotiate_encoding` when the preferred encoding is not shared.
    ///
    /// `#[serde(default)]`: when absent from serialized input this is an empty
    /// list (not the list from the `Default` impl). If the *peer's* list is
    /// empty, `negotiate_encoding` always falls back to `Encoding::Cl100kBase`.
    #[serde(default)]
    pub encodings: Vec<Encoding>,
    /// Encoding this side prefers. `negotiate_encoding` selects it whenever
    /// the peer supports it, even if it is not listed in `encodings`.
    ///
    /// `#[serde(default)]`: when absent from serialized input,
    /// `Encoding::default()` is used.
    #[serde(default)]
    pub preferred_encoding: Encoding,
}
30
31impl Default for CompressionCaps {
32 fn default() -> Self {
33 Self {
34 algorithms: vec![
38 Algorithm::M2M,
39 Algorithm::TokenNative,
40 Algorithm::Brotli,
41 Algorithm::None,
42 ],
43 max_payload: 0, streaming: true,
45 ml_routing: false,
46 encodings: vec![Encoding::Cl100kBase, Encoding::O200kBase],
47 preferred_encoding: Encoding::Cl100kBase,
48 }
49 }
50}
51
52impl CompressionCaps {
53 pub fn with_ml_routing(mut self) -> Self {
55 self.ml_routing = true;
56 self
57 }
58
59 pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
61 self.algorithms = algorithms;
62 self
63 }
64
65 pub fn with_encodings(mut self, encodings: Vec<Encoding>) -> Self {
67 self.encodings = encodings;
68 self
69 }
70
71 pub fn with_preferred_encoding(mut self, encoding: Encoding) -> Self {
73 self.preferred_encoding = encoding;
74 self
75 }
76
77 pub fn supports(&self, algorithm: Algorithm) -> bool {
79 self.algorithms.contains(&algorithm)
80 }
81
82 pub fn supports_encoding(&self, encoding: Encoding) -> bool {
84 self.encodings.contains(&encoding)
85 }
86
87 pub fn negotiate(&self, other: &CompressionCaps) -> Option<Algorithm> {
89 for algo in &self.algorithms {
91 if other.supports(*algo) {
92 return Some(*algo);
93 }
94 }
95 None
96 }
97
98 pub fn negotiate_encoding(&self, other: &CompressionCaps) -> Encoding {
100 if other.supports_encoding(self.preferred_encoding) {
102 return self.preferred_encoding;
103 }
104 for enc in &self.encodings {
106 if other.supports_encoding(*enc) {
107 return *enc;
108 }
109 }
110 Encoding::Cl100kBase
112 }
113}
114
/// Security-related capabilities advertised by an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityCaps {
    /// Whether threat detection is enabled. `Capabilities::negotiate` enables
    /// it if either side sets this.
    pub threat_detection: bool,
    /// Version of the threat-detection model, recorded by
    /// `with_threat_detection`; `None` by default.
    pub model_version: Option<String>,
    /// Whether blocking mode is enabled. `Capabilities::negotiate` enables it
    /// if either side sets this.
    pub blocking_mode: bool,
    /// Blocking threshold; 0.8 by default, and `with_blocking` clamps it into
    /// `[0.0, 1.0]`.
    /// NOTE(review): how scores are compared against it (`>=` vs `>`) is not
    /// visible in this file — confirm with the consumer.
    pub block_threshold: f32,
}
127
128impl Default for SecurityCaps {
129 fn default() -> Self {
130 Self {
131 threat_detection: false,
132 model_version: None,
133 blocking_mode: false,
134 block_threshold: 0.8,
135 }
136 }
137}
138
139impl SecurityCaps {
140 pub fn with_threat_detection(mut self, model_version: &str) -> Self {
142 self.threat_detection = true;
143 self.model_version = Some(model_version.to_string());
144 self
145 }
146
147 pub fn with_blocking(mut self, threshold: f32) -> Self {
149 self.blocking_mode = true;
150 self.block_threshold = threshold.clamp(0.0, 1.0);
151 self
152 }
153}
154
/// Full capability set an agent advertises: protocol version, identity,
/// compression and security capabilities, plus free-form extensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Protocol version string. Only the component before the first `.`
    /// (the major version) is used by `is_compatible`.
    pub version: String,
    /// Agent identifier; the `Default` impl generates a random v4 UUID.
    pub agent_id: String,
    /// Agent type label ("m2m-rust" by default).
    pub agent_type: String,
    /// Compression capabilities.
    pub compression: CompressionCaps,
    /// Security capabilities.
    pub security: SecurityCaps,
    /// Arbitrary string key/value extensions. `#[serde(default)]` makes this
    /// optional in serialized input (empty map when absent).
    #[serde(default)]
    pub extensions: std::collections::HashMap<String, String>,
}
172
173impl Default for Capabilities {
174 fn default() -> Self {
175 Self {
176 version: super::PROTOCOL_VERSION.to_string(),
177 agent_id: uuid::Uuid::new_v4().to_string(),
178 agent_type: "m2m-rust".to_string(),
179 compression: CompressionCaps::default(),
180 security: SecurityCaps::default(),
181 extensions: std::collections::HashMap::new(),
182 }
183 }
184}
185
186impl Capabilities {
187 pub fn new(agent_type: &str) -> Self {
189 Self {
190 agent_type: agent_type.to_string(),
191 ..Default::default()
192 }
193 }
194
195 pub fn with_compression(mut self, caps: CompressionCaps) -> Self {
197 self.compression = caps;
198 self
199 }
200
201 pub fn with_security(mut self, caps: SecurityCaps) -> Self {
203 self.security = caps;
204 self
205 }
206
207 pub fn with_extension(mut self, key: &str, value: &str) -> Self {
209 self.extensions.insert(key.to_string(), value.to_string());
210 self
211 }
212
213 pub fn is_compatible(&self, other: &Capabilities) -> bool {
215 let self_major = self.version.split('.').next().unwrap_or("0");
217 let other_major = other.version.split('.').next().unwrap_or("0");
218 self_major == other_major
219 }
220
221 pub fn negotiate(&self, peer: &Capabilities) -> Option<NegotiatedCaps> {
223 if !self.is_compatible(peer) {
224 return None;
225 }
226
227 let algorithm = self.compression.negotiate(&peer.compression)?;
228 let encoding = self.compression.negotiate_encoding(&peer.compression);
229
230 Some(NegotiatedCaps {
231 algorithm,
232 encoding,
233 streaming: self.compression.streaming && peer.compression.streaming,
234 ml_routing: self.compression.ml_routing && peer.compression.ml_routing,
235 threat_detection: self.security.threat_detection || peer.security.threat_detection,
236 blocking_mode: self.security.blocking_mode || peer.security.blocking_mode,
237 })
238 }
239}
240
241#[derive(Debug, Clone)]
243pub struct NegotiatedCaps {
244 pub algorithm: Algorithm,
246 pub encoding: Encoding,
248 pub streaming: bool,
250 pub ml_routing: bool,
252 pub threat_detection: bool,
254 pub blocking_mode: bool,
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_algorithm_negotiation() {
        let caps1 = CompressionCaps::default();
        let caps2 = CompressionCaps {
            algorithms: vec![Algorithm::Brotli, Algorithm::None],
            ..Default::default()
        };

        // caps1 prefers M2M and TokenNative, which caps2 lacks, so the first
        // shared algorithm in caps1's order is Brotli.
        assert_eq!(caps1.negotiate(&caps2), Some(Algorithm::Brotli));
    }

    #[test]
    fn test_no_common_algorithm() {
        let caps1 = CompressionCaps {
            algorithms: vec![Algorithm::TokenNative],
            ..Default::default()
        };
        let caps2 = CompressionCaps {
            algorithms: vec![Algorithm::Brotli],
            ..Default::default()
        };

        assert_eq!(caps1.negotiate(&caps2), None);
    }

    #[test]
    fn test_version_compatibility() {
        let caps1 = Capabilities::default();
        let mut caps2 = Capabilities::default();

        assert!(caps1.is_compatible(&caps2));

        // Same major version, different minor: compatible.
        caps2.version = "3.1".to_string();
        assert!(caps1.is_compatible(&caps2));

        // Different major version: incompatible.
        caps2.version = "4.0".to_string();
        assert!(!caps1.is_compatible(&caps2));
    }

    #[test]
    fn test_full_negotiation() {
        let caps1 = Capabilities::default()
            .with_security(SecurityCaps::default().with_threat_detection("1.0"));

        let caps2 = Capabilities::default();

        let negotiated = caps1.negotiate(&caps2).unwrap();
        assert_eq!(negotiated.algorithm, Algorithm::M2M);
        assert_eq!(negotiated.encoding, Encoding::Cl100kBase);
        // Threat detection is enabled if either side requests it.
        assert!(negotiated.threat_detection);
    }

    #[test]
    fn test_encoding_negotiation_prefers_own_preference() {
        let caps1 = CompressionCaps::default().with_preferred_encoding(Encoding::O200kBase);
        let caps2 = CompressionCaps::default();

        assert_eq!(caps1.negotiate_encoding(&caps2), Encoding::O200kBase);
    }

    #[test]
    fn test_encoding_negotiation_falls_back() {
        let caps1 = CompressionCaps::default().with_preferred_encoding(Encoding::O200kBase);

        // Preferred encoding not shared: first shared entry of caps1's list.
        let cl100k_only = CompressionCaps::default().with_encodings(vec![Encoding::Cl100kBase]);
        assert_eq!(caps1.negotiate_encoding(&cl100k_only), Encoding::Cl100kBase);

        // Nothing shared: hard-coded Cl100kBase fallback.
        let no_encodings = CompressionCaps::default().with_encodings(Vec::new());
        assert_eq!(caps1.negotiate_encoding(&no_encodings), Encoding::Cl100kBase);
    }

    #[test]
    fn test_blocking_threshold_is_clamped() {
        let high = SecurityCaps::default().with_blocking(1.5);
        assert!(high.blocking_mode);
        assert_eq!(high.block_threshold, 1.0);

        let low = SecurityCaps::default().with_blocking(-0.5);
        assert_eq!(low.block_threshold, 0.0);
    }

    #[test]
    fn test_incompatible_versions_do_not_negotiate() {
        let caps1 = Capabilities::default();
        let mut caps2 = Capabilities::default();
        caps2.version = "4.0".to_string();

        assert!(caps1.negotiate(&caps2).is_none());
    }
}