cardinal_config/
lib.rs

1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6
7pub mod config;
8
9#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
10pub struct HealthCheck {
11    pub path: String,
12    pub interval_ms: u64,
13    pub timeout_ms: u64,
14    pub expect_status: u16,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub enum MiddlewareType {
19    Inbound,
20    Outbound,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
24pub struct Middleware {
25    pub r#type: MiddlewareType,
26    pub name: String,
27}
28
29#[derive(Debug, Clone)]
30pub enum Plugin {
31    Builtin(BuiltinPlugin),
32    Wasm(WasmPluginConfig),
33}
34
35impl Plugin {
36    pub fn name(&self) -> &str {
37        match self {
38            Plugin::Builtin(builtin) => &builtin.name,
39            Plugin::Wasm(wasm) => &wasm.name,
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
45pub struct BuiltinPlugin {
46    pub name: String,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
50pub struct WasmPluginConfig {
51    pub name: String,
52    pub path: String,
53    pub memory_name: Option<String>,
54    pub handle_name: Option<String>,
55}
56
57#[derive(Deserialize)]
58#[serde(untagged)]
59enum PluginSerde {
60    Name(String),
61    Builtin { builtin: BuiltinPlugin },
62    Wasm { wasm: WasmPluginConfig },
63}
64
65impl<'de> Deserialize<'de> for Plugin {
66    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        match PluginSerde::deserialize(deserializer)? {
71            PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
72            PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
73            PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
74        }
75    }
76}
77
78impl Serialize for Plugin {
79    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
80    where
81        S: Serializer,
82    {
83        match self {
84            Plugin::Builtin(builtin) => {
85                #[derive(Serialize)]
86                struct Wrapper<'a> {
87                    builtin: &'a BuiltinPlugin,
88                }
89                Wrapper { builtin }.serialize(serializer)
90            }
91            Plugin::Wasm(wasm) => {
92                #[derive(Serialize)]
93                struct Wrapper<'a> {
94                    wasm: &'a WasmPluginConfig,
95                }
96                Wrapper { wasm }.serialize(serializer)
97            }
98        }
99    }
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(untagged)]
104pub enum DestinationMatchValue {
105    String(String),
106    Regex { regex: String },
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder)]
110pub struct DestinationMatch {
111    pub host: Option<DestinationMatchValue>, // exact or wildcard “*.tenant.com”
112    pub path_prefix: Option<DestinationMatchValue>, // e.g. “/billing/”
113    pub path_exact: Option<String>,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
117pub struct Destination {
118    pub name: String,
119    pub url: String,
120    pub health_check: Option<HealthCheck>,
121    #[serde(default)]
122    pub default: bool,
123    #[serde(default)]
124    pub r#match: Option<DestinationMatch>,
125    #[serde(default)]
126    pub routes: Vec<Route>,
127    #[serde(default)]
128    pub middleware: Vec<Middleware>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
132pub struct ServerConfig {
133    pub address: String,
134    pub force_path_parameter: bool,
135    pub log_upstream_response: bool,
136    pub global_request_middleware: Vec<String>,
137    pub global_response_middleware: Vec<String>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
141pub struct Route {
142    pub path: String,
143    pub method: String,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
147pub struct CardinalConfig {
148    pub server: ServerConfig,
149    pub destinations: BTreeMap<String, Destination>,
150    #[serde(default)]
151    pub plugins: Vec<Plugin>,
152}
153
154impl Default for ServerConfig {
155    fn default() -> Self {
156        ServerConfig {
157            address: "0.0.0.0:1704".into(),
158            force_path_parameter: true,
159            log_upstream_response: true,
160            global_response_middleware: vec![],
161            global_request_middleware: vec![],
162        }
163    }
164}
165
166pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
167    let builder = get_config_builder(paths)?;
168    let config: CardinalConfig = builder.build()?.try_deserialize()?;
169    validate_config(&config)?;
170
171    Ok(config)
172}
173
174pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
175    if config
176        .server
177        .address
178        .parse::<std::net::SocketAddr>()
179        .is_err()
180    {
181        return Err(ConfigError::Message(format!(
182            "Invalid server address: {}",
183            config.server.address
184        )));
185    }
186
187    let all_plugin_names = config
188        .plugins
189        .iter()
190        .map(|p| p.name())
191        .collect::<Vec<&str>>();
192
193    for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
194        if !all_plugin_names.contains(&middleware.name.as_str()) {
195            return Err(ConfigError::Message(format!(
196                "Middleware {} not found. {0} must be included in the list of plugins.",
197                middleware.name
198            )));
199        }
200    }
201
202    for destination in config.destinations.values() {
203        for route in &destination.routes {
204            if !route.path.starts_with('/') {
205                return Err(ConfigError::Message(format!(
206                    "Route path {} must start with a '/'.",
207                    route.path
208                )));
209            }
210        }
211    }
212
213    for destination in config.destinations.values() {
214        for route in &destination.routes {
215            if !route.method.eq("GET")
216                && !route.method.eq("POST")
217                && !route.method.eq("PUT")
218                && !route.method.eq("DELETE")
219                && !route.method.eq("PATCH")
220                && !route.method.eq("HEAD")
221                && !route.method.eq("OPTIONS")
222            {
223                return Err(ConfigError::Message(format!(
224                    "Route method {} is not supported.",
225                    route.method
226                )));
227            }
228        }
229    }
230
231    Ok(())
232}
233
#[cfg(test)]
mod tests {
    // Serialization round-trips for plugins and destination matchers, plus
    // coverage for the hand-written `Plugin` deserializer and `validate_config`.

    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, to_value};

    /// Builds a minimal destination named `svc` with the given routes and middleware.
    fn destination_with(routes: Vec<Route>, middleware: Vec<Middleware>) -> Destination {
        Destination {
            name: "svc".to_string(),
            url: "http://svc.internal".to_string(),
            health_check: None,
            default: false,
            r#match: None,
            routes,
            middleware,
        }
    }

    /// Adds `destination` to an otherwise default configuration.
    fn config_with(destination: Destination) -> CardinalConfig {
        let mut config = CardinalConfig::default();
        config
            .destinations
            .insert(destination.name.clone(), destination);
        config
    }

    /// Runs `validate_config`, asserts that it fails, and returns the error text.
    fn validation_message(config: &CardinalConfig) -> String {
        validate_config(config)
            .expect_err("expected validation to fail")
            .to_string()
    }

    #[test]
    fn serialize_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "builtin": {
                "name": "Logger"
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn serialize_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "wasm": {
                "name": "RateLimit",
                "path": "plugins/ratelimit.wasm",
                "memory_name": null,
                "handle_name": null
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn toml_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[builtin]
name = "Logger"
"#;

        assert_eq!(toml_str, expected);
    }

    #[test]
    fn toml_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let toml_str = toml::to_string(&plugin).unwrap();

        // None fields are skipped
        let expected = r#"[wasm]
name = "RateLimit"
path = "plugins/ratelimit.wasm"
"#;

        assert_eq!(toml_str, expected);
    }

    #[test]
    fn deserialize_plugin_from_bare_name() {
        let plugin: Plugin = serde_json::from_value(json!("Logger")).unwrap();
        assert!(matches!(&plugin, Plugin::Builtin(builtin) if builtin.name == "Logger"));
    }

    #[test]
    fn deserialize_plugin_from_builtin_table() {
        let plugin: Plugin =
            serde_json::from_value(json!({"builtin": {"name": "Logger"}})).unwrap();
        assert!(matches!(&plugin, Plugin::Builtin(builtin) if builtin.name == "Logger"));
    }

    #[test]
    fn deserialize_plugin_from_wasm_table() {
        let plugin: Plugin = serde_json::from_value(json!({
            "wasm": { "name": "RateLimit", "path": "plugins/ratelimit.wasm" }
        }))
        .unwrap();

        match plugin {
            Plugin::Wasm(wasm) => {
                assert_eq!(wasm.name, "RateLimit");
                assert_eq!(wasm.path, "plugins/ratelimit.wasm");
                assert_eq!(wasm.memory_name, None);
                assert_eq!(wasm.handle_name, None);
            }
            other => panic!("expected a wasm plugin, got {:?}", other),
        }
    }

    #[test]
    fn plugin_json_roundtrip() {
        let original = Plugin::Wasm(WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: Some("memory".to_string()),
            handle_name: None,
        });

        let decoded: Plugin = serde_json::from_value(to_value(&original).unwrap()).unwrap();

        match decoded {
            Plugin::Wasm(wasm) => {
                assert_eq!(wasm.name, "RateLimit");
                assert_eq!(wasm.memory_name.as_deref(), Some("memory"));
                assert_eq!(wasm.handle_name, None);
            }
            other => panic!("expected a wasm plugin, got {:?}", other),
        }
    }

    #[test]
    fn deserialize_plugin_rejects_unknown_shapes() {
        assert!(serde_json::from_value::<Plugin>(json!({"lua": {"name": "x"}})).is_err());
        assert!(serde_json::from_value::<Plugin>(json!(42)).is_err());
    }

    #[test]
    fn destination_match_value_string_roundtrip_json() {
        let value = DestinationMatchValue::String("api.example.com".to_string());
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!("api.example.com"));

        let from_string: DestinationMatchValue =
            serde_json::from_value(json!("api.example.com")).unwrap();
        assert_eq!(from_string, value);
    }

    #[test]
    fn destination_match_value_regex_roundtrip_json() {
        let value = DestinationMatchValue::Regex {
            regex: "^api\\.".to_string(),
        };
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!({"regex": "^api\\."}));

        let decoded: DestinationMatchValue =
            serde_json::from_value(json!({"regex": "^api\\."})).unwrap();
        assert_eq!(decoded, value);
    }

    #[test]
    fn destination_match_value_string_roundtrip_toml() {
        let value = DestinationMatchValue::String("billing".to_string());
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "value = \"billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    #[test]
    fn destination_match_value_regex_roundtrip_toml() {
        let value = DestinationMatchValue::Regex {
            regex: "^/billing".to_string(),
        };
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "[value]\nregex = \"^/billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    #[test]
    fn destination_struct_match_variants() {
        let string_toml = r#"
name = "customer_service"
url = "https://svc.internal/api"

[match]
host = "support.example.com"
path_prefix = "/helpdesk"
"#;

        let customer: Destination = toml::from_str(string_toml).unwrap();
        let matcher = customer.r#match.as_ref().expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );
        assert_eq!(matcher.path_exact, None);

        let regex_toml = r#"
name = "billing"
url = "https://billing.internal"

[match]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let billing: Destination = toml::from_str(regex_toml).unwrap();
        let matcher = billing.r#match.as_ref().expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );
        assert_eq!(matcher.path_exact, None);
    }

    #[test]
    fn destination_match_toml_mixed_variants() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<DestinationMatch>,
        }

        impl DestinationHarness {
            fn matcher(&self) -> &DestinationMatch {
                self.matcher.as_ref().expect("matcher section present")
            }
        }

        let toml_source = r#"
[destinations.customer_service]
name = "customer_service"
url = "https://svc.internal/api"

[destinations.customer_service.match]
host = "support.example.com"
path_prefix = "/helpdesk"

[destinations.billing]
name = "billing"
url = "https://billing.internal"

[destinations.billing.match]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();

        let customer = parsed
            .destinations
            .get("customer_service")
            .unwrap()
            .matcher();
        assert_eq!(
            customer.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            customer.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );

        let billing = parsed.destinations.get("billing").unwrap().matcher();
        assert_eq!(
            billing.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            billing.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );

        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn validate_accepts_default_config() {
        assert!(validate_config(&CardinalConfig::default()).is_ok());
    }

    #[test]
    fn validate_rejects_unparseable_server_address() {
        let mut config = CardinalConfig::default();
        // SocketAddr parsing accepts IP literals only, not host names.
        config.server.address = "localhost:1704".to_string();
        assert!(validation_message(&config).contains("Invalid server address: localhost:1704"));
    }

    #[test]
    fn validate_requires_middleware_to_reference_a_plugin() {
        let auth = Middleware {
            r#type: MiddlewareType::Inbound,
            name: "Auth".to_string(),
        };
        let mut config = config_with(destination_with(Vec::new(), vec![auth]));
        assert!(validation_message(&config).contains("Middleware Auth not found"));

        config.plugins.push(Plugin::Builtin(BuiltinPlugin {
            name: "Auth".to_string(),
        }));
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn validate_rejects_route_path_without_leading_slash() {
        let route = Route {
            path: "users".to_string(),
            method: "GET".to_string(),
        };
        let config = config_with(destination_with(vec![route], Vec::new()));
        assert!(validation_message(&config).contains("Route path users must start with a '/'."));
    }

    #[test]
    fn validate_accepts_every_supported_route_method() {
        let routes = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
            .iter()
            .map(|method| Route {
                path: "/items".to_string(),
                method: method.to_string(),
            })
            .collect();
        let config = config_with(destination_with(routes, Vec::new()));
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn validate_rejects_unsupported_route_methods() {
        // Matching is case-sensitive and limited to the supported list.
        for method in &["get", "TRACE", "CONNECT"] {
            let route = Route {
                path: "/items".to_string(),
                method: method.to_string(),
            };
            let config = config_with(destination_with(vec![route], Vec::new()));
            let expected = format!("Route method {} is not supported.", method);
            assert!(validation_message(&config).contains(&expected));
        }
    }

    #[test]
    fn validate_reports_path_errors_before_method_errors() {
        let routes = vec![
            Route {
                path: "/ok".to_string(),
                method: "BREW".to_string(),
            },
            Route {
                path: "bad".to_string(),
                method: "GET".to_string(),
            },
        ];
        let config = config_with(destination_with(routes, Vec::new()));
        assert!(validation_message(&config).contains("Route path bad"));
    }
}
503}