1use serde::{Deserialize, Serialize};
8
9use crate::MPL_VERSION;
10
/// Handshake message in which a client lists what it would like to use.
///
/// `negotiate` in this file reads a `ClientHello` and answers with a `ServerSelect`.
/// Every `Vec` field is marked `#[serde(default)]`, so it deserializes to an empty
/// vector when the key is missing; every `Option` field is left out of the serialized
/// output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientHello {
    /// MPL version string. `ClientHello::new` fills it from `MPL_VERSION`.
    /// Required on deserialization (no serde default).
    pub mpl_version: String,

    /// Requested protocols. `negotiate` selects the first entry the server also lists,
    /// so the order expresses client preference.
    #[serde(default)]
    pub protocols: Vec<String>,

    /// Requested SType identifiers (the tests in this file use `org.calendar.Event.v1`).
    #[serde(default)]
    pub stypes: Vec<String>,

    /// Requested tools, each with its own list of wanted features.
    #[serde(default)]
    pub tools: Vec<ToolRequest>,

    /// Requested profile name, if any (the tests in this file use `qom-basic`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,

    /// Requested policy identifiers.
    #[serde(default)]
    pub policies: Vec<String>,

    /// Requested feature identifiers that are not tied to a specific tool
    /// (per-tool features live on `ToolRequest`).
    #[serde(default)]
    pub features: Vec<String>,

    /// Optional model identifier.
    /// NOTE(review): nothing in this file reads this field; confirm its meaning with callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,

    /// Optional identifier of the client sending the hello.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
}
49
50impl ClientHello {
51 pub fn new() -> Self {
53 Self {
54 mpl_version: MPL_VERSION.to_string(),
55 protocols: Vec::new(),
56 stypes: Vec::new(),
57 tools: Vec::new(),
58 profile: None,
59 policies: Vec::new(),
60 features: Vec::new(),
61 model: None,
62 client_id: None,
63 }
64 }
65
66 pub fn with_protocols(mut self, protocols: Vec<String>) -> Self {
68 self.protocols = protocols;
69 self
70 }
71
72 pub fn with_stypes(mut self, stypes: Vec<String>) -> Self {
74 self.stypes = stypes;
75 self
76 }
77
78 pub fn with_tools(mut self, tools: Vec<ToolRequest>) -> Self {
80 self.tools = tools;
81 self
82 }
83
84 pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
86 self.profile = Some(profile.into());
87 self
88 }
89
90 pub fn with_policies(mut self, policies: Vec<String>) -> Self {
92 self.policies = policies;
93 self
94 }
95
96 pub fn with_features(mut self, features: Vec<String>) -> Self {
98 self.features = features;
99 self
100 }
101}
102
103impl Default for ClientHello {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ToolRequest {
112 pub id: String,
114
115 #[serde(default)]
117 pub features: Vec<String>,
118}
119
120impl ToolRequest {
121 pub fn new(id: impl Into<String>) -> Self {
122 Self {
123 id: id.into(),
124 features: Vec::new(),
125 }
126 }
127
128 pub fn with_features(mut self, features: Vec<String>) -> Self {
129 self.features = features;
130 self
131 }
132}
133
/// The server's answer to a `ClientHello`: what was selected and what was downgraded.
///
/// Every `Vec` field is marked `#[serde(default)]`, so it deserializes to an empty
/// vector when the key is missing; every `Option` field is left out of the serialized
/// output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSelect {
    /// MPL version string. The constructors in this file fill it from `MPL_VERSION`.
    /// Required on deserialization (no serde default).
    pub mpl_version: String,

    /// The single protocol that was selected, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub protocol: Option<String>,

    /// SType identifiers that were accepted.
    #[serde(default)]
    pub stypes: Vec<String>,

    /// Per-tool results, one entry per tool the client asked for when produced by `negotiate`.
    #[serde(default)]
    pub tools: Vec<ToolResponse>,

    /// The profile in effect; may differ from the requested one (see `downgrades`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,

    /// Policy identifiers that were accepted.
    #[serde(default)]
    pub policies: Vec<String>,

    /// Accepted feature identifiers that are not tied to a specific tool.
    #[serde(default)]
    pub features: Vec<String>,

    /// Records of requested items that could not be granted as asked.
    #[serde(default)]
    pub downgrades: Vec<Downgrade>,

    /// `true` for replies built by `ServerSelect::success`, `false` for `ServerSelect::failed`.
    /// Required on deserialization (no serde default).
    pub success: bool,

    /// Error text; set by `ServerSelect::failed`, `None` otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,

    /// Optional identifier of the responding server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_id: Option<String>,
}
179
180impl ServerSelect {
181 pub fn success() -> Self {
183 Self {
184 mpl_version: MPL_VERSION.to_string(),
185 protocol: None,
186 stypes: Vec::new(),
187 tools: Vec::new(),
188 profile: None,
189 policies: Vec::new(),
190 features: Vec::new(),
191 downgrades: Vec::new(),
192 success: true,
193 error: None,
194 server_id: None,
195 }
196 }
197
198 pub fn failed(error: impl Into<String>) -> Self {
200 Self {
201 mpl_version: MPL_VERSION.to_string(),
202 protocol: None,
203 stypes: Vec::new(),
204 tools: Vec::new(),
205 profile: None,
206 policies: Vec::new(),
207 features: Vec::new(),
208 downgrades: Vec::new(),
209 success: false,
210 error: Some(error.into()),
211 server_id: None,
212 }
213 }
214
215 pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
217 self.protocol = Some(protocol.into());
218 self
219 }
220
221 pub fn with_stypes(mut self, stypes: Vec<String>) -> Self {
223 self.stypes = stypes;
224 self
225 }
226
227 pub fn with_tools(mut self, tools: Vec<ToolResponse>) -> Self {
229 self.tools = tools;
230 self
231 }
232
233 pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
235 self.profile = Some(profile.into());
236 self
237 }
238
239 pub fn with_downgrade(mut self, downgrade: Downgrade) -> Self {
241 self.downgrades.push(downgrade);
242 self
243 }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct ToolResponse {
249 pub id: String,
251
252 pub available: bool,
254
255 #[serde(default)]
257 pub features: Vec<String>,
258
259 #[serde(skip_serializing_if = "Option::is_none")]
261 pub reason: Option<String>,
262}
263
264impl ToolResponse {
265 pub fn available(id: impl Into<String>) -> Self {
266 Self {
267 id: id.into(),
268 available: true,
269 features: Vec::new(),
270 reason: None,
271 }
272 }
273
274 pub fn unavailable(id: impl Into<String>, reason: impl Into<String>) -> Self {
275 Self {
276 id: id.into(),
277 available: false,
278 features: Vec::new(),
279 reason: Some(reason.into()),
280 }
281 }
282
283 pub fn with_features(mut self, features: Vec<String>) -> Self {
284 self.features = features;
285 self
286 }
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize)]
291pub struct Downgrade {
292 pub category: DowngradeCategory,
294
295 pub requested: String,
297
298 #[serde(skip_serializing_if = "Option::is_none")]
300 pub selected: Option<String>,
301
302 pub reason: String,
304}
305
306impl Downgrade {
307 pub fn new(
308 category: DowngradeCategory,
309 requested: impl Into<String>,
310 reason: impl Into<String>,
311 ) -> Self {
312 Self {
313 category,
314 requested: requested.into(),
315 selected: None,
316 reason: reason.into(),
317 }
318 }
319
320 pub fn with_selected(mut self, selected: impl Into<String>) -> Self {
321 self.selected = Some(selected.into());
322 self
323 }
324}
325
/// The kind of requested item a `Downgrade` refers to.
///
/// Because of `rename_all = "snake_case"`, variants are serialized in lowercase
/// (`"protocol"`, `"stype"`, `"tool"`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DowngradeCategory {
    /// No requested protocol could be selected.
    Protocol,
    /// A requested SType was not accepted.
    Stype,
    /// A requested tool was not available.
    Tool,
    /// The requested profile was not granted as asked.
    Profile,
    /// A requested policy was not accepted.
    Policy,
    /// A requested feature was not accepted.
    Feature,
    /// NOTE(review): not produced anywhere in this file; presumably for model downgrades.
    Model,
}
338
339pub fn negotiate(client: &ClientHello, server_capabilities: &ServerCapabilities) -> ServerSelect {
341 let mut response = ServerSelect::success();
342 let mut downgrades = Vec::new();
343
344 if let Some(protocol) = client
346 .protocols
347 .iter()
348 .find(|p| server_capabilities.protocols.contains(p))
349 {
350 response.protocol = Some(protocol.clone());
351 } else if !client.protocols.is_empty() {
352 downgrades.push(Downgrade::new(
353 DowngradeCategory::Protocol,
354 client.protocols.join(", "),
355 "No compatible protocol found",
356 ));
357 }
358
359 for stype in &client.stypes {
361 if server_capabilities.stypes.contains(stype) {
362 response.stypes.push(stype.clone());
363 } else {
364 downgrades.push(Downgrade::new(
365 DowngradeCategory::Stype,
366 stype,
367 "SType not supported",
368 ));
369 }
370 }
371
372 for tool_req in &client.tools {
374 if let Some(server_tool) = server_capabilities.tools.iter().find(|t| t.id == tool_req.id) {
375 let supported_features: Vec<_> = tool_req
376 .features
377 .iter()
378 .filter(|f| server_tool.features.contains(f))
379 .cloned()
380 .collect();
381
382 let unsupported: Vec<_> = tool_req
383 .features
384 .iter()
385 .filter(|f| !server_tool.features.contains(f))
386 .cloned()
387 .collect();
388
389 response
390 .tools
391 .push(ToolResponse::available(&tool_req.id).with_features(supported_features));
392
393 for feature in unsupported {
394 downgrades.push(Downgrade::new(
395 DowngradeCategory::Feature,
396 format!("{}:{}", tool_req.id, feature),
397 "Feature not supported for tool",
398 ));
399 }
400 } else {
401 response
402 .tools
403 .push(ToolResponse::unavailable(&tool_req.id, "Tool not available"));
404 downgrades.push(Downgrade::new(
405 DowngradeCategory::Tool,
406 &tool_req.id,
407 "Tool not available",
408 ));
409 }
410 }
411
412 if let Some(requested_profile) = &client.profile {
414 if server_capabilities.profiles.contains(requested_profile) {
415 response.profile = Some(requested_profile.clone());
416 } else if let Some(fallback) = server_capabilities.profiles.first() {
417 response.profile = Some(fallback.clone());
418 downgrades.push(
419 Downgrade::new(
420 DowngradeCategory::Profile,
421 requested_profile,
422 "Requested profile not available",
423 )
424 .with_selected(fallback),
425 );
426 }
427 }
428
429 for policy in &client.policies {
431 if server_capabilities.policies.contains(policy) {
432 response.policies.push(policy.clone());
433 } else {
434 downgrades.push(Downgrade::new(
435 DowngradeCategory::Policy,
436 policy,
437 "Policy not supported",
438 ));
439 }
440 }
441
442 response.downgrades = downgrades;
443 response
444}
445
/// What a server is able to offer; the second input to `negotiate`.
///
/// Derives `Default`, which yields a server that offers nothing (all lists empty).
/// It does not derive the serde traits, unlike the message types in this file.
#[derive(Debug, Clone, Default)]
pub struct ServerCapabilities {
    /// Protocols the server can speak.
    pub protocols: Vec<String>,
    /// SType identifiers the server supports.
    pub stypes: Vec<String>,
    /// Tools the server provides, each with the features it offers.
    pub tools: Vec<ToolCapability>,
    /// Profiles the server offers. `negotiate` falls back to the first entry when the
    /// requested profile is not in this list.
    pub profiles: Vec<String>,
    /// Policy identifiers the server supports.
    pub policies: Vec<String>,
    /// Feature identifiers the server supports that are not tied to a specific tool.
    pub features: Vec<String>,
}
456
/// One tool a server provides, together with the features it offers for that tool.
///
/// Derives `PartialEq`/`Eq` so capabilities can be compared directly (for example with
/// `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCapability {
    /// Identifier of the tool; `negotiate` matches it against `ToolRequest::id`.
    pub id: String,
    /// Features the server offers for this tool.
    pub features: Vec<String>,
}
463
464impl ToolCapability {
465 pub fn new(id: impl Into<String>) -> Self {
466 Self {
467 id: id.into(),
468 features: Vec::new(),
469 }
470 }
471
472 pub fn with_features(mut self, features: Vec<String>) -> Self {
473 self.features = features;
474 self
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls store the supplied values on the hello.
    #[test]
    fn test_client_hello_builder() {
        let wanted_stypes = vec![String::from("org.calendar.Event.v1")];
        let wanted_tools = vec![ToolRequest::new("calendar.create.v1")];

        let hello = ClientHello::new()
            .with_stypes(wanted_stypes)
            .with_tools(wanted_tools)
            .with_profile("qom-basic");

        assert_eq!(hello.stypes.len(), 1);
        assert_eq!(hello.tools.len(), 1);
        assert_eq!(hello.profile.as_deref(), Some("qom-basic"));
    }

    /// A request the server fully supports is granted as asked.
    #[test]
    fn test_negotiation_success() {
        let stype = "org.calendar.Event.v1";
        let profile = "qom-basic";

        let client = ClientHello::new()
            .with_stypes(vec![stype.to_string()])
            .with_profile(profile);

        let server = ServerCapabilities {
            stypes: vec![stype.to_string()],
            profiles: vec![profile.to_string()],
            ..ServerCapabilities::default()
        };

        let response = negotiate(&client, &server);

        assert!(response.success);
        assert!(response.stypes.iter().any(|s| s == stype));
        assert_eq!(response.profile.as_deref(), Some(profile));
    }

    /// An unsupported profile falls back to the server's first profile and is reported.
    #[test]
    fn test_negotiation_downgrade() {
        let client = ClientHello::new().with_profile("qom-strict-argcheck");

        let server = ServerCapabilities {
            profiles: vec![String::from("qom-basic")],
            ..ServerCapabilities::default()
        };

        let response = negotiate(&client, &server);

        assert!(response.success);
        assert_eq!(response.profile.as_deref(), Some("qom-basic"));
        assert!(!response.downgrades.is_empty());
        assert_eq!(
            response.downgrades.first().map(|d| d.category),
            Some(DowngradeCategory::Profile)
        );
    }
}