// cc_lb_plugin_wire/handshake/mod.rs

1extern crate alloc;
2
3use alloc::collections::{BTreeMap, BTreeSet};
4use alloc::string::String;
5use alloc::vec::Vec;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9pub mod canonical;
10
11use crate::limits::{
12    CAPABILITIES_MAX_COUNT, FUNCTION_VERSIONS_KEYS_MAX, FUNCTION_VERSIONS_PER_FN_MAX, VERSION_MAX,
13    VERSION_MIN,
14};
15pub use canonical::{CanonicalError, CanonicalOffer, canonicalize, host_offer_hash};
16
/// Handshake message schema version 1, the only schema version constant defined here.
pub const HANDSHAKE_SCHEMA_VERSION_V1: u32 = 1;
18
19#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(deny_unknown_fields)]
21pub struct HandshakeOfferRaw {
22    pub handshake_schema_version: u32,
23    pub envelope_version: u32,
24    pub function_versions: Vec<FunctionVersionOfferRaw>,
25    pub host_capabilities: Vec<String>,
26}
27
28#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
29#[serde(deny_unknown_fields)]
30pub struct FunctionVersionOfferRaw {
31    pub function: String,
32    pub versions: Vec<u32>,
33}
34
35#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(deny_unknown_fields)]
37pub struct HandshakeOffer {
38    pub handshake_schema_version: u32,
39    pub envelope_version: u32,
40    pub function_versions: BTreeMap<String, Vec<u32>>,
41    pub host_capabilities: BTreeSet<String>,
42}
43
44#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(deny_unknown_fields)]
46pub struct HandshakeAccept {
47    pub handshake_schema_version: u32,
48    pub envelope_version: u32,
49    pub chosen_versions: BTreeMap<String, u32>,
50    pub plugin_supported: BTreeMap<String, Vec<u32>>,
51    pub implemented_functions: BTreeSet<String>,
52    pub required_capabilities: BTreeSet<String>,
53}
54
55#[derive(Debug, Clone, Error)]
56pub enum HandshakeError {
57    #[error("function count {count} exceeds maximum {max}")]
58    FunctionCountExceeded { count: usize, max: usize },
59
60    #[error("function '{function}' has {count} versions, exceeds maximum {max}")]
61    FunctionVersionCountExceeded {
62        function: String,
63        count: usize,
64        max: usize,
65    },
66
67    #[error("version {version} outside valid range [{min}, {max}]")]
68    VersionOutOfRange { version: u32, min: u32, max: u32 },
69
70    #[error("capability count {count} exceeds maximum {max}")]
71    CapabilityCountExceeded { count: usize, max: usize },
72
73    #[error("chosen version not offered: function '{function}' version {version} not in offer")]
74    ChosenVersionNotOffered { function: String, version: u32 },
75
76    #[error("chosen for unknown function: '{function}' not in host offer")]
77    ChosenForUnknownFunction { function: String },
78
79    #[error(
80        "downgrade attempt detected: function '{function}' chosen version {chosen} but max intersection is {max_intersection}"
81    )]
82    DowngradeAttempt {
83        function: String,
84        chosen: u32,
85        max_intersection: u32,
86    },
87
88    #[error("required capability '{capability}' not available in host capabilities")]
89    RequiredCapabilityUnavailable { capability: String },
90
91    #[error("handshake schema version mismatch: got {got}, expected {expected}")]
92    HandshakeSchemaVersionMismatch { got: u32, expected: u32 },
93
94    #[error("canonical offer hash error: {0}")]
95    Canonical(#[from] CanonicalError),
96}
97
98impl HandshakeOffer {
99    pub fn validate(&self) -> Result<(), HandshakeError> {
100        if self.function_versions.len() > FUNCTION_VERSIONS_KEYS_MAX {
101            return Err(HandshakeError::FunctionCountExceeded {
102                count: self.function_versions.len(),
103                max: FUNCTION_VERSIONS_KEYS_MAX,
104            });
105        }
106
107        for (function, versions) in &self.function_versions {
108            if versions.len() > FUNCTION_VERSIONS_PER_FN_MAX {
109                return Err(HandshakeError::FunctionVersionCountExceeded {
110                    function: function.clone(),
111                    count: versions.len(),
112                    max: FUNCTION_VERSIONS_PER_FN_MAX,
113                });
114            }
115
116            for &version in versions {
117                if !(VERSION_MIN..=VERSION_MAX).contains(&version) {
118                    return Err(HandshakeError::VersionOutOfRange {
119                        version,
120                        min: VERSION_MIN,
121                        max: VERSION_MAX,
122                    });
123                }
124            }
125        }
126
127        if self.host_capabilities.len() > CAPABILITIES_MAX_COUNT {
128            return Err(HandshakeError::CapabilityCountExceeded {
129                count: self.host_capabilities.len(),
130                max: CAPABILITIES_MAX_COUNT,
131            });
132        }
133
134        Ok(())
135    }
136
137    pub fn canonical_hash(&self) -> Result<[u8; 32], CanonicalError> {
138        let raw = HandshakeOfferRaw {
139            handshake_schema_version: self.handshake_schema_version,
140            envelope_version: self.envelope_version,
141            function_versions: self
142                .function_versions
143                .iter()
144                .map(|(fn_name, versions)| FunctionVersionOfferRaw {
145                    function: fn_name.clone(),
146                    versions: versions.clone(),
147                })
148                .collect(),
149            host_capabilities: self.host_capabilities.iter().cloned().collect(),
150        };
151        host_offer_hash(&raw)
152    }
153}
154
155impl HandshakeAccept {
156    #[allow(clippy::collapsible_if)]
157    pub fn validate_against_offer(&self, offer: &HandshakeOffer) -> Result<(), HandshakeError> {
158        if self.handshake_schema_version != offer.handshake_schema_version {
159            return Err(HandshakeError::HandshakeSchemaVersionMismatch {
160                got: self.handshake_schema_version,
161                expected: offer.handshake_schema_version,
162            });
163        }
164
165        for (function, &chosen) in &self.chosen_versions {
166            if !self.plugin_supported.contains_key(function) {
167                return Err(HandshakeError::ChosenForUnknownFunction {
168                    function: function.clone(),
169                });
170            }
171
172            if !offer.function_versions.contains_key(function) {
173                return Err(HandshakeError::ChosenForUnknownFunction {
174                    function: function.clone(),
175                });
176            }
177
178            let plugin_versions = &self.plugin_supported[function];
179            let offer_versions = &offer.function_versions[function];
180
181            if !offer_versions.contains(&chosen) {
182                return Err(HandshakeError::ChosenVersionNotOffered {
183                    function: function.clone(),
184                    version: chosen,
185                });
186            }
187
188            let intersection_max = offer_versions
189                .iter()
190                .filter(|v| plugin_versions.contains(v))
191                .max()
192                .copied();
193
194            if let Some(max_intersection) = intersection_max {
195                if chosen < max_intersection {
196                    return Err(HandshakeError::DowngradeAttempt {
197                        function: function.clone(),
198                        chosen,
199                        max_intersection,
200                    });
201                }
202            }
203        }
204
205        for capability in &self.required_capabilities {
206            if !offer.host_capabilities.contains(capability) {
207                return Err(HandshakeError::RequiredCapabilityUnavailable {
208                    capability: capability.clone(),
209                });
210            }
211        }
212
213        Ok(())
214    }
215}
216
#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn handshake_schema_version_locked() {
        assert_eq!(HANDSHAKE_SCHEMA_VERSION_V1, 1);
    }

    #[test]
    fn empty_offer_valid() {
        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: BTreeMap::new(),
            host_capabilities: BTreeSet::new(),
        };
        assert!(offer.validate().is_ok());
    }

    #[test]
    fn function_count_exceeds_limit() {
        let mut functions = BTreeMap::new();
        for i in 0..=FUNCTION_VERSIONS_KEYS_MAX {
            functions.insert(alloc::format!("fn{}", i), vec![1]);
        }

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: functions,
            host_capabilities: BTreeSet::new(),
        };

        match offer.validate() {
            Err(HandshakeError::FunctionCountExceeded { .. }) => {}
            other => panic!("expected FunctionCountExceeded, got {:?}", other),
        }
    }

    #[test]
    fn version_count_per_function_exceeds_limit() {
        let mut functions = BTreeMap::new();
        functions.insert(
            "route".to_string(),
            (1..=FUNCTION_VERSIONS_PER_FN_MAX + 1)
                .map(|v| v as u32)
                .collect::<Vec<_>>(),
        );

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: functions,
            host_capabilities: BTreeSet::new(),
        };

        match offer.validate() {
            Err(HandshakeError::FunctionVersionCountExceeded { .. }) => {}
            other => panic!("expected FunctionVersionCountExceeded, got {:?}", other),
        }
    }

    #[test]
    fn invalid_version_too_low() {
        let mut functions = BTreeMap::new();
        functions.insert("route".to_string(), vec![0]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: functions,
            host_capabilities: BTreeSet::new(),
        };

        match offer.validate() {
            Err(HandshakeError::VersionOutOfRange { .. }) => {}
            other => panic!("expected VersionOutOfRange, got {:?}", other),
        }
    }

    #[test]
    fn invalid_version_too_high() {
        // When VERSION_MAX is u32::MAX there is no representable value above the range.
        let Some(too_high) = VERSION_MAX.checked_add(1) else {
            return;
        };

        let mut functions = BTreeMap::new();
        functions.insert("route".to_string(), vec![too_high]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: functions,
            host_capabilities: BTreeSet::new(),
        };

        match offer.validate() {
            Err(HandshakeError::VersionOutOfRange { version, .. }) => {
                assert_eq!(version, too_high);
            }
            other => panic!("expected VersionOutOfRange, got {:?}", other),
        }
    }

    #[test]
    fn capability_count_exceeds_limit() {
        let mut capabilities = BTreeSet::new();
        for i in 0..=CAPABILITIES_MAX_COUNT {
            capabilities.insert(alloc::format!("cap{}", i));
        }

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: BTreeMap::new(),
            host_capabilities: capabilities,
        };

        match offer.validate() {
            Err(HandshakeError::CapabilityCountExceeded { .. }) => {}
            other => panic!("expected CapabilityCountExceeded, got {:?}", other),
        }
    }

    #[test]
    fn downgrade_attack_rejected() {
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1, 2, 3]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: BTreeSet::new(),
        };

        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1, 2, 3]);

        let mut chosen = BTreeMap::new();
        chosen.insert("route".to_string(), 1);

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: BTreeSet::new(),
            required_capabilities: BTreeSet::new(),
        };

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::DowngradeAttempt {
                function,
                chosen,
                max_intersection,
            }) => {
                assert_eq!(function, "route");
                assert_eq!(chosen, 1);
                assert_eq!(max_intersection, 3);
            }
            other => panic!("expected DowngradeAttempt, got {:?}", other),
        }
    }

    #[test]
    fn chosen_max_of_intersection_accepted() {
        // The plugin also supports 3, but the host only offers up to 2, so choosing 2 is
        // the best common version rather than a downgrade.
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1, 2]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: BTreeSet::new(),
        };

        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1, 2, 3]);

        let mut chosen = BTreeMap::new();
        chosen.insert("route".to_string(), 2);

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: BTreeSet::new(),
            required_capabilities: BTreeSet::new(),
        };

        assert!(accept.validate_against_offer(&offer).is_ok());
    }

    #[test]
    fn chosen_not_in_offer_rejected() {
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1, 2]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: BTreeSet::new(),
        };

        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1, 2, 3]);

        let mut chosen = BTreeMap::new();
        chosen.insert("route".to_string(), 99);

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: BTreeSet::new(),
            required_capabilities: BTreeSet::new(),
        };

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::ChosenVersionNotOffered { .. }) => {}
            other => panic!("expected ChosenVersionNotOffered, got {:?}", other),
        }
    }

    #[test]
    fn chosen_for_function_missing_from_offer_rejected() {
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1]);

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: BTreeSet::new(),
        };

        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1]);
        plugin_supported.insert("shape".to_string(), vec![1]);

        let mut chosen = BTreeMap::new();
        chosen.insert("shape".to_string(), 1);

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: BTreeSet::new(),
            required_capabilities: BTreeSet::new(),
        };

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::ChosenForUnknownFunction { function }) => {
                assert_eq!(function, "shape");
            }
            other => panic!("expected ChosenForUnknownFunction, got {:?}", other),
        }
    }

    #[test]
    fn missing_required_capability() {
        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: BTreeMap::new(),
            host_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("log".to_string());
                caps
            },
        };

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: BTreeMap::new(),
            plugin_supported: BTreeMap::new(),
            implemented_functions: BTreeSet::new(),
            required_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("trace".to_string());
                caps
            },
        };

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::RequiredCapabilityUnavailable { capability }) => {
                assert_eq!(capability, "trace");
            }
            other => panic!("expected RequiredCapabilityUnavailable, got {:?}", other),
        }
    }

    #[test]
    fn required_capability_present_accepted() {
        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: BTreeMap::new(),
            host_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("log".to_string());
                caps.insert("trace".to_string());
                caps
            },
        };

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: BTreeMap::new(),
            plugin_supported: BTreeMap::new(),
            implemented_functions: BTreeSet::new(),
            required_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("trace".to_string());
                caps
            },
        };

        assert!(accept.validate_against_offer(&offer).is_ok());
    }

    #[test]
    fn schema_version_mismatch() {
        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: BTreeMap::new(),
            host_capabilities: BTreeSet::new(),
        };

        let accept = HandshakeAccept {
            handshake_schema_version: 2,
            envelope_version: 1,
            chosen_versions: BTreeMap::new(),
            plugin_supported: BTreeMap::new(),
            implemented_functions: BTreeSet::new(),
            required_capabilities: BTreeSet::new(),
        };

        match accept.validate_against_offer(&offer) {
            Err(HandshakeError::HandshakeSchemaVersionMismatch { .. }) => {}
            other => panic!("expected HandshakeSchemaVersionMismatch, got {:?}", other),
        }
    }

    #[test]
    fn valid_handshake_roundtrip() {
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1, 2]);
        offer_fns.insert("shape".to_string(), vec![1]);

        let mut offer_caps = BTreeSet::new();
        offer_caps.insert("streaming".to_string());

        let offer = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: offer_caps,
        };

        assert!(offer.validate().is_ok());

        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1, 2]);
        plugin_supported.insert("shape".to_string(), vec![1]);

        let mut plugin_caps = BTreeSet::new();
        plugin_caps.insert("streaming".to_string());

        let mut chosen = BTreeMap::new();
        chosen.insert("route".to_string(), 2);
        chosen.insert("shape".to_string(), 1);

        let accept = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: {
                let mut fns = BTreeSet::new();
                fns.insert("route".to_string());
                fns.insert("shape".to_string());
                fns
            },
            required_capabilities: plugin_caps,
        };

        assert!(accept.validate_against_offer(&offer).is_ok());
    }

    #[test]
    fn serde_roundtrip_offer() {
        let mut offer_fns = BTreeMap::new();
        offer_fns.insert("route".to_string(), vec![1, 2]);

        let original = HandshakeOffer {
            handshake_schema_version: 1,
            envelope_version: 1,
            function_versions: offer_fns,
            host_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("streaming".to_string());
                caps
            },
        };

        let json = serde_json::to_vec(&original).expect("serialize");
        let deserialized: HandshakeOffer = serde_json::from_slice(&json).expect("deserialize");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn serde_roundtrip_accept() {
        let mut plugin_supported = BTreeMap::new();
        plugin_supported.insert("route".to_string(), vec![1, 2]);

        let mut chosen = BTreeMap::new();
        chosen.insert("route".to_string(), 2);

        let original = HandshakeAccept {
            handshake_schema_version: 1,
            envelope_version: 1,
            chosen_versions: chosen,
            plugin_supported,
            implemented_functions: {
                let mut fns = BTreeSet::new();
                fns.insert("route".to_string());
                fns
            },
            required_capabilities: {
                let mut caps = BTreeSet::new();
                caps.insert("streaming".to_string());
                caps
            },
        };

        let json = serde_json::to_vec(&original).expect("serialize");
        let deserialized: HandshakeAccept = serde_json::from_slice(&json).expect("deserialize");

        assert_eq!(original, deserialized);
    }
}