cc_lb_plugin_wire/
wire_function.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum FallbackPolicy {
7 FailRequest,
9 SilentSkip,
11 UseDefault,
13 PassThrough,
15}
16
17pub trait WireFunction: 'static {
19 const NAME: &'static str;
21
22 const FALLBACK: FallbackPolicy;
25
26 const SUPPORTED_VERSIONS: &'static [u32];
29
30 type Request: serde::Serialize + serde::de::DeserializeOwned;
32
33 type Response: serde::Serialize + serde::de::DeserializeOwned;
35
36 fn dry_run_request() -> Self::Request;
38
39 fn dry_run_response() -> Self::Response;
41}
42
43#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum PluginKind {
49 Dialect,
51 SignerFactory,
53 Observability,
55}
56
/// Returns the name of every wire function, in a fixed order.
///
/// NOTE(review): `"filter"` is listed here, but no [`PluginKind`] allows it
/// in [`allowed_functions_for_kind`] — confirm that is intended.
pub fn all_wire_functions() -> &'static [&'static str] {
    // A `const` array borrowed here is promoted to a `'static` allocation.
    const NAMES: [&str; 7] = [
        "shape",
        "normalize_error",
        "build_signer",
        "sign",
        "on_unauthorized",
        "observe",
        "filter",
    ];
    &NAMES
}
69
70pub fn required_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
72 match kind {
73 PluginKind::Dialect => &["shape"],
74 PluginKind::SignerFactory => &["build_signer", "sign"],
75 PluginKind::Observability => &["observe"],
76 }
77}
78
79pub fn allowed_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
83 match kind {
84 PluginKind::Dialect => &["shape", "normalize_error"],
85 PluginKind::SignerFactory => &["build_signer", "sign", "on_unauthorized"],
86 PluginKind::Observability => &["observe"],
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `FallbackPolicy` variant, in declaration order.
    const ALL_POLICIES: [FallbackPolicy; 4] = [
        FallbackPolicy::FailRequest,
        FallbackPolicy::SilentSkip,
        FallbackPolicy::UseDefault,
        FallbackPolicy::PassThrough,
    ];

    /// Every `PluginKind` variant, in declaration order.
    const ALL_KINDS: [PluginKind; 3] = [
        PluginKind::Dialect,
        PluginKind::SignerFactory,
        PluginKind::Observability,
    ];

    /// Asserts that every function `kind` requires is also allowed for `kind`.
    fn assert_required_subset_of_allowed(kind: PluginKind) {
        let allowed = allowed_functions_for_kind(kind);
        for req in required_functions_for_kind(kind) {
            assert!(allowed.contains(req), "{} not in allowed for {:?}", req, kind);
        }
    }

    #[test]
    fn all_seven_functions_listed() {
        let funcs = all_wire_functions();
        assert_eq!(funcs.len(), 7);
        // Seven distinct names plus len() == 7 also rules out duplicates.
        for name in [
            "shape",
            "normalize_error",
            "build_signer",
            "sign",
            "on_unauthorized",
            "observe",
            "filter",
        ] {
            assert!(funcs.contains(&name), "{} missing", name);
        }
    }

    #[test]
    fn security_fallbacks_correct() {
        // Comparing each variant with itself can never fail. Instead, check
        // that the variants are pairwise distinct both as values and on the
        // wire, so two policies cannot silently collapse into one.
        // NOTE(review): this still does not check which policy any concrete
        // `WireFunction` impl declares; add `<Impl as WireFunction>::FALLBACK`
        // assertions where those impls are in scope.
        for (i, a) in ALL_POLICIES.iter().enumerate() {
            for b in &ALL_POLICIES[i + 1..] {
                assert_ne!(a, b);
                assert_ne!(
                    serde_json::to_string(a).unwrap(),
                    serde_json::to_string(b).unwrap()
                );
            }
        }
    }

    #[test]
    fn required_subset_allowed_dialect() {
        assert_required_subset_of_allowed(PluginKind::Dialect);
    }

    #[test]
    fn required_subset_allowed_signer_factory() {
        assert_required_subset_of_allowed(PluginKind::SignerFactory);
    }

    #[test]
    fn required_subset_allowed_observability() {
        assert_required_subset_of_allowed(PluginKind::Observability);
    }

    #[test]
    fn kind_functions_are_known_wire_functions() {
        // Guards against a typo in the per-kind tables drifting from the master list.
        let known = all_wire_functions();
        for kind in ALL_KINDS {
            for name in allowed_functions_for_kind(kind) {
                assert!(
                    known.contains(name),
                    "{} allowed for {:?} but not a wire function",
                    name,
                    kind
                );
            }
        }
    }

    #[test]
    fn fallback_policy_serde() {
        for policy in &ALL_POLICIES {
            let json = serde_json::to_string(policy).unwrap();
            let deserialized: FallbackPolicy = serde_json::from_str(&json).unwrap();
            assert_eq!(*policy, deserialized);
        }
    }

    #[test]
    fn fallback_policy_snake_case_names() {
        let expected = [
            (FallbackPolicy::FailRequest, "\"fail_request\""),
            (FallbackPolicy::SilentSkip, "\"silent_skip\""),
            (FallbackPolicy::UseDefault, "\"use_default\""),
            (FallbackPolicy::PassThrough, "\"pass_through\""),
        ];
        for (policy, json) in expected {
            assert_eq!(serde_json::to_string(&policy).unwrap(), json);
            assert_eq!(serde_json::from_str::<FallbackPolicy>(json).unwrap(), policy);
        }
    }

    #[test]
    fn plugin_kind_serde() {
        for kind in &ALL_KINDS {
            let json = serde_json::to_string(kind).unwrap();
            let deserialized: PluginKind = serde_json::from_str(&json).unwrap();
            assert_eq!(*kind, deserialized);
        }
    }

    #[test]
    fn plugin_kind_snake_case_names() {
        let expected = [
            (PluginKind::Dialect, "\"dialect\""),
            (PluginKind::SignerFactory, "\"signer_factory\""),
            (PluginKind::Observability, "\"observability\""),
        ];
        for (kind, json) in expected {
            assert_eq!(serde_json::to_string(&kind).unwrap(), json);
            assert_eq!(serde_json::from_str::<PluginKind>(json).unwrap(), kind);
        }
    }

    #[test]
    fn kind_allowed_and_required_nonempty() {
        for kind in ALL_KINDS {
            let required = required_functions_for_kind(kind);
            let allowed = allowed_functions_for_kind(kind);
            assert!(!required.is_empty(), "required empty for {:?}", kind);
            assert!(!allowed.is_empty(), "allowed empty for {:?}", kind);
            assert!(
                required.len() <= allowed.len(),
                "required > allowed for {:?}",
                kind
            );
        }
    }
}