1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use cc_lb_plugin_api::PluginSlot;
5use cc_lb_plugin_wire::handshake::{
6 HANDSHAKE_SCHEMA_VERSION_V1, HandshakeAccept, HandshakeError, HandshakeOffer,
7};
8use cc_lb_plugin_wire::limits::{
9 HANDSHAKE_FUEL, HANDSHAKE_OUTPUT_MAX_BYTES, HANDSHAKE_WALL_MS, IMPLEMENTED_FUNCTIONS_MAX,
10 VERSION_MAX, VERSION_MIN,
11};
12use cc_lb_plugin_wire::v1::{
13 build_signer::BuildSignerFn, normalize_error::NormalizeErrorFn, observe::ObserveFn,
14 on_unauthorized::OnUnauthorizedFn, shape::ShapeFn, sign::SignFn,
15};
16use cc_lb_plugin_wire::v3::filter::FilterFn;
17use cc_lb_plugin_wire::wire_function::{WireFunction, all_wire_functions};
18use extism::{Manifest, PluginBuilder, Wasm};
19use thiserror::Error;
20
/// Name of the Wasm export the host calls to run the capability handshake.
const HANDSHAKE_EXPORT: &str = "cc_lb_handshake";
/// Envelope version advertised in every [`HandshakeOffer`] built by [`build_offer`].
const ENVELOPE_VERSION_V1: u32 = 1;
23
24pub fn slot_set_from_handshake(implemented_functions: &BTreeSet<String>) -> Vec<PluginSlot> {
28 let mut slots = BTreeSet::new();
29 for name in implemented_functions {
30 if let Some(slot) = wire_function_to_slot(name) {
31 slots.insert(slot);
32 }
33 }
34 slots.into_iter().collect()
35}
36
37fn wire_function_to_slot(name: &str) -> Option<PluginSlot> {
38 match name {
39 "filter" => Some(PluginSlot::Router),
40 "shape" => Some(PluginSlot::Shape),
41 "observe" => Some(PluginSlot::ObservabilityHook),
42 _ => None,
43 }
44}
45
46pub fn slot_set_from_extism_exports(plugin_bytes: &[u8]) -> Vec<PluginSlot> {
52 use extism::{Manifest, PluginBuilder, Wasm};
53
54 let manifest = Manifest::new([Wasm::data(plugin_bytes.to_vec())])
55 .with_timeout(std::time::Duration::from_millis(HANDSHAKE_WALL_MS))
56 .disallow_all_hosts();
57 let Ok(plugin) = PluginBuilder::new(&manifest)
58 .with_wasi(false)
59 .with_cache_disabled()
60 .with_fuel_limit(HANDSHAKE_FUEL)
61 .build()
62 else {
63 return Vec::new();
64 };
65 if plugin.function_exists("route") {
66 vec![PluginSlot::Router]
67 } else {
68 Vec::new()
69 }
70}
71
72pub fn build_offer(host_caps: &BTreeSet<String>) -> HandshakeOffer {
73 let mut function_versions = BTreeMap::new();
74 for (name, versions) in wire_function_versions() {
75 function_versions.insert(name.to_owned(), versions.to_vec());
76 }
77
78 HandshakeOffer {
79 handshake_schema_version: HANDSHAKE_SCHEMA_VERSION_V1,
80 envelope_version: ENVELOPE_VERSION_V1,
81 function_versions,
82 host_capabilities: host_caps.clone(),
83 }
84}
85
86pub fn build_plugin(
87 wasm: &[u8],
88 wall_ms: u64,
89 fuel: u64,
90) -> Result<extism::Plugin, BuildPluginError> {
91 let manifest = Manifest::new([Wasm::data(wasm.to_vec())])
92 .with_timeout(Duration::from_millis(wall_ms))
93 .disallow_all_hosts();
94 PluginBuilder::new(&manifest)
95 .with_wasi(false)
96 .with_cache_disabled()
97 .with_fuel_limit(fuel)
98 .build()
99 .map_err(|source| BuildPluginError::Instantiate {
100 reason: source.to_string(),
101 })
102}
103
/// Failure returned by [`build_plugin`].
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum BuildPluginError {
    /// Extism could not build the plugin; `reason` is the stringified Extism error.
    #[error("failed to instantiate plugin: {reason}")]
    Instantiate { reason: String },
}
110
111pub fn execute_handshake(
112 plugin_bytes: &[u8],
113 offer: &HandshakeOffer,
114) -> Result<HandshakeAccept, HandshakeExecutionError> {
115 offer.validate()?;
116 metrics::counter!("cc_lb_plugin_handshake_total").increment(1);
117
118 let mut plugin = build_plugin(plugin_bytes, HANDSHAKE_WALL_MS, HANDSHAKE_FUEL).map_err(
119 |source| match source {
120 BuildPluginError::Instantiate { reason } => {
121 HandshakeExecutionError::Instantiate { reason }
122 }
123 },
124 )?;
125
126 if !plugin.function_exists(HANDSHAKE_EXPORT) {
127 return Err(HandshakeExecutionError::MissingHandshakeExport);
128 }
129
130 let request =
131 serde_json::to_string(offer).map_err(|source| HandshakeExecutionError::SerializeOffer {
132 reason: source.to_string(),
133 })?;
134 let response = plugin
135 .call::<&str, String>(HANDSHAKE_EXPORT, request.as_str())
136 .map_err(|source| classify_call_error(source.to_string()))?;
137
138 if response.len() > HANDSHAKE_OUTPUT_MAX_BYTES {
139 return Err(HandshakeExecutionError::OutputTooLarge {
140 bytes: response.len(),
141 max: HANDSHAKE_OUTPUT_MAX_BYTES,
142 });
143 }
144
145 let accept: HandshakeAccept = serde_json::from_str(&response).map_err(|source| {
146 HandshakeExecutionError::DecodeAccept {
147 reason: source.to_string(),
148 }
149 })?;
150 accept.validate_against_offer(offer)?;
151 validate_accept_shape(&accept, offer)?;
152 cross_check_implemented_exports(&plugin, &accept)?;
153
154 Ok(accept)
155}
156
157fn classify_call_error(reason: String) -> HandshakeExecutionError {
158 let lower = reason.to_ascii_lowercase();
159 if lower.contains("timeout")
160 || lower.contains("timed out")
161 || lower.contains("deadline")
162 || lower.contains("fuel")
163 {
164 HandshakeExecutionError::Timeout
165 } else {
166 HandshakeExecutionError::Call { reason }
167 }
168}
169
170fn wire_function_versions() -> [(&'static str, &'static [u32]); 7] {
171 [
172 (
173 <ShapeFn as WireFunction>::NAME,
174 <ShapeFn as WireFunction>::SUPPORTED_VERSIONS,
175 ),
176 (
177 <NormalizeErrorFn as WireFunction>::NAME,
178 <NormalizeErrorFn as WireFunction>::SUPPORTED_VERSIONS,
179 ),
180 (
181 <BuildSignerFn as WireFunction>::NAME,
182 <BuildSignerFn as WireFunction>::SUPPORTED_VERSIONS,
183 ),
184 (
185 <SignFn as WireFunction>::NAME,
186 <SignFn as WireFunction>::SUPPORTED_VERSIONS,
187 ),
188 (
189 <OnUnauthorizedFn as WireFunction>::NAME,
190 <OnUnauthorizedFn as WireFunction>::SUPPORTED_VERSIONS,
191 ),
192 (
193 <ObserveFn as WireFunction>::NAME,
194 <ObserveFn as WireFunction>::SUPPORTED_VERSIONS,
195 ),
196 (
197 <FilterFn as WireFunction>::NAME,
198 <FilterFn as WireFunction>::SUPPORTED_VERSIONS,
199 ),
200 ]
201}
202
203fn validate_accept_shape(
204 accept: &HandshakeAccept,
205 offer: &HandshakeOffer,
206) -> Result<(), HandshakeExecutionError> {
207 if accept.implemented_functions.len() > IMPLEMENTED_FUNCTIONS_MAX {
208 return Err(HandshakeExecutionError::ImplementedFunctionCountExceeded {
209 count: accept.implemented_functions.len(),
210 max: IMPLEMENTED_FUNCTIONS_MAX,
211 });
212 }
213
214 for function in &accept.implemented_functions {
215 if !offer.function_versions.contains_key(function) {
216 return Err(HandshakeExecutionError::ImplementedUnknownFunction {
217 function: function.clone(),
218 });
219 }
220 }
221
222 for (function, version) in &accept.plugin_supported {
223 if !offer.function_versions.contains_key(function) {
224 return Err(HandshakeExecutionError::SupportedUnknownFunction {
225 function: function.clone(),
226 });
227 }
228 for &supported in version {
229 if !(VERSION_MIN..=VERSION_MAX).contains(&supported) {
230 return Err(HandshakeExecutionError::SupportedVersionOutOfRange {
231 function: function.clone(),
232 version: supported,
233 min: VERSION_MIN,
234 max: VERSION_MAX,
235 });
236 }
237 }
238 }
239
240 for (function, chosen) in &accept.chosen_versions {
241 if !accept.implemented_functions.contains(function) {
242 return Err(HandshakeExecutionError::ChosenFunctionNotImplemented {
243 function: function.clone(),
244 });
245 }
246 let Some(supported) = accept.plugin_supported.get(function) else {
247 return Err(HandshakeExecutionError::ChosenVersionNotSupported {
248 function: function.clone(),
249 version: *chosen,
250 });
251 };
252 if !supported.contains(chosen) {
253 return Err(HandshakeExecutionError::ChosenVersionNotSupported {
254 function: function.clone(),
255 version: *chosen,
256 });
257 }
258 }
259
260 Ok(())
261}
262
263fn cross_check_implemented_exports(
264 plugin: &extism::Plugin,
265 accept: &HandshakeAccept,
266) -> Result<(), HandshakeExecutionError> {
267 for function in &accept.implemented_functions {
268 if !plugin.function_exists(function) {
269 return Err(HandshakeExecutionError::DeclaredFunctionMissing {
270 function: function.clone(),
271 });
272 }
273 }
274 for function in all_wire_functions() {
275 if plugin.function_exists(function) && !accept.implemented_functions.contains(*function) {
276 return Err(HandshakeExecutionError::UndeclaredExport {
277 function: (*function).to_owned(),
278 });
279 }
280 }
281 Ok(())
282}
283
/// Reasons [`execute_handshake`] can reject a plugin.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum HandshakeExecutionError {
    /// The offer failed `HandshakeOffer::validate`, or the decoded accept failed
    /// `HandshakeAccept::validate_against_offer`.
    #[error("handshake validation failed: {0}")]
    Validation(#[from] HandshakeError),
    /// The module could not be instantiated in the handshake sandbox.
    #[error("handshake plugin instantiation failed: {reason}")]
    Instantiate { reason: String },
    /// The module has no `cc_lb_handshake` export.
    #[error("plugin does not export cc_lb_handshake")]
    MissingHandshakeExport,
    /// The offer could not be encoded as JSON.
    #[error("handshake offer serialization failed: {reason}")]
    SerializeOffer { reason: String },
    /// Calling the handshake export failed, and the error text did not look
    /// like a timeout or fuel exhaustion.
    #[error("handshake call failed: {reason}")]
    Call { reason: String },
    /// The handshake call failed with error text that mentions a timeout,
    /// deadline, or fuel (see `classify_call_error`).
    #[error("handshake call exceeded timeout/fuel budget")]
    Timeout,
    /// The handshake response is longer than `HANDSHAKE_OUTPUT_MAX_BYTES` bytes.
    #[error("handshake output size {bytes} exceeds maximum {max}")]
    OutputTooLarge { bytes: usize, max: usize },
    /// The handshake response is not valid `HandshakeAccept` JSON.
    #[error("handshake accept decode failed: {reason}")]
    DecodeAccept { reason: String },
    /// The accept declares more than `IMPLEMENTED_FUNCTIONS_MAX` implemented functions.
    #[error("implemented function count {count} exceeds maximum {max}")]
    ImplementedFunctionCountExceeded { count: usize, max: usize },
    /// An implemented function does not appear in the offer.
    #[error("implemented unknown function: {function}")]
    ImplementedUnknownFunction { function: String },
    /// A `plugin_supported` entry names a function that is not in the offer.
    #[error("supported unknown function: {function}")]
    SupportedUnknownFunction { function: String },
    /// A `plugin_supported` version lies outside `VERSION_MIN..=VERSION_MAX`.
    #[error(
        "supported version {version} for function {function} outside valid range [{min}, {max}]"
    )]
    SupportedVersionOutOfRange {
        function: String,
        version: u32,
        min: u32,
        max: u32,
    },
    /// A `chosen_versions` entry names a function that is missing from
    /// `implemented_functions`.
    #[error("chosen function not listed as implemented: {function}")]
    ChosenFunctionNotImplemented { function: String },
    /// A chosen version is absent from the plugin's own `plugin_supported`
    /// list for that function.
    #[error("chosen version {version} for function {function} not listed as plugin-supported")]
    ChosenVersionNotSupported { function: String, version: u32 },
    /// A function declared as implemented has no matching Wasm export.
    #[error("declared function missing wasm export: {function}")]
    DeclaredFunctionMissing { function: String },
    /// The module exports a known wire function that it did not declare as implemented.
    #[error("undeclared wire function export present: {function}")]
    UndeclaredExport { function: String },
}
327
328#[cfg(test)]
329mod tests {
330 use std::collections::{BTreeMap, BTreeSet};
331
332 use cc_lb_plugin_wire::handshake::HandshakeError;
333 use serde_json::json;
334
335 use super::*;
336
337 #[test]
338 fn slot_set_from_handshake_maps_filter_shape_observe() {
339 let fns: BTreeSet<String> = ["filter", "shape", "observe"]
340 .into_iter()
341 .map(String::from)
342 .collect();
343 assert_eq!(
344 slot_set_from_handshake(&fns),
345 vec![
346 PluginSlot::Router,
347 PluginSlot::ObservabilityHook,
348 PluginSlot::Shape,
349 ],
350 );
351 }
352
353 #[test]
354 fn slot_set_from_handshake_handles_partial_exports() {
355 let only_shape: BTreeSet<String> = ["shape".to_owned()].into_iter().collect();
356 assert_eq!(
357 slot_set_from_handshake(&only_shape),
358 vec![PluginSlot::Shape]
359 );
360
361 let only_filter: BTreeSet<String> = ["filter".to_owned()].into_iter().collect();
362 assert_eq!(
363 slot_set_from_handshake(&only_filter),
364 vec![PluginSlot::Router],
365 );
366 }
367
368 #[test]
369 fn slot_set_from_handshake_is_empty_when_no_slot_functions_exported() {
370 let empty: BTreeSet<String> = BTreeSet::new();
371 assert!(slot_set_from_handshake(&empty).is_empty());
372
373 let unrelated: BTreeSet<String> = ["sign".to_owned(), "build_signer".to_owned()]
374 .into_iter()
375 .collect();
376 assert!(slot_set_from_handshake(&unrelated).is_empty());
377 }
378
379 #[test]
380 fn slot_set_from_extism_exports_maps_only_route_to_router() {
381 let wasm = wat::parse_str(r#"(module (func (export "route") (result i32) (i32.const 0)))"#)
382 .expect("route-only wat parses");
383 assert_eq!(
384 slot_set_from_extism_exports(&wasm),
385 vec![PluginSlot::Router],
386 );
387 }
388
389 #[test]
390 fn slot_set_from_extism_exports_ignores_filter_shape_observe_exports() {
391 for export in ["filter", "shape", "observe"] {
392 let wat = format!(r#"(module (func (export "{export}") (result i32) (i32.const 0)))"#,);
393 let wasm = wat::parse_str(&wat).expect("single-export wat parses");
394 assert!(
395 slot_set_from_extism_exports(&wasm).is_empty(),
396 "fallback must not trust {export} export without handshake validation",
397 );
398 }
399 }
400
401 #[test]
402 fn slot_set_from_extism_exports_returns_router_for_route_plus_shape_legacy() {
403 let wasm = wat::parse_str(
404 r#"(module
405 (func (export "route") (result i32) (i32.const 0))
406 (func (export "shape") (result i32) (i32.const 0)))"#,
407 )
408 .expect("legacy route+shape wat parses");
409 assert_eq!(
410 slot_set_from_extism_exports(&wasm),
411 vec![PluginSlot::Router],
412 "shape must not promote without handshake; only route -> Router is trusted",
413 );
414 }
415
416 #[test]
417 fn slot_set_from_extism_exports_is_empty_for_module_without_known_exports() {
418 let wasm = wat::parse_str(r#"(module (func (export "noop") (result i32) (i32.const 0)))"#)
419 .expect("noop wat parses");
420 assert!(slot_set_from_extism_exports(&wasm).is_empty());
421 }
422
423 #[test]
424 fn slot_set_from_extism_exports_is_empty_for_corrupt_bytes() {
425 let bytes = vec![0u8; 16];
426 assert!(slot_set_from_extism_exports(&bytes).is_empty());
427 }
428
429 #[test]
430 fn slot_set_from_handshake_ignores_unknown_function_names_for_forward_compat() {
431 let mixed: BTreeSet<String> = ["filter", "future_slot_v9000"]
432 .into_iter()
433 .map(String::from)
434 .collect();
435 assert_eq!(slot_set_from_handshake(&mixed), vec![PluginSlot::Router]);
436 }
437
438 #[test]
439 fn build_offer_lists_v1_wire_functions_and_host_capabilities() {
440 let host_caps = BTreeSet::from(["streaming".to_owned(), "storage".to_owned()]);
441
442 let offer = build_offer(&host_caps);
443
444 assert_eq!(offer.handshake_schema_version, HANDSHAKE_SCHEMA_VERSION_V1);
445 assert_eq!(offer.envelope_version, ENVELOPE_VERSION_V1);
446 assert_eq!(offer.host_capabilities, host_caps);
447 for (name, versions) in wire_function_versions() {
448 assert_eq!(offer.function_versions.get(name), Some(&versions.to_vec()));
449 }
450 offer.validate().expect("host offer is valid");
451 }
452
453 #[test]
454 fn execute_handshake_accepts_valid_plugin_and_checks_export() {
455 let offer = build_offer(&BTreeSet::from(["streaming".to_owned()]));
456 let accept = accept_json(
457 &["shape"],
458 &[("shape", &[1])],
459 &[("shape", 1)],
460 &["streaming"],
461 );
462 let wasm = handshake_module(&accept, &["shape"], false);
463
464 let actual = execute_handshake(&wasm, &offer).expect("handshake succeeds");
465
466 assert!(actual.implemented_functions.contains("shape"));
467 assert_eq!(actual.chosen_versions.get("shape"), Some(&1));
468 }
469
470 #[test]
471 fn execute_handshake_rejects_downgrade() {
472 let mut offer = build_offer(&BTreeSet::new());
473 offer
474 .function_versions
475 .insert("shape".to_owned(), vec![1, 2, 3]);
476 let accept = accept_json(&["shape"], &[("shape", &[1, 2, 3])], &[("shape", 1)], &[]);
477 let wasm = handshake_module(&accept, &["shape"], false);
478
479 let err = execute_handshake(&wasm, &offer).expect_err("downgrade rejected");
480
481 match err {
482 HandshakeExecutionError::Validation(HandshakeError::DowngradeAttempt { .. }) => {}
483 other => panic!("expected downgrade error, got {other:?}"),
484 }
485 }
486
487 #[test]
488 fn execute_handshake_rejects_implemented_function_without_export() {
489 let offer = build_offer(&BTreeSet::new());
490 let accept = accept_json(&["shape"], &[("shape", &[1])], &[("shape", 1)], &[]);
491 let wasm = handshake_module(&accept, &[], false);
492
493 let err = execute_handshake(&wasm, &offer).expect_err("missing export rejected");
494
495 match err {
496 HandshakeExecutionError::DeclaredFunctionMissing { function } => {
497 assert_eq!(function, "shape");
498 }
499 other => panic!("expected missing export, got {other:?}"),
500 }
501 }
502
503 #[test]
504 fn execute_handshake_rejects_user_host_imports() {
505 let offer = build_offer(&BTreeSet::new());
506 let accept = accept_json(&[], &[], &[], &[]);
507 let wasm = handshake_module(&accept, &[], true);
508
509 let err = execute_handshake(&wasm, &offer).expect_err("host import rejected");
510
511 match err {
512 HandshakeExecutionError::Instantiate { .. } | HandshakeExecutionError::Call { .. } => {}
513 other => panic!("expected purity failure, got {other:?}"),
514 }
515 }
516
517 fn accept_json(
518 implemented: &[&str],
519 supported: &[(&str, &[u32])],
520 chosen: &[(&str, u32)],
521 required_caps: &[&str],
522 ) -> String {
523 let implemented_functions: BTreeSet<_> =
524 implemented.iter().map(|name| name.to_string()).collect();
525 let plugin_supported: BTreeMap<_, _> = supported
526 .iter()
527 .map(|(name, versions)| (name.to_string(), versions.to_vec()))
528 .collect();
529 let chosen_versions: BTreeMap<_, _> = chosen
530 .iter()
531 .map(|(name, version)| (name.to_string(), *version))
532 .collect();
533 let required_capabilities: BTreeSet<_> = required_caps
534 .iter()
535 .map(|capability| capability.to_string())
536 .collect();
537
538 json!({
539 "handshake_schema_version": HANDSHAKE_SCHEMA_VERSION_V1,
540 "envelope_version": ENVELOPE_VERSION_V1,
541 "chosen_versions": chosen_versions,
542 "plugin_supported": plugin_supported,
543 "implemented_functions": implemented_functions,
544 "required_capabilities": required_capabilities,
545 })
546 .to_string()
547 }
548
549 fn handshake_module(output: &str, extra_exports: &[&str], import_user_host: bool) -> Vec<u8> {
550 let output_helper = bytes_helper("handshake_out", output.as_bytes());
551 let user_import = if import_user_host {
552 r#"(import "extism:host/user" "cc_lb_log" (func $cc_lb_log (param i64 i64)))"#
553 } else {
554 ""
555 };
556 let user_call = if import_user_host {
557 " (call $cc_lb_log (call $handshake_out) (call $handshake_out))"
558 } else {
559 ""
560 };
561 let mut exports = String::new();
562 for export in extra_exports {
563 exports.push_str(&format!(
564 r#"
565(func (export "{export}") (result i32)
566 (i32.const 0))
567"#
568 ));
569 }
570
571 let wat = format!(
572 r#"
573(module
574 (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
575 (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
576 (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
577 {user_import}
578 {output_helper}
579 (func (export "cc_lb_handshake") (result i32)
580{user_call}
581 (call $output_set (call $handshake_out) (i64.const {len}))
582 (i32.const 0))
583 {exports}
584)
585"#,
586 len = output.len()
587 );
588 wat::parse_str(&wat).expect("handshake wat parses")
589 }
590
591 fn bytes_helper(name: &str, bytes: &[u8]) -> String {
592 let mut stores = String::new();
593 for (index, byte) in bytes.iter().enumerate() {
594 stores.push_str(&format!(
595 " (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
596 ));
597 }
598 format!(
599 r#"
600(func ${name} (result i64)
601 (local $ptr i64)
602 (local.set $ptr (call $alloc (i64.const {len})))
603{stores} (local.get $ptr))
604"#,
605 len = bytes.len()
606 )
607 }
608}