use serde::{Deserialize, Serialize};
/// Fallback behaviour a [`WireFunction`] declares through its `FALLBACK` constant.
///
/// Serialized with serde in `snake_case`, e.g. `FailRequest` <-> `"fail_request"`.
///
/// NOTE(review): the per-variant notes below are read off the variant names only;
/// the code that acts on a policy is not visible in this module — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FallbackPolicy {
    /// Presumably: fail the whole request. Wire name `"fail_request"`.
    FailRequest,
    /// Presumably: skip the function without reporting. Wire name `"silent_skip"`.
    SilentSkip,
    /// Presumably: substitute a default value. Wire name `"use_default"`.
    UseDefault,
    /// Presumably: forward the input unchanged. Wire name `"pass_through"`.
    PassThrough,
}
pub trait WireFunction: 'static {
const NAME: &'static str;
const FALLBACK: FallbackPolicy;
const SUPPORTED_VERSIONS: &'static [u32];
type Request: serde::Serialize + serde::de::DeserializeOwned;
type Response: serde::Serialize + serde::de::DeserializeOwned;
fn dry_run_request() -> Self::Request;
fn dry_run_response() -> Self::Response;
}
/// Category of a plugin.
///
/// `required_functions_for_kind` and `allowed_functions_for_kind` map each kind to
/// its wire-function names. Serialized with serde in `snake_case`, e.g.
/// `SignerFactory` <-> `"signer_factory"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PluginKind {
    /// Wire name `"dialect"`.
    Dialect,
    /// Wire name `"signer_factory"`.
    SignerFactory,
    /// Wire name `"observability"`.
    Observability,
}
/// Returns the name of every wire function, in a fixed order.
///
/// NOTE(review): `"filter"` appears here but `allowed_functions_for_kind` does not
/// allow it for any `PluginKind` — confirm whether that is intentional.
pub fn all_wire_functions() -> &'static [&'static str] {
    static NAMES: [&str; 7] = [
        "shape",
        "normalize_error",
        "build_signer",
        "sign",
        "on_unauthorized",
        "observe",
        "filter",
    ];
    &NAMES
}
/// Names of the wire functions required for a plugin of the given `kind`.
///
/// Each list is expected to be a subset of `allowed_functions_for_kind(kind)`.
pub fn required_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
    const DIALECT: &[&str] = &["shape"];
    const SIGNER_FACTORY: &[&str] = &["build_signer", "sign"];
    const OBSERVABILITY: &[&str] = &["observe"];

    match kind {
        PluginKind::Dialect => DIALECT,
        PluginKind::SignerFactory => SIGNER_FACTORY,
        PluginKind::Observability => OBSERVABILITY,
    }
}
/// Names of the wire functions permitted for a plugin of the given `kind`.
///
/// Each list is expected to contain every name in `required_functions_for_kind(kind)`.
pub fn allowed_functions_for_kind(kind: PluginKind) -> &'static [&'static str] {
    const DIALECT: &[&str] = &["shape", "normalize_error"];
    const SIGNER_FACTORY: &[&str] = &["build_signer", "sign", "on_unauthorized"];
    const OBSERVABILITY: &[&str] = &["observe"];

    match kind {
        PluginKind::Dialect => DIALECT,
        PluginKind::SignerFactory => SIGNER_FACTORY,
        PluginKind::Observability => OBSERVABILITY,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `FallbackPolicy` variant, in declaration order.
    const ALL_POLICIES: [FallbackPolicy; 4] = [
        FallbackPolicy::FailRequest,
        FallbackPolicy::SilentSkip,
        FallbackPolicy::UseDefault,
        FallbackPolicy::PassThrough,
    ];

    /// Every `PluginKind` variant, in declaration order.
    const ALL_KINDS: [PluginKind; 3] = [
        PluginKind::Dialect,
        PluginKind::SignerFactory,
        PluginKind::Observability,
    ];

    /// Asserts that each function required for `kind` is also allowed for `kind`.
    fn assert_required_subset_allowed(kind: PluginKind) {
        let allowed = allowed_functions_for_kind(kind);
        for req in required_functions_for_kind(kind) {
            assert!(
                allowed.contains(req),
                "{} not in allowed for {:?}",
                req,
                kind
            );
        }
    }

    #[test]
    fn all_seven_functions_listed() {
        let funcs = all_wire_functions();
        assert_eq!(funcs.len(), 7);
        for expected in &[
            "shape",
            "normalize_error",
            "build_signer",
            "sign",
            "on_unauthorized",
            "observe",
            "filter",
        ] {
            assert!(funcs.contains(expected), "{} missing", expected);
        }
    }

    #[test]
    fn wire_function_names_unique() {
        let funcs = all_wire_functions();
        for (i, name) in funcs.iter().enumerate() {
            // Only look ahead: each pair is checked once.
            assert!(
                !funcs[i + 1..].contains(name),
                "duplicate wire function {}",
                name
            );
        }
    }

    #[test]
    fn allowed_functions_are_known_wire_functions() {
        let known = all_wire_functions();
        for &kind in &ALL_KINDS {
            for func in allowed_functions_for_kind(kind) {
                assert!(
                    known.contains(func),
                    "{} allowed for {:?} but not a wire function",
                    func,
                    kind
                );
            }
        }
    }

    /// `PartialEq` must distinguish every pair of distinct variants
    /// (comparing a variant only with itself can never fail).
    #[test]
    fn fallback_policy_variants_distinct() {
        for (i, a) in ALL_POLICIES.iter().enumerate() {
            for (j, b) in ALL_POLICIES.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn required_subset_allowed_dialect() {
        assert_required_subset_allowed(PluginKind::Dialect);
    }

    #[test]
    fn required_subset_allowed_signer_factory() {
        assert_required_subset_allowed(PluginKind::SignerFactory);
    }

    #[test]
    fn required_subset_allowed_observability() {
        assert_required_subset_allowed(PluginKind::Observability);
    }

    #[test]
    fn fallback_policy_serde() {
        for policy in &ALL_POLICIES {
            let json = serde_json::to_string(policy).unwrap();
            let deserialized: FallbackPolicy = serde_json::from_str(&json).unwrap();
            assert_eq!(*policy, deserialized);
        }
    }

    #[test]
    fn fallback_policy_snake_case_names() {
        let cases = [
            (FallbackPolicy::FailRequest, "\"fail_request\""),
            (FallbackPolicy::SilentSkip, "\"silent_skip\""),
            (FallbackPolicy::UseDefault, "\"use_default\""),
            (FallbackPolicy::PassThrough, "\"pass_through\""),
        ];
        for (policy, expected) in &cases {
            assert_eq!(serde_json::to_string(policy).unwrap(), *expected);
        }
    }

    #[test]
    fn plugin_kind_serde() {
        for kind in &ALL_KINDS {
            let json = serde_json::to_string(kind).unwrap();
            let deserialized: PluginKind = serde_json::from_str(&json).unwrap();
            assert_eq!(*kind, deserialized);
        }
    }

    #[test]
    fn plugin_kind_snake_case_names() {
        let cases = [
            (PluginKind::Dialect, "\"dialect\""),
            (PluginKind::SignerFactory, "\"signer_factory\""),
            (PluginKind::Observability, "\"observability\""),
        ];
        for (kind, expected) in &cases {
            assert_eq!(serde_json::to_string(kind).unwrap(), *expected);
        }
    }

    #[test]
    fn kind_allowed_and_required_nonempty() {
        for &kind in &ALL_KINDS {
            let required = required_functions_for_kind(kind);
            let allowed = allowed_functions_for_kind(kind);
            assert!(!required.is_empty(), "required empty for {:?}", kind);
            assert!(!allowed.is_empty(), "allowed empty for {:?}", kind);
            assert!(
                required.len() <= allowed.len(),
                "required > allowed for {:?}",
                kind
            );
        }
    }
}