Skip to main content

cc_lb_runtime_protocol/
self_check.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
3use cc_lb_plugin_api::PluginSlot;
4use cc_lb_plugin_wire::limits::{
5    IMPLEMENTED_FUNCTIONS_MAX, SELF_CHECK_FUEL, SELF_CHECK_OUTPUT_MAX_BYTES, SELF_CHECK_WALL_MS,
6};
7use cc_lb_plugin_wire::self_check::{
8    SelfCheckError, SelfCheckRequest, SelfCheckResponse, SelfCheckStatus,
9};
10use thiserror::Error;
11
12use crate::handshake::{BuildPluginError, build_plugin, slot_to_wire_function};
13
14const SELF_CHECK_EXPORT: &str = "cc_lb_self_check";
15
16pub fn execute_self_check(
17    plugin_bytes: &[u8],
18    supported_slots: &[PluginSlot],
19) -> Result<SelfCheckResponse, SelfCheckExecutionError> {
20    let mut plugin =
21        build_plugin(plugin_bytes, SELF_CHECK_WALL_MS, SELF_CHECK_FUEL).map_err(|source| {
22            match source {
23                BuildPluginError::Instantiate { reason } => {
24                    SelfCheckExecutionError::Instantiate { reason }
25                }
26            }
27        })?;
28
29    if !plugin.function_exists(SELF_CHECK_EXPORT) {
30        return Err(SelfCheckExecutionError::MissingSelfCheckExport);
31    }
32
33    if supported_slots.is_empty() {
34        let response = SelfCheckResponse {
35            status: SelfCheckStatus::Success,
36            failures: Vec::new(),
37            completed_at: unix_timestamp()?,
38        };
39        response.validate()?;
40        return Ok(response);
41    }
42
43    let request = build_request(supported_slots)?;
44    let request = serde_json::to_string(&request).map_err(|source| {
45        SelfCheckExecutionError::SerializeRequest {
46            reason: source.to_string(),
47        }
48    })?;
49    let response = plugin
50        .call::<&str, String>(SELF_CHECK_EXPORT, request.as_str())
51        .map_err(|source| SelfCheckExecutionError::Call {
52            reason: source.to_string(),
53        })?;
54
55    if response.len() > SELF_CHECK_OUTPUT_MAX_BYTES {
56        return Err(SelfCheckExecutionError::OutputTooLarge {
57            bytes: response.len(),
58            max: SELF_CHECK_OUTPUT_MAX_BYTES,
59        });
60    }
61
62    let response: SelfCheckResponse = serde_json::from_str(&response).map_err(|source| {
63        SelfCheckExecutionError::DecodeResponse {
64            reason: source.to_string(),
65        }
66    })?;
67    response.validate()?;
68    validate_status_failures(&response)?;
69
70    Ok(response)
71}
72
73fn build_request(
74    supported_slots: &[PluginSlot],
75) -> Result<SelfCheckRequest, SelfCheckExecutionError> {
76    let functions: Vec<_> = supported_slots
77        .iter()
78        .copied()
79        .map(slot_to_wire_function)
80        .map(str::to_owned)
81        .collect();
82    if functions.len() > IMPLEMENTED_FUNCTIONS_MAX {
83        return Err(SelfCheckExecutionError::Validation(
84            SelfCheckError::TooManyFunctions,
85        ));
86    }
87
88    let initiated_at = unix_timestamp()?;
89    let request = SelfCheckRequest {
90        functions_to_test: functions,
91        initiated_at,
92    };
93    if !request.functions_to_test.is_empty() {
94        request.validate()?;
95    }
96    Ok(request)
97}
98
99fn unix_timestamp() -> Result<i64, SelfCheckExecutionError> {
100    let seconds = SystemTime::now()
101        .duration_since(UNIX_EPOCH)
102        .map_err(|source| SelfCheckExecutionError::Clock {
103            reason: source.to_string(),
104        })?
105        .as_secs();
106    i64::try_from(seconds).map_err(|source| SelfCheckExecutionError::Clock {
107        reason: source.to_string(),
108    })
109}
110
111fn validate_status_failures(response: &SelfCheckResponse) -> Result<(), SelfCheckExecutionError> {
112    match response.status {
113        SelfCheckStatus::Success if !response.failures.is_empty() => {
114            Err(SelfCheckExecutionError::SuccessWithFailures {
115                count: response.failures.len(),
116            })
117        }
118        SelfCheckStatus::Failure if response.failures.is_empty() => {
119            Err(SelfCheckExecutionError::FailureWithoutFailures)
120        }
121        SelfCheckStatus::Failure => Err(SelfCheckExecutionError::FailureStatus {
122            failures: response.failures.len(),
123        }),
124        SelfCheckStatus::Success => Ok(()),
125    }
126}
127
/// Errors produced while building, calling, or validating a plugin self-check.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum SelfCheckExecutionError {
    /// A request or response failed wire-level validation.
    ///
    /// This includes requests with too many functions.
    #[error("self-check validation failed: {0}")]
    Validation(#[from] SelfCheckError),
    /// Building the plugin failed; `reason` is the builder's message.
    #[error("self-check plugin instantiation failed: {reason}")]
    Instantiate { reason: String },
    /// The plugin does not export the `cc_lb_self_check` function.
    #[error("plugin does not export cc_lb_self_check")]
    MissingSelfCheckExport,
    /// The self-check request could not be serialized to JSON.
    #[error("self-check request serialization failed: {reason}")]
    SerializeRequest { reason: String },
    /// Calling the self-check export returned an error.
    #[error("self-check call failed: {reason}")]
    Call { reason: String },
    /// The plugin output was larger than the permitted maximum.
    ///
    /// `bytes` is the actual output length and `max` is the limit.
    #[error("self-check output size {bytes} exceeds maximum {max}")]
    OutputTooLarge { bytes: usize, max: usize },
    /// The plugin output could not be decoded as a self-check response.
    #[error("self-check response decode failed: {reason}")]
    DecodeResponse { reason: String },
    /// The response reported success but listed `count` failures.
    #[error("self-check success response included {count} failure(s)")]
    SuccessWithFailures { count: usize },
    /// The response reported failure but listed no failures.
    #[error("self-check failure response did not include failures")]
    FailureWithoutFailures,
    /// The response consistently reported failure, with `failures` entries.
    #[error("self-check reported failure status with {failures} failure(s)")]
    FailureStatus { failures: usize },
    /// Reading the system clock or converting it to Unix seconds failed.
    #[error("self-check timestamp generation failed: {reason}")]
    Clock { reason: String },
}
154
#[cfg(test)]
mod tests {
    use super::*;

