1use std::panic::{self, AssertUnwindSafe};
2
3use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
4use cc_lb_plugin_wire::wire_function::{FallbackPolicy, WireFunction};
5use extism::Plugin;
6use serde_json::Value;
7use thiserror::Error;
8
9#[non_exhaustive]
10#[derive(Clone, Debug, PartialEq, Eq)]
11pub enum DispatchOutcome<R> {
12 Ok(R),
13 Fallback(FallbackPolicy),
14}
15
16pub fn dispatch_wire_call<F: WireFunction>(
17 plugin: &mut Plugin,
18 metadata: &AugmentedMetadata,
19 request: F::Request,
20) -> DispatchOutcome<F::Response> {
21 match panic::catch_unwind(AssertUnwindSafe(|| {
22 dispatch_wire_call_inner::<F>(plugin, metadata, request)
23 })) {
24 Ok(Ok(response)) => DispatchOutcome::Ok(response),
25 Ok(Err(error)) => fallback::<F, F::Response>(error),
26 Err(_) => fallback::<F, F::Response>(DispatchError::Panic),
27 }
28}
29
30fn dispatch_wire_call_inner<F: WireFunction>(
31 plugin: &mut Plugin,
32 metadata: &AugmentedMetadata,
33 request: F::Request,
34) -> Result<F::Response, DispatchError> {
35 let negotiated_version = *metadata
36 .negotiated_functions
37 .get(F::NAME)
38 .ok_or(DispatchError::MissingNegotiatedVersion)?;
39
40 let request_value =
41 serde_json::to_value(request).map_err(|source| DispatchError::RequestValue {
42 reason: source.to_string(),
43 })?;
44 let Value::Object(mut request_map) = request_value else {
45 return Err(DispatchError::RequestNotObject);
46 };
47 request_map.insert("_v".to_owned(), Value::from(negotiated_version));
48
49 let input = serde_json::to_string(&Value::Object(request_map)).map_err(|source| {
50 DispatchError::SerializeRequestEnvelope {
51 reason: source.to_string(),
52 }
53 })?;
54 let output = plugin
55 .call::<String, String>(F::NAME, input)
56 .map_err(|source| DispatchError::PluginCall {
57 reason: source.to_string(),
58 })?;
59
60 let response_value: Value = serde_json::from_str(&output).map_err(|source| {
61 DispatchError::DeserializeResponseEnvelope {
62 reason: source.to_string(),
63 }
64 })?;
65 let Value::Object(mut response_map) = response_value else {
66 return Err(DispatchError::ResponseNotObject);
67 };
68 let response_version = response_map
69 .remove("_v")
70 .ok_or(DispatchError::ResponseVersionMissing)?;
71 let actual_version = response_version
72 .as_u64()
73 .ok_or(DispatchError::ResponseVersionInvalid)?;
74 if actual_version != u64::from(negotiated_version) {
75 return Err(DispatchError::ResponseVersionMismatch {
76 expected: negotiated_version,
77 actual: actual_version,
78 });
79 }
80
81 serde_json::from_value(Value::Object(response_map)).map_err(|source| {
82 DispatchError::ResponseDecode {
83 reason: source.to_string(),
84 }
85 })
86}
87
88fn fallback<F: WireFunction, R>(error: DispatchError) -> DispatchOutcome<R> {
89 let stage_label = error.stage_label();
90 metrics::counter!(
91 "cc_lb_plugin_dispatch_errors_total",
92 "function" => F::NAME,
93 "stage" => stage_label
94 )
95 .increment(1);
96 tracing::warn!(
97 target: "cc_lb_plugin.dispatch",
98 function = F::NAME,
99 error = ?error,
100 "dispatch error, applying fallback {:?}",
101 F::FALLBACK
102 );
103 DispatchOutcome::Fallback(F::FALLBACK)
104}
105
106#[derive(Debug, Error)]
107enum DispatchError {
108 #[error("negotiated version missing")]
109 MissingNegotiatedVersion,
110 #[error("request value serialization failed: {reason}")]
111 RequestValue { reason: String },
112 #[error("request serialized to non-object JSON")]
113 RequestNotObject,
114 #[error("request envelope serialization failed: {reason}")]
115 SerializeRequestEnvelope { reason: String },
116 #[error("plugin call failed: {reason}")]
117 PluginCall { reason: String },
118 #[error("response envelope deserialization failed: {reason}")]
119 DeserializeResponseEnvelope { reason: String },
120 #[error("response envelope was not a JSON object")]
121 ResponseNotObject,
122 #[error("response envelope missing _v")]
123 ResponseVersionMissing,
124 #[error("response envelope _v was not an unsigned integer")]
125 ResponseVersionInvalid,
126 #[error("response envelope version mismatch: expected {expected}, actual {actual}")]
127 ResponseVersionMismatch { expected: u32, actual: u64 },
128 #[error("response decode failed: {reason}")]
129 ResponseDecode { reason: String },
130 #[error("dispatch panicked")]
131 Panic,
132}
133
134impl DispatchError {
135 fn stage_label(&self) -> &'static str {
136 match self {
137 Self::MissingNegotiatedVersion => "version_lookup",
138 Self::RequestValue { .. } | Self::RequestNotObject => "request_envelope",
139 Self::SerializeRequestEnvelope { .. } => "serialize_request",
140 Self::PluginCall { .. } => "plugin_call",
141 Self::DeserializeResponseEnvelope { .. } => "deserialize_response",
142 Self::ResponseNotObject
143 | Self::ResponseVersionMissing
144 | Self::ResponseVersionInvalid
145 | Self::ResponseVersionMismatch { .. } => "response_envelope",
146 Self::ResponseDecode { .. } => "response_decode",
147 Self::Panic => "panic",
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};
    use std::fmt;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, Mutex};

    use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
    use cc_lb_plugin_wire::identity::{CC_LB_PLUGIN_MAGIC, PluginIdentity};
    use cc_lb_plugin_wire::v1::sign::{SignFn, SignResponse};
    use cc_lb_plugin_wire::wire_function::{FallbackPolicy, WireFunction};
    use extism::{Manifest, PluginBuilder, Wasm};
    use metrics::{
        Counter, CounterFn, Gauge, Histogram, Key, KeyName, Metadata, Recorder, SharedString, Unit,
    };
    use tracing::field::{Field, Visit};
    use tracing::span::{Attributes, Id, Record};
    use tracing::{Event, Level, Metadata as TracingMetadata, Subscriber};

    use super::*;

    #[test]
    fn successful_call_decodes_versioned_response() {
        let mut plugin = plugin_with_output(
            <SignFn as WireFunction>::NAME,
            r#"{"_v":1,"url":null,"method":null,"headers":null,"body_base64":null}"#,
        );
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let outcome = dispatch_wire_call::<SignFn>(
            &mut plugin,
            &metadata,
            <SignFn as WireFunction>::dry_run_request(),
        );

        assert_eq!(outcome, DispatchOutcome::Ok(SignResponse::dry_run_sample()));
    }

    #[test]
    fn sign_fallback_is_fail_request_const() {
        assert_eq!(
            <SignFn as WireFunction>::FALLBACK,
            FallbackPolicy::FailRequest
        );
        let mut plugin = plugin_with_output(
            <SignFn as WireFunction>::NAME,
            r#"{"_v":1,"url":null,"method":null,"headers":null,"body_base64":null}"#,
        );
        let metadata = metadata_with_function("shape");

        let outcome = dispatch_wire_call::<SignFn>(
            &mut plugin,
            &metadata,
            <SignFn as WireFunction>::dry_run_request(),
        );

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
    }

    #[test]
    fn response_version_mismatch_falls_back_to_function_const() {
        let mut plugin = plugin_with_output(
            <SignFn as WireFunction>::NAME,
            r#"{"_v":2,"url":null,"method":null,"headers":null,"body_base64":null}"#,
        );
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let outcome = dispatch_wire_call::<SignFn>(
            &mut plugin,
            &metadata,
            <SignFn as WireFunction>::dry_run_request(),
        );

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
    }

    /// Malformed plugin output must fall back and be attributed to the right
    /// metric stage.
    #[test]
    fn malformed_response_envelopes_fall_back_with_stage_label() {
        let cases = [
            ("not json", "deserialize_response"),
            ("[1,2]", "response_envelope"),
            (
                r#"{"url":null,"method":null,"headers":null,"body_base64":null}"#,
                "response_envelope",
            ),
            (
                r#"{"_v":"1","url":null,"method":null,"headers":null,"body_base64":null}"#,
                "response_envelope",
            ),
        ];

        for (output, expected_stage) in cases {
            let (outcome, metric_keys) = dispatch_sign_capturing_metrics(output);

            assert_eq!(
                outcome,
                DispatchOutcome::Fallback(FallbackPolicy::FailRequest),
                "unexpected outcome for plugin output {output}"
            );
            assert!(
                metric_keys.contains(expected_stage),
                "stage {expected_stage} missing from {metric_keys} for plugin output {output}"
            );
        }
    }

    /// The stage labels are the metric's label values; pin them so a rename is
    /// a deliberate, test-visible change.
    #[test]
    fn stage_labels_match_metric_contract() {
        let cases = [
            (DispatchError::MissingNegotiatedVersion, "version_lookup"),
            (
                DispatchError::RequestValue {
                    reason: String::new(),
                },
                "request_envelope",
            ),
            (DispatchError::RequestNotObject, "request_envelope"),
            (
                DispatchError::SerializeRequestEnvelope {
                    reason: String::new(),
                },
                "serialize_request",
            ),
            (
                DispatchError::PluginCall {
                    reason: String::new(),
                },
                "plugin_call",
            ),
            (
                DispatchError::DeserializeResponseEnvelope {
                    reason: String::new(),
                },
                "deserialize_response",
            ),
            (DispatchError::ResponseNotObject, "response_envelope"),
            (DispatchError::ResponseVersionMissing, "response_envelope"),
            (DispatchError::ResponseVersionInvalid, "response_envelope"),
            (
                DispatchError::ResponseVersionMismatch {
                    expected: 1,
                    actual: 2,
                },
                "response_envelope",
            ),
            (
                DispatchError::ResponseDecode {
                    reason: String::new(),
                },
                "response_decode",
            ),
            (DispatchError::Panic, "panic"),
        ];

        for (error, expected) in cases {
            assert_eq!(error.stage_label(), expected, "stage label for {error:?}");
        }
    }

    #[test]
    fn error_logs_emitted() {
        let logs = CapturingSubscriber::default();
        let captured_logs = logs.events.clone();
        // Named `recorder` (not `metrics`) so it is not confused with the
        // `metrics` crate used on the lines below.
        let recorder = CapturingMetrics::default();
        let captured_metric_total = recorder.total.clone();
        let captured_metric_keys = recorder.keys.clone();
        let mut plugin = plugin_with_output(
            <SignFn as WireFunction>::NAME,
            r#"{"_v":1,"url":null,"method":null,"headers":null,"body_base64":null}"#,
        );
        let metadata = metadata_with_function("shape");

        let outcome = metrics::with_local_recorder(&recorder, || {
            tracing::subscriber::with_default(logs, || {
                dispatch_wire_call::<SignFn>(
                    &mut plugin,
                    &metadata,
                    <SignFn as WireFunction>::dry_run_request(),
                )
            })
        });

        assert_eq!(
            outcome,
            DispatchOutcome::Fallback(FallbackPolicy::FailRequest)
        );
        let rendered_logs = captured_logs.lock().expect("logs lock").join("\n");
        assert!(
            rendered_logs.contains("target=cc_lb_plugin.dispatch"),
            "log target missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("dispatch error, applying fallback FailRequest"),
            "message missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("function=\"sign\""),
            "function field missing from {rendered_logs}"
        );
        assert!(
            rendered_logs.contains("MissingNegotiatedVersion"),
            "error field missing from {rendered_logs}"
        );
        assert_eq!(captured_metric_total.load(Ordering::SeqCst), 1);
        let metric_keys = captured_metric_keys
            .lock()
            .expect("metric keys lock")
            .join("\n");
        assert!(
            metric_keys.contains("cc_lb_plugin_dispatch_errors_total"),
            "metric name missing from {metric_keys}"
        );
        assert!(
            metric_keys.contains("version_lookup"),
            "metric stage missing from {metric_keys}"
        );
    }

    /// Dispatches `sign` against a fixture plugin whose export returns `output`,
    /// with `sign` negotiated at version 1. Returns the outcome plus the
    /// debug-rendered metric keys registered during the call.
    fn dispatch_sign_capturing_metrics(
        output: &str,
    ) -> (DispatchOutcome<<SignFn as WireFunction>::Response>, String) {
        let recorder = CapturingMetrics::default();
        let captured_keys = recorder.keys.clone();
        let mut plugin = plugin_with_output(<SignFn as WireFunction>::NAME, output);
        let metadata = metadata_with_function(<SignFn as WireFunction>::NAME);

        let outcome = metrics::with_local_recorder(&recorder, || {
            dispatch_wire_call::<SignFn>(
                &mut plugin,
                &metadata,
                <SignFn as WireFunction>::dry_run_request(),
            )
        });

        let rendered_keys = captured_keys
            .lock()
            .expect("metric keys lock")
            .join("\n");
        (outcome, rendered_keys)
    }

    /// Builds metadata in which only `function_name` is negotiated, at version 1.
    fn metadata_with_function(function_name: &str) -> AugmentedMetadata {
        AugmentedMetadata {
            identity: PluginIdentity {
                magic: CC_LB_PLUGIN_MAGIC,
                abi_envelope: 1,
                plugin_name: "dispatch-test".to_owned(),
                plugin_version: "1.0.0".to_owned(),
            },
            negotiated_functions: BTreeMap::from([(function_name.to_owned(), 1)]),
            negotiated_capabilities: BTreeSet::new(),
            handshake_completed_at: 1,
            self_check_passed: true,
            self_check_completed_at: 1,
            expires_at: 2,
        }
    }

    /// Builds a WASI-less plugin whose export `function_name` ignores its input
    /// and sets `output` as its output bytes.
    fn plugin_with_output(function_name: &str, output: &str) -> extism::Plugin {
        let output_bytes = output.as_bytes();
        let wat = format!(
            r#"
(module
 (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
 (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
 (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
 {output_helper}
 (func (export "{function_name}") (result i32)
 (call $output_set (call $dispatch_out) (i64.const {output_len}))
 (i32.const 0))
)
"#,
            output_helper = bytes_helper("dispatch_out", output_bytes),
            output_len = output_bytes.len(),
        );
        let wasm = wat::parse_str(wat).expect("dispatch fixture wat parses");
        let manifest = Manifest::new([Wasm::data(wasm)]).disallow_all_hosts();
        PluginBuilder::new(&manifest)
            .with_wasi(false)
            .with_cache_disabled()
            .build()
            .expect("dispatch fixture plugin builds")
    }

    /// Emits a WAT function `$name` that allocates extism memory, writes
    /// `bytes` into it one `store_u8` at a time, and returns the pointer.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            stores.push_str(&format!(
                " (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
            ));
        }
        format!(
            r#"
(func ${name} (result i64)
 (local $ptr i64)
 (local.set $ptr (call $alloc (i64.const {len})))
{stores} (local.get $ptr))
"#,
            len = bytes.len()
        )
    }

    /// Records every enabled event as `target=<target> <field>=<debug>...`.
    #[derive(Clone, Default)]
    struct CapturingSubscriber {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl Subscriber for CapturingSubscriber {
        fn enabled(&self, metadata: &TracingMetadata<'_>) -> bool {
            metadata.target() == "cc_lb_plugin.dispatch" && metadata.level() <= &Level::WARN
        }

        fn new_span(&self, _span: &Attributes<'_>) -> Id {
            Id::from_u64(1)
        }

        fn record(&self, _span: &Id, _values: &Record<'_>) {}

        fn record_follows_from(&self, _span: &Id, _follows: &Id) {}

        fn event(&self, event: &Event<'_>) {
            let mut visitor = EventVisitor::default();
            event.record(&mut visitor);
            self.events.lock().expect("events lock").push(format!(
                "target={} {}",
                event.metadata().target(),
                visitor.fields
            ));
        }

        fn enter(&self, _span: &Id) {}

        fn exit(&self, _span: &Id) {}
    }

    /// Accumulates an event's fields as space-separated `name=debug` pairs.
    #[derive(Default)]
    struct EventVisitor {
        fields: String,
    }

    impl Visit for EventVisitor {
        fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
            if !self.fields.is_empty() {
                self.fields.push(' ');
            }
            self.fields.push_str(&format!("{}={value:?}", field.name()));
        }
    }

    /// Metrics recorder that logs each registered counter key and sums all
    /// counter increments into one shared total.
    #[derive(Clone, Default)]
    struct CapturingMetrics {
        total: Arc<AtomicU64>,
        keys: Arc<Mutex<Vec<String>>>,
    }

    impl Recorder for CapturingMetrics {
        fn describe_counter(&self, _key: KeyName, _unit: Option<Unit>, _description: SharedString) {
        }

        fn describe_gauge(&self, _key: KeyName, _unit: Option<Unit>, _description: SharedString) {}

        fn describe_histogram(
            &self,
            _key: KeyName,
            _unit: Option<Unit>,
            _description: SharedString,
        ) {
        }

        fn register_counter(&self, key: &Key, _metadata: &Metadata<'_>) -> Counter {
            self.keys
                .lock()
                .expect("metric keys lock")
                .push(format!("{key:?}"));
            Counter::from_arc(Arc::new(TestCounter {
                total: self.total.clone(),
            }))
        }

        fn register_gauge(&self, _key: &Key, _metadata: &Metadata<'_>) -> Gauge {
            Gauge::noop()
        }

        fn register_histogram(&self, _key: &Key, _metadata: &Metadata<'_>) -> Histogram {
            Histogram::noop()
        }
    }

    /// Counter handle that forwards into the recorder's shared total.
    struct TestCounter {
        total: Arc<AtomicU64>,
    }

    impl CounterFn for TestCounter {
        fn increment(&self, value: u64) {
            self.total.fetch_add(value, Ordering::SeqCst);
        }

        fn absolute(&self, value: u64) {
            self.total.store(value, Ordering::SeqCst);
        }
    }
}