// cc_lb_plugin_wire/wire_function.rs

use serde::{Deserialize, Serialize};

3/// Fallback behavior when a wire function fails at runtime.
4#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum FallbackPolicy {
7    /// Fail the entire request (used for critical functions like signing).
8    FailRequest,
9    /// Silently skip the function call (used for observability hooks).
10    SilentSkip,
11    /// Use the default/fallback response.
12    UseDefault,
13    /// Pass through to the next handler (used for error normalization).
14    PassThrough,
15}
16
17/// Type-level specification for a wire function exported by a plugin.
18pub trait WireFunction: 'static {
19    /// Name of the wire function (e.g., "sign", "observe").
20    const NAME: &'static str;
21
22    /// Compile-time fallback policy for this function. Must never be runtime-configurable
23    /// to prevent security incidents (e.g., skipping signing).
24    const FALLBACK: FallbackPolicy;
25
26    /// Wire protocol versions this function supports. Host advertises max version,
27    /// plugin chooses compatible version during handshake.
28    const SUPPORTED_VERSIONS: &'static [u32];
29
30    /// Request type for this wire function.
31    type Request: serde::Serialize + serde::de::DeserializeOwned;
32
33    /// Response type for this wire function.
34    type Response: serde::Serialize + serde::de::DeserializeOwned;
35
36    /// Dry-run request for self_check (skip-handler model: serialize/deserialize round-trip only).
37    fn dry_run_request() -> Self::Request;
38
39    /// Dry-run response for self_check (skip-handler model: serialize/deserialize round-trip only).
40    fn dry_run_response() -> Self::Response;
41}
42
43/// Plugin kind/slot type. Used by PDK macros and host-side plugin dispatch.
44/// Note: Kind is NOT embedded in the custom section (duck typing); it's inferred
45/// from declared functions or provided explicitly by PDK macro hints.
46#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum PluginKind {
49    /// Request/response dialect plugin (implements `shape`).
50    Dialect,
51    /// API key signing plugin (implements `build_signer`, `sign`).
52    SignerFactory,
53    /// Observability/logging plugin (implements `observe`).
54    Observability,
55}
56
/// Returns the name of every wire function recognised by the system.
///
/// NOTE(review): `"filter"` appears here, but `allowed_functions_for_kind`
/// does not allow it for any [`PluginKind`]. Confirm that is intended.
pub fn all_wire_functions() -> &'static [&'static str] {
    const WIRE_FUNCTIONS: &[&str] = &[
        "shape",
        "normalize_error",
        "build_signer",
        "sign",
        "on_unauthorized",
        "observe",
        "filter",
    ];
    WIRE_FUNCTIONS
}