    // Non-empty on purpose: with an empty slot list `execute_self_check`
    // returns early and never calls the guest export.
    const TEST_SUPPORTED_SLOTS: &[PluginSlot] = &[
        PluginSlot::Router,
        PluginSlot::Shape,
        PluginSlot::ObservabilityHook,
    ];

    /// A consistent success response with no failures is returned to the caller.
    #[test]
    fn execute_self_check_accepts_success_response() {
        let wasm = self_check_module(
            r#"{"status":"success","failures":[],"completed_at":1}"#,
            false,
        );

        let response =
            execute_self_check(&wasm, TEST_SUPPORTED_SLOTS).expect("self-check succeeds");

        assert_eq!(response.status, SelfCheckStatus::Success);
        assert!(response.failures.is_empty());
    }

    /// A well-formed failure response is still turned into an error.
    #[test]
    fn execute_self_check_rejects_failure_status() {
        let wasm = self_check_module(
            r#"{"status":"failure","failures":[{"stage":"wire_function_test","message":"bad wire shape"}],"completed_at":1}"#,
            false,
        );

        let err = execute_self_check(&wasm, TEST_SUPPORTED_SLOTS)
            .expect_err("failure status rejected at executor level");

        match err {
            SelfCheckExecutionError::FailureStatus { failures } => assert_eq!(failures, 1),
            other => panic!("expected failure status rejection, got {other:?}"),
        }
    }

