Skip to main content

cc_lb_plugin_wire/handshake/
canonical.rs

1extern crate alloc;
2
3use alloc::string::{String, ToString};
4use alloc::{collections::BTreeMap, collections::BTreeSet, vec::Vec};
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use thiserror::Error;
8
9use super::HandshakeOfferRaw;
10
11#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(deny_unknown_fields)]
13pub struct CanonicalOffer {
14    pub handshake_schema_version: u32,
15    pub envelope_version: u32,
16    pub function_versions: BTreeMap<String, Vec<u32>>,
17    pub host_capabilities: BTreeSet<String>,
18}
19
20#[derive(Clone, Debug, Error)]
21pub enum CanonicalError {
22    #[error("failed to serialize canonical handshake offer: {0}")]
23    Serialize(String),
24}
25
26pub fn canonicalize(offer: &HandshakeOfferRaw) -> CanonicalOffer {
27    let mut function_versions = BTreeMap::new();
28
29    for offered_function in &offer.function_versions {
30        let versions = function_versions
31            .entry(offered_function.function.clone())
32            .or_insert_with(Vec::new);
33        versions.extend(offered_function.versions.iter().copied());
34    }
35
36    for versions in function_versions.values_mut() {
37        versions.sort_unstable();
38        versions.dedup();
39    }
40
41    let host_capabilities = offer
42        .host_capabilities
43        .iter()
44        .map(|capability| capability.to_ascii_lowercase())
45        .collect();
46
47    CanonicalOffer {
48        handshake_schema_version: offer.handshake_schema_version,
49        envelope_version: offer.envelope_version,
50        function_versions,
51        host_capabilities,
52    }
53}
54
55pub fn host_offer_hash(offer: &HandshakeOfferRaw) -> Result<[u8; 32], CanonicalError> {
56    let canonical = canonicalize(offer);
57    let bytes =
58        serde_json::to_vec(&canonical).map_err(|e| CanonicalError::Serialize(e.to_string()))?;
59    Ok(Sha256::digest(bytes).into())
60}
61
#[cfg(test)]
mod tests {
    // The library side of this file only uses `alloc`; the test module links `std`
    // (presumably for the test harness and proptest).
    extern crate std;

    use super::*;
    use alloc::{string::ToString, vec};
    use proptest::prelude::*;

    use crate::handshake::FunctionVersionOfferRaw;

