Skip to main content

cc_lb_plugin_conformance/
handshake.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use cc_lb_plugin_wire::handshake::HandshakeError as WireHandshakeError;
4use cc_lb_plugin_wire::identity::IdentityError as WireIdentityError;
5use cc_lb_runtime_protocol::handshake::{HandshakeExecutionError, build_offer, execute_handshake};
6use cc_lb_runtime_protocol::identity::{IdentityReadError, read_identity};
7use thiserror::Error;
8
9pub fn run(wasm: &[u8]) -> Result<HandshakeReport, HandshakeError> {
10    run_with_caps(wasm, &BTreeSet::new())
11}
12
13pub fn run_with_caps(
14    wasm: &[u8],
15    host_capabilities: &BTreeSet<String>,
16) -> Result<HandshakeReport, HandshakeError> {
17    let offer = build_offer(host_capabilities);
18    let accept = execute_handshake(wasm, &offer).map_err(HandshakeError::from_protocol)?;
19    let identity = read_identity(wasm).map_err(HandshakeError::from_identity_error)?;
20
21    Ok(HandshakeReport {
22        chosen_versions: accept.chosen_versions,
23        envelope_version: accept.envelope_version,
24        identity: PluginIdentity {
25            name: identity.plugin_name,
26            version: identity.plugin_version,
27            abi_envelope: identity.abi_envelope,
28        },
29        negotiated_capabilities: accept.required_capabilities,
30    })
31}
32
33#[non_exhaustive]
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub struct HandshakeReport {
36    pub chosen_versions: BTreeMap<String, u32>,
37    pub envelope_version: u32,
38    pub identity: PluginIdentity,
39    pub negotiated_capabilities: BTreeSet<String>,
40}
41
/// Identity a plugin declares about itself, as surfaced in a [`HandshakeReport`].
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct PluginIdentity {
    /// Declared plugin name.
    pub name: String,
    /// Declared plugin version string.
    pub version: String,
    /// ABI envelope number declared alongside the identity.
    pub abi_envelope: u32,
}
49
50#[non_exhaustive]
51#[derive(Debug, Clone, Error, PartialEq, Eq)]
52pub enum HandshakeError {
53    #[error("plugin instantiation failed: {reason}")]
54    Instantiate { reason: String },
55    #[error("plugin does not export cc_lb_handshake")]
56    MissingExport,
57    #[error("invalid identity field {field}: {reason}")]
58    InvalidIdentity { field: &'static str, reason: String },
59    #[error("declared function missing wasm export: {name}")]
60    FunctionMissing { name: String },
61    #[error(
62        "downgrade attempt for function {function}: host offered max {host_offered_max}, plugin chose {plugin_chose}"
63    )]
64    DowngradeAttempt {
65        function: String,
66        host_offered_max: u32,
67        plugin_chose: u32,
68    },
69    #[error("missing host capability: {name}")]
70    MissingCapability { name: String },
71    #[error("handshake serialization failed: {reason}")]
72    Serialization { reason: String },
73    #[error("wasm execution failed: {reason}")]
74    WasmTrap { reason: String },
75}
76
77impl HandshakeError {
78    pub(crate) fn from_protocol(error: HandshakeExecutionError) -> Self {
79        let reason = error.to_string();
80        match error {
81            HandshakeExecutionError::Validation(error) => map_wire_handshake_error(error, reason),
82            HandshakeExecutionError::Instantiate { reason } => {
83                HandshakeError::Instantiate { reason }
84            }
85            HandshakeExecutionError::MissingHandshakeExport => HandshakeError::MissingExport,
86            HandshakeExecutionError::SerializeOffer { reason }
87            | HandshakeExecutionError::DecodeAccept { reason } => {
88                HandshakeError::Serialization { reason }
89            }
90            HandshakeExecutionError::Call { reason } => map_call_error(reason),
91            HandshakeExecutionError::Timeout => HandshakeError::WasmTrap { reason },
92            HandshakeExecutionError::OutputTooLarge { .. } => {
93                HandshakeError::Serialization { reason }
94            }
95            HandshakeExecutionError::ImplementedFunctionCountExceeded { .. } => {
96                HandshakeError::InvalidIdentity {
97                    field: "implemented_functions",
98                    reason,
99                }
100            }
101            HandshakeExecutionError::ImplementedUnknownFunction { .. } => {
102                HandshakeError::InvalidIdentity {
103                    field: "implemented_functions",
104                    reason,
105                }
106            }
107            HandshakeExecutionError::SupportedUnknownFunction { .. }
108            | HandshakeExecutionError::SupportedVersionOutOfRange { .. } => {
109                HandshakeError::InvalidIdentity {
110                    field: "plugin_supported",
111                    reason,
112                }
113            }
114            HandshakeExecutionError::ChosenFunctionNotImplemented { function } => {
115                HandshakeError::FunctionMissing { name: function }
116            }
117            HandshakeExecutionError::ChosenVersionNotSupported { .. } => {
118                HandshakeError::InvalidIdentity {
119                    field: "chosen_versions",
120                    reason,
121                }
122            }
123            HandshakeExecutionError::DeclaredFunctionMissing { function } => {
124                HandshakeError::FunctionMissing { name: function }
125            }
126            HandshakeExecutionError::UndeclaredExport { .. } => HandshakeError::InvalidIdentity {
127                field: "implemented_functions",
128                reason,
129            },
130            _ => unreachable!(),
131        }
132    }
133
134    pub(crate) fn from_identity_error(error: IdentityReadError) -> Self {
135        let field = match &error {
136            IdentityReadError::WasmParseError(_) => "wasm",
137            IdentityReadError::MissingCustomSection
138            | IdentityReadError::DuplicateCustomSection
139            | IdentityReadError::SectionTooLarge { .. } => "custom_section",
140            IdentityReadError::MagicMismatch { .. } => "magic",
141            IdentityReadError::MalformedPayload(_) => "payload",
142            IdentityReadError::Validation(WireIdentityError::MagicMismatch) => "magic",
143            IdentityReadError::Validation(
144                WireIdentityError::PluginNameEmpty | WireIdentityError::PluginNameInvalid,
145            ) => "name",
146            IdentityReadError::Validation(
147                WireIdentityError::PluginVersionEmpty | WireIdentityError::PluginVersionTooLong,
148            ) => "version",
149            _ => unreachable!(),
150        };
151
152        HandshakeError::InvalidIdentity {
153            field,
154            reason: error.to_string(),
155        }
156    }
157}
158
159fn map_call_error(reason: String) -> HandshakeError {
160    if let Some(name) = reason
161        .strip_prefix("required capability '")
162        .and_then(|rest| rest.strip_suffix("' not available in host capabilities"))
163    {
164        return HandshakeError::MissingCapability {
165            name: name.to_owned(),
166        };
167    }
168
169    HandshakeError::WasmTrap { reason }
170}
171
172fn map_wire_handshake_error(error: WireHandshakeError, reason: String) -> HandshakeError {
173    match error {
174        WireHandshakeError::FunctionCountExceeded { .. }
175        | WireHandshakeError::FunctionVersionCountExceeded { .. }
176        | WireHandshakeError::VersionOutOfRange { .. } => HandshakeError::InvalidIdentity {
177            field: "function_versions",
178            reason,
179        },
180        WireHandshakeError::CapabilityCountExceeded { .. } => HandshakeError::InvalidIdentity {
181            field: "host_capabilities",
182            reason,
183        },
184        WireHandshakeError::ChosenVersionNotOffered { .. }
185        | WireHandshakeError::ChosenForUnknownFunction { .. } => HandshakeError::InvalidIdentity {
186            field: "chosen_versions",
187            reason,
188        },
189        WireHandshakeError::DowngradeAttempt {
190            function,
191            chosen,
192            max_intersection,
193        } => HandshakeError::DowngradeAttempt {
194            function,
195            host_offered_max: max_intersection,
196            plugin_chose: chosen,
197        },
198        WireHandshakeError::RequiredCapabilityUnavailable { capability } => {
199            HandshakeError::MissingCapability { name: capability }
200        }
201        WireHandshakeError::HandshakeSchemaVersionMismatch { .. } => {
202            HandshakeError::InvalidIdentity {
203                field: "handshake_schema_version",
204                reason,
205            }
206        }
207        WireHandshakeError::Canonical(_) => HandshakeError::Serialization { reason },
208    }
209}