Skip to main content

cc_lb_runtime_protocol/
handshake.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use cc_lb_plugin_api::PluginSlot;
5use cc_lb_plugin_wire::handshake::{
6    HANDSHAKE_SCHEMA_VERSION_V1, HandshakeAccept, HandshakeError, HandshakeOffer,
7};
8use cc_lb_plugin_wire::limits::{
9    HANDSHAKE_FUEL, HANDSHAKE_OUTPUT_MAX_BYTES, HANDSHAKE_WALL_MS, IMPLEMENTED_FUNCTIONS_MAX,
10    VERSION_MAX, VERSION_MIN,
11};
12use cc_lb_plugin_wire::v1::{
13    build_signer::BuildSignerFn, normalize_error::NormalizeErrorFn, observe::ObserveFn,
14    on_unauthorized::OnUnauthorizedFn, shape::ShapeFn, sign::SignFn,
15};
16use cc_lb_plugin_wire::v3::filter::FilterFn;
17use cc_lb_plugin_wire::wire_function::{WireFunction, all_wire_functions};
18use extism::{Manifest, PluginBuilder, Wasm};
19use thiserror::Error;
20
/// Name of the wasm export the host calls to run the handshake.
const HANDSHAKE_EXPORT: &str = "cc_lb_handshake";
/// Envelope version this host advertises in its handshake offer.
const ENVELOPE_VERSION_V1: u32 = 1;
23
24/// Returns the slots whose wire function the plugin claims to implement.
25/// The result is deterministically sorted, deduplicated, and ignores unknown
26/// function names so newer plugins remain forward-compatible with this host.
27pub fn slot_set_from_handshake(implemented_functions: &BTreeSet<String>) -> Vec<PluginSlot> {
28    let mut slots = BTreeSet::new();
29    for name in implemented_functions {
30        if let Some(slot) = wire_function_to_slot(name) {
31            slots.insert(slot);
32        }
33    }
34    slots.into_iter().collect()
35}
36
37fn wire_function_to_slot(name: &str) -> Option<PluginSlot> {
38    match name {
39        name if name == slot_to_wire_function(PluginSlot::Router) => Some(PluginSlot::Router),
40        name if name == slot_to_wire_function(PluginSlot::Shape) => Some(PluginSlot::Shape),
41        name if name == slot_to_wire_function(PluginSlot::ObservabilityHook) => {
42            Some(PluginSlot::ObservabilityHook)
43        }
44        _ => None,
45    }
46}
47
48pub(crate) fn slot_to_wire_function(slot: PluginSlot) -> &'static str {
49    match slot {
50        PluginSlot::Router => "filter",
51        PluginSlot::Shape => "shape",
52        PluginSlot::ObservabilityHook => "observe",
53    }
54}
55
56pub fn build_offer(host_caps: &BTreeSet<String>) -> HandshakeOffer {
57    let mut function_versions = BTreeMap::new();
58    for (name, versions) in wire_function_versions() {
59        function_versions.insert(name.to_owned(), versions.to_vec());
60    }
61
62    HandshakeOffer {
63        handshake_schema_version: HANDSHAKE_SCHEMA_VERSION_V1,
64        envelope_version: ENVELOPE_VERSION_V1,
65        function_versions,
66        host_capabilities: host_caps.clone(),
67    }
68}
69
70pub fn build_plugin(
71    wasm: &[u8],
72    wall_ms: u64,
73    fuel: u64,
74) -> Result<extism::Plugin, BuildPluginError> {
75    let manifest = Manifest::new([Wasm::data(wasm.to_vec())])
76        .with_timeout(Duration::from_millis(wall_ms))
77        .disallow_all_hosts();
78    PluginBuilder::new(&manifest)
79        .with_wasi(false)
80        .with_cache_disabled()
81        .with_fuel_limit(fuel)
82        .build()
83        .map_err(|source| BuildPluginError::Instantiate {
84            reason: source.to_string(),
85        })
86}
87
/// Errors returned by [`build_plugin`].
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum BuildPluginError {
    /// Extism could not build the plugin; `reason` is the rendered error text.
    #[error("failed to instantiate plugin: {reason}")]
    Instantiate { reason: String },
}
94
95pub fn execute_handshake(
96    plugin_bytes: &[u8],
97    offer: &HandshakeOffer,
98) -> Result<HandshakeAccept, HandshakeExecutionError> {
99    offer.validate()?;
100    metrics::counter!("cc_lb_plugin_handshake_total").increment(1);
101
102    let mut plugin = build_plugin(plugin_bytes, HANDSHAKE_WALL_MS, HANDSHAKE_FUEL).map_err(
103        |source| match source {
104            BuildPluginError::Instantiate { reason } => {
105                HandshakeExecutionError::Instantiate { reason }
106            }
107        },
108    )?;
109
110    if !plugin.function_exists(HANDSHAKE_EXPORT) {
111        return Err(HandshakeExecutionError::MissingHandshakeExport);
112    }
113
114    let request =
115        serde_json::to_string(offer).map_err(|source| HandshakeExecutionError::SerializeOffer {
116            reason: source.to_string(),
117        })?;
118    let response = plugin
119        .call::<&str, String>(HANDSHAKE_EXPORT, request.as_str())
120        .map_err(|source| classify_call_error(source.to_string()))?;
121
122    if response.len() > HANDSHAKE_OUTPUT_MAX_BYTES {
123        return Err(HandshakeExecutionError::OutputTooLarge {
124            bytes: response.len(),
125            max: HANDSHAKE_OUTPUT_MAX_BYTES,
126        });
127    }
128
129    let accept: HandshakeAccept = serde_json::from_str(&response).map_err(|source| {
130        HandshakeExecutionError::DecodeAccept {
131            reason: source.to_string(),
132        }
133    })?;
134    accept.validate_against_offer(offer)?;
135    validate_accept_shape(&accept, offer)?;
136    cross_check_implemented_exports(&plugin, &accept)?;
137
138    Ok(accept)
139}
140
141fn classify_call_error(reason: String) -> HandshakeExecutionError {
142    let lower = reason.to_ascii_lowercase();
143    if lower.contains("timeout")
144        || lower.contains("timed out")
145        || lower.contains("deadline")
146        || lower.contains("fuel")
147    {
148        HandshakeExecutionError::Timeout
149    } else {
150        HandshakeExecutionError::Call { reason }
151    }
152}
153
154fn wire_function_versions() -> [(&'static str, &'static [u32]); 7] {
155    [
156        (
157            <ShapeFn as WireFunction>::NAME,
158            <ShapeFn as WireFunction>::SUPPORTED_VERSIONS,
159        ),
160        (
161            <NormalizeErrorFn as WireFunction>::NAME,
162            <NormalizeErrorFn as WireFunction>::SUPPORTED_VERSIONS,
163        ),
164        (
165            <BuildSignerFn as WireFunction>::NAME,
166            <BuildSignerFn as WireFunction>::SUPPORTED_VERSIONS,
167        ),
168        (
169            <SignFn as WireFunction>::NAME,
170            <SignFn as WireFunction>::SUPPORTED_VERSIONS,
171        ),
172        (
173            <OnUnauthorizedFn as WireFunction>::NAME,
174            <OnUnauthorizedFn as WireFunction>::SUPPORTED_VERSIONS,
175        ),
176        (
177            <ObserveFn as WireFunction>::NAME,
178            <ObserveFn as WireFunction>::SUPPORTED_VERSIONS,
179        ),
180        (
181            <FilterFn as WireFunction>::NAME,
182            <FilterFn as WireFunction>::SUPPORTED_VERSIONS,
183        ),
184    ]
185}
186
187fn validate_accept_shape(
188    accept: &HandshakeAccept,
189    offer: &HandshakeOffer,
190) -> Result<(), HandshakeExecutionError> {
191    if accept.implemented_functions.len() > IMPLEMENTED_FUNCTIONS_MAX {
192        return Err(HandshakeExecutionError::ImplementedFunctionCountExceeded {
193            count: accept.implemented_functions.len(),
194            max: IMPLEMENTED_FUNCTIONS_MAX,
195        });
196    }
197
198    for function in &accept.implemented_functions {
199        if !offer.function_versions.contains_key(function) {
200            return Err(HandshakeExecutionError::ImplementedUnknownFunction {
201                function: function.clone(),
202            });
203        }
204    }
205
206    for (function, version) in &accept.plugin_supported {
207        if !offer.function_versions.contains_key(function) {
208            return Err(HandshakeExecutionError::SupportedUnknownFunction {
209                function: function.clone(),
210            });
211        }
212        for &supported in version {
213            if !(VERSION_MIN..=VERSION_MAX).contains(&supported) {
214                return Err(HandshakeExecutionError::SupportedVersionOutOfRange {
215                    function: function.clone(),
216                    version: supported,
217                    min: VERSION_MIN,
218                    max: VERSION_MAX,
219                });
220            }
221        }
222    }
223
224    for (function, chosen) in &accept.chosen_versions {
225        if !accept.implemented_functions.contains(function) {
226            return Err(HandshakeExecutionError::ChosenFunctionNotImplemented {
227                function: function.clone(),
228            });
229        }
230        let Some(supported) = accept.plugin_supported.get(function) else {
231            return Err(HandshakeExecutionError::ChosenVersionNotSupported {
232                function: function.clone(),
233                version: *chosen,
234            });
235        };
236        if !supported.contains(chosen) {
237            return Err(HandshakeExecutionError::ChosenVersionNotSupported {
238                function: function.clone(),
239                version: *chosen,
240            });
241        }
242    }
243
244    Ok(())
245}
246
247fn cross_check_implemented_exports(
248    plugin: &extism::Plugin,
249    accept: &HandshakeAccept,
250) -> Result<(), HandshakeExecutionError> {
251    for function in &accept.implemented_functions {
252        if !plugin.function_exists(function) {
253            return Err(HandshakeExecutionError::DeclaredFunctionMissing {
254                function: function.clone(),
255            });
256        }
257    }
258    for function in all_wire_functions() {
259        if plugin.function_exists(function) && !accept.implemented_functions.contains(*function) {
260            return Err(HandshakeExecutionError::UndeclaredExport {
261                function: (*function).to_owned(),
262            });
263        }
264    }
265    Ok(())
266}
267
/// Errors returned by [`execute_handshake`].
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum HandshakeExecutionError {
    /// Wraps a [`HandshakeError`] from the wire crate's own offer/accept
    /// validation (`HandshakeOffer::validate`,
    /// `HandshakeAccept::validate_against_offer`).
    #[error("handshake validation failed: {0}")]
    Validation(#[from] HandshakeError),
    /// The plugin could not be instantiated; `reason` is the text reported by
    /// `build_plugin`.
    #[error("handshake plugin instantiation failed: {reason}")]
    Instantiate { reason: String },
    /// The module does not export the `cc_lb_handshake` function.
    #[error("plugin does not export cc_lb_handshake")]
    MissingHandshakeExport,
    /// The host offer could not be serialized to JSON.
    #[error("handshake offer serialization failed: {reason}")]
    SerializeOffer { reason: String },
    /// The handshake call failed for a reason not classified as `Timeout`.
    #[error("handshake call failed: {reason}")]
    Call { reason: String },
    /// The call's error text mentioned a timeout, deadline or fuel
    /// (see `classify_call_error`).
    #[error("handshake call exceeded timeout/fuel budget")]
    Timeout,
    /// The handshake output was larger than `max` bytes.
    #[error("handshake output size {bytes} exceeds maximum {max}")]
    OutputTooLarge { bytes: usize, max: usize },
    /// The handshake output could not be decoded as a `HandshakeAccept`.
    #[error("handshake accept decode failed: {reason}")]
    DecodeAccept { reason: String },
    /// The accept lists more implemented functions than allowed.
    #[error("implemented function count {count} exceeds maximum {max}")]
    ImplementedFunctionCountExceeded { count: usize, max: usize },
    /// An implemented function is not one the host offered.
    #[error("implemented unknown function: {function}")]
    ImplementedUnknownFunction { function: String },
    /// A `plugin_supported` entry names a function the host did not offer.
    #[error("supported unknown function: {function}")]
    SupportedUnknownFunction { function: String },
    /// A `plugin_supported` version lies outside the inclusive `[min, max]`.
    #[error(
        "supported version {version} for function {function} outside valid range [{min}, {max}]"
    )]
    SupportedVersionOutOfRange {
        function: String,
        version: u32,
        min: u32,
        max: u32,
    },
    /// A chosen version was given for a function not listed as implemented.
    #[error("chosen function not listed as implemented: {function}")]
    ChosenFunctionNotImplemented { function: String },
    /// A chosen version is absent from the function's `plugin_supported`
    /// list (or the function has no such list).
    #[error("chosen version {version} for function {function} not listed as plugin-supported")]
    ChosenVersionNotSupported { function: String, version: u32 },
    /// A function declared as implemented has no matching wasm export.
    #[error("declared function missing wasm export: {function}")]
    DeclaredFunctionMissing { function: String },
    /// The module exports a known wire function it did not declare.
    #[error("undeclared wire function export present: {function}")]
    UndeclaredExport { function: String },
}
311
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use cc_lb_plugin_wire::handshake::HandshakeError;
    use serde_json::json;

    use super::*;

    #[test]
    fn slot_set_from_handshake_maps_filter_shape_observe() {
        let fns: BTreeSet<String> = ["filter", "shape", "observe"]
            .into_iter()
            .map(String::from)
            .collect();
        // Expected order follows `PluginSlot`'s `Ord`, not the input order.
        assert_eq!(
            slot_set_from_handshake(&fns),
            vec![
                PluginSlot::Router,
                PluginSlot::ObservabilityHook,
                PluginSlot::Shape,
            ],
        );
    }

    #[test]
    fn slot_set_from_handshake_handles_partial_exports() {
        let only_shape: BTreeSet<String> = ["shape".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_shape),
            vec![PluginSlot::Shape]
        );

        let only_filter: BTreeSet<String> = ["filter".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_filter),
            vec![PluginSlot::Router],
        );
    }

    #[test]
    fn slot_set_from_handshake_is_empty_when_no_slot_functions_exported() {
        let empty: BTreeSet<String> = BTreeSet::new();
        assert!(slot_set_from_handshake(&empty).is_empty());

        // Known wire functions that do not back any slot map to nothing.
        let unrelated: BTreeSet<String> = ["sign".to_owned(), "build_signer".to_owned()]
            .into_iter()
            .collect();
        assert!(slot_set_from_handshake(&unrelated).is_empty());
    }

    #[test]
    fn slot_set_from_handshake_ignores_unknown_function_names_for_forward_compat() {
        let mixed: BTreeSet<String> = ["filter", "future_slot_v9000"]
            .into_iter()
            .map(String::from)
            .collect();
        assert_eq!(slot_set_from_handshake(&mixed), vec![PluginSlot::Router]);
    }

    #[test]
    fn build_offer_lists_v1_wire_functions_and_host_capabilities() {
        let host_caps = BTreeSet::from(["streaming".to_owned(), "storage".to_owned()]);

        let offer = build_offer(&host_caps);

        assert_eq!(offer.handshake_schema_version, HANDSHAKE_SCHEMA_VERSION_V1);
        assert_eq!(offer.envelope_version, ENVELOPE_VERSION_V1);
        assert_eq!(offer.host_capabilities, host_caps);
        // Every host wire function must appear with exactly its declared versions.
        for (name, versions) in wire_function_versions() {
            assert_eq!(offer.function_versions.get(name), Some(&versions.to_vec()));
        }
        offer.validate().expect("host offer is valid");
    }

    #[test]
    fn execute_handshake_accepts_valid_plugin_and_checks_export() {
        let offer = build_offer(&BTreeSet::from(["streaming".to_owned()]));
        let accept = accept_json(
            &["shape"],
            &[("shape", &[1])],
            &[("shape", 1)],
            &["streaming"],
        );
        // The module really exports `shape`, so the export cross-check passes.
        let wasm = handshake_module(&accept, &["shape"], false);

        let actual = execute_handshake(&wasm, &offer).expect("handshake succeeds");

        assert!(actual.implemented_functions.contains("shape"));
        assert_eq!(actual.chosen_versions.get("shape"), Some(&1));
    }

    #[test]
    fn execute_handshake_rejects_downgrade() {
        // The offer advertises versions 1..=3 for `shape`; the plugin supports all
        // three yet chooses 1, which the test expects to be reported as a downgrade.
        let mut offer = build_offer(&BTreeSet::new());
        offer
            .function_versions
            .insert("shape".to_owned(), vec![1, 2, 3]);
        let accept = accept_json(&["shape"], &[("shape", &[1, 2, 3])], &[("shape", 1)], &[]);
        let wasm = handshake_module(&accept, &["shape"], false);

        let err = execute_handshake(&wasm, &offer).expect_err("downgrade rejected");

        match err {
            HandshakeExecutionError::Validation(HandshakeError::DowngradeAttempt { .. }) => {}
            other => panic!("expected downgrade error, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_rejects_implemented_function_without_export() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&["shape"], &[("shape", &[1])], &[("shape", 1)], &[]);
        // No extra exports: the accept claims `shape` but the module lacks it.
        let wasm = handshake_module(&accept, &[], false);

        let err = execute_handshake(&wasm, &offer).expect_err("missing export rejected");

        match err {
            HandshakeExecutionError::DeclaredFunctionMissing { function } => {
                assert_eq!(function, "shape");
            }
            other => panic!("expected missing export, got {other:?}"),
        }
    }

    #[test]
    fn execute_handshake_rejects_user_host_imports() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&[], &[], &[], &[]);
        // The module imports `extism:host/user::cc_lb_log`, which `build_plugin`
        // does not register; failure is accepted at instantiation or call time.
        let wasm = handshake_module(&accept, &[], true);

        let err = execute_handshake(&wasm, &offer).expect_err("host import rejected");

        match err {
            HandshakeExecutionError::Instantiate { .. } | HandshakeExecutionError::Call { .. } => {}
            other => panic!("expected purity failure, got {other:?}"),
        }
    }

    /// Builds the JSON text of a handshake accept from borrowed test data, using
    /// the v1 schema and envelope versions.
    fn accept_json(
        implemented: &[&str],
        supported: &[(&str, &[u32])],
        chosen: &[(&str, u32)],
        required_caps: &[&str],
    ) -> String {
        let implemented_functions: BTreeSet<_> =
            implemented.iter().map(|name| name.to_string()).collect();
        let plugin_supported: BTreeMap<_, _> = supported
            .iter()
            .map(|(name, versions)| (name.to_string(), versions.to_vec()))
            .collect();
        let chosen_versions: BTreeMap<_, _> = chosen
            .iter()
            .map(|(name, version)| (name.to_string(), *version))
            .collect();
        let required_capabilities: BTreeSet<_> = required_caps
            .iter()
            .map(|capability| capability.to_string())
            .collect();

        json!({
            "handshake_schema_version": HANDSHAKE_SCHEMA_VERSION_V1,
            "envelope_version": ENVELOPE_VERSION_V1,
            "chosen_versions": chosen_versions,
            "plugin_supported": plugin_supported,
            "implemented_functions": implemented_functions,
            "required_capabilities": required_capabilities,
        })
        .to_string()
    }

    /// Compiles a minimal wasm module (from WAT) whose `cc_lb_handshake` export
    /// sets `output` as its Extism output.
    ///
    /// `extra_exports` adds one no-op export per name, each returning `0`. When
    /// `import_user_host` is true the module also imports
    /// `extism:host/user::cc_lb_log` and calls it from the handshake body.
    fn handshake_module(output: &str, extra_exports: &[&str], import_user_host: bool) -> Vec<u8> {
        let output_helper = bytes_helper("handshake_out", output.as_bytes());
        let user_import = if import_user_host {
            r#"(import "extism:host/user" "cc_lb_log" (func $cc_lb_log (param i64 i64)))"#
        } else {
            ""
        };
        let user_call = if import_user_host {
            "  (call $cc_lb_log (call $handshake_out) (call $handshake_out))"
        } else {
            ""
        };
        let mut exports = String::new();
        for export in extra_exports {
            exports.push_str(&format!(
                r#"
(func (export "{export}") (result i32)
  (i32.const 0))
"#
            ));
        }

        // `{len}` is the byte length of `output`, passed to `output_set` together
        // with the pointer returned by the generated `$handshake_out` helper.
        let wat = format!(
            r#"
(module
  (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
  (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
  (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
  {user_import}
  {output_helper}
  (func (export "cc_lb_handshake") (result i32)
{user_call}
    (call $output_set (call $handshake_out) (i64.const {len}))
    (i32.const 0))
  {exports}
)
"#,
            len = output.len()
        );
        wat::parse_str(&wat).expect("handshake wat parses")
    }

    /// Emits a WAT function `$name` that allocates `bytes.len()` bytes with the
    /// Extism `alloc` import, writes each byte via `store_u8`, and returns the
    /// allocation's pointer.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            stores.push_str(&format!(
                "  (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
            ));
        }
        format!(
            r#"
(func ${name} (result i64)
  (local $ptr i64)
  (local.set $ptr (call $alloc (i64.const {len})))
{stores}  (local.get $ptr))
"#,
            len = bytes.len()
        )
    }
}