1extern crate alloc;
2
3use alloc::collections::{BTreeMap, BTreeSet};
4use alloc::string::String;
5use alloc::vec::Vec;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9pub mod canonical;
10
11use crate::limits::{
12 CAPABILITIES_MAX_COUNT, FUNCTION_VERSIONS_KEYS_MAX, FUNCTION_VERSIONS_PER_FN_MAX, VERSION_MAX,
13 VERSION_MIN,
14};
15pub use canonical::{CanonicalError, CanonicalOffer, canonicalize, host_offer_hash};
16
/// Handshake message schema version 1.
///
/// NOTE(review): the validation code in this file does not compare against this constant;
/// `HandshakeAccept::validate_against_offer` only checks that the accept's schema version
/// equals the offer's. Confirm whether callers are expected to check it elsewhere.
pub const HANDSHAKE_SCHEMA_VERSION_V1: u32 = 1;
18
/// Serializable host offer whose `function_versions` is a list of per-function entries
/// rather than a map.
///
/// [`HandshakeOffer::canonical_hash`] builds one of these and passes it to
/// `host_offer_hash`. `deny_unknown_fields` makes deserialization reject unrecognized keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HandshakeOfferRaw {
    /// Handshake message schema version.
    pub handshake_schema_version: u32,
    /// Envelope version carried by the offer (NOTE(review): semantics not visible here).
    pub envelope_version: u32,
    /// One entry per function; this type itself does not sort or deduplicate entries.
    pub function_versions: Vec<FunctionVersionOfferRaw>,
    /// Host capability names; this type itself does not sort or deduplicate them.
    pub host_capabilities: Vec<String>,
}
27
/// A single function's entry inside [`HandshakeOfferRaw::function_versions`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FunctionVersionOfferRaw {
    /// Function name.
    pub function: String,
    /// Versions offered for `function`, in the order supplied.
    pub versions: Vec<u32>,
}
34
/// Host handshake offer: the versions offered per function and the host's capabilities.
///
/// Checked for size/range limits by [`HandshakeOffer::validate`]; a reply is checked against
/// it by [`HandshakeAccept::validate_against_offer`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HandshakeOffer {
    /// Handshake message schema version.
    pub handshake_schema_version: u32,
    /// Envelope version carried by the offer (NOTE(review): semantics not visible here).
    pub envelope_version: u32,
    /// Offered versions keyed by function name. The map is ordered by name; each version
    /// list is kept as supplied (not sorted or deduplicated by this type).
    pub function_versions: BTreeMap<String, Vec<u32>>,
    /// Capability names the host provides.
    pub host_capabilities: BTreeSet<String>,
}
43
/// Reply to a [`HandshakeOffer`] (presumably produced by the plugin — confirm with callers).
///
/// Checked against the offer it answers by [`HandshakeAccept::validate_against_offer`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HandshakeAccept {
    /// Must equal the offer's `handshake_schema_version`.
    pub handshake_schema_version: u32,
    /// Envelope version of the reply; not compared with the offer by the validation here.
    pub envelope_version: u32,
    /// The version chosen for each function, keyed by function name.
    pub chosen_versions: BTreeMap<String, u32>,
    /// Versions the replying side declares it supports, keyed by function name; used for
    /// downgrade detection.
    pub plugin_supported: BTreeMap<String, Vec<u32>>,
    /// Function names the replying side implements; not checked by the validation here.
    pub implemented_functions: BTreeSet<String>,
    /// Capabilities that must all be present in the offer's `host_capabilities`.
    pub required_capabilities: BTreeSet<String>,
}
54
55#[derive(Debug, Clone, Error)]
56pub enum HandshakeError {
57 #[error("function count {count} exceeds maximum {max}")]
58 FunctionCountExceeded { count: usize, max: usize },
59
60 #[error("function '{function}' has {count} versions, exceeds maximum {max}")]
61 FunctionVersionCountExceeded {
62 function: String,
63 count: usize,
64 max: usize,
65 },
66
67 #[error("version {version} outside valid range [{min}, {max}]")]
68 VersionOutOfRange { version: u32, min: u32, max: u32 },
69
70 #[error("capability count {count} exceeds maximum {max}")]
71 CapabilityCountExceeded { count: usize, max: usize },
72
73 #[error("chosen version not offered: function '{function}' version {version} not in offer")]
74 ChosenVersionNotOffered { function: String, version: u32 },
75
76 #[error("chosen for unknown function: '{function}' not in host offer")]
77 ChosenForUnknownFunction { function: String },
78
79 #[error(
80 "downgrade attempt detected: function '{function}' chosen version {chosen} but max intersection is {max_intersection}"
81 )]
82 DowngradeAttempt {
83 function: String,
84 chosen: u32,
85 max_intersection: u32,
86 },
87
88 #[error("required capability '{capability}' not available in host capabilities")]
89 RequiredCapabilityUnavailable { capability: String },
90
91 #[error("handshake schema version mismatch: got {got}, expected {expected}")]
92 HandshakeSchemaVersionMismatch { got: u32, expected: u32 },
93
94 #[error("canonical offer hash error: {0}")]
95 Canonical(#[from] CanonicalError),
96}
97
98impl HandshakeOffer {
99 pub fn validate(&self) -> Result<(), HandshakeError> {
100 if self.function_versions.len() > FUNCTION_VERSIONS_KEYS_MAX {
101 return Err(HandshakeError::FunctionCountExceeded {
102 count: self.function_versions.len(),
103 max: FUNCTION_VERSIONS_KEYS_MAX,
104 });
105 }
106
107 for (function, versions) in &self.function_versions {
108 if versions.len() > FUNCTION_VERSIONS_PER_FN_MAX {
109 return Err(HandshakeError::FunctionVersionCountExceeded {
110 function: function.clone(),
111 count: versions.len(),
112 max: FUNCTION_VERSIONS_PER_FN_MAX,
113 });
114 }
115
116 for &version in versions {
117 if !(VERSION_MIN..=VERSION_MAX).contains(&version) {
118 return Err(HandshakeError::VersionOutOfRange {
119 version,
120 min: VERSION_MIN,
121 max: VERSION_MAX,
122 });
123 }
124 }
125 }
126
127 if self.host_capabilities.len() > CAPABILITIES_MAX_COUNT {
128 return Err(HandshakeError::CapabilityCountExceeded {
129 count: self.host_capabilities.len(),
130 max: CAPABILITIES_MAX_COUNT,
131 });
132 }
133
134 Ok(())
135 }
136
137 pub fn canonical_hash(&self) -> Result<[u8; 32], CanonicalError> {
138 let raw = HandshakeOfferRaw {
139 handshake_schema_version: self.handshake_schema_version,
140 envelope_version: self.envelope_version,
141 function_versions: self
142 .function_versions
143 .iter()
144 .map(|(fn_name, versions)| FunctionVersionOfferRaw {
145 function: fn_name.clone(),
146 versions: versions.clone(),
147 })
148 .collect(),
149 host_capabilities: self.host_capabilities.iter().cloned().collect(),
150 };
151 host_offer_hash(&raw)
152 }
153}
154
155impl HandshakeAccept {
156 #[allow(clippy::collapsible_if)]
157 pub fn validate_against_offer(&self, offer: &HandshakeOffer) -> Result<(), HandshakeError> {
158 if self.handshake_schema_version != offer.handshake_schema_version {
159 return Err(HandshakeError::HandshakeSchemaVersionMismatch {
160 got: self.handshake_schema_version,
161 expected: offer.handshake_schema_version,
162 });
163 }
164
165 for (function, &chosen) in &self.chosen_versions {
166 if !self.plugin_supported.contains_key(function) {
167 return Err(HandshakeError::ChosenForUnknownFunction {
168 function: function.clone(),
169 });
170 }
171
172 if !offer.function_versions.contains_key(function) {
173 return Err(HandshakeError::ChosenForUnknownFunction {
174 function: function.clone(),
175 });
176 }
177
178 let plugin_versions = &self.plugin_supported[function];
179 let offer_versions = &offer.function_versions[function];
180
181 if !offer_versions.contains(&chosen) {
182 return Err(HandshakeError::ChosenVersionNotOffered {
183 function: function.clone(),
184 version: chosen,
185 });
186 }
187
188 let intersection_max = offer_versions
189 .iter()
190 .filter(|v| plugin_versions.contains(v))
191 .max()
192 .copied();
193
194 if let Some(max_intersection) = intersection_max {
195 if chosen < max_intersection {
196 return Err(HandshakeError::DowngradeAttempt {
197 function: function.clone(),
198 chosen,
199 max_intersection,
200 });
201 }
202 }
203 }
204
205 for capability in &self.required_capabilities {
206 if !offer.host_capabilities.contains(capability) {
207 return Err(HandshakeError::RequiredCapabilityUnavailable {
208 capability: capability.clone(),
209 });
210 }
211 }
212
213 Ok(())
214 }
215}
216
#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    /// Offer with schema and envelope version 1 plus the given function table and capabilities.
    fn offer_of(
        function_versions: BTreeMap<String, Vec<u32>>,
        host_capabilities: BTreeSet<String>,
    ) -> HandshakeOffer {
        HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions,
            host_capabilities,
        }
    }

    /// Accept with envelope version 1 and the given schema version and contents.
    fn accept_of(
        handshake_schema_version: u32,
        chosen_versions: BTreeMap<String, u32>,
        plugin_supported: BTreeMap<String, Vec<u32>>,
        implemented_functions: BTreeSet<String>,
        required_capabilities: BTreeSet<String>,
    ) -> HandshakeAccept {
        HandshakeAccept {
            handshake_schema_version,
            envelope_version: 1,
            chosen_versions,
            plugin_supported,
            implemented_functions,
            required_capabilities,
        }
    }

    /// Collects string literals into an owned, sorted set.
    fn names(items: &[&str]) -> BTreeSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Collects `(name, value)` pairs into a map keyed by owned names.
    fn table<V: Clone>(entries: &[(&str, V)]) -> BTreeMap<String, V> {
        entries
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    #[test]
    fn handshake_schema_version_locked() {
        assert_eq!(HANDSHAKE_SCHEMA_VERSION_V1, 1);
    }

    #[test]
    fn empty_offer_valid() {
        let offer = offer_of(BTreeMap::new(), BTreeSet::new());
        assert!(offer.validate().is_ok());
    }

    #[test]
    fn function_count_exceeds_limit() {
        let functions: BTreeMap<String, Vec<u32>> = (0..=FUNCTION_VERSIONS_KEYS_MAX)
            .map(|i| (alloc::format!("fn{}", i), vec![1]))
            .collect();

        let result = offer_of(functions, BTreeSet::new()).validate();
        assert!(
            matches!(result, Err(HandshakeError::FunctionCountExceeded { .. })),
            "expected FunctionCountExceeded, got {:?}",
            result
        );
    }

    #[test]
    fn version_count_per_function_exceeds_limit() {
        let too_many: Vec<u32> = (1..=FUNCTION_VERSIONS_PER_FN_MAX + 1)
            .map(|v| v as u32)
            .collect();

        let result = offer_of(table(&[("route", too_many)]), BTreeSet::new()).validate();
        assert!(
            matches!(
                result,
                Err(HandshakeError::FunctionVersionCountExceeded { .. })
            ),
            "expected FunctionVersionCountExceeded, got {:?}",
            result
        );
    }

    #[test]
    fn invalid_version_too_low() {
        let result = offer_of(table(&[("route", vec![0])]), BTreeSet::new()).validate();
        assert!(
            matches!(result, Err(HandshakeError::VersionOutOfRange { .. })),
            "expected VersionOutOfRange, got {:?}",
            result
        );
    }

    #[test]
    fn capability_count_exceeds_limit() {
        let capabilities: BTreeSet<String> = (0..=CAPABILITIES_MAX_COUNT)
            .map(|i| alloc::format!("cap{}", i))
            .collect();

        let result = offer_of(BTreeMap::new(), capabilities).validate();
        assert!(
            matches!(result, Err(HandshakeError::CapabilityCountExceeded { .. })),
            "expected CapabilityCountExceeded, got {:?}",
            result
        );
    }

    #[test]
    fn downgrade_attack_rejected() {
        let offer = offer_of(table(&[("route", vec![1, 2, 3])]), BTreeSet::new());
        let accept = accept_of(
            1,
            table(&[("route", 1)]),
            table(&[("route", vec![1, 2, 3])]),
            BTreeSet::new(),
            BTreeSet::new(),
        );

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::DowngradeAttempt {
                function,
                chosen,
                max_intersection,
            }) => {
                assert_eq!(function, "route");
                assert_eq!(chosen, 1);
                assert_eq!(max_intersection, 3);
            }
            other => panic!("expected DowngradeAttempt, got {:?}", other),
        }
    }

    #[test]
    fn chosen_not_in_offer_rejected() {
        let offer = offer_of(table(&[("route", vec![1, 2])]), BTreeSet::new());
        let accept = accept_of(
            1,
            table(&[("route", 99)]),
            table(&[("route", vec![1, 2, 3])]),
            BTreeSet::new(),
            BTreeSet::new(),
        );

        let result = accept.validate_against_offer(&offer);
        assert!(
            matches!(result, Err(HandshakeError::ChosenVersionNotOffered { .. })),
            "expected ChosenVersionNotOffered, got {:?}",
            result
        );
    }

    #[test]
    fn missing_required_capability() {
        let offer = offer_of(BTreeMap::new(), names(&["log"]));
        let accept = accept_of(
            1,
            BTreeMap::new(),
            BTreeMap::new(),
            BTreeSet::new(),
            names(&["trace"]),
        );

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::RequiredCapabilityUnavailable { capability }) => {
                assert_eq!(capability, "trace");
            }
            other => panic!("expected RequiredCapabilityUnavailable, got {:?}", other),
        }
    }

    #[test]
    fn schema_version_mismatch() {
        let offer = offer_of(BTreeMap::new(), BTreeSet::new());
        let accept = accept_of(
            2,
            BTreeMap::new(),
            BTreeMap::new(),
            BTreeSet::new(),
            BTreeSet::new(),
        );

        let result = accept.validate_against_offer(&offer);
        assert!(
            matches!(
                result,
                Err(HandshakeError::HandshakeSchemaVersionMismatch { .. })
            ),
            "expected HandshakeSchemaVersionMismatch, got {:?}",
            result
        );
    }

    #[test]
    fn valid_handshake_roundtrip() {
        let offer = offer_of(
            table(&[("route", vec![1, 2]), ("shape", vec![1])]),
            names(&["streaming"]),
        );
        assert!(offer.validate().is_ok());

        let accept = accept_of(
            1,
            table(&[("route", 2), ("shape", 1)]),
            table(&[("route", vec![1, 2]), ("shape", vec![1])]),
            names(&["route", "shape"]),
            names(&["streaming"]),
        );
        assert!(accept.validate_against_offer(&offer).is_ok());
    }

    #[test]
    fn serde_roundtrip_offer() {
        let original = offer_of(table(&[("route", vec![1, 2])]), names(&["streaming"]));

        let json = serde_json::to_vec(&original).expect("serialize");
        let deserialized: HandshakeOffer = serde_json::from_slice(&json).expect("deserialize");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn serde_roundtrip_accept() {
        let original = accept_of(
            1,
            table(&[("route", 2)]),
            table(&[("route", vec![1, 2])]),
            names(&["route"]),
            names(&["streaming"]),
        );

        let json = serde_json::to_vec(&original).expect("serialize");
        let deserialized: HandshakeAccept = serde_json::from_slice(&json).expect("deserialize");

        assert_eq!(original, deserialized);
    }
}