70/// Returns the set of wire functions REQUIRED for a plugin of the given kind.
71pub fn required_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
72    match kind {
73        PluginKind::Dialect => &["shape"],
74        PluginKind::SignerFactory => &["build_signer", "sign"],
75        PluginKind::Observability => &["observe"],
76    }
77}
78
79/// Returns the set of wire functions ALLOWED for a plugin of the given kind.
80/// A plugin may implement any subset of allowed functions; at minimum,
81/// it must implement all required functions.
82pub fn allowed_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
83    match kind {
84        PluginKind::Dialect => &["shape", "normalize_error"],
85        PluginKind::SignerFactory => &["build_signer", "sign", "on_unauthorized"],
86        PluginKind::Observability => &["observe"],
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `PluginKind` variant. Update this when a variant is added.
    const ALL_KINDS: [PluginKind; 3] = [
        PluginKind::Dialect,
        PluginKind::SignerFactory,
        PluginKind::Observability,
    ];

    /// Every `FallbackPolicy` variant. Update this when a variant is added.
    const ALL_POLICIES: [FallbackPolicy; 4] = [
        FallbackPolicy::FailRequest,
        FallbackPolicy::SilentSkip,
        FallbackPolicy::UseDefault,
        FallbackPolicy::PassThrough,
    ];

    /// Asserts that every function required for `kind` is also allowed for it.
    fn assert_required_subset_of_allowed(kind: PluginKind) {
        let allowed = allowed_functions_for_kind(kind);
        for req in required_functions_for_kind(kind) {
            assert!(allowed.contains(req), "{} not in allowed for {:?}", req, kind);
        }
    }

    #[test]
    fn all_seven_functions_listed() {
        assert_eq!(all_wire_functions().len(), 7);
        let funcs = all_wire_functions();
        assert!(funcs.contains(&"shape"));
        assert!(funcs.contains(&"normalize_error"));
        assert!(funcs.contains(&"build_signer"));
        assert!(funcs.contains(&"sign"));
        assert!(funcs.contains(&"on_unauthorized"));
        assert!(funcs.contains(&"observe"));
        assert!(funcs.contains(&"filter"));
    }

    #[test]
    fn wire_function_names_unique() {
        let funcs = all_wire_functions();
        for (i, name) in funcs.iter().enumerate() {
            assert!(
                !funcs[i + 1..].contains(name),
                "duplicate wire function {}",
                name
            );
        }
    }

    #[test]
    fn allowed_functions_are_known_wire_functions() {
        // Every function a kind may export must be a recognised wire function;
        // otherwise a typo in one table would go unnoticed.
        let all = all_wire_functions();
        for &kind in &ALL_KINDS {
            for func in allowed_functions_for_kind(kind) {
                assert!(
                    all.contains(func),
                    "{} allowed for {:?} but not a known wire function",
                    func,
                    kind
                );
            }
        }
    }

    #[test]
    fn fallback_policies_have_distinct_wire_names() {
        // Expected per-function mapping. These are compile-time constants on
        // each WireFunction impl, and those impls are not visible from this
        // module. TODO: assert this mapping next to the concrete impls.
        //   - sign and build_signer must fail the request on error
        //     (signing cannot be skipped without breaking security)
        //   - observe must silently skip
        //   - normalize_error and on_unauthorized must pass through
        //
        // What this module can check: no two policies share a serialized name,
        // so a policy can never be confused with another after a round-trip.
        let names: Vec<String> = ALL_POLICIES
            .iter()
            .map(|p| serde_json::to_string(p).unwrap())
            .collect();
        for (i, name) in names.iter().enumerate() {
            assert!(
                !names[i + 1..].contains(name),
                "duplicate serialized policy name {}",
                name
            );
        }
    }

    #[test]
    fn required_subset_allowed_dialect() {
        assert_required_subset_of_allowed(PluginKind::Dialect);
    }

    #[test]
    fn required_subset_allowed_signer_factory() {
        assert_required_subset_of_allowed(PluginKind::SignerFactory);
    }

    #[test]
    fn required_subset_allowed_observability() {
        assert_required_subset_of_allowed(PluginKind::Observability);
    }

    #[test]
    fn fallback_policy_serde() {
        for policy in &ALL_POLICIES {
            let json = serde_json::to_string(policy).unwrap();
            let deserialized: FallbackPolicy = serde_json::from_str(&json).unwrap();
            assert_eq!(*policy, deserialized);
        }
    }

    #[test]
    fn fallback_policy_snake_case_names() {
        assert_eq!(
            serde_json::to_string(&FallbackPolicy::FailRequest).unwrap(),
            "\"fail_request\""
        );
        assert_eq!(
            serde_json::to_string(&FallbackPolicy::SilentSkip).unwrap(),
            "\"silent_skip\""
        );
        assert_eq!(
            serde_json::to_string(&FallbackPolicy::UseDefault).unwrap(),
            "\"use_default\""
        );
        assert_eq!(
            serde_json::to_string(&FallbackPolicy::PassThrough).unwrap(),
            "\"pass_through\""
        );
    }

    #[test]
    fn plugin_kind_serde() {
        for kind in &ALL_KINDS {
            let json = serde_json::to_string(kind).unwrap();
            let deserialized: PluginKind = serde_json::from_str(&json).unwrap();
            assert_eq!(*kind, deserialized);
        }
    }

    #[test]
    fn plugin_kind_snake_case_names() {
        assert_eq!(
            serde_json::to_string(&PluginKind::Dialect).unwrap(),
            "\"dialect\""
        );
        assert_eq!(
            serde_json::to_string(&PluginKind::SignerFactory).unwrap(),
            "\"signer_factory\""
        );
        assert_eq!(
            serde_json::to_string(&PluginKind::Observability).unwrap(),
            "\"observability\""
        );
    }

    #[test]
    fn kind_allowed_and_required_nonempty() {
        for &kind in &ALL_KINDS {
            let required = required_functions_for_kind(kind);
            let allowed = allowed_functions_for_kind(kind);
            assert!(!required.is_empty(), "required empty for {:?}", kind);
            assert!(!allowed.is_empty(), "allowed empty for {:?}", kind);
            assert!(
                required.len() <= allowed.len(),
                "required > allowed for {:?}",
                kind
            );
        }
    }
}