1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use cc_lb_plugin_api::PluginSlot;
5use cc_lb_plugin_wire::handshake::{
6 HANDSHAKE_SCHEMA_VERSION_V1, HandshakeAccept, HandshakeError, HandshakeOffer,
7};
8use cc_lb_plugin_wire::limits::{
9 HANDSHAKE_FUEL, HANDSHAKE_OUTPUT_MAX_BYTES, HANDSHAKE_WALL_MS, IMPLEMENTED_FUNCTIONS_MAX,
10 VERSION_MAX, VERSION_MIN,
11};
12use cc_lb_plugin_wire::v1::{
13 build_signer::BuildSignerFn, normalize_error::NormalizeErrorFn, observe::ObserveFn,
14 on_unauthorized::OnUnauthorizedFn, shape::ShapeFn, sign::SignFn,
15};
16use cc_lb_plugin_wire::v3::filter::FilterFn;
17use cc_lb_plugin_wire::wire_function::{WireFunction, all_wire_functions};
18use extism::{Manifest, PluginBuilder, Wasm};
19use thiserror::Error;
20
/// Name of the wasm export that `execute_handshake` requires and calls with
/// the JSON-encoded offer.
const HANDSHAKE_EXPORT: &str = "cc_lb_handshake";
/// Envelope version the host places in every offer produced by `build_offer`.
const ENVELOPE_VERSION_V1: u32 = 1;
23
24pub fn slot_set_from_handshake(implemented_functions: &BTreeSet<String>) -> Vec<PluginSlot> {
28 let mut slots = BTreeSet::new();
29 for name in implemented_functions {
30 if let Some(slot) = wire_function_to_slot(name) {
31 slots.insert(slot);
32 }
33 }
34 slots.into_iter().collect()
35}
36
37fn wire_function_to_slot(name: &str) -> Option<PluginSlot> {
38 match name {
39 name if name == slot_to_wire_function(PluginSlot::Router) => Some(PluginSlot::Router),
40 name if name == slot_to_wire_function(PluginSlot::Shape) => Some(PluginSlot::Shape),
41 name if name == slot_to_wire_function(PluginSlot::ObservabilityHook) => {
42 Some(PluginSlot::ObservabilityHook)
43 }
44 _ => None,
45 }
46}
47
48pub(crate) fn slot_to_wire_function(slot: PluginSlot) -> &'static str {
49 match slot {
50 PluginSlot::Router => "filter",
51 PluginSlot::Shape => "shape",
52 PluginSlot::ObservabilityHook => "observe",
53 }
54}
55
56pub fn build_offer(host_caps: &BTreeSet<String>) -> HandshakeOffer {
57 let mut function_versions = BTreeMap::new();
58 for (name, versions) in wire_function_versions() {
59 function_versions.insert(name.to_owned(), versions.to_vec());
60 }
61
62 HandshakeOffer {
63 handshake_schema_version: HANDSHAKE_SCHEMA_VERSION_V1,
64 envelope_version: ENVELOPE_VERSION_V1,
65 function_versions,
66 host_capabilities: host_caps.clone(),
67 }
68}
69
70pub fn build_plugin(
71 wasm: &[u8],
72 wall_ms: u64,
73 fuel: u64,
74) -> Result<extism::Plugin, BuildPluginError> {
75 let manifest = Manifest::new([Wasm::data(wasm.to_vec())])
76 .with_timeout(Duration::from_millis(wall_ms))
77 .disallow_all_hosts();
78 PluginBuilder::new(&manifest)
79 .with_wasi(false)
80 .with_cache_disabled()
81 .with_fuel_limit(fuel)
82 .build()
83 .map_err(|source| BuildPluginError::Instantiate {
84 reason: source.to_string(),
85 })
86}
87
/// Error returned by [`build_plugin`] when wasm bytes cannot be turned into a
/// runnable Extism plugin.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum BuildPluginError {
    /// `PluginBuilder::build` failed; `reason` is the stringified Extism error.
    #[error("failed to instantiate plugin: {reason}")]
    Instantiate { reason: String },
}
94
95pub fn execute_handshake(
96 plugin_bytes: &[u8],
97 offer: &HandshakeOffer,
98) -> Result<HandshakeAccept, HandshakeExecutionError> {
99 offer.validate()?;
100 metrics::counter!("cc_lb_plugin_handshake_total").increment(1);
101
102 let mut plugin = build_plugin(plugin_bytes, HANDSHAKE_WALL_MS, HANDSHAKE_FUEL).map_err(
103 |source| match source {
104 BuildPluginError::Instantiate { reason } => {
105 HandshakeExecutionError::Instantiate { reason }
106 }
107 },
108 )?;
109
110 if !plugin.function_exists(HANDSHAKE_EXPORT) {
111 return Err(HandshakeExecutionError::MissingHandshakeExport);
112 }
113
114 let request =
115 serde_json::to_string(offer).map_err(|source| HandshakeExecutionError::SerializeOffer {
116 reason: source.to_string(),
117 })?;
118 let response = plugin
119 .call::<&str, String>(HANDSHAKE_EXPORT, request.as_str())
120 .map_err(|source| classify_call_error(source.to_string()))?;
121
122 if response.len() > HANDSHAKE_OUTPUT_MAX_BYTES {
123 return Err(HandshakeExecutionError::OutputTooLarge {
124 bytes: response.len(),
125 max: HANDSHAKE_OUTPUT_MAX_BYTES,
126 });
127 }
128
129 let accept: HandshakeAccept = serde_json::from_str(&response).map_err(|source| {
130 HandshakeExecutionError::DecodeAccept {
131 reason: source.to_string(),
132 }
133 })?;
134 accept.validate_against_offer(offer)?;
135 validate_accept_shape(&accept, offer)?;
136 cross_check_implemented_exports(&plugin, &accept)?;
137
138 Ok(accept)
139}
140
141fn classify_call_error(reason: String) -> HandshakeExecutionError {
142 let lower = reason.to_ascii_lowercase();
143 if lower.contains("timeout")
144 || lower.contains("timed out")
145 || lower.contains("deadline")
146 || lower.contains("fuel")
147 {
148 HandshakeExecutionError::Timeout
149 } else {
150 HandshakeExecutionError::Call { reason }
151 }
152}
153
154fn wire_function_versions() -> [(&'static str, &'static [u32]); 7] {
155 [
156 (
157 <ShapeFn as WireFunction>::NAME,
158 <ShapeFn as WireFunction>::SUPPORTED_VERSIONS,
159 ),
160 (
161 <NormalizeErrorFn as WireFunction>::NAME,
162 <NormalizeErrorFn as WireFunction>::SUPPORTED_VERSIONS,
163 ),
164 (
165 <BuildSignerFn as WireFunction>::NAME,
166 <BuildSignerFn as WireFunction>::SUPPORTED_VERSIONS,
167 ),
168 (
169 <SignFn as WireFunction>::NAME,
170 <SignFn as WireFunction>::SUPPORTED_VERSIONS,
171 ),
172 (
173 <OnUnauthorizedFn as WireFunction>::NAME,
174 <OnUnauthorizedFn as WireFunction>::SUPPORTED_VERSIONS,
175 ),
176 (
177 <ObserveFn as WireFunction>::NAME,
178 <ObserveFn as WireFunction>::SUPPORTED_VERSIONS,
179 ),
180 (
181 <FilterFn as WireFunction>::NAME,
182 <FilterFn as WireFunction>::SUPPORTED_VERSIONS,
183 ),
184 ]
185}
186
187fn validate_accept_shape(
188 accept: &HandshakeAccept,
189 offer: &HandshakeOffer,
190) -> Result<(), HandshakeExecutionError> {
191 if accept.implemented_functions.len() > IMPLEMENTED_FUNCTIONS_MAX {
192 return Err(HandshakeExecutionError::ImplementedFunctionCountExceeded {
193 count: accept.implemented_functions.len(),
194 max: IMPLEMENTED_FUNCTIONS_MAX,
195 });
196 }
197
198 for function in &accept.implemented_functions {
199 if !offer.function_versions.contains_key(function) {
200 return Err(HandshakeExecutionError::ImplementedUnknownFunction {
201 function: function.clone(),
202 });
203 }
204 }
205
206 for (function, version) in &accept.plugin_supported {
207 if !offer.function_versions.contains_key(function) {
208 return Err(HandshakeExecutionError::SupportedUnknownFunction {
209 function: function.clone(),
210 });
211 }
212 for &supported in version {
213 if !(VERSION_MIN..=VERSION_MAX).contains(&supported) {
214 return Err(HandshakeExecutionError::SupportedVersionOutOfRange {
215 function: function.clone(),
216 version: supported,
217 min: VERSION_MIN,
218 max: VERSION_MAX,
219 });
220 }
221 }
222 }
223
224 for (function, chosen) in &accept.chosen_versions {
225 if !accept.implemented_functions.contains(function) {
226 return Err(HandshakeExecutionError::ChosenFunctionNotImplemented {
227 function: function.clone(),
228 });
229 }
230 let Some(supported) = accept.plugin_supported.get(function) else {
231 return Err(HandshakeExecutionError::ChosenVersionNotSupported {
232 function: function.clone(),
233 version: *chosen,
234 });
235 };
236 if !supported.contains(chosen) {
237 return Err(HandshakeExecutionError::ChosenVersionNotSupported {
238 function: function.clone(),
239 version: *chosen,
240 });
241 }
242 }
243
244 Ok(())
245}
246
247fn cross_check_implemented_exports(
248 plugin: &extism::Plugin,
249 accept: &HandshakeAccept,
250) -> Result<(), HandshakeExecutionError> {
251 for function in &accept.implemented_functions {
252 if !plugin.function_exists(function) {
253 return Err(HandshakeExecutionError::DeclaredFunctionMissing {
254 function: function.clone(),
255 });
256 }
257 }
258 for function in all_wire_functions() {
259 if plugin.function_exists(function) && !accept.implemented_functions.contains(*function) {
260 return Err(HandshakeExecutionError::UndeclaredExport {
261 function: (*function).to_owned(),
262 });
263 }
264 }
265 Ok(())
266}
267
/// Every way [`execute_handshake`] can fail, from offer validation through
/// the final export cross-check.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum HandshakeExecutionError {
    /// A wire-crate `HandshakeError`, converted via `#[from]` when the offer's
    /// `validate()` or the accept's `validate_against_offer()` fails
    /// (presumably the latter's error type is `HandshakeError` too; confirm in
    /// the wire crate).
    #[error("handshake validation failed: {0}")]
    Validation(#[from] HandshakeError),
    /// Building the plugin failed; `reason` is the Extism error text.
    #[error("handshake plugin instantiation failed: {reason}")]
    Instantiate { reason: String },
    /// The module does not export `cc_lb_handshake`.
    #[error("plugin does not export cc_lb_handshake")]
    MissingHandshakeExport,
    /// The offer could not be serialized to JSON.
    #[error("handshake offer serialization failed: {reason}")]
    SerializeOffer { reason: String },
    /// The handshake call failed, and its error text did not match any
    /// timeout/fuel marker.
    #[error("handshake call failed: {reason}")]
    Call { reason: String },
    /// The handshake call's error text matched a timeout, deadline, or fuel
    /// marker (see `classify_call_error`).
    #[error("handshake call exceeded timeout/fuel budget")]
    Timeout,
    /// The plugin's response was longer than `max` bytes.
    #[error("handshake output size {bytes} exceeds maximum {max}")]
    OutputTooLarge { bytes: usize, max: usize },
    /// The plugin's response could not be decoded as a `HandshakeAccept`.
    #[error("handshake accept decode failed: {reason}")]
    DecodeAccept { reason: String },
    /// The accept declared more than `max` implemented functions.
    #[error("implemented function count {count} exceeds maximum {max}")]
    ImplementedFunctionCountExceeded { count: usize, max: usize },
    /// An implemented function is not one of the offered wire functions.
    #[error("implemented unknown function: {function}")]
    ImplementedUnknownFunction { function: String },
    /// A `plugin_supported` entry names a function that was not offered.
    #[error("supported unknown function: {function}")]
    SupportedUnknownFunction { function: String },
    /// A `plugin_supported` version lies outside `[min, max]`.
    #[error(
        "supported version {version} for function {function} outside valid range [{min}, {max}]"
    )]
    SupportedVersionOutOfRange {
        function: String,
        version: u32,
        min: u32,
        max: u32,
    },
    /// A `chosen_versions` entry names a function absent from
    /// `implemented_functions`.
    #[error("chosen function not listed as implemented: {function}")]
    ChosenFunctionNotImplemented { function: String },
    /// A chosen version is missing from the plugin's own `plugin_supported`
    /// list for that function (or the list itself is missing).
    #[error("chosen version {version} for function {function} not listed as plugin-supported")]
    ChosenVersionNotSupported { function: String, version: u32 },
    /// A declared implemented function has no matching wasm export.
    #[error("declared function missing wasm export: {function}")]
    DeclaredFunctionMissing { function: String },
    /// The plugin exports a known wire function that it did not declare.
    #[error("undeclared wire function export present: {function}")]
    UndeclaredExport { function: String },
}
311
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use cc_lb_plugin_wire::handshake::HandshakeError;
    use serde_json::json;

    use super::*;

    #[test]
    fn slot_set_from_handshake_maps_filter_shape_observe() {
        let fns: BTreeSet<String> = ["filter", "shape", "observe"]
            .into_iter()
            .map(String::from)
            .collect();
        assert_eq!(
            slot_set_from_handshake(&fns),
            vec![
                PluginSlot::Router,
                PluginSlot::ObservabilityHook,
                PluginSlot::Shape,
            ],
        );
    }

    #[test]
    fn slot_set_from_handshake_handles_partial_exports() {
        let only_shape: BTreeSet<String> = ["shape".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_shape),
            vec![PluginSlot::Shape]
        );

        let only_filter: BTreeSet<String> = ["filter".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_filter),
            vec![PluginSlot::Router],
        );
    }

    #[test]
    fn slot_set_from_handshake_is_empty_when_no_slot_functions_exported() {
        let empty: BTreeSet<String> = BTreeSet::new();
        assert!(slot_set_from_handshake(&empty).is_empty());

        let unrelated: BTreeSet<String> = ["sign".to_owned(), "build_signer".to_owned()]
            .into_iter()
            .collect();
        assert!(slot_set_from_handshake(&unrelated).is_empty());
    }

    #[test]
    fn slot_set_from_handshake_ignores_unknown_function_names_for_forward_compat() {
        let mixed: BTreeSet<String> = ["filter", "future_slot_v9000"]
            .into_iter()
            .map(String::from)
            .collect();
        assert_eq!(slot_set_from_handshake(&mixed), vec![PluginSlot::Router]);
    }

    #[test]
    fn build_offer_lists_v1_wire_functions_and_host_capabilities() {
        let host_caps = BTreeSet::from(["streaming".to_owned(), "storage".to_owned()]);

        let offer = build_offer(&host_caps);

        assert_eq!(offer.handshake_schema_version, HANDSHAKE_SCHEMA_VERSION_V1);
        assert_eq!(offer.envelope_version, ENVELOPE_VERSION_V1);
        assert_eq!(offer.host_capabilities, host_caps);
        for (name, versions) in wire_function_versions() {
            assert_eq!(offer.function_versions.get(name), Some(&versions.to_vec()));
        }
        offer.validate().expect("host offer is valid");
    }

    #[test]
    fn classify_call_error_maps_budget_failures_to_timeout() {
        for message in [
            "Timeout",
            "call timed out",
            "deadline exceeded",
            "wasm trap: all fuel consumed",
        ] {
            let err = classify_call_error(message.to_owned());
            assert!(
                matches!(err, HandshakeExecutionError::Timeout),
                "expected timeout for {message:?}, got {err:?}"
            );
        }

        match classify_call_error("wasm trap: unreachable".to_owned()) {
            HandshakeExecutionError::Call { reason } => {
                assert_eq!(reason, "wasm trap: unreachable");
            }
            other => panic!("expected call error, got {other:?}"),
        }
    }

    #[test]
    fn validate_accept_shape_rejects_unknown_implemented_function() {
        let offer = build_offer(&BTreeSet::new());
        let accept: HandshakeAccept =
            serde_json::from_str(&accept_json(&["not_a_wire_function"], &[], &[], &[]))
                .expect("accept decodes");

        match validate_accept_shape(&accept, &offer) {
            Err(HandshakeExecutionError::ImplementedUnknownFunction { function }) => {
                assert_eq!(function, "not_a_wire_function");
            }
            other => panic!("expected unknown implemented function, got {other:?}"),
        }
    }

    #[test]
    fn validate_accept_shape_rejects_chosen_version_not_plugin_supported() {
        let offer = build_offer(&BTreeSet::new());
        let accept: HandshakeAccept = serde_json::from_str(&accept_json(
            &["shape"],
            &[("shape", &[1])],
            &[("shape", 2)],
            &[],
        ))
        .expect("accept decodes");

        match validate_accept_shape(&accept, &offer) {
            Err(HandshakeExecutionError::ChosenVersionNotSupported { function, version }) => {
                assert_eq!(function, "shape");
                assert_eq!(version, 2);
            }
            other => panic!("expected unsupported chosen version, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_accepts_valid_plugin_and_checks_export() {
        let offer = build_offer(&BTreeSet::from(["streaming".to_owned()]));
        let accept = accept_json(
            &["shape"],
            &[("shape", &[1])],
            &[("shape", 1)],
            &["streaming"],
        );
        let wasm = handshake_module(&accept, &["shape"], false);

        let actual = execute_handshake(&wasm, &offer).expect("handshake succeeds");

        assert!(actual.implemented_functions.contains("shape"));
        assert_eq!(actual.chosen_versions.get("shape"), Some(&1));
    }

    #[test]
    fn execute_handshake_rejects_downgrade() {
        let mut offer = build_offer(&BTreeSet::new());
        offer
            .function_versions
            .insert("shape".to_owned(), vec![1, 2, 3]);
        let accept = accept_json(&["shape"], &[("shape", &[1, 2, 3])], &[("shape", 1)], &[]);
        let wasm = handshake_module(&accept, &["shape"], false);

        let err = execute_handshake(&wasm, &offer).expect_err("downgrade rejected");

        match err {
            HandshakeExecutionError::Validation(HandshakeError::DowngradeAttempt { .. }) => {}
            other => panic!("expected downgrade error, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_rejects_module_without_handshake_export() {
        let offer = build_offer(&BTreeSet::new());
        let wasm = wat::parse_str(
            r#"(module (func (export "unrelated") (result i32) (i32.const 0)))"#,
        )
        .expect("wat parses");

        let err = execute_handshake(&wasm, &offer).expect_err("missing handshake export rejected");

        assert!(
            matches!(err, HandshakeExecutionError::MissingHandshakeExport),
            "expected missing handshake export, got {err:?}"
        );
    }

    #[test]
    fn execute_handshake_rejects_implemented_function_without_export() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&["shape"], &[("shape", &[1])], &[("shape", 1)], &[]);
        let wasm = handshake_module(&accept, &[], false);

        let err = execute_handshake(&wasm, &offer).expect_err("missing export rejected");

        match err {
            HandshakeExecutionError::DeclaredFunctionMissing { function } => {
                assert_eq!(function, "shape");
            }
            other => panic!("expected missing export, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_rejects_undeclared_wire_export() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&["shape"], &[("shape", &[1])], &[("shape", 1)], &[]);
        // Export a second wire function that the accept never declares.
        let undeclared = <FilterFn as WireFunction>::NAME;
        let wasm = handshake_module(&accept, &["shape", undeclared], false);

        let err = execute_handshake(&wasm, &offer).expect_err("undeclared export rejected");

        match err {
            HandshakeExecutionError::UndeclaredExport { function } => {
                assert_eq!(function, undeclared);
            }
            other => panic!("expected undeclared export, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_rejects_user_host_imports() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&[], &[], &[], &[]);
        let wasm = handshake_module(&accept, &[], true);

        let err = execute_handshake(&wasm, &offer).expect_err("host import rejected");

        match err {
            HandshakeExecutionError::Instantiate { .. } | HandshakeExecutionError::Call { .. } => {}
            other => panic!("expected purity failure, got {other:?}"),
        }
    }

    /// Renders a handshake accept as JSON text with the v1 schema/envelope
    /// versions and the given implemented functions, supported versions,
    /// chosen versions and required capabilities.
    fn accept_json(
        implemented: &[&str],
        supported: &[(&str, &[u32])],
        chosen: &[(&str, u32)],
        required_caps: &[&str],
    ) -> String {
        let implemented_functions: BTreeSet<_> =
            implemented.iter().map(|name| name.to_string()).collect();
        let plugin_supported: BTreeMap<_, _> = supported
            .iter()
            .map(|(name, versions)| (name.to_string(), versions.to_vec()))
            .collect();
        let chosen_versions: BTreeMap<_, _> = chosen
            .iter()
            .map(|(name, version)| (name.to_string(), *version))
            .collect();
        let required_capabilities: BTreeSet<_> = required_caps
            .iter()
            .map(|capability| capability.to_string())
            .collect();

        json!({
            "handshake_schema_version": HANDSHAKE_SCHEMA_VERSION_V1,
            "envelope_version": ENVELOPE_VERSION_V1,
            "chosen_versions": chosen_versions,
            "plugin_supported": plugin_supported,
            "implemented_functions": implemented_functions,
            "required_capabilities": required_capabilities,
        })
        .to_string()
    }

    /// Compiles a WAT module whose `cc_lb_handshake` export sets `output` as
    /// the plugin output via `extism:host/env` `output_set`.
    ///
    /// Each name in `extra_exports` becomes a trivial exported function. When
    /// `import_user_host` is true, the module also imports and calls the user
    /// host function `extism:host/user` `cc_lb_log`.
    fn handshake_module(output: &str, extra_exports: &[&str], import_user_host: bool) -> Vec<u8> {
        let output_helper = bytes_helper("handshake_out", output.as_bytes());
        let user_import = if import_user_host {
            r#"(import "extism:host/user" "cc_lb_log" (func $cc_lb_log (param i64 i64)))"#
        } else {
            ""
        };
        let user_call = if import_user_host {
            "    (call $cc_lb_log (call $handshake_out) (call $handshake_out))"
        } else {
            ""
        };
        let mut exports = String::new();
        for export in extra_exports {
            exports.push_str(&format!(
                r#"
(func (export "{export}") (result i32)
  (i32.const 0))
"#
            ));
        }

        let wat = format!(
            r#"
(module
  (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
  (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
  (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
  {user_import}
  {output_helper}
  (func (export "cc_lb_handshake") (result i32)
{user_call}
    (call $output_set (call $handshake_out) (i64.const {len}))
    (i32.const 0))
  {exports}
)
"#,
            len = output.len()
        );
        wat::parse_str(&wat).expect("handshake wat parses")
    }

    /// Emits a WAT function `$name` that allocates `bytes.len()` bytes with
    /// `alloc`, writes each byte with `store_u8`, and returns the allocation
    /// offset.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            stores.push_str(&format!(
                "    (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
            ));
        }
        format!(
            r#"
(func ${name} (result i64)
  (local $ptr i64)
  (local.set $ptr (call $alloc (i64.const {len})))
{stores}  (local.get $ptr))
"#,
            len = bytes.len()
        )
    }
}