    /// A module without the `cc_lb_self_check` export is rejected.
    #[test]
    fn execute_self_check_rejects_missing_export() {
        let wasm = wat::parse_str(r#"(module (func (export "shape") (result i32) (i32.const 0)))"#)
            .expect("wat parses");

        let err =
            execute_self_check(&wasm, TEST_SUPPORTED_SLOTS).expect_err("missing export rejected");

        match err {
            SelfCheckExecutionError::MissingSelfCheckExport => {}
            other => panic!("expected missing export, got {other:?}"),
        }
    }

    /// A module importing a user host function (`extism:host/user`) must fail.
    ///
    /// Failure may surface at instantiation or at call time; either is accepted.
    #[test]
    fn execute_self_check_rejects_user_host_imports() {
        let wasm = self_check_module(
            r#"{"status":"success","failures":[],"completed_at":1}"#,
            true,
        );

        let err =
            execute_self_check(&wasm, TEST_SUPPORTED_SLOTS).expect_err("host import rejected");

        match err {
            SelfCheckExecutionError::Instantiate { .. } | SelfCheckExecutionError::Call { .. } => {}
            other => panic!("expected purity failure, got {other:?}"),
        }
    }

    /// Output one byte over `SELF_CHECK_OUTPUT_MAX_BYTES` is rejected by size.
    ///
    /// The size error is raised before any attempt to decode the (non-JSON)
    /// payload.
    #[test]
    fn execute_self_check_rejects_oversized_output() {
        let output = "x".repeat(SELF_CHECK_OUTPUT_MAX_BYTES + 1);
        // NOTE: `bytes_helper` emits one `store_u8` call per byte, so this
        // fixture's WAT grows linearly with the output cap.
        let wasm = self_check_module(&output, false);

        let err = execute_self_check(&wasm, TEST_SUPPORTED_SLOTS)
            .expect_err("oversized response rejected");

        match err {
            SelfCheckExecutionError::OutputTooLarge { bytes, max } => {
                assert_eq!(bytes, SELF_CHECK_OUTPUT_MAX_BYTES + 1);
                assert_eq!(max, SELF_CHECK_OUTPUT_MAX_BYTES);
            }
            other => panic!("expected oversized output, got {other:?}"),
        }
    }

    /// A "success" status that carries failures is rejected as inconsistent.
    #[test]
    fn execute_self_check_rejects_success_with_failures() {
        let wasm = self_check_module(
            r#"{"status":"success","failures":[{"stage":"wire_function_test","message":"bad"}],"completed_at":1}"#,
            false,
        );

        let err = execute_self_check(&wasm, TEST_SUPPORTED_SLOTS)
            .expect_err("status/failures mismatch rejected");

        match err {
            SelfCheckExecutionError::SuccessWithFailures { count } => assert_eq!(count, 1),
            other => panic!("expected status/failures mismatch, got {other:?}"),
        }
    }

    /// A "failure" status with an empty failure list is rejected as inconsistent.
    #[test]
    fn execute_self_check_rejects_failure_without_failures() {
        let wasm = self_check_module(
            r#"{"status":"failure","failures":[],"completed_at":1}"#,
            false,
        );

        let err = execute_self_check(&wasm, TEST_SUPPORTED_SLOTS)
            .expect_err("status/failures mismatch rejected");

        match err {
            SelfCheckExecutionError::FailureWithoutFailures => {}
            other => panic!("expected status/failures mismatch, got {other:?}"),
        }
    }

    /// Wire-level response validation errors are propagated as `Validation`.
    ///
    /// Here a `completed_at` of 0 triggers `InvalidTimestamp`.
    #[test]
    fn execute_self_check_rejects_response_validation_errors() {
        let wasm = self_check_module(
            r#"{"status":"success","failures":[],"completed_at":0}"#,
            false,
        );

        let err =
            execute_self_check(&wasm, TEST_SUPPORTED_SLOTS).expect_err("invalid response rejected");

        match err {
            SelfCheckExecutionError::Validation(SelfCheckError::InvalidTimestamp(_)) => {}
            other => panic!("expected response validation error, got {other:?}"),
        }
    }

    /// Compiles a WAT module whose `cc_lb_self_check` export emits `output`.
    ///
    /// The export copies `output` into memory obtained from the
    /// `extism:host/env` `alloc`/`store_u8` imports. It then passes that
    /// buffer to `output_set` and returns 0.
    ///
    /// With `import_user_host`, the module also imports and calls
    /// `extism:host/user::cc_lb_log`. Its arguments are just the pointer
    /// returned by the output helper.
    fn self_check_module(output: &str, import_user_host: bool) -> Vec<u8> {
        let output_helper = bytes_helper("self_check_out", output.as_bytes());
        let user_import = if import_user_host {
            r#"(import "extism:host/user" "cc_lb_log" (func $cc_lb_log (param i64 i64)))"#
        } else {
            ""
        };
        let user_call = if import_user_host {
            "  (call $cc_lb_log (call $self_check_out) (call $self_check_out))"
        } else {
            ""
        };

        let wat = format!(
            r#"
(module
  (import "extism:host/env" "alloc" (func $alloc (param i64) (result i64)))
  (import "extism:host/env" "store_u8" (func $store_u8 (param i64 i32)))
  (import "extism:host/env" "output_set" (func $output_set (param i64 i64)))
  {user_import}
  {output_helper}
  (func (export "cc_lb_self_check") (result i32)
{user_call}
    (call $output_set (call $self_check_out) (i64.const {len}))
    (i32.const 0))
)
"#,
            len = output.len()
        );
        wat::parse_str(&wat).expect("self-check wat parses")
    }

    /// Generates WAT for a function `$name` that returns a pointer to `bytes`.
    ///
    /// The function allocates `bytes.len()` bytes via `$alloc` and writes them
    /// one at a time with `$store_u8`. Every call allocates a fresh copy.
    fn bytes_helper(name: &str, bytes: &[u8]) -> String {
        let mut stores = String::new();
        for (index, byte) in bytes.iter().enumerate() {
            stores.push_str(&format!(
                "  (call $store_u8 (i64.add (local.get $ptr) (i64.const {index})) (i32.const {byte}))\n"
            ));
        }
        format!(
            r#"
(func ${name} (result i64)
  (local $ptr i64)
  (local.set $ptr (call $alloc (i64.const {len})))
{stores}  (local.get $ptr))
"#,
            len = bytes.len()
        )
    }
}