Skip to main content

cc_lb_runtime_protocol/
dispatch.rs

1use std::panic::{self, AssertUnwindSafe};
2
3use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
4use cc_lb_plugin_wire::wire_function::{FallbackPolicy, WireFunction};
5use extism::Plugin;
6use serde_json::Value;
7use thiserror::Error;
8
9#[non_exhaustive]
10#[derive(Clone, Debug, PartialEq, Eq)]
11pub enum DispatchOutcome<R> {
12    Ok(R),
13    Fallback(FallbackPolicy),
14}
15
16pub fn dispatch_wire_call<F: WireFunction>(
17    plugin: &mut Plugin,
18    metadata: &AugmentedMetadata,
19    request: F::Request,
20) -> DispatchOutcome<F::Response> {
21    match panic::catch_unwind(AssertUnwindSafe(|| {
22        dispatch_wire_call_inner::<F>(plugin, metadata, request)
23    })) {
24        Ok(Ok(response)) => DispatchOutcome::Ok(response),
25        Ok(Err(error)) => fallback::<F, F::Response>(error),
26        Err(_) => fallback::<F, F::Response>(DispatchError::Panic),
27    }
28}
29
30fn dispatch_wire_call_inner<F: WireFunction>(
31    plugin: &mut Plugin,
32    metadata: &AugmentedMetadata,
33    request: F::Request,
34) -> Result<F::Response, DispatchError> {
35    let negotiated_version = *metadata
36        .negotiated_functions
37        .get(F::NAME)
38        .ok_or(DispatchError::MissingNegotiatedVersion)?;
39
40    let request_value =
41        serde_json::to_value(request).map_err(|source| DispatchError::RequestValue {
42            reason: source.to_string(),
43        })?;
44    let Value::Object(mut request_map) = request_value else {
45        return Err(DispatchError::RequestNotObject);
46    };
47    request_map.insert("_v".to_owned(), Value::from(negotiated_version));
48
49    let input = serde_json::to_string(&Value::Object(request_map)).map_err(|source| {
50        DispatchError::SerializeRequestEnvelope {
51            reason: source.to_string(),
52        }
53    })?;
54    let output = plugin
55        .call::<String, String>(F::NAME, input)
56        .map_err(|source| DispatchError::PluginCall {
57            reason: source.to_string(),
58        })?;
59
60    let response_value: Value = serde_json::from_str(&output).map_err(|source| {
61        DispatchError::DeserializeResponseEnvelope {
62            reason: source.to_string(),
63        }
64    })?;
65    let Value::Object(mut response_map) = response_value else {
66        return Err(DispatchError::ResponseNotObject);
67    };
68    let response_version = response_map
69        .remove("_v")
70        .ok_or(DispatchError::ResponseVersionMissing)?;
71    let actual_version = response_version
72        .as_u64()
73        .ok_or(DispatchError::ResponseVersionInvalid)?;
74    if actual_version != u64::from(negotiated_version) {
75        return Err(DispatchError::ResponseVersionMismatch {
76            expected: negotiated_version,
77            actual: actual_version,
78        });
79    }
80
81    serde_json::from_value(Value::Object(response_map)).map_err(|source| {
82        DispatchError::ResponseDecode {
83            reason: source.to_string(),
84        }
85    })
86}
87
88fn fallback<F: WireFunction, R>(error: DispatchError) -> DispatchOutcome<R> {
89    let stage_label = error.stage_label();
90    metrics::counter!(
91        "cc_lb_plugin_dispatch_errors_total",
92        "function" => F::NAME,
93        "stage" => stage_label
94    )
95    .increment(1);
96    tracing::warn!(
97        target: "cc_lb_plugin.dispatch",
98        function = F::NAME,
99        error = ?error,
100        "dispatch error, applying fallback {:?}",
101        F::FALLBACK
102    );
103    DispatchOutcome::Fallback(F::FALLBACK)
104}
105
106#[derive(Debug, Error)]
107enum DispatchError {
108    #[error("negotiated version missing")]
109    MissingNegotiatedVersion,
110    #[error("request value serialization failed: {reason}")]
111    RequestValue { reason: String },
112    #[error("request serialized to non-object JSON")]
113    RequestNotObject,
114    #[error("request envelope serialization failed: {reason}")]
115    SerializeRequestEnvelope { reason: String },
116    #[error("plugin call failed: {reason}")]
117    PluginCall { reason: String },
118    #[error("response envelope deserialization failed: {reason}")]
119    DeserializeResponseEnvelope { reason: String },
120    #[error("response envelope was not a JSON object")]
121    ResponseNotObject,
122    #[error("response envelope missing _v")]
123    ResponseVersionMissing,
124    #[error("response envelope _v was not an unsigned integer")]
125    ResponseVersionInvalid,
126    #[error("response envelope version mismatch: expected {expected}, actual {actual}")]
127    ResponseVersionMismatch { expected: u32, actual: u64 },
128    #[error("response decode failed: {reason}")]
129    ResponseDecode { reason: String },
130    #[error("dispatch panicked")]
131    Panic,
132}
133
134impl DispatchError {
135    fn stage_label(&self) -> &'static str {
136        match self {
137            Self::MissingNegotiatedVersion => "version_lookup",
138            Self::RequestValue { .. } | Self::RequestNotObject => "request_envelope",
139            Self::SerializeRequestEnvelope { .. } => "serialize_request",
140            Self::PluginCall { .. } => "plugin_call",
141            Self::DeserializeResponseEnvelope { .. } => "deserialize_response",
142            Self::ResponseNotObject
143            | Self::ResponseVersionMissing
144            | Self::ResponseVersionInvalid
145            | Self::ResponseVersionMismatch { .. } => "response_envelope",
146            Self::ResponseDecode { .. } => "response_decode",
147            Self::Panic => "panic",
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    //! Dispatch tests drive a real Extism plugin compiled from an inline WAT
    //! module. Its single export writes a fixed byte string as the call output,
    //! so each stage of the request/response envelope pipeline can be exercised
    //! end to end.

    use std::collections::{BTreeMap, BTreeSet};
    use std::fmt::{self, Write as _};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, Mutex};

    use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
    use cc_lb_plugin_wire::identity::{CC_LB_PLUGIN_MAGIC, PluginIdentity};
    use cc_lb_plugin_wire::v1::sign::{SignFn, SignResponse};
    use cc_lb_plugin_wire::wire_function::{FallbackPolicy, WireFunction};
    use extism::{Manifest, PluginBuilder, Wasm};
    use metrics::{
        Counter, CounterFn, Gauge, Histogram, Key, KeyName, Metadata, Recorder, SharedString, Unit,
    };
    use tracing::field::{Field, Visit};
    use tracing::span::{Attributes, Id, Record};
    use tracing::{Event, Level, Metadata as TracingMetadata, Subscriber};

    use super::*;

    /// A well-formed `sign` response envelope tagged with version 1, the
    /// version `metadata_with_function` negotiates.
    const VALID_SIGN_RESPONSE: &str =
        r#"{"_v":1,"url":null,"method":null,"headers":null,"body_base64":null}"#;

    #[test]
    fn successful_call_decodes_versioned_response() {
        let mut plugin = plugin_with_output(<SignFn as WireFunction>::NAME, VALID_SIGN_RESPONSE);
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let outcome = dispatch_wire_call::<SignFn>(
            &mut plugin,
            &metadata,
            <SignFn as WireFunction>::dry_run_request(),
        );

        assert_eq!(outcome, DispatchOutcome::Ok(SignResponse::dry_run_sample()));
    }

    #[test]
    fn sign_fallback_is_fail_request_const() {
        assert_eq!(
            <SignFn as WireFunction>::FALLBACK,
            FallbackPolicy::FailRequest
        );
        let mut plugin = plugin_with_output(<SignFn as WireFunction>::NAME, VALID_SIGN_RESPONSE);
        let metadata = metadata_with_function("shape");

        let outcome = dispatch_wire_call::<SignFn>(
            &mut plugin,
            &metadata,
            <SignFn as WireFunction>::dry_run_request(),
        );

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
    }

    #[test]
    fn response_version_mismatch_falls_back_to_function_const() {
        let mut plugin = plugin_with_output(
            <SignFn as WireFunction>::NAME,
            r#"{"_v":2,"url":null,"method":null,"headers":null,"body_base64":null}"#,
        );
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let (outcome, metric_keys) = dispatch_sign_recording_metrics(&mut plugin, &metadata);

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
        assert!(
            metric_keys.contains("response_envelope"),
            "metric stage missing from {metric_keys}"
        );
    }

    #[test]
    fn malformed_response_envelopes_fall_back_with_stage_label() {
        // (plugin output, stage label expected on the error counter)
        let cases = [
            ("not json", "deserialize_response"),
            ("[1,2,3]", "response_envelope"),
            (
                r#"{"url":null,"method":null,"headers":null,"body_base64":null}"#,
                "response_envelope",
            ),
            (
                r#"{"_v":"1","url":null,"method":null,"headers":null,"body_base64":null}"#,
                "response_envelope",
            ),
        ];
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        for (output, expected_stage) in cases {
            let mut plugin = plugin_with_output(<SignFn as WireFunction>::NAME, output);

            let (outcome, metric_keys) = dispatch_sign_recording_metrics(&mut plugin, &metadata);

            assert_eq!(
                outcome,
                DispatchOutcome::Fallback(FallbackPolicy::FailRequest),
                "unexpected outcome for plugin output {output}"
            );
            assert!(
                metric_keys.contains(expected_stage),
                "stage {expected_stage} missing for plugin output {output}: {metric_keys}"
            );
        }
    }

    #[test]
    fn missing_plugin_export_falls_back_at_plugin_call_stage() {
        // The plugin only exports `not_sign`, so calling `sign` must fail inside the plugin call.
        let mut plugin = plugin_with_output("not_sign", VALID_SIGN_RESPONSE);
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let (outcome, metric_keys) = dispatch_sign_recording_metrics(&mut plugin, &metadata);

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
        assert!(
            metric_keys.contains("plugin_call"),
            "metric stage missing from {metric_keys}"
        );
    }

    #[test]
    fn error_logs_emitted() {
        let logs = CapturingSubscriber::default();
        let captured_logs = logs.events.clone();
        let metrics = CapturingMetrics::default();
        let captured_metric_total = metrics.total.clone();
        let captured_metric_keys = metrics.keys.clone();
        let mut plugin = plugin_with_output(<SignFn as WireFunction>::NAME, VALID_SIGN_RESPONSE);
        let metadata = metadata_with_function("shape");

        let outcome = metrics::with_local_recorder(&metrics, || {
            tracing::subscriber::with_default(logs, || {
                dispatch_wire_call::<SignFn>(
                    &mut plugin,
                    &metadata,
                    <SignFn as WireFunction>::dry_run_request(),
                )
            })
        });

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
        let rendered_logs = captured_logs.lock().expect("logs lock").join("\n");
        assert!(
            rendered_logs.contains("target=cc_lb_plugin.dispatch"),
            "log target missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("dispatch error, applying fallback FailRequest"),
            "message missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("function=\"sign\""),
            "function field missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("MissingNegotiatedVersion"),
            "error field missing from {rendered_logs}"
        );
        assert_eq!(captured_metric_total.load(Ordering::SeqCst), 1);
        let metric_keys = captured_metric_keys
            .lock()
            .expect("metric keys lock")
            .join("\n");
        assert!(
            metric_keys.contains("cc_lb_plugin_dispatch_errors_total"),
            "metric name missing from {metric_keys}"
        );
        assert!(
            metric_keys.contains("version_lookup"),
            "metric stage missing from {metric_keys}"
        );
    }

    /// Runs a `sign` dispatch under a thread-local capturing recorder.
    ///
    /// Returns the outcome together with the newline-joined `Debug` renderings
    /// of every counter key registered during the call.
    fn dispatch_sign_recording_metrics(
        plugin: &mut extism::Plugin,
        metadata: &AugmentedMetadata,
    ) -> (DispatchOutcome<<SignFn as WireFunction>::Response>, String) {
        let recorder = CapturingMetrics::default();
        let captured_keys = recorder.keys.clone();

        let outcome = metrics::with_local_recorder(&recorder, || {
            dispatch_wire_call::<SignFn>(
                plugin,
                metadata,
                <SignFn as WireFunction>::dry_run_request(),
            )
        });

        let keys = captured_keys.lock().expect("metric keys lock").join("\n");
        (outcome, keys)
    }

    /// Builds metadata that negotiates version 1 for `function_name` only.
    fn metadata_with_function(function_name: &str) -> AugmentedMetadata {
        AugmentedMetadata {
            identity: PluginIdentity {
                magic: CC_LB_PLUGIN_MAGIC,
                abi_envelope: 1,
                plugin_name: "dispatch-test".to_owned(),
                plugin_version: "1.0.0".to_owned(),
            },
            negotiated_functions: BTreeMap::from([(function_name.to_owned(), 1)]),
            negotiated_capabilities: BTreeSet::new(),
            handshake_completed_at: 1,
            self_check_passed: true,
            self_check_completed_at: 1,
            expires_at: 2,
        }
    }

    /// Builds a plugin whose only export, `function_name`, sets `output` as the
    /// call output. WASI is disabled, all hosts are disallowed and the
    /// compilation cache is off.
    fn plugin_with_output(function_name: &str, output: &str) -> extism::Plugin {
        let output_bytes = output.as_bytes();
        let wat = format!(
            r#"
(module
  (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
  (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
  (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
  {output_helper}
  (func (export "{function_name}") (result i32)
    (call $output_set (call $dispatch_out) (i64.const {output_len}))
    (i32.const 0))
)
"#,
            output_helper = bytes_helper("dispatch_out", output_bytes),
            output_len = output_bytes.len(),
        );
        let wasm = wat::parse_str(wat).expect("dispatch fixture wat parses");
        let manifest = Manifest::new([Wasm::data(wasm)]).disallow_all_hosts();
        PluginBuilder::new(&manifest)
            .with_wasi(false)
            .with_cache_disabled()
            .build()
            .expect("dispatch fixture plugin builds")
    }

    /// Emits a WAT function `$name` that allocates `bytes.len()` bytes in Extism
    /// memory, stores `bytes` into it one `store_u8` call at a time, and returns
    /// the allocation's offset.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            // Write straight into the buffer rather than allocating a String per byte.
            writeln!(
                stores,
                "  (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))"
            )
            .expect("writing to a String cannot fail");
        }
        format!(
            r#"
(func ${name} (result i64)
  (local $ptr i64)
  (local.set $ptr (call $alloc (i64.const {len})))
{stores}  (local.get $ptr))
"#,
            len = bytes.len()
        )
    }

    /// Records events at `WARN` or more severe on the `cc_lb_plugin.dispatch`
    /// target as `target=<target> <field>=<debug> ...` lines.
    #[derive(Clone, Default)]
    struct CapturingSubscriber {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl Subscriber for CapturingSubscriber {
        fn enabled(&self, metadata: &TracingMetadata<'_>) -> bool {
            metadata.target() == "cc_lb_plugin.dispatch" && metadata.level() <= &Level::WARN
        }

        fn new_span(&self, _span: &Attributes<'_>) -> Id {
            Id::from_u64(1)
        }

        fn record(&self, _span: &Id, _values: &Record<'_>) {}

        fn record_follows_from(&self, _span: &Id, _follows: &Id) {}

        fn event(&self, event: &Event<'_>) {
            let mut visitor = EventVisitor::default();
            event.record(&mut visitor);
            self.events.lock().expect("events lock").push(format!(
                "target={} {}",
                event.metadata().target(),
                visitor.fields
            ));
        }

        fn enter(&self, _span: &Id) {}

        fn exit(&self, _span: &Id) {}
    }

    /// Renders every visited field as space-separated `name=<debug>` pairs.
    #[derive(Default)]
    struct EventVisitor {
        fields: String,
    }

    impl Visit for EventVisitor {
        fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
            if !self.fields.is_empty() {
                self.fields.push(' ');
            }
            self.fields.push_str(&format!("{}={value:?}", field.name()));
        }
    }

    /// Sums all counter increments into `total` and records the `Debug` form of
    /// every counter key registered; gauges and histograms are no-ops.
    #[derive(Clone, Default)]
    struct CapturingMetrics {
        total: Arc<AtomicU64>,
        keys: Arc<Mutex<Vec<String>>>,
    }

    impl Recorder for CapturingMetrics {
        fn describe_counter(&self, _key: KeyName, _unit: Option<Unit>, _description: SharedString) {
        }

        fn describe_gauge(&self, _key: KeyName, _unit: Option<Unit>, _description: SharedString) {}

        fn describe_histogram(
            &self,
            _key: KeyName,
            _unit: Option<Unit>,
            _description: SharedString,
        ) {
        }

        fn register_counter(&self, key: &Key, _metadata: &Metadata<'_>) -> Counter {
            self.keys
                .lock()
                .expect("metric keys lock")
                .push(format!("{key:?}"));
            Counter::from_arc(Arc::new(TestCounter {
                total: self.total.clone(),
            }))
        }

        fn register_gauge(&self, _key: &Key, _metadata: &Metadata<'_>) -> Gauge {
            Gauge::noop()
        }

        fn register_histogram(&self, _key: &Key, _metadata: &Metadata<'_>) -> Histogram {
            Histogram::noop()
        }
    }

    /// Counter handle that forwards into a shared total.
    struct TestCounter {
        total: Arc<AtomicU64>,
    }

    impl CounterFn for TestCounter {
        fn increment(&self, value: u64) {
            self.total.fetch_add(value, Ordering::SeqCst);
        }

        fn absolute(&self, value: u64) {
            self.total.store(value, Ordering::SeqCst);
        }
    }
}