    /// Picks one of a fixed set of function names. The set is small, so generated
    /// offers frequently repeat a name and exercise the merge path in `canonicalize`.
    fn function_name_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("shape".to_string()),
            Just("normalize_error".to_string()),
            Just("build_signer".to_string()),
            Just("sign".to_string()),
            Just("on_unauthorized".to_string()),
            Just("observe".to_string()),
        ]
    }

    /// Picks host capabilities, including several case variants of the same name, so
    /// generated offers exercise ASCII lowercasing and set de-duplication.
    fn capability_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("Streaming".to_string()),
            Just("streaming".to_string()),
            Just("BATCHING".to_string()),
            Just("batching".to_string()),
            Just("Trace-Context".to_string()),
            Just("trace-context".to_string()),
            Just("CANARY".to_string()),
        ]
    }

    /// One raw function entry: a name plus 0..8 versions drawn from 0..16
    /// (duplicates and unsorted order are both possible).
    fn function_offer_strategy() -> impl Strategy<Value = FunctionVersionOfferRaw> {
        (
            function_name_strategy(),
            prop::collection::vec(0u32..16, 0..8),
        )
            .prop_map(|(function, versions)| FunctionVersionOfferRaw { function, versions })
    }

    /// A full raw offer: small scalar versions (0..4) and up to 24 function entries
    /// and 24 capability entries.
    fn offer_strategy() -> impl Strategy<Value = HandshakeOfferRaw> {
        (
            0u32..4,
            0u32..4,
            prop::collection::vec(function_offer_strategy(), 0..24),
            prop::collection::vec(capability_strategy(), 0..24),
        )
            .prop_map(
                |(
                    handshake_schema_version,
                    envelope_version,
                    function_versions,
                    host_capabilities,
                )| {
                    HandshakeOfferRaw {
                        handshake_schema_version,
                        envelope_version,
                        function_versions,
                        host_capabilities,
                    }
                },
            )
    }

    /// Reorders `values` by ascending `keys[i]`, breaking ties by original index.
    ///
    /// Items past the end of `keys` use their own index as the key, so the result is
    /// always a permutation of `values` whatever `keys.len()` is.
    fn keyed_permutation<Item: Clone>(values: &[Item], keys: &[u64]) -> Vec<Item> {
        let mut keyed_values: Vec<_> = values
            .iter()
            .cloned()
            .enumerate()
            .map(|(value_index, value)| {
                (
                    keys.get(value_index).copied().unwrap_or(value_index as u64),
                    value_index,
                    value,
                )
            })
            .collect();
        keyed_values.sort_by_key(|(key, value_index, _)| (*key, *value_index));
        keyed_values
            .into_iter()
            .map(|(_, _, value)| value)
            .collect()
    }

    /// Rebuilds `offer` with the function entries, each entry's version list, and the
    /// capabilities permuted by `keyed_permutation`; scalar fields are copied as-is.
    fn shuffled(
        offer: &HandshakeOfferRaw,
        function_keys: &[u64],
        capability_keys: &[u64],
        version_keys: &[Vec<u64>],
    ) -> HandshakeOfferRaw {
        let mut function_versions = keyed_permutation(&offer.function_versions, function_keys);
        for (function_index, function) in function_versions.iter_mut().enumerate() {
            let keys = version_keys
                .get(function_index)
                .map(Vec::as_slice)
                .unwrap_or_default();
            function.versions = keyed_permutation(&function.versions, keys);
        }

        HandshakeOfferRaw {
            handshake_schema_version: offer.handshake_schema_version,
            envelope_version: offer.envelope_version,
            function_versions,
            host_capabilities: keyed_permutation(&offer.host_capabilities, capability_keys),
        }
    }

    /// Rebuilds `offer` with every capability re-cased: ASCII-uppercased where
    /// `uppercase_mask[i]` is `true`, ASCII-lowercased otherwise (including indices
    /// past the end of the mask). All other fields are copied unchanged.
    fn recased(offer: &HandshakeOfferRaw, uppercase_mask: &[bool]) -> HandshakeOfferRaw {
        let host_capabilities = offer
            .host_capabilities
            .iter()
            .enumerate()
            .map(|(index, capability)| {
                if uppercase_mask.get(index).copied().unwrap_or(false) {
                    capability.to_ascii_uppercase()
                } else {
                    capability.to_ascii_lowercase()
                }
            })
            .collect();

        HandshakeOfferRaw {
            handshake_schema_version: offer.handshake_schema_version,
            envelope_version: offer.envelope_version,
            function_versions: offer.function_versions.clone(),
            host_capabilities,
        }
    }

    proptest! {
        // Reordering any list in the offer must not change the canonical form.
        #[test]
        fn canonicalize_is_shuffle_invariant(
            offer in offer_strategy(),
            function_keys in prop::collection::vec(any::<u64>(), 0..32),
            capability_keys in prop::collection::vec(any::<u64>(), 0..32),
            version_keys in prop::collection::vec(prop::collection::vec(any::<u64>(), 0..16), 0..32),
        ) {
            let shuffled = shuffled(&offer, &function_keys, &capability_keys, &version_keys);

            prop_assert_eq!(canonicalize(&offer), canonicalize(&shuffled));
        }

        // Same property, checked end-to-end through the serialized hash.
        #[test]
        fn host_offer_hash_is_shuffle_invariant(
            offer in offer_strategy(),
            function_keys in prop::collection::vec(any::<u64>(), 0..32),
            capability_keys in prop::collection::vec(any::<u64>(), 0..32),
            version_keys in prop::collection::vec(prop::collection::vec(any::<u64>(), 0..16), 0..32),
        ) {
            let shuffled = shuffled(&offer, &function_keys, &capability_keys, &version_keys);

            prop_assert_eq!(
                host_offer_hash(&offer).unwrap(),
                host_offer_hash(&shuffled).unwrap(),
            );
        }

        // Changing the ASCII case of capabilities must not change the canonical form
        // or the hash, since `canonicalize` lowercases them.
        #[test]
        fn canonicalize_ignores_capability_ascii_case(
            offer in offer_strategy(),
            uppercase_mask in prop::collection::vec(any::<bool>(), 0..24),
        ) {
            let recased_offer = recased(&offer, &uppercase_mask);

            prop_assert_eq!(canonicalize(&offer), canonicalize(&recased_offer));
            prop_assert_eq!(
                host_offer_hash(&offer).unwrap(),
                host_offer_hash(&recased_offer).unwrap(),
            );
        }
    }

    // Concrete example: repeated function entries merge, versions sort and dedup,
    // and capability case variants collapse into one lowercase entry.
    #[test]
    fn canonicalize_sorts_dedups_and_lowercases() {
        let offer = HandshakeOfferRaw {
            handshake_schema_version: 1,
            envelope_version: 2,
            function_versions: vec![
                FunctionVersionOfferRaw {
                    function: "route".to_string(),
                    versions: vec![3, 1, 3],
                },
                FunctionVersionOfferRaw {
                    function: "route".to_string(),
                    versions: vec![2, 1],
                },
            ],
            host_capabilities: vec!["Streaming".to_string(), "streaming".to_string()],
        };

        let canonical = canonicalize(&offer);

        assert_eq!(canonical.function_versions["route"], vec![1, 2, 3]);
        assert_eq!(canonical.host_capabilities.len(), 1);
        assert!(canonical.host_capabilities.contains("streaming"));
    }
}