Skip to main content

cc_lb_runtime_protocol/
handshake.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use cc_lb_plugin_api::PluginSlot;
5use cc_lb_plugin_wire::handshake::{
6    HANDSHAKE_SCHEMA_VERSION_V1, HandshakeAccept, HandshakeError, HandshakeOffer,
7};
8use cc_lb_plugin_wire::limits::{
9    HANDSHAKE_FUEL, HANDSHAKE_OUTPUT_MAX_BYTES, HANDSHAKE_WALL_MS, IMPLEMENTED_FUNCTIONS_MAX,
10    VERSION_MAX, VERSION_MIN,
11};
12use cc_lb_plugin_wire::v1::{
13    build_signer::BuildSignerFn, normalize_error::NormalizeErrorFn, observe::ObserveFn,
14    on_unauthorized::OnUnauthorizedFn, shape::ShapeFn, sign::SignFn,
15};
16use cc_lb_plugin_wire::v3::filter::FilterFn;
17use cc_lb_plugin_wire::wire_function::{WireFunction, all_wire_functions};
18use extism::{Manifest, PluginBuilder, Wasm};
19use thiserror::Error;
20
/// Name of the wasm export that [`execute_handshake`] requires and invokes.
const HANDSHAKE_EXPORT: &str = "cc_lb_handshake";
/// Envelope version the host advertises in every offer built by [`build_offer`].
const ENVELOPE_VERSION_V1: u32 = 1;
23
24/// Returns the slots whose wire function the plugin claims to implement.
25/// The result is deterministically sorted, deduplicated, and ignores unknown
26/// function names so newer plugins remain forward-compatible with this host.
27pub fn slot_set_from_handshake(implemented_functions: &BTreeSet<String>) -> Vec<PluginSlot> {
28    let mut slots = BTreeSet::new();
29    for name in implemented_functions {
30        if let Some(slot) = wire_function_to_slot(name) {
31            slots.insert(slot);
32        }
33    }
34    slots.into_iter().collect()
35}
36
37fn wire_function_to_slot(name: &str) -> Option<PluginSlot> {
38    match name {
39        "filter" => Some(PluginSlot::Router),
40        "shape" => Some(PluginSlot::Shape),
41        "observe" => Some(PluginSlot::ObservabilityHook),
42        _ => None,
43    }
44}
45
46/// Narrowly recovers slot metadata for legacy wire-v1/v2 router plugins that
47/// export `route`. Restricted to the `route` -> `Router` migration so we do
48/// not bypass the handshake's `implemented_functions` source of truth for
49/// other slots; new uploads and shape / observe plugins must come through
50/// the handshake gate.
51pub fn slot_set_from_extism_exports(plugin_bytes: &[u8]) -> Vec<PluginSlot> {
52    use extism::{Manifest, PluginBuilder, Wasm};
53
54    let manifest = Manifest::new([Wasm::data(plugin_bytes.to_vec())])
55        .with_timeout(std::time::Duration::from_millis(HANDSHAKE_WALL_MS))
56        .disallow_all_hosts();
57    let Ok(plugin) = PluginBuilder::new(&manifest)
58        .with_wasi(false)
59        .with_cache_disabled()
60        .with_fuel_limit(HANDSHAKE_FUEL)
61        .build()
62    else {
63        return Vec::new();
64    };
65    if plugin.function_exists("route") {
66        vec![PluginSlot::Router]
67    } else {
68        Vec::new()
69    }
70}
71
72pub fn build_offer(host_caps: &BTreeSet<String>) -> HandshakeOffer {
73    let mut function_versions = BTreeMap::new();
74    for (name, versions) in wire_function_versions() {
75        function_versions.insert(name.to_owned(), versions.to_vec());
76    }
77
78    HandshakeOffer {
79        handshake_schema_version: HANDSHAKE_SCHEMA_VERSION_V1,
80        envelope_version: ENVELOPE_VERSION_V1,
81        function_versions,
82        host_capabilities: host_caps.clone(),
83    }
84}
85
86pub fn build_plugin(
87    wasm: &[u8],
88    wall_ms: u64,
89    fuel: u64,
90) -> Result<extism::Plugin, BuildPluginError> {
91    let manifest = Manifest::new([Wasm::data(wasm.to_vec())])
92        .with_timeout(Duration::from_millis(wall_ms))
93        .disallow_all_hosts();
94    PluginBuilder::new(&manifest)
95        .with_wasi(false)
96        .with_cache_disabled()
97        .with_fuel_limit(fuel)
98        .build()
99        .map_err(|source| BuildPluginError::Instantiate {
100            reason: source.to_string(),
101        })
102}
103
/// Failure to turn wasm bytes into a runnable Extism plugin in [`build_plugin`].
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum BuildPluginError {
    /// The Extism plugin builder failed. `reason` is its error message.
    #[error("failed to instantiate plugin: {reason}")]
    Instantiate { reason: String },
}
110
111pub fn execute_handshake(
112    plugin_bytes: &[u8],
113    offer: &HandshakeOffer,
114) -> Result<HandshakeAccept, HandshakeExecutionError> {
115    offer.validate()?;
116    metrics::counter!("cc_lb_plugin_handshake_total").increment(1);
117
118    let mut plugin = build_plugin(plugin_bytes, HANDSHAKE_WALL_MS, HANDSHAKE_FUEL).map_err(
119        |source| match source {
120            BuildPluginError::Instantiate { reason } => {
121                HandshakeExecutionError::Instantiate { reason }
122            }
123        },
124    )?;
125
126    if !plugin.function_exists(HANDSHAKE_EXPORT) {
127        return Err(HandshakeExecutionError::MissingHandshakeExport);
128    }
129
130    let request =
131        serde_json::to_string(offer).map_err(|source| HandshakeExecutionError::SerializeOffer {
132            reason: source.to_string(),
133        })?;
134    let response = plugin
135        .call::<&str, String>(HANDSHAKE_EXPORT, request.as_str())
136        .map_err(|source| classify_call_error(source.to_string()))?;
137
138    if response.len() > HANDSHAKE_OUTPUT_MAX_BYTES {
139        return Err(HandshakeExecutionError::OutputTooLarge {
140            bytes: response.len(),
141            max: HANDSHAKE_OUTPUT_MAX_BYTES,
142        });
143    }
144
145    let accept: HandshakeAccept = serde_json::from_str(&response).map_err(|source| {
146        HandshakeExecutionError::DecodeAccept {
147            reason: source.to_string(),
148        }
149    })?;
150    accept.validate_against_offer(offer)?;
151    validate_accept_shape(&accept, offer)?;
152    cross_check_implemented_exports(&plugin, &accept)?;
153
154    Ok(accept)
155}
156
157fn classify_call_error(reason: String) -> HandshakeExecutionError {
158    let lower = reason.to_ascii_lowercase();
159    if lower.contains("timeout")
160        || lower.contains("timed out")
161        || lower.contains("deadline")
162        || lower.contains("fuel")
163    {
164        HandshakeExecutionError::Timeout
165    } else {
166        HandshakeExecutionError::Call { reason }
167    }
168}
169
170fn wire_function_versions() -> [(&'static str, &'static [u32]); 7] {
171    [
172        (
173            <ShapeFn as WireFunction>::NAME,
174            <ShapeFn as WireFunction>::SUPPORTED_VERSIONS,
175        ),
176        (
177            <NormalizeErrorFn as WireFunction>::NAME,
178            <NormalizeErrorFn as WireFunction>::SUPPORTED_VERSIONS,
179        ),
180        (
181            <BuildSignerFn as WireFunction>::NAME,
182            <BuildSignerFn as WireFunction>::SUPPORTED_VERSIONS,
183        ),
184        (
185            <SignFn as WireFunction>::NAME,
186            <SignFn as WireFunction>::SUPPORTED_VERSIONS,
187        ),
188        (
189            <OnUnauthorizedFn as WireFunction>::NAME,
190            <OnUnauthorizedFn as WireFunction>::SUPPORTED_VERSIONS,
191        ),
192        (
193            <ObserveFn as WireFunction>::NAME,
194            <ObserveFn as WireFunction>::SUPPORTED_VERSIONS,
195        ),
196        (
197            <FilterFn as WireFunction>::NAME,
198            <FilterFn as WireFunction>::SUPPORTED_VERSIONS,
199        ),
200    ]
201}
202
203fn validate_accept_shape(
204    accept: &HandshakeAccept,
205    offer: &HandshakeOffer,
206) -> Result<(), HandshakeExecutionError> {
207    if accept.implemented_functions.len() > IMPLEMENTED_FUNCTIONS_MAX {
208        return Err(HandshakeExecutionError::ImplementedFunctionCountExceeded {
209            count: accept.implemented_functions.len(),
210            max: IMPLEMENTED_FUNCTIONS_MAX,
211        });
212    }
213
214    for function in &accept.implemented_functions {
215        if !offer.function_versions.contains_key(function) {
216            return Err(HandshakeExecutionError::ImplementedUnknownFunction {
217                function: function.clone(),
218            });
219        }
220    }
221
222    for (function, version) in &accept.plugin_supported {
223        if !offer.function_versions.contains_key(function) {
224            return Err(HandshakeExecutionError::SupportedUnknownFunction {
225                function: function.clone(),
226            });
227        }
228        for &supported in version {
229            if !(VERSION_MIN..=VERSION_MAX).contains(&supported) {
230                return Err(HandshakeExecutionError::SupportedVersionOutOfRange {
231                    function: function.clone(),
232                    version: supported,
233                    min: VERSION_MIN,
234                    max: VERSION_MAX,
235                });
236            }
237        }
238    }
239
240    for (function, chosen) in &accept.chosen_versions {
241        if !accept.implemented_functions.contains(function) {
242            return Err(HandshakeExecutionError::ChosenFunctionNotImplemented {
243                function: function.clone(),
244            });
245        }
246        let Some(supported) = accept.plugin_supported.get(function) else {
247            return Err(HandshakeExecutionError::ChosenVersionNotSupported {
248                function: function.clone(),
249                version: *chosen,
250            });
251        };
252        if !supported.contains(chosen) {
253            return Err(HandshakeExecutionError::ChosenVersionNotSupported {
254                function: function.clone(),
255                version: *chosen,
256            });
257        }
258    }
259
260    Ok(())
261}
262
263fn cross_check_implemented_exports(
264    plugin: &extism::Plugin,
265    accept: &HandshakeAccept,
266) -> Result<(), HandshakeExecutionError> {
267    for function in &accept.implemented_functions {
268        if !plugin.function_exists(function) {
269            return Err(HandshakeExecutionError::DeclaredFunctionMissing {
270                function: function.clone(),
271            });
272        }
273    }
274    for function in all_wire_functions() {
275        if plugin.function_exists(function) && !accept.implemented_functions.contains(*function) {
276            return Err(HandshakeExecutionError::UndeclaredExport {
277                function: (*function).to_owned(),
278            });
279        }
280    }
281    Ok(())
282}
283
/// Every reason [`execute_handshake`] can reject a plugin.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum HandshakeExecutionError {
    /// The wire crate's own validation failed. This covers
    /// `HandshakeOffer::validate` and `HandshakeAccept::validate_against_offer`.
    #[error("handshake validation failed: {0}")]
    Validation(#[from] HandshakeError),
    /// The wasm bytes could not be instantiated; `reason` is the builder's text.
    #[error("handshake plugin instantiation failed: {reason}")]
    Instantiate { reason: String },
    /// The module has no `cc_lb_handshake` export.
    #[error("plugin does not export cc_lb_handshake")]
    MissingHandshakeExport,
    /// The host offer could not be JSON-encoded.
    #[error("handshake offer serialization failed: {reason}")]
    SerializeOffer { reason: String },
    /// Calling the handshake export failed, and the message did not look like
    /// a timeout or fuel overrun.
    #[error("handshake call failed: {reason}")]
    Call { reason: String },
    /// The call's error message mentioned a timeout, deadline, or fuel.
    #[error("handshake call exceeded timeout/fuel budget")]
    Timeout,
    /// The handshake output was longer than `HANDSHAKE_OUTPUT_MAX_BYTES`.
    #[error("handshake output size {bytes} exceeds maximum {max}")]
    OutputTooLarge { bytes: usize, max: usize },
    /// The handshake output could not be parsed as a `HandshakeAccept`.
    #[error("handshake accept decode failed: {reason}")]
    DecodeAccept { reason: String },
    /// More than `IMPLEMENTED_FUNCTIONS_MAX` functions were declared implemented.
    #[error("implemented function count {count} exceeds maximum {max}")]
    ImplementedFunctionCountExceeded { count: usize, max: usize },
    /// A declared implemented function is not in the offer.
    #[error("implemented unknown function: {function}")]
    ImplementedUnknownFunction { function: String },
    /// A `plugin_supported` entry names a function not in the offer.
    #[error("supported unknown function: {function}")]
    SupportedUnknownFunction { function: String },
    /// A `plugin_supported` version lies outside `VERSION_MIN..=VERSION_MAX`.
    #[error(
        "supported version {version} for function {function} outside valid range [{min}, {max}]"
    )]
    SupportedVersionOutOfRange {
        function: String,
        version: u32,
        min: u32,
        max: u32,
    },
    /// A `chosen_versions` entry names a function not declared implemented.
    #[error("chosen function not listed as implemented: {function}")]
    ChosenFunctionNotImplemented { function: String },
    /// A chosen version is absent from the plugin's own `plugin_supported` list.
    #[error("chosen version {version} for function {function} not listed as plugin-supported")]
    ChosenVersionNotSupported { function: String, version: u32 },
    /// A declared implemented function has no matching wasm export.
    #[error("declared function missing wasm export: {function}")]
    DeclaredFunctionMissing { function: String },
    /// The module exports a known wire function that it did not declare.
    #[error("undeclared wire function export present: {function}")]
    UndeclaredExport { function: String },
}
327
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use cc_lb_plugin_wire::handshake::HandshakeError;
    use serde_json::json;

    use super::*;

    // All three slot-bearing names map to slots. The result comes back in
    // `PluginSlot` order rather than input order.
    #[test]
    fn slot_set_from_handshake_maps_filter_shape_observe() {
        let fns: BTreeSet<String> = ["filter", "shape", "observe"]
            .into_iter()
            .map(String::from)
            .collect();
        assert_eq!(
            slot_set_from_handshake(&fns),
            vec![
                PluginSlot::Router,
                PluginSlot::ObservabilityHook,
                PluginSlot::Shape,
            ],
        );
    }

    // A single declared function yields only its own slot.
    #[test]
    fn slot_set_from_handshake_handles_partial_exports() {
        let only_shape: BTreeSet<String> = ["shape".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_shape),
            vec![PluginSlot::Shape]
        );

        let only_filter: BTreeSet<String> = ["filter".to_owned()].into_iter().collect();
        assert_eq!(
            slot_set_from_handshake(&only_filter),
            vec![PluginSlot::Router],
        );
    }

    // Wire functions without a slot (e.g. signing hooks) contribute nothing.
    #[test]
    fn slot_set_from_handshake_is_empty_when_no_slot_functions_exported() {
        let empty: BTreeSet<String> = BTreeSet::new();
        assert!(slot_set_from_handshake(&empty).is_empty());

        let unrelated: BTreeSet<String> = ["sign".to_owned(), "build_signer".to_owned()]
            .into_iter()
            .collect();
        assert!(slot_set_from_handshake(&unrelated).is_empty());
    }

    // Legacy fallback: a bare `route` export is recognized as a router.
    #[test]
    fn slot_set_from_extism_exports_maps_only_route_to_router() {
        let wasm = wat::parse_str(r#"(module (func (export "route") (result i32) (i32.const 0)))"#)
            .expect("route-only wat parses");
        assert_eq!(
            slot_set_from_extism_exports(&wasm),
            vec![PluginSlot::Router],
        );
    }

    // The legacy fallback must never trust handshake-era export names.
    #[test]
    fn slot_set_from_extism_exports_ignores_filter_shape_observe_exports() {
        for export in ["filter", "shape", "observe"] {
            let wat = format!(r#"(module (func (export "{export}") (result i32) (i32.const 0)))"#,);
            let wasm = wat::parse_str(&wat).expect("single-export wat parses");
            assert!(
                slot_set_from_extism_exports(&wasm).is_empty(),
                "fallback must not trust {export} export without handshake validation",
            );
        }
    }

    // `route` alongside `shape` still maps to Router only.
    #[test]
    fn slot_set_from_extism_exports_returns_router_for_route_plus_shape_legacy() {
        let wasm = wat::parse_str(
            r#"(module
                (func (export "route") (result i32) (i32.const 0))
                (func (export "shape") (result i32) (i32.const 0)))"#,
        )
        .expect("legacy route+shape wat parses");
        assert_eq!(
            slot_set_from_extism_exports(&wasm),
            vec![PluginSlot::Router],
            "shape must not promote without handshake; only route -> Router is trusted",
        );
    }

    // A valid module with no recognized exports yields no slots.
    #[test]
    fn slot_set_from_extism_exports_is_empty_for_module_without_known_exports() {
        let wasm = wat::parse_str(r#"(module (func (export "noop") (result i32) (i32.const 0)))"#)
            .expect("noop wat parses");
        assert!(slot_set_from_extism_exports(&wasm).is_empty());
    }

    // Bytes that are not a wasm module yield an empty list instead of an error.
    #[test]
    fn slot_set_from_extism_exports_is_empty_for_corrupt_bytes() {
        let bytes = vec![0u8; 16];
        assert!(slot_set_from_extism_exports(&bytes).is_empty());
    }

    // Unknown names are skipped silently; known ones still map.
    #[test]
    fn slot_set_from_handshake_ignores_unknown_function_names_for_forward_compat() {
        let mixed: BTreeSet<String> = ["filter", "future_slot_v9000"]
            .into_iter()
            .map(String::from)
            .collect();
        assert_eq!(slot_set_from_handshake(&mixed), vec![PluginSlot::Router]);
    }

    // The offer carries the schema/envelope versions, the caller's
    // capabilities, and every catalogue entry, and it passes its own
    // validation.
    #[test]
    fn build_offer_lists_v1_wire_functions_and_host_capabilities() {
        let host_caps = BTreeSet::from(["streaming".to_owned(), "storage".to_owned()]);

        let offer = build_offer(&host_caps);

        assert_eq!(offer.handshake_schema_version, HANDSHAKE_SCHEMA_VERSION_V1);
        assert_eq!(offer.envelope_version, ENVELOPE_VERSION_V1);
        assert_eq!(offer.host_capabilities, host_caps);
        for (name, versions) in wire_function_versions() {
            assert_eq!(offer.function_versions.get(name), Some(&versions.to_vec()));
        }
        offer.validate().expect("host offer is valid");
    }

    // Happy path: the plugin declares `shape` v1 and really exports `shape`.
    #[test]
    fn execute_handshake_accepts_valid_plugin_and_checks_export() {
        let offer = build_offer(&BTreeSet::from(["streaming".to_owned()]));
        let accept = accept_json(
            &["shape"],
            &[("shape", &[1])],
            &[("shape", 1)],
            &["streaming"],
        );
        let wasm = handshake_module(&accept, &["shape"], false);

        let actual = execute_handshake(&wasm, &offer).expect("handshake succeeds");

        assert!(actual.implemented_functions.contains("shape"));
        assert_eq!(actual.chosen_versions.get("shape"), Some(&1));
    }

    // Host and plugin both list shape v1..v3, but the plugin picks v1. The
    // test expects this to surface as `HandshakeError::DowngradeAttempt`.
    #[test]
    fn execute_handshake_rejects_downgrade() {
        let mut offer = build_offer(&BTreeSet::new());
        offer
            .function_versions
            .insert("shape".to_owned(), vec![1, 2, 3]);
        let accept = accept_json(&["shape"], &[("shape", &[1, 2, 3])], &[("shape", 1)], &[]);
        let wasm = handshake_module(&accept, &["shape"], false);

        let err = execute_handshake(&wasm, &offer).expect_err("downgrade rejected");

        match err {
            HandshakeExecutionError::Validation(HandshakeError::DowngradeAttempt { .. }) => {}
            other => panic!("expected downgrade error, got {other:?}"),
        }
    }

    // Declaring `shape` without exporting it fails the export cross-check.
    #[test]
    fn execute_handshake_rejects_implemented_function_without_export() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&["shape"], &[("shape", &[1])], &[("shape", 1)], &[]);
        let wasm = handshake_module(&accept, &[], false);

        let err = execute_handshake(&wasm, &offer).expect_err("missing export rejected");

        match err {
            HandshakeExecutionError::DeclaredFunctionMissing { function } => {
                assert_eq!(function, "shape");
            }
            other => panic!("expected missing export, got {other:?}"),
        }
    }

    // A module importing a user host function must be rejected.
    // `build_plugin` registers no host functions, so this is expected to fail
    // at instantiation or at call time.
    #[test]
    fn execute_handshake_rejects_user_host_imports() {
        let offer = build_offer(&BTreeSet::new());
        let accept = accept_json(&[], &[], &[], &[]);
        let wasm = handshake_module(&accept, &[], true);

        let err = execute_handshake(&wasm, &offer).expect_err("host import rejected");

        match err {
            HandshakeExecutionError::Instantiate { .. } | HandshakeExecutionError::Call { .. } => {}
            other => panic!("expected purity failure, got {other:?}"),
        }
    }

    // Renders a handshake accept as the JSON string a plugin would emit.
    // `supported` and `chosen` are (function, versions) and (function, version)
    // pairs.
    fn accept_json(
        implemented: &[&str],
        supported: &[(&str, &[u32])],
        chosen: &[(&str, u32)],
        required_caps: &[&str],
    ) -> String {
        let implemented_functions: BTreeSet<_> =
            implemented.iter().map(|name| name.to_string()).collect();
        let plugin_supported: BTreeMap<_, _> = supported
            .iter()
            .map(|(name, versions)| (name.to_string(), versions.to_vec()))
            .collect();
        let chosen_versions: BTreeMap<_, _> = chosen
            .iter()
            .map(|(name, version)| (name.to_string(), *version))
            .collect();
        let required_capabilities: BTreeSet<_> = required_caps
            .iter()
            .map(|capability| capability.to_string())
            .collect();

        json!({
            "handshake_schema_version": HANDSHAKE_SCHEMA_VERSION_V1,
            "envelope_version": ENVELOPE_VERSION_V1,
            "chosen_versions": chosen_versions,
            "plugin_supported": plugin_supported,
            "implemented_functions": implemented_functions,
            "required_capabilities": required_capabilities,
        })
        .to_string()
    }

    // Builds a WAT module whose `cc_lb_handshake` export writes `output` via
    // the extism `output_set` import.
    // - `extra_exports` adds stub functions (returning 0) under those names.
    // - `import_user_host` adds an `extism:host/user` `cc_lb_log` import and
    //   calls it from the handshake.
    fn handshake_module(output: &str, extra_exports: &[&str], import_user_host: bool) -> Vec<u8> {
        let output_helper = bytes_helper("handshake_out", output.as_bytes());
        let user_import = if import_user_host {
            r#"(import "extism:host/user" "cc_lb_log" (func $cc_lb_log (param i64 i64)))"#
        } else {
            ""
        };
        let user_call = if import_user_host {
            "  (call $cc_lb_log (call $handshake_out) (call $handshake_out))"
        } else {
            ""
        };
        let mut exports = String::new();
        for export in extra_exports {
            exports.push_str(&format!(
                r#"
(func (export "{export}") (result i32)
  (i32.const 0))
"#
            ));
        }

        let wat = format!(
            r#"
(module
  (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
  (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
  (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
  {user_import}
  {output_helper}
  (func (export "cc_lb_handshake") (result i32)
{user_call}
    (call $output_set (call $handshake_out) (i64.const {len}))
    (i32.const 0))
  {exports}
)
"#,
            len = output.len()
        );
        wat::parse_str(&wat).expect("handshake wat parses")
    }

    // Emits a WAT function `${name}` that allocates `bytes.len()` bytes via
    // `$alloc`, writes each byte with `$store_u8`, and returns the pointer.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            stores.push_str(&format!(
                "  (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
            ));
        }
        format!(
            r#"
(func ${name} (result i64)
  (local $ptr i64)
  (local.set $ptr (call $alloc (i64.const {len})))
{stores}  (local.get $ptr))
"#,
            len = bytes.len()
        )
    }
}