Skip to main content

cc_lb_plugin_conformance/
identity.rs

1use std::collections::BTreeSet;
2
3use cc_lb_plugin_wire::identity::{CC_LB_PLUGIN_MAGIC, CC_LB_PLUGIN_SECTION_NAME, PluginIdentity};
4use cc_lb_plugin_wire::limits;
5use serde_json::Value;
6use thiserror::Error;
7use wasmparser::{ExternalKind, Parser, Payload};
8
9const ABI_ENVELOPE_VERSION_V1: u64 = 1;
10const REQUIRED_LIFECYCLE_EXPORTS: [&str; 2] = ["cc_lb_handshake", "cc_lb_self_check"];
11const IDENTITY_FIELDS: [&str; 4] = ["abi_envelope", "magic", "plugin_name", "plugin_version"];
12
13pub fn read(wasm: &[u8]) -> Result<IdentityReport, IdentityError> {
14    let static_checks = run_static_checks(wasm);
15    let identity = cc_lb_runtime_protocol::identity::read_identity(wasm)
16        .map_err(|err| IdentityError::Read(IdentityReadError::new(err)))?;
17
18    Ok(IdentityReport {
19        identity,
20        static_checks,
21    })
22}
23
24#[non_exhaustive]
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct IdentityReport {
27    pub identity: PluginIdentity,
28    pub static_checks: Vec<StaticCheck>,
29}
30
/// Outcome of one static conformance check on a plugin module.
///
/// Every variant carries `pass` (whether the check succeeded) and `detail`. As produced by
/// this module, `detail` is `Some(reason)` exactly when the check failed. Use
/// [`StaticCheck::passed`] and [`StaticCheck::detail`] to read them without matching on
/// each variant.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StaticCheck {
    /// Exactly one identity custom section is present.
    CustomSectionExactlyOnce { pass: bool, detail: Option<String> },
    /// Every identity custom section payload is within `limits::CUSTOM_SECTION_MAX_SIZE`.
    CustomSectionSizeWithinLimit { pass: bool, detail: Option<String> },
    /// The identity payload is a JSON object with exactly the four identity fields.
    JsonFieldsExactlyFour { pass: bool, detail: Option<String> },
    /// `magic`, `abi_envelope`, `plugin_name` and `plugin_version` have the expected shapes/values.
    IdentityFieldsWellFormed { pass: bool, detail: Option<String> },
    /// `plugin_name` matches `limits::PLUGIN_NAME_PATTERN` and its byte limit.
    PluginNameRegexCompliant { pass: bool, detail: Option<String> },
    /// `plugin_version` is non-empty and within `limits::PLUGIN_VERSION_MAX_BYTES`.
    PluginVersionLengthCompliant { pass: bool, detail: Option<String> },
    /// All required lifecycle functions are exported.
    AllRequiredExportsPresent { pass: bool, detail: Option<String> },
    /// `cc_lb_runtime_protocol::build_plugin` accepts the module under handshake limits.
    ExtismCanInstantiate { pass: bool, detail: Option<String> },
    /// The module imports nothing from a WASI namespace.
    NoWasiImports { pass: bool, detail: Option<String> },
}

impl StaticCheck {
    /// Returns whether this check passed.
    #[must_use]
    pub fn passed(&self) -> bool {
        self.parts().0
    }

    /// Returns the failure reason, if any.
    #[must_use]
    pub fn detail(&self) -> Option<&str> {
        self.parts().1
    }

