cc_lb_plugin_wire/handshake/
canonical.rs1extern crate alloc;
2
3use alloc::string::{String, ToString};
4use alloc::{collections::BTreeMap, collections::BTreeSet, vec::Vec};
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use thiserror::Error;
8
9use super::HandshakeOfferRaw;
10
11#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(deny_unknown_fields)]
13pub struct CanonicalOffer {
14 pub handshake_schema_version: u32,
15 pub envelope_version: u32,
16 pub function_versions: BTreeMap<String, Vec<u32>>,
17 pub host_capabilities: BTreeSet<String>,
18}
19
20#[derive(Clone, Debug, Error)]
21pub enum CanonicalError {
22 #[error("failed to serialize canonical handshake offer: {0}")]
23 Serialize(String),
24}
25
26pub fn canonicalize(offer: &HandshakeOfferRaw) -> CanonicalOffer {
27 let mut function_versions = BTreeMap::new();
28
29 for offered_function in &offer.function_versions {
30 let versions = function_versions
31 .entry(offered_function.function.clone())
32 .or_insert_with(Vec::new);
33 versions.extend(offered_function.versions.iter().copied());
34 }
35
36 for versions in function_versions.values_mut() {
37 versions.sort_unstable();
38 versions.dedup();
39 }
40
41 let host_capabilities = offer
42 .host_capabilities
43 .iter()
44 .map(|capability| capability.to_ascii_lowercase())
45 .collect();
46
47 CanonicalOffer {
48 handshake_schema_version: offer.handshake_schema_version,
49 envelope_version: offer.envelope_version,
50 function_versions,
51 host_capabilities,
52 }
53}
54
55pub fn host_offer_hash(offer: &HandshakeOfferRaw) -> Result<[u8; 32], CanonicalError> {
56 let canonical = canonicalize(offer);
57 let bytes =
58 serde_json::to_vec(&canonical).map_err(|e| CanonicalError::Serialize(e.to_string()))?;
59 Ok(Sha256::digest(bytes).into())
60}
61
#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use alloc::{string::ToString, vec};
    use proptest::prelude::*;

    use crate::handshake::FunctionVersionOfferRaw;

    /// Function names drawn from a small fixed pool, so generated offers often list
    /// the same function more than once and exercise the merge path.
    fn function_name_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("shape".to_string()),
            Just("normalize_error".to_string()),
            Just("build_signer".to_string()),
            Just("sign".to_string()),
            Just("on_unauthorized".to_string()),
            Just("observe".to_string()),
        ]
    }

    /// Capability names, mostly pairs that differ only in ASCII case, so generated
    /// offers exercise lowercasing and set deduplication.
    fn capability_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("Streaming".to_string()),
            Just("streaming".to_string()),
            Just("BATCHING".to_string()),
            Just("batching".to_string()),
            Just("Trace-Context".to_string()),
            Just("trace-context".to_string()),
            Just("CANARY".to_string()),
        ]
    }

    /// One function entry with up to 7 versions in `0..16` (duplicates likely).
    fn function_offer_strategy() -> impl Strategy<Value = FunctionVersionOfferRaw> {
        (
            function_name_strategy(),
            prop::collection::vec(0u32..16, 0..8),
        )
            .prop_map(|(function, versions)| FunctionVersionOfferRaw { function, versions })
    }

    /// A whole raw offer with small version numbers and up to 23 function entries
    /// and 23 capabilities.
    fn offer_strategy() -> impl Strategy<Value = HandshakeOfferRaw> {
        (
            0u32..4,
            0u32..4,
            prop::collection::vec(function_offer_strategy(), 0..24),
            prop::collection::vec(capability_strategy(), 0..24),
        )
            .prop_map(
                |(
                    handshake_schema_version,
                    envelope_version,
                    function_versions,
                    host_capabilities,
                )| {
                    HandshakeOfferRaw {
                        handshake_schema_version,
                        envelope_version,
                        function_versions,
                        host_capabilities,
                    }
                },
            )
    }

    /// Reorders `values` by sorting on `(key, original index)`, where `keys[i]` is the
    /// key for `values[i]`. Values without a key fall back to their own index as the
    /// key. The index tie-break makes the sort total, so the result is always a
    /// permutation of `values`.
    fn keyed_permutation<Item: Clone>(values: &[Item], keys: &[u64]) -> Vec<Item> {
        let mut keyed_values: Vec<_> = values
            .iter()
            .cloned()
            .enumerate()
            .map(|(value_index, value)| {
                (
                    keys.get(value_index).copied().unwrap_or(value_index as u64),
                    value_index,
                    value,
                )
            })
            .collect();
        keyed_values.sort_by_key(|(key, value_index, _)| (*key, *value_index));
        keyed_values
            .into_iter()
            .map(|(_, _, value)| value)
            .collect()
    }

    /// Returns a copy of `offer` with these parts permuted:
    /// - the function entries;
    /// - each function's versions, keyed by that function's position after permutation;
    /// - the capabilities.
    ///
    /// No value is added, removed, or modified.
    fn shuffled(
        offer: &HandshakeOfferRaw,
        function_keys: &[u64],
        capability_keys: &[u64],
        version_keys: &[Vec<u64>],
    ) -> HandshakeOfferRaw {
        let mut function_versions = keyed_permutation(&offer.function_versions, function_keys);
        for (function_index, function) in function_versions.iter_mut().enumerate() {
            let keys = version_keys
                .get(function_index)
                .map(Vec::as_slice)
                .unwrap_or_default();
            function.versions = keyed_permutation(&function.versions, keys);
        }

        HandshakeOfferRaw {
            handshake_schema_version: offer.handshake_schema_version,
            envelope_version: offer.envelope_version,
            function_versions,
            host_capabilities: keyed_permutation(&offer.host_capabilities, capability_keys),
        }
    }

    /// Builds a minimal offer with a single `sign` function and a single capability,
    /// varying only the parts the hash-sensitivity test needs to change.
    fn sample_offer(
        envelope_version: u32,
        sign_versions: Vec<u32>,
        capability: &str,
    ) -> HandshakeOfferRaw {
        HandshakeOfferRaw {
            handshake_schema_version: 1,
            envelope_version,
            function_versions: vec![FunctionVersionOfferRaw {
                function: "sign".to_string(),
                versions: sign_versions,
            }],
            host_capabilities: vec![capability.to_string()],
        }
    }

    proptest! {
        #[test]
        fn canonicalize_is_shuffle_invariant(
            offer in offer_strategy(),
            function_keys in prop::collection::vec(any::<u64>(), 0..32),
            capability_keys in prop::collection::vec(any::<u64>(), 0..32),
            version_keys in prop::collection::vec(prop::collection::vec(any::<u64>(), 0..16), 0..32),
        ) {
            let shuffled = shuffled(&offer, &function_keys, &capability_keys, &version_keys);

            prop_assert_eq!(canonicalize(&offer), canonicalize(&shuffled));
        }

        #[test]
        fn host_offer_hash_is_shuffle_invariant(
            offer in offer_strategy(),
            function_keys in prop::collection::vec(any::<u64>(), 0..32),
            capability_keys in prop::collection::vec(any::<u64>(), 0..32),
            version_keys in prop::collection::vec(prop::collection::vec(any::<u64>(), 0..16), 0..32),
        ) {
            let shuffled = shuffled(&offer, &function_keys, &capability_keys, &version_keys);

            prop_assert_eq!(
                host_offer_hash(&offer).unwrap(),
                host_offer_hash(&shuffled).unwrap(),
            );
        }
    }

    #[test]
    fn canonicalize_sorts_dedups_and_lowercases() {
        let offer = HandshakeOfferRaw {
            handshake_schema_version: 1,
            envelope_version: 2,
            function_versions: vec![
                FunctionVersionOfferRaw {
                    function: "route".to_string(),
                    versions: vec![3, 1, 3],
                },
                FunctionVersionOfferRaw {
                    function: "route".to_string(),
                    versions: vec![2, 1],
                },
            ],
            host_capabilities: vec!["Streaming".to_string(), "streaming".to_string()],
        };

        let canonical = canonicalize(&offer);

        assert_eq!(canonical.function_versions["route"], vec![1, 2, 3]);
        assert_eq!(canonical.host_capabilities.len(), 1);
        assert!(canonical.host_capabilities.contains("streaming"));
    }

    /// Pins the exact bytes that `host_offer_hash` feeds to SHA-256. `serde_json`
    /// emits struct fields in declaration order, so reordering or renaming
    /// `CanonicalOffer` fields would silently change every digest; this test
    /// makes such a change visible.
    #[test]
    fn canonical_json_encoding_is_stable() {
        let offer = HandshakeOfferRaw {
            handshake_schema_version: 1,
            envelope_version: 2,
            function_versions: vec![FunctionVersionOfferRaw {
                function: "route".to_string(),
                versions: vec![3, 1, 2, 3],
            }],
            host_capabilities: vec!["Streaming".to_string(), "BATCHING".to_string()],
        };

        let encoded = serde_json::to_string(&canonicalize(&offer)).unwrap();

        assert_eq!(
            encoded,
            r#"{"handshake_schema_version":1,"envelope_version":2,"function_versions":{"route":[1,2,3]},"host_capabilities":["batching","streaming"]}"#
        );
    }

    /// Shuffle invariance alone would be satisfied by a constant hash. This checks
    /// two things: offers that canonicalize identically hash identically, and
    /// offers that differ after canonicalization hash differently.
    #[test]
    fn host_offer_hash_distinguishes_offers() {
        let hash = |offer: HandshakeOfferRaw| host_offer_hash(&offer).unwrap();
        let baseline = hash(sample_offer(1, vec![1], "streaming"));

        // Equivalent after canonicalization: duplicate version, different ASCII case.
        assert_eq!(baseline, hash(sample_offer(1, vec![1, 1], "STREAMING")));

        // Envelope version, offered versions, and capabilities must all affect the digest.
        assert_ne!(baseline, hash(sample_offer(2, vec![1], "streaming")));
        assert_ne!(baseline, hash(sample_offer(1, vec![1, 2], "streaming")));
        assert_ne!(baseline, hash(sample_offer(1, vec![1], "batching")));
    }
}