cardinal_config/
lib.rs

1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6
7pub mod config;
8
/// Health-check settings that can be attached to a [`Destination`] via
/// `Destination::health_check`.
///
/// Only the configuration shape is defined here; nothing in this module runs the check.
/// All four fields are required when deserializing (no `#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct HealthCheck {
    /// Path to probe — presumably relative to `Destination::url` (TODO confirm with the checker).
    pub path: String,
    /// Check interval in milliseconds (per the `_ms` suffix; consumer not visible here).
    pub interval_ms: u64,
    /// Per-check timeout in milliseconds (per the `_ms` suffix; consumer not visible here).
    pub timeout_ms: u64,
    /// Status code treated as healthy — NOTE(review): presumably an exact match; confirm.
    pub expect_status: u16,
}
16
/// Direction tag for a [`Middleware`] entry.
///
/// With no serde rename attributes, the variants (de)serialize as the strings
/// `"Inbound"` and `"Outbound"`.
/// NOTE(review): presumably `Inbound` runs on incoming requests and `Outbound` on
/// upstream responses — confirm against the code that consumes this config.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MiddlewareType {
    Inbound,
    Outbound,
}
22
/// A middleware reference listed under a [`Destination`].
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Middleware {
    /// Direction of the middleware. The raw identifier `r#type` lets the Rust field use the
    /// reserved word; serde reads and writes it under the key `type`.
    pub r#type: MiddlewareType,
    /// Plugin name. [`validate_config`] rejects the config unless this matches the name of
    /// an entry in `CardinalConfig::plugins`.
    pub name: String,
}
28
/// A plugin declared in the `plugins` list of [`CardinalConfig`].
///
/// `Serialize`/`Deserialize` are implemented by hand further down this file rather than
/// derived, so that a bare string is accepted as shorthand for a builtin plugin on input.
#[derive(Debug, Clone)]
pub enum Plugin {
    /// A plugin identified only by name.
    Builtin(BuiltinPlugin),
    /// A plugin backed by a WebAssembly module.
    Wasm(WasmPluginConfig),
}
34
35impl Plugin {
36    pub fn name(&self) -> &str {
37        match self {
38            Plugin::Builtin(builtin) => &builtin.name,
39            Plugin::Wasm(wasm) => &wasm.name,
40        }
41    }
42}
43
/// A plugin referenced purely by name.
///
/// NOTE(review): presumably names a plugin compiled into the gateway — confirm with the
/// plugin registry. A bare string in the `plugins` list deserializes into this type.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct BuiltinPlugin {
    /// Plugin name; matched against `Middleware::name` by [`validate_config`].
    pub name: String,
}
48
/// Configuration for a plugin backed by a WebAssembly module.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct WasmPluginConfig {
    /// Plugin name; matched against `Middleware::name` by [`validate_config`].
    pub name: String,
    /// Path to the module file (e.g. `plugins/ratelimit.wasm`).
    /// NOTE(review): resolution base (cwd vs. config dir) is decided by the loader — confirm.
    pub path: String,
    /// Optional; omitted in input => `None`.
    /// NOTE(review): presumably overrides the module's exported memory name — confirm with the wasm loader.
    pub memory_name: Option<String>,
    /// Optional; omitted in input => `None`.
    /// NOTE(review): presumably overrides the exported handler function name — confirm with the wasm loader.
    pub handle_name: Option<String>,
}
56
/// Input shapes accepted for a [`Plugin`], used only by its `Deserialize` impl.
///
/// With `#[serde(untagged)]` the variants are tried in declaration order, and the first
/// one that succeeds wins:
///
/// - `"Logger"`: a bare string, shorthand for a builtin plugin;
/// - `{ builtin = { name = "Logger" } }`;
/// - `{ wasm = { name = "...", path = "...", ... } }`.
///
/// Unknown keys are ignored (no `deny_unknown_fields`), so a map carrying both `builtin`
/// and `wasm` resolves to `Builtin`. When nothing matches, serde reports a generic
/// "did not match any variant" error rather than a field-specific one.
#[derive(Deserialize)]
#[serde(untagged)]
enum PluginSerde {
    Name(String),
    Builtin { builtin: BuiltinPlugin },
    Wasm { wasm: WasmPluginConfig },
}
64
65impl<'de> Deserialize<'de> for Plugin {
66    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        match PluginSerde::deserialize(deserializer)? {
71            PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
72            PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
73            PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
74        }
75    }
76}
77
/// Always writes the explicit single-key form, `{ "builtin": { ... } }` or
/// `{ "wasm": { ... } }`, and never the bare-string shorthand. Serialized output therefore
/// deserializes back through `PluginSerde` to the same variant.
impl Serialize for Plugin {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Plugin::Builtin(builtin) => {
                // Borrowing one-field wrapper: serde emits a struct with the single key
                // `builtin`, without cloning the plugin config.
                #[derive(Serialize)]
                struct Wrapper<'a> {
                    builtin: &'a BuiltinPlugin,
                }
                Wrapper { builtin }.serialize(serializer)
            }
            Plugin::Wasm(wasm) => {
                // Same shape under the key `wasm`. The tests in this file show `None`
                // fields written as `null` by JSON but omitted by TOML.
                #[derive(Serialize)]
                struct Wrapper<'a> {
                    wasm: &'a WasmPluginConfig,
                }
                Wrapper { wasm }.serialize(serializer)
            }
        }
    }
}
101
/// One entry of the `destinations` map in [`CardinalConfig`].
///
/// NOTE(review): `routes` and `middleware` default to empty for serde, but the derived
/// builder has no `#[builder(default)]`, so builder users must still set every field.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Destination {
    /// Destination name. Nothing in this module checks it against the map key it is stored under.
    pub name: String,
    /// Destination URL; not validated by [`validate_config`].
    pub url: String,
    /// Optional health-check settings; a missing key deserializes as `None`.
    pub health_check: Option<HealthCheck>,
    /// Routes for this destination; empty when omitted. Each path and method is checked by
    /// [`validate_config`].
    #[serde(default)]
    pub routes: Vec<Route>,
    /// Middleware for this destination; empty when omitted. Every name must match a
    /// configured plugin (checked by [`validate_config`]).
    #[serde(default)]
    pub middleware: Vec<Middleware>,
}
112
/// The `server` section of [`CardinalConfig`].
///
/// NOTE(review): no field carries `#[serde(default)]`, so every key must be present in the
/// loaded config unless `get_config_builder` supplies defaults (not visible here). The
/// `Default` impl for this type is not consulted by serde.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct ServerConfig {
    /// Server address. [`validate_config`] requires it to parse as a
    /// `std::net::SocketAddr`, so hostnames are rejected.
    pub address: String,
    /// Flag consumed outside this module; its semantics are not visible here.
    pub force_path_parameter: bool,
    /// Flag consumed outside this module; presumably toggles logging of upstream responses (TODO confirm).
    pub log_upstream_response: bool,
    /// Names of global request middleware.
    /// NOTE(review): unlike destination middleware, these are not checked against `plugins` by [`validate_config`].
    pub global_request_middleware: Vec<String>,
    /// Names of global response middleware; likewise not checked by [`validate_config`].
    pub global_response_middleware: Vec<String>,
}
121
/// A path/method pair listed under a [`Destination`].
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Route {
    /// Request path; [`validate_config`] requires it to start with `/`.
    pub path: String,
    /// HTTP method. [`validate_config`] accepts exactly `GET`, `POST`, `PUT`, `DELETE`,
    /// `PATCH`, `HEAD` or `OPTIONS`; the comparison is case-sensitive.
    pub method: String,
}
127
/// Root configuration, as returned by [`load_config`].
#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
pub struct CardinalConfig {
    /// Server settings; required in input (no serde default).
    pub server: ServerConfig,
    /// Destinations keyed by a string identifier; required in input. A `BTreeMap` gives
    /// deterministic (sorted-key) iteration order during validation.
    pub destinations: BTreeMap<String, Destination>,
    /// Declared plugins; empty when omitted. Bare strings are accepted as builtin plugin names.
    #[serde(default)]
    pub plugins: Vec<Plugin>,
}
135
136impl Default for ServerConfig {
137    fn default() -> Self {
138        ServerConfig {
139            address: "0.0.0.0:1704".into(),
140            force_path_parameter: true,
141            log_upstream_response: true,
142            global_response_middleware: vec![],
143            global_request_middleware: vec![],
144        }
145    }
146}
147
148pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
149    let builder = get_config_builder(paths)?;
150    let config: CardinalConfig = builder.build()?.try_deserialize()?;
151    validate_config(&config)?;
152
153    Ok(config)
154}
155
156pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
157    if config
158        .server
159        .address
160        .parse::<std::net::SocketAddr>()
161        .is_err()
162    {
163        return Err(ConfigError::Message(format!(
164            "Invalid server address: {}",
165            config.server.address
166        )));
167    }
168
169    let all_plugin_names = config
170        .plugins
171        .iter()
172        .map(|p| p.name())
173        .collect::<Vec<&str>>();
174
175    for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
176        if !all_plugin_names.contains(&middleware.name.as_str()) {
177            return Err(ConfigError::Message(format!(
178                "Middleware {} not found. {0} must be included in the list of plugins.",
179                middleware.name
180            )));
181        }
182    }
183
184    for destination in config.destinations.values() {
185        for route in &destination.routes {
186            if !route.path.starts_with('/') {
187                return Err(ConfigError::Message(format!(
188                    "Route path {} must start with a '/'.",
189                    route.path
190                )));
191            }
192        }
193    }
194
195    for destination in config.destinations.values() {
196        for route in &destination.routes {
197            if !route.method.eq("GET")
198                && !route.method.eq("POST")
199                && !route.method.eq("PUT")
200                && !route.method.eq("DELETE")
201                && !route.method.eq("PATCH")
202                && !route.method.eq("HEAD")
203                && !route.method.eq("OPTIONS")
204            {
205                return Err(ConfigError::Message(format!(
206                    "Route method {} is not supported.",
207                    route.method
208                )));
209            }
210        }
211    }
212
213    Ok(())
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{from_value, json, to_value};

    #[test]
    fn serialize_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "builtin": {
                "name": "Logger"
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn serialize_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "wasm": {
                "name": "RateLimit",
                "path": "plugins/ratelimit.wasm",
                "memory_name": null,
                "handle_name": null
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn toml_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[builtin]
name = "Logger"
"#;

        assert_eq!(toml_str, expected);
    }

    #[test]
    fn toml_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let toml_str = toml::to_string(&plugin).unwrap();

        // None fields are skipped
        let expected = r#"[wasm]
name = "RateLimit"
path = "plugins/ratelimit.wasm"
"#;

        assert_eq!(toml_str, expected);
    }

    /// Returns the inner builtin config, failing the test for any other variant.
    fn expect_builtin(plugin: Plugin) -> BuiltinPlugin {
        match plugin {
            Plugin::Builtin(builtin) => builtin,
            other => panic!("expected a builtin plugin, got {:?}", other),
        }
    }

    /// Returns the inner wasm config, failing the test for any other variant.
    fn expect_wasm(plugin: Plugin) -> WasmPluginConfig {
        match plugin {
            Plugin::Wasm(wasm) => wasm,
            other => panic!("expected a wasm plugin, got {:?}", other),
        }
    }

    #[test]
    fn deserialize_bare_name_as_builtin_plugin() {
        let plugin: Plugin = from_value(json!("Logger")).unwrap();
        assert_eq!(expect_builtin(plugin).name, "Logger");
    }

    #[test]
    fn deserialize_builtin_plugin_table() {
        let plugin: Plugin = from_value(json!({ "builtin": { "name": "Logger" } })).unwrap();
        assert_eq!(expect_builtin(plugin).name, "Logger");
    }

    #[test]
    fn deserialize_wasm_plugin_with_optional_fields_omitted() {
        let plugin: Plugin = from_value(json!({
            "wasm": { "name": "RateLimit", "path": "plugins/ratelimit.wasm" }
        }))
        .unwrap();
        let wasm = expect_wasm(plugin);
        assert_eq!(wasm.name, "RateLimit");
        assert_eq!(wasm.path, "plugins/ratelimit.wasm");
        assert!(wasm.memory_name.is_none());
        assert!(wasm.handle_name.is_none());
    }

    #[test]
    fn serialized_plugins_deserialize_back() {
        let builtin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });
        let round_tripped: Plugin = from_value(to_value(&builtin).unwrap()).unwrap();
        assert_eq!(expect_builtin(round_tripped).name, "Logger");

        let wasm = Plugin::Wasm(WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: Some("memory".to_string()),
            handle_name: None,
        });
        let round_tripped = expect_wasm(from_value(to_value(&wasm).unwrap()).unwrap());
        assert_eq!(round_tripped.name, "RateLimit");
        assert_eq!(round_tripped.memory_name, Some("memory".to_string()));
        assert!(round_tripped.handle_name.is_none());
    }

    /// Builds a config with the default server section and one destination that uses the
    /// given middleware names and `(path, method)` routes. Each entry of `plugins` becomes
    /// a builtin plugin.
    fn config_with(plugins: &[&str], middleware: &[&str], routes: &[(&str, &str)]) -> CardinalConfig {
        let destination = Destination {
            name: "users".to_string(),
            url: "http://127.0.0.1:9000".to_string(),
            health_check: None,
            routes: routes
                .iter()
                .map(|&(path, method)| Route {
                    path: path.to_string(),
                    method: method.to_string(),
                })
                .collect(),
            middleware: middleware
                .iter()
                .map(|&name| Middleware {
                    r#type: MiddlewareType::Inbound,
                    name: name.to_string(),
                })
                .collect(),
        };
        let mut destinations = BTreeMap::new();
        destinations.insert("users".to_string(), destination);
        CardinalConfig {
            server: ServerConfig::default(),
            destinations,
            plugins: plugins
                .iter()
                .map(|&name| {
                    Plugin::Builtin(BuiltinPlugin {
                        name: name.to_string(),
                    })
                })
                .collect(),
        }
    }

    /// Runs `validate_config` and returns its error message; fails the test on success.
    fn validation_error(config: &CardinalConfig) -> String {
        match validate_config(config) {
            Err(ConfigError::Message(msg)) => msg,
            other => panic!("expected ConfigError::Message, got {:?}", other),
        }
    }

    #[test]
    fn default_config_is_valid() {
        assert!(validate_config(&CardinalConfig::default()).is_ok());
    }

    #[test]
    fn accepts_middleware_declared_as_plugin() {
        let config = config_with(&["auth"], &["auth"], &[("/users", "GET")]);
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn rejects_non_socket_server_address() {
        let mut config = CardinalConfig::default();
        config.server.address = "localhost:1704".to_string();
        assert_eq!(
            validation_error(&config),
            "Invalid server address: localhost:1704"
        );
    }

    #[test]
    fn rejects_middleware_missing_from_plugins() {
        let config = config_with(&[], &["auth"], &[]);
        assert_eq!(
            validation_error(&config),
            "Middleware auth not found. auth must be included in the list of plugins."
        );
    }

    #[test]
    fn rejects_route_path_without_leading_slash() {
        let config = config_with(&[], &[], &[("users", "GET")]);
        assert_eq!(
            validation_error(&config),
            "Route path users must start with a '/'."
        );
    }

    #[test]
    fn rejects_unsupported_and_lowercase_methods() {
        for &method in &["FETCH", "get"] {
            let config = config_with(&[], &[], &[("/users", method)]);
            assert_eq!(
                validation_error(&config),
                format!("Route method {} is not supported.", method)
            );
        }
    }

    #[test]
    fn middleware_errors_take_precedence_over_route_errors() {
        let config = config_with(&[], &["auth"], &[("users", "FETCH")]);
        assert_eq!(
            validation_error(&config),
            "Middleware auth not found. auth must be included in the list of plugins."
        );
    }
}