    /// Borrows the `(pass, detail)` pair shared by every variant.
    fn parts(&self) -> (bool, Option<&str>) {
        match self {
            Self::CustomSectionExactlyOnce { pass, detail }
            | Self::CustomSectionSizeWithinLimit { pass, detail }
            | Self::JsonFieldsExactlyFour { pass, detail }
            | Self::IdentityFieldsWellFormed { pass, detail }
            | Self::PluginNameRegexCompliant { pass, detail }
            | Self::PluginVersionLengthCompliant { pass, detail }
            | Self::AllRequiredExportsPresent { pass, detail }
            | Self::ExtismCanInstantiate { pass, detail }
            | Self::NoWasiImports { pass, detail } => (*pass, detail.as_deref()),
        }
    }
}
44
45#[derive(Debug, Error)]
46#[error(transparent)]
47pub struct IdentityReadError(pub(crate) cc_lb_runtime_protocol::identity::IdentityReadError);
48
49impl IdentityReadError {
50    pub(crate) fn new(err: cc_lb_runtime_protocol::identity::IdentityReadError) -> Self {
51        Self(err)
52    }
53}
54
55#[non_exhaustive]
56#[derive(Debug, Error)]
57pub enum IdentityError {
58    #[error(transparent)]
59    Read(#[from] IdentityReadError),
60}
61
/// Facts gathered in one pass over a wasm module's payloads and shared by the static checks.
#[derive(Default)]
struct WasmFacts {
    /// Raw payload of every custom section named `CC_LB_PLUGIN_SECTION_NAME`, in module order.
    identity_sections: Vec<Vec<u8>>,
    /// Names of exports whose kind is a function.
    function_exports: BTreeSet<String>,
    /// Module name of every import.
    import_modules: BTreeSet<String>,
    /// Message of the first parse error; collection stops once this is set.
    parse_error: Option<String>,
}
69
70impl WasmFacts {
71    fn parse_error(&self) -> Result<(), String> {
72        match &self.parse_error {
73            Some(error) => Err(format!("invalid wasm: {error}")),
74            None => Ok(()),
75        }
76    }
77
78    fn single_identity_payload(&self) -> Result<&[u8], String> {
79        self.parse_error()?;
80        match self.identity_sections.as_slice() {
81            [payload] => Ok(payload.as_slice()),
82            [] => Err(format!(
83                "missing {CC_LB_PLUGIN_SECTION_NAME} custom section"
84            )),
85            sections => Err(format!(
86                "{CC_LB_PLUGIN_SECTION_NAME} custom section appears {} times",
87                sections.len()
88            )),
89        }
90    }
91}
92
93fn run_static_checks(wasm: &[u8]) -> Vec<StaticCheck> {
94    let facts = collect_wasm_facts(wasm);
95    let checks = [
96        custom_section_exactly_once(&facts),
97        custom_section_size_within_limit(&facts),
98        json_fields_exactly_four(&facts),
99        identity_fields_well_formed(&facts),
100        plugin_name_regex_compliant(&facts),
101        plugin_version_length_compliant(&facts),
102        all_required_exports_present(&facts),
103        extism_can_instantiate(wasm),
104        no_wasi_imports(&facts),
105    ];
106
107    checks.into_iter().collect()
108}
109
110fn collect_wasm_facts(wasm: &[u8]) -> WasmFacts {
111    let mut facts = WasmFacts::default();
112
113    for payload in Parser::new(0).parse_all(wasm) {
114        match payload {
115            Ok(Payload::CustomSection(section)) => {
116                if section.name() == CC_LB_PLUGIN_SECTION_NAME {
117                    facts.identity_sections.push(section.data().to_vec());
118                }
119            }
120            Ok(Payload::ExportSection(section)) => {
121                for export in section {
122                    match export {
123                        Ok(export) if export.kind == ExternalKind::Func => {
124                            facts.function_exports.insert(export.name.to_owned());
125                        }
126                        Ok(_) => {}
127                        Err(error) => {
128                            facts.parse_error = Some(error.to_string());
129                            return facts;
130                        }
131                    }
132                }
133            }
134            Ok(Payload::ImportSection(section)) => {
135                for import in section.into_imports() {
136                    match import {
137                        Ok(import) => {
138                            facts.import_modules.insert(import.module.to_owned());
139                        }
140                        Err(error) => {
141                            facts.parse_error = Some(error.to_string());
142                            return facts;
143                        }
144                    }
145                }
146            }
147            Ok(_) => {}
148            Err(error) => {
149                facts.parse_error = Some(error.to_string());
150                return facts;
151            }
152        }
153    }
154
155    facts
156}
157
158fn custom_section_exactly_once(facts: &WasmFacts) -> StaticCheck {
159    let result = facts.parse_error().and_then(|()| {
160        if facts.identity_sections.len() == 1 {
161            Ok(())
162        } else {
163            Err(format!(
164                "expected exactly one {CC_LB_PLUGIN_SECTION_NAME} custom section, found {}",
165                facts.identity_sections.len()
166            ))
167        }
168    });
169    let (pass, detail) = static_check_result(result);
170    StaticCheck::CustomSectionExactlyOnce { pass, detail }
171}
172
173fn custom_section_size_within_limit(facts: &WasmFacts) -> StaticCheck {
174    let result = facts.parse_error().and_then(|()| {
175        if facts.identity_sections.is_empty() {
176            return Err(format!(
177                "missing {CC_LB_PLUGIN_SECTION_NAME} custom section"
178            ));
179        }
180        let oversized = facts
181            .identity_sections
182            .iter()
183            .map(Vec::len)
184            .filter(|size| *size > limits::CUSTOM_SECTION_MAX_SIZE)
185            .collect::<Vec<_>>();
186        if oversized.is_empty() {
187            Ok(())
188        } else {
189            Err(format!(
190                "custom section payload size(s) {oversized:?} exceed max {} bytes",
191                limits::CUSTOM_SECTION_MAX_SIZE
192            ))
193        }
194    });
195    let (pass, detail) = static_check_result(result);
196    StaticCheck::CustomSectionSizeWithinLimit { pass, detail }
197}
198
199fn json_fields_exactly_four(facts: &WasmFacts) -> StaticCheck {
200    let result = section_json_value(facts).and_then(|value| {
201        let object = value
202            .as_object()
203            .ok_or_else(|| "identity payload top level is not a JSON object".to_owned())?;
204        let mut keys = object.keys().map(String::as_str).collect::<Vec<_>>();
205        keys.sort_unstable();
206        if keys == IDENTITY_FIELDS && object.len() == limits::CUSTOM_SECTION_FIELD_COUNT {
207            Ok(())
208        } else {
209            Err(format!(
210                "identity JSON fields must be exactly {IDENTITY_FIELDS:?}; found {keys:?}"
211            ))
212        }
213    });
214    let (pass, detail) = static_check_result(result);
215    StaticCheck::JsonFieldsExactlyFour { pass, detail }
216}
217
218fn identity_fields_well_formed(facts: &WasmFacts) -> StaticCheck {
219    let result = section_json_value(facts).and_then(|value| {
220        let mut errors = Vec::new();
221
222        match magic_field(&value) {
223            Ok(magic) if magic == CC_LB_PLUGIN_MAGIC => {}
224            Ok(magic) => errors.push(format!(
225                "magic must equal {CC_LB_PLUGIN_MAGIC:?}; found {magic:?}"
226            )),
227            Err(error) => errors.push(error),
228        }
229
230        match value.get("abi_envelope").and_then(Value::as_u64) {
231            Some(ABI_ENVELOPE_VERSION_V1) => {}
232            Some(found) => errors.push(format!(
233                "abi_envelope must be {ABI_ENVELOPE_VERSION_V1}; found {found}"
234            )),
235            None => errors.push("abi_envelope must be an unsigned integer".to_owned()),
236        }
237
238        if let Err(error) = string_field(&value, "plugin_name") {
239            errors.push(error);
240        }
241        if let Err(error) = string_field(&value, "plugin_version") {
242            errors.push(error);
243        }
244
245        if errors.is_empty() {
246            Ok(())
247        } else {
248            Err(errors.join("; "))
249        }
250    });
251    let (pass, detail) = static_check_result(result);
252    StaticCheck::IdentityFieldsWellFormed { pass, detail }
253}
254
255fn plugin_name_regex_compliant(facts: &WasmFacts) -> StaticCheck {
256    let result = section_json_value(facts).and_then(|value| {
257        let name = string_field(&value, "plugin_name")?;
258        if plugin_name_matches_wire_pattern(name) {
259            Ok(())
260        } else {
261            Err(format!(
262                "plugin_name {name:?} must match {} and be at most {} bytes",
263                limits::PLUGIN_NAME_PATTERN,
264                limits::PLUGIN_NAME_MAX_BYTES
265            ))
266        }
267    });
268    let (pass, detail) = static_check_result(result);
269    StaticCheck::PluginNameRegexCompliant { pass, detail }
270}
271
272fn plugin_version_length_compliant(facts: &WasmFacts) -> StaticCheck {
273    let result = section_json_value(facts).and_then(|value| {
274        let version = string_field(&value, "plugin_version")?;
275        let byte_len = version.len();
276        if !version.is_empty() && byte_len <= limits::PLUGIN_VERSION_MAX_BYTES {
277            Ok(())
278        } else {
279            Err(format!(
280                "plugin_version must be non-empty and at most {} bytes; found {byte_len} bytes",
281                limits::PLUGIN_VERSION_MAX_BYTES
282            ))
283        }
284    });
285    let (pass, detail) = static_check_result(result);
286    StaticCheck::PluginVersionLengthCompliant { pass, detail }
287}
288
289fn all_required_exports_present(facts: &WasmFacts) -> StaticCheck {
290    let result = facts.parse_error().and_then(|()| {
291        let missing = REQUIRED_LIFECYCLE_EXPORTS
292            .iter()
293            .filter(|name| !facts.function_exports.contains(**name))
294            .copied()
295            .collect::<Vec<_>>();
296        if missing.is_empty() {
297            Ok(())
298        } else {
299            Err(format!("missing wasm function export(s): {missing:?}"))
300        }
301    });
302    let (pass, detail) = static_check_result(result);
303    StaticCheck::AllRequiredExportsPresent { pass, detail }
304}
305
306fn extism_can_instantiate(wasm: &[u8]) -> StaticCheck {
307    let result = cc_lb_runtime_protocol::build_plugin(
308        wasm,
309        limits::HANDSHAKE_WALL_MS,
310        limits::HANDSHAKE_FUEL,
311    )
312    .map(|_| ())
313    .map_err(|error| error.to_string());
314    let (pass, detail) = static_check_result(result);
315    StaticCheck::ExtismCanInstantiate { pass, detail }
316}
317
318fn no_wasi_imports(facts: &WasmFacts) -> StaticCheck {
319    let result = facts.parse_error().and_then(|()| {
320        let forbidden = facts
321            .import_modules
322            .iter()
323            .filter(|module| is_wasi_module(module))
324            .map(String::as_str)
325            .collect::<Vec<_>>();
326        if forbidden.is_empty() {
327            Ok(())
328        } else {
329            Err(format!("forbidden WASI import module(s): {forbidden:?}"))
330        }
331    });
332    let (pass, detail) = static_check_result(result);
333    StaticCheck::NoWasiImports { pass, detail }
334}
335
/// Converts a check outcome into the `(pass, detail)` pair stored in each `StaticCheck`.
fn static_check_result(result: Result<(), String>) -> (bool, Option<String>) {
    let detail = result.err();
    (detail.is_none(), detail)
}
342
343fn section_json_value(facts: &WasmFacts) -> Result<Value, String> {
344    serde_json::from_slice(facts.single_identity_payload()?)
345        .map_err(|error| format!("malformed identity JSON: {error}"))
346}
347
348fn magic_field(value: &Value) -> Result<[u8; 8], String> {
349    let array = value
350        .get("magic")
351        .and_then(Value::as_array)
352        .ok_or_else(|| "magic must be an array of 8 bytes".to_owned())?;
353    if array.len() != CC_LB_PLUGIN_MAGIC.len() {
354        return Err(format!(
355            "magic must contain {} bytes; found {}",
356            CC_LB_PLUGIN_MAGIC.len(),
357            array.len()
358        ));
359    }
360
361    let mut magic = [0u8; 8];
362    for (index, byte) in array.iter().enumerate() {
363        let value = byte
364            .as_u64()
365            .ok_or_else(|| format!("magic[{index}] must be an integer byte"))?;
366        if value > u8::MAX as u64 {
367            return Err(format!("magic[{index}] exceeds u8 max: {value}"));
368        }
369        magic[index] = value as u8;
370    }
371    Ok(magic)
372}
373
374fn string_field<'a>(value: &'a Value, field: &'static str) -> Result<&'a str, String> {
375    value
376        .get(field)
377        .and_then(Value::as_str)
378        .ok_or_else(|| format!("{field} must be a string"))
379}
380
381fn plugin_name_matches_wire_pattern(name: &str) -> bool {
382    let bytes = name.as_bytes();
383    !bytes.is_empty()
384        && bytes.len() <= limits::PLUGIN_NAME_MAX_BYTES
385        && bytes[0].is_ascii_lowercase()
386        && bytes.iter().all(|byte| {
387            byte.is_ascii_lowercase() || byte.is_ascii_digit() || *byte == b'_' || *byte == b'-'
388        })
389}
390
/// Returns `true` when `module` names a WASI import namespace.
///
/// Legacy core-module WASI namespaces all share the `wasi_` prefix
/// (`wasi_snapshot_preview1`, `wasi_unstable`, and proposal modules such as
/// `wasi_ephemeral_nn`). Component-model WASI packages use `wasi:` names. Prefix matching
/// catches proposal namespaces that a closed list would miss.
fn is_wasi_module(module: &str) -> bool {
    module.starts_with("wasi_") || module.starts_with("wasi:")
}
397
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn reads_identity_and_all_static_checks_pass_for_synthetic_module() {
        let wasm = wasm_with_identity_payload(valid_identity_payload("test-plugin").as_bytes());

        let report = read(&wasm).expect("synthetic module identity reads");

        assert_eq!(report.identity.magic, CC_LB_PLUGIN_MAGIC);
        assert_eq!(report.identity.abi_envelope, 1);
        assert_eq!(report.identity.plugin_name, "test-plugin");
        assert_eq!(report.identity.plugin_version, "1.0.0");
        assert_eq!(report.static_checks.len(), 9);
        assert_all_variants_present(&report.static_checks);
        assert!(report.static_checks.iter().all(check_passed));
    }

    #[test]
    fn malformed_json_fails_json_check_and_read() {
        let wasm = wasm_with_identity_payload(b"{");

        let error = read(&wasm).expect_err("malformed JSON is rejected");
        assert!(format!("{error}").contains("malformed"));

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::JsonFieldsExactlyFour { .. })
        });
        assert!(!check_passed(check));
    }

    #[test]
    fn missing_section_fails_custom_section_check_and_read() {
        let wasm = wasm_with_exports_and_optional_identity(None, &REQUIRED_EXPORTS_FOR_TESTS);

        let error = read(&wasm).expect_err("missing custom section is rejected");
        assert!(format!("{error}").contains("missing"));

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::CustomSectionExactlyOnce { .. })
        });
        assert!(!check_passed(check));
    }

    #[test]
    fn oversized_name_fails_name_check_and_read() {
        let oversized_name = format!("a{}", "b".repeat(limits::PLUGIN_NAME_MAX_BYTES));
        let payload = valid_identity_payload(&oversized_name);
        let wasm = wasm_with_identity_payload(payload.as_bytes());

        let error = read(&wasm).expect_err("oversized name is rejected");
        assert!(format!("{error}").contains("invalid plugin identity"));

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::PluginNameRegexCompliant { .. })
        });
        assert!(!check_passed(check));
    }

    #[test]
    fn duplicate_section_fails_custom_section_check() {
        let payload = valid_identity_payload("test-plugin");
        let mut wasm = wasm_with_identity_payload(payload.as_bytes());
        // Custom sections may appear anywhere, so a trailing copy is still a valid module.
        push_custom_section(&mut wasm, CC_LB_PLUGIN_SECTION_NAME, payload.as_bytes());

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::CustomSectionExactlyOnce { .. })
        });
        assert!(!check_passed(check));
        assert!(check_detail(check).is_some_and(|detail| detail.contains("found 2")));
    }

    #[test]
    fn missing_export_fails_required_exports_check() {
        let payload = valid_identity_payload("test-plugin");
        let wasm =
            wasm_with_exports_and_optional_identity(Some(payload.as_bytes()), &["cc_lb_handshake"]);

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::AllRequiredExportsPresent { .. })
        });
        assert!(!check_passed(check));
        let detail = check_detail(check).expect("failing check carries a detail");
        assert!(detail.contains("cc_lb_self_check"));
        assert!(!detail.contains("cc_lb_handshake"));
    }

    #[test]
    fn wasi_import_fails_no_wasi_imports_check() {
        let payload = valid_identity_payload("test-plugin");
        let mut wasm = wasm_with_exports_and_optional_identity(Some(payload.as_bytes()), &[]);
        push_type_section(&mut wasm);
        push_import_section(&mut wasm, "wasi_snapshot_preview1", "fd_write");

        let checks = run_static_checks(&wasm);
        let check = find_check(&checks, |check| {
            matches!(check, StaticCheck::NoWasiImports { .. })
        });
        assert!(!check_passed(check));
        assert!(check_detail(check).is_some_and(|detail| detail.contains("wasi_snapshot_preview1")));
    }

    // Kept independent of `REQUIRED_LIFECYCLE_EXPORTS` so the tests act as a separate oracle.
    const REQUIRED_EXPORTS_FOR_TESTS: [&str; 2] = ["cc_lb_handshake", "cc_lb_self_check"];

    /// Serializes a well-formed identity payload with the given plugin name.
    fn valid_identity_payload(name: &str) -> String {
        json!({
            "magic": CC_LB_PLUGIN_MAGIC,
            "abi_envelope": 1,
            "plugin_name": name,
            "plugin_version": "1.0.0",
        })
        .to_string()
    }

    /// Builds a module with the identity payload and all required exports.
    fn wasm_with_identity_payload(payload: &[u8]) -> Vec<u8> {
        wasm_with_exports_and_optional_identity(Some(payload), &REQUIRED_EXPORTS_FOR_TESTS)
    }

    /// Builds a minimal module: header, an optional identity custom section, and one
    /// `() -> ()` function per export name.
    fn wasm_with_exports_and_optional_identity(
        identity_payload: Option<&[u8]>,
        exports: &[&str],
    ) -> Vec<u8> {
        // "\0asm" magic followed by binary format version 1.
        let mut wasm = Vec::from([0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        if let Some(payload) = identity_payload {
            push_custom_section(&mut wasm, CC_LB_PLUGIN_SECTION_NAME, payload);
        }
        if !exports.is_empty() {
            push_type_section(&mut wasm);
            push_function_section(&mut wasm, exports.len());
            push_export_section(&mut wasm, exports);
            push_code_section(&mut wasm, exports.len());
        }
        wasm
    }

    /// Appends a custom section (id 0) with the given name and raw payload.
    fn push_custom_section(wasm: &mut Vec<u8>, name: &str, payload: &[u8]) {
        let mut section = Vec::new();
        push_name(&mut section, name);
        section.extend_from_slice(payload);
        push_section(wasm, 0, &section);
    }

    /// Appends a type section (id 1) holding a single `() -> ()` function type.
    fn push_type_section(wasm: &mut Vec<u8>) {
        let mut section = Vec::new();
        encode_u32(1, &mut section);
        section.push(0x60);
        encode_u32(0, &mut section);
        encode_u32(0, &mut section);
        push_section(wasm, 1, &section);
    }

    /// Appends an import section (id 2) with one function import of type index 0.
    fn push_import_section(wasm: &mut Vec<u8>, module: &str, field: &str) {
        let mut section = Vec::new();
        encode_u32(1, &mut section);
        push_name(&mut section, module);
        push_name(&mut section, field);
        // Import kind 0x00 = function, followed by its type index.
        section.push(0x00);
        encode_u32(0, &mut section);
        push_section(wasm, 2, &section);
    }

    /// Appends a function section (id 3) declaring `function_count` functions of type 0.
    fn push_function_section(wasm: &mut Vec<u8>, function_count: usize) {
        let mut section = Vec::new();
        encode_u32(function_count as u32, &mut section);
        section.extend(std::iter::repeat_n(0, function_count));
        push_section(wasm, 3, &section);
    }

    /// Appends an export section (id 7) exporting function `i` under `exports[i]`.
    fn push_export_section(wasm: &mut Vec<u8>, exports: &[&str]) {
        let mut section = Vec::new();
        encode_u32(exports.len() as u32, &mut section);
        for (index, export) in exports.iter().enumerate() {
            push_name(&mut section, export);
            section.push(0x00);
            encode_u32(index as u32, &mut section);
        }
        push_section(wasm, 7, &section);
    }

    /// Appends a code section (id 10) with `function_count` empty bodies (no locals, `end`).
    fn push_code_section(wasm: &mut Vec<u8>, function_count: usize) {
        let mut section = Vec::new();
        encode_u32(function_count as u32, &mut section);
        for _ in 0..function_count {
            encode_u32(2, &mut section);
            section.push(0x00);
            section.push(0x0b);
        }
        push_section(wasm, 10, &section);
    }

    /// Writes a length-prefixed UTF-8 name.
    fn push_name(output: &mut Vec<u8>, name: &str) {
        encode_u32(name.len() as u32, output);
        output.extend_from_slice(name.as_bytes());
    }

    /// Writes a section header (id + LEB128 size) followed by its payload.
    fn push_section(wasm: &mut Vec<u8>, id: u8, payload: &[u8]) {
        wasm.push(id);
        encode_u32(payload.len() as u32, wasm);
        wasm.extend_from_slice(payload);
    }

    /// Unsigned LEB128 encoding of `value`.
    fn encode_u32(mut value: u32, output: &mut Vec<u8>) {
        loop {
            let mut byte = (value & 0x7f) as u8;
            value >>= 7;
            if value != 0 {
                byte |= 0x80;
            }
            output.push(byte);
            if value == 0 {
                break;
            }
        }
    }

    fn check_passed(check: &StaticCheck) -> bool {
        match check {
            StaticCheck::CustomSectionExactlyOnce { pass, .. }
            | StaticCheck::CustomSectionSizeWithinLimit { pass, .. }
            | StaticCheck::JsonFieldsExactlyFour { pass, .. }
            | StaticCheck::IdentityFieldsWellFormed { pass, .. }
            | StaticCheck::PluginNameRegexCompliant { pass, .. }
            | StaticCheck::PluginVersionLengthCompliant { pass, .. }
            | StaticCheck::AllRequiredExportsPresent { pass, .. }
            | StaticCheck::ExtismCanInstantiate { pass, .. }
            | StaticCheck::NoWasiImports { pass, .. } => *pass,
        }
    }

    fn check_detail(check: &StaticCheck) -> Option<&str> {
        match check {
            StaticCheck::CustomSectionExactlyOnce { detail, .. }
            | StaticCheck::CustomSectionSizeWithinLimit { detail, .. }
            | StaticCheck::JsonFieldsExactlyFour { detail, .. }
            | StaticCheck::IdentityFieldsWellFormed { detail, .. }
            | StaticCheck::PluginNameRegexCompliant { detail, .. }
            | StaticCheck::PluginVersionLengthCompliant { detail, .. }
            | StaticCheck::AllRequiredExportsPresent { detail, .. }
            | StaticCheck::ExtismCanInstantiate { detail, .. }
            | StaticCheck::NoWasiImports { detail, .. } => detail.as_deref(),
        }
    }

    fn find_check<F>(checks: &[StaticCheck], predicate: F) -> &StaticCheck
    where
        F: Fn(&StaticCheck) -> bool,
    {
        checks
            .iter()
            .find(|check| predicate(check))
            .expect("check exists")
    }

    fn assert_all_variants_present(checks: &[StaticCheck]) {
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::CustomSectionExactlyOnce { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::CustomSectionSizeWithinLimit { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::JsonFieldsExactlyFour { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::IdentityFieldsWellFormed { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::PluginNameRegexCompliant { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::PluginVersionLengthCompliant { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::AllRequiredExportsPresent { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::ExtismCanInstantiate { .. }))
        );
        assert!(
            checks
                .iter()
                .any(|check| matches!(check, StaticCheck::NoWasiImports { .. }))
        );
    }
}