cardinal_config/
lib.rs

1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6use ts_rs::TS;
7
8pub mod config;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
11#[ts(export)]
12pub struct HealthCheck {
13    pub path: String,
14    pub interval_ms: u64,
15    pub timeout_ms: u64,
16    pub expect_status: u16,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
20#[ts(export)]
21pub enum MiddlewareType {
22    Inbound,
23    Outbound,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
27#[ts(export)]
28pub struct Middleware {
29    pub r#type: MiddlewareType,
30    pub name: String,
31}
32
33#[derive(Debug, Clone, TS)]
34#[ts(export)]
35pub enum Plugin {
36    Builtin(BuiltinPlugin),
37    Wasm(WasmPluginConfig),
38}
39
40impl Plugin {
41    pub fn name(&self) -> &str {
42        match self {
43            Plugin::Builtin(builtin) => &builtin.name,
44            Plugin::Wasm(wasm) => &wasm.name,
45        }
46    }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
50#[ts(export)]
51pub struct BuiltinPlugin {
52    pub name: String,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
56#[ts(export)]
57pub struct WasmPluginConfig {
58    pub name: String,
59    pub path: String,
60    pub memory_name: Option<String>,
61    pub handle_name: Option<String>,
62}
63
64#[derive(Deserialize, TS)]
65#[serde(untagged)]
66#[ts(export)]
67enum PluginSerde {
68    Name(String),
69    Builtin { builtin: BuiltinPlugin },
70    Wasm { wasm: WasmPluginConfig },
71}
72
73impl<'de> Deserialize<'de> for Plugin {
74    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
75    where
76        D: Deserializer<'de>,
77    {
78        match PluginSerde::deserialize(deserializer)? {
79            PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
80            PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
81            PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
82        }
83    }
84}
85
86impl Serialize for Plugin {
87    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
88    where
89        S: Serializer,
90    {
91        match self {
92            Plugin::Builtin(builtin) => {
93                #[derive(Serialize)]
94                struct Wrapper<'a> {
95                    builtin: &'a BuiltinPlugin,
96                }
97                Wrapper { builtin }.serialize(serializer)
98            }
99            Plugin::Wasm(wasm) => {
100                #[derive(Serialize)]
101                struct Wrapper<'a> {
102                    wasm: &'a WasmPluginConfig,
103                }
104                Wrapper { wasm }.serialize(serializer)
105            }
106        }
107    }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TS)]
111#[serde(untagged)]
112#[ts(export)]
113pub enum DestinationMatchValue {
114    String(String),
115    Regex { regex: String },
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS)]
119#[ts(export)]
120pub struct DestinationMatch {
121    pub host: Option<DestinationMatchValue>, // exact or wildcard “*.tenant.com”
122    pub path_prefix: Option<DestinationMatchValue>, // e.g. “/billing/”
123    pub path_exact: Option<String>,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
127#[ts(export)]
128pub struct DestinationTimeouts {
129    pub connect: Option<u64>,
130    pub read: Option<u64>,
131    pub write: Option<u64>,
132    pub idle: Option<u64>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
136#[ts(export)]
137pub enum DestinationRetryBackoffType {
138    Exponential,
139    Linear,
140    None,
141}
142
143impl Default for DestinationRetryBackoffType {
144    fn default() -> Self {
145        DestinationRetryBackoffType::None
146    }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
150#[ts(export)]
151pub struct DestinationRetry {
152    pub max_attempts: u64,
153    pub interval_ms: u64,
154    pub backoff_type: DestinationRetryBackoffType,
155    pub max_interval: Option<u64>,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
159#[ts(export)]
160pub struct Destination {
161    pub name: String,
162    pub url: String,
163    pub health_check: Option<HealthCheck>,
164    #[serde(default)]
165    pub default: bool,
166    #[serde(default)]
167    pub r#match: Option<Vec<DestinationMatch>>,
168    #[serde(default)]
169    pub routes: Vec<Route>,
170    #[serde(default)]
171    pub middleware: Vec<Middleware>,
172    #[serde(default)]
173    pub timeout: Option<DestinationTimeouts>,
174    #[serde(default)]
175    pub retry: Option<DestinationRetry>,
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
179#[ts(export)]
180pub struct ServerConfig {
181    pub address: String,
182    pub force_path_parameter: bool,
183    pub log_upstream_response: bool,
184    pub global_request_middleware: Vec<String>,
185    pub global_response_middleware: Vec<String>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
189#[ts(export)]
190pub struct Route {
191    pub path: String,
192    pub method: String,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder, TS)]
196#[ts(export)]
197pub struct CardinalConfig {
198    pub server: ServerConfig,
199    pub destinations: BTreeMap<String, Destination>,
200    #[serde(default)]
201    pub plugins: Vec<Plugin>,
202}
203
204impl Default for ServerConfig {
205    fn default() -> Self {
206        ServerConfig {
207            address: "0.0.0.0:1704".into(),
208            force_path_parameter: true,
209            log_upstream_response: true,
210            global_response_middleware: vec![],
211            global_request_middleware: vec![],
212        }
213    }
214}
215
216pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
217    let builder = get_config_builder(paths)?;
218    let config: CardinalConfig = builder.build()?.try_deserialize()?;
219    validate_config(&config)?;
220
221    Ok(config)
222}
223
224pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
225    if config
226        .server
227        .address
228        .parse::<std::net::SocketAddr>()
229        .is_err()
230    {
231        return Err(ConfigError::Message(format!(
232            "Invalid server address: {}",
233            config.server.address
234        )));
235    }
236
237    let all_plugin_names = config
238        .plugins
239        .iter()
240        .map(|p| p.name())
241        .collect::<Vec<&str>>();
242
243    for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
244        if !all_plugin_names.contains(&middleware.name.as_str()) {
245            return Err(ConfigError::Message(format!(
246                "Middleware {} not found. {0} must be included in the list of plugins.",
247                middleware.name
248            )));
249        }
250    }
251
252    for destination in config.destinations.values() {
253        for route in &destination.routes {
254            if !route.path.starts_with('/') {
255                return Err(ConfigError::Message(format!(
256                    "Route path {} must start with a '/'.",
257                    route.path
258                )));
259            }
260        }
261    }
262
263    for destination in config.destinations.values() {
264        for route in &destination.routes {
265            if !route.method.eq("GET")
266                && !route.method.eq("POST")
267                && !route.method.eq("PUT")
268                && !route.method.eq("DELETE")
269                && !route.method.eq("PATCH")
270                && !route.method.eq("HEAD")
271                && !route.method.eq("OPTIONS")
272            {
273                return Err(ConfigError::Message(format!(
274                    "Route method {} is not supported.",
275                    route.method
276                )));
277            }
278        }
279    }
280
281    Ok(())
282}
283
#[cfg(test)]
mod tests {
    //! Serialization round-trips for the config types, plus coverage for `Plugin`
    //! deserialization shapes and every `validate_config` rule.

    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, to_value};

    /// Minimal valid destination used as a starting point by the validation tests.
    fn test_destination(name: &str) -> Destination {
        Destination {
            name: name.to_string(),
            url: "http://127.0.0.1:9000".to_string(),
            health_check: None,
            default: false,
            r#match: None,
            routes: Vec::new(),
            middleware: Vec::new(),
            timeout: None,
            retry: None,
        }
    }

    /// Unwraps the message of a `ConfigError::Message`, panicking on anything else.
    fn expect_message(result: Result<(), ConfigError>) -> String {
        match result {
            Err(ConfigError::Message(msg)) => msg,
            other => panic!("expected ConfigError::Message, got {:?}", other),
        }
    }

    #[test]
    fn serialize_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "builtin": {
                "name": "Logger"
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn serialize_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "wasm": {
                "name": "RateLimit",
                "path": "plugins/ratelimit.wasm",
                "memory_name": null,
                "handle_name": null
            }
        });

        assert_eq!(val, expected);
    }

    #[test]
    fn toml_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[builtin]
name = "Logger"
"#;

        assert_eq!(toml_str, expected);
    }

    #[test]
    fn toml_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let toml_str = toml::to_string(&plugin).unwrap();

        // None fields are skipped
        let expected = r#"[wasm]
name = "RateLimit"
path = "plugins/ratelimit.wasm"
"#;

        assert_eq!(toml_str, expected);
    }

    #[test]
    fn deserialize_plugin_accepts_all_input_shapes_json() {
        let plugins: Vec<Plugin> = serde_json::from_value(json!([
            "Logger",
            { "builtin": { "name": "Auth" } },
            { "wasm": { "name": "RateLimit", "path": "plugins/ratelimit.wasm" } }
        ]))
        .unwrap();

        assert_eq!(plugins.len(), 3);
        // A bare string is shorthand for a builtin plugin.
        assert!(matches!(&plugins[0], Plugin::Builtin(b) if b.name == "Logger"));
        assert!(matches!(&plugins[1], Plugin::Builtin(b) if b.name == "Auth"));
        match &plugins[2] {
            Plugin::Wasm(wasm) => {
                assert_eq!(wasm.name, "RateLimit");
                assert_eq!(wasm.path, "plugins/ratelimit.wasm");
                assert_eq!(wasm.memory_name, None);
                assert_eq!(wasm.handle_name, None);
            }
            other => panic!("expected a wasm plugin, got {:?}", other),
        }
    }

    #[test]
    fn deserialize_plugin_from_toml_tables_and_names() {
        #[derive(Deserialize)]
        struct Harness {
            plugins: Vec<Plugin>,
        }

        let tables = r#"
[[plugins]]
builtin = { name = "Auth" }

[[plugins]]
wasm = { name = "RateLimit", path = "plugins/ratelimit.wasm" }
"#;
        let parsed: Harness = toml::from_str(tables).unwrap();
        let names: Vec<&str> = parsed.plugins.iter().map(Plugin::name).collect();
        assert_eq!(names, ["Auth", "RateLimit"]);
        assert!(matches!(&parsed.plugins[0], Plugin::Builtin(_)));
        assert!(matches!(&parsed.plugins[1], Plugin::Wasm(_)));

        let shorthand: Harness = toml::from_str(r#"plugins = ["Logger", "Auth"]"#).unwrap();
        assert!(shorthand
            .plugins
            .iter()
            .all(|p| matches!(p, Plugin::Builtin(_))));
        let names: Vec<&str> = shorthand.plugins.iter().map(Plugin::name).collect();
        assert_eq!(names, ["Logger", "Auth"]);
    }

    #[test]
    fn plugin_json_roundtrip_preserves_variant_and_fields() {
        let original = Plugin::Wasm(WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: Some("memory".to_string()),
            handle_name: None,
        });

        let decoded: Plugin = serde_json::from_value(to_value(&original).unwrap()).unwrap();
        match decoded {
            Plugin::Wasm(wasm) => {
                assert_eq!(wasm.name, "RateLimit");
                assert_eq!(wasm.memory_name.as_deref(), Some("memory"));
                assert_eq!(wasm.handle_name, None);
            }
            other => panic!("expected a wasm plugin, got {:?}", other),
        }
    }

    #[test]
    fn destination_match_value_string_roundtrip_json() {
        let value = DestinationMatchValue::String("api.example.com".to_string());
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!("api.example.com"));

        let from_string: DestinationMatchValue =
            serde_json::from_value(json!("api.example.com")).unwrap();
        assert_eq!(from_string, value);
    }

    #[test]
    fn destination_match_value_regex_roundtrip_json() {
        let value = DestinationMatchValue::Regex {
            regex: "^api\\.".to_string(),
        };
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!({"regex": "^api\\."}));

        let decoded: DestinationMatchValue =
            serde_json::from_value(json!({"regex": "^api\\."})).unwrap();
        assert_eq!(decoded, value);
    }

    #[test]
    fn destination_match_value_string_roundtrip_toml() {
        let value = DestinationMatchValue::String("billing".to_string());
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "value = \"billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    #[test]
    fn destination_match_value_regex_roundtrip_toml() {
        let value = DestinationMatchValue::Regex {
            regex: "^/billing".to_string(),
        };
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "[value]\nregex = \"^/billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    #[test]
    fn destination_struct_match_variants() {
        let string_toml = r#"
name = "customer_service"
url = "https://svc.internal/api"

[[match]]
host = "support.example.com"
path_prefix = "/helpdesk"
"#;

        let customer: Destination = toml::from_str(string_toml).unwrap();
        let matcher = customer
            .r#match
            .as_ref()
            .and_then(|entries| entries.first())
            .expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );
        assert_eq!(matcher.path_exact, None);

        let regex_toml = r#"
name = "billing"
url = "https://billing.internal"

[[match]]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let billing: Destination = toml::from_str(regex_toml).unwrap();
        let matcher = billing
            .r#match
            .as_ref()
            .and_then(|entries| entries.first())
            .expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );
        assert_eq!(matcher.path_exact, None);
    }

    #[test]
    fn destination_match_toml_mixed_variants() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<Vec<DestinationMatch>>,
        }

        impl DestinationHarness {
            fn first_match(&self) -> &DestinationMatch {
                self.matcher
                    .as_ref()
                    .and_then(|entries| entries.first())
                    .expect("matcher section present")
            }

            fn match_count(&self) -> usize {
                self.matcher.as_ref().map(|m| m.len()).unwrap_or(0)
            }
        }

        let toml_source = r#"
[destinations.customer_service]
name = "customer_service"
url = "https://svc.internal/api"

[[destinations.customer_service.match]]
host = "support.example.com"
path_prefix = "/helpdesk"

[[destinations.customer_service.match]]
host = "support.example.com"
path_prefix = { regex = '^/support' }

[destinations.billing]
name = "billing"
url = "https://billing.internal"

[[destinations.billing.match]]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();

        let customer = parsed.destinations.get("customer_service").unwrap();
        assert_eq!(customer.match_count(), 2);
        let customer_match = customer.first_match();
        assert_eq!(
            customer_match.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            customer_match.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );

        let customer_matches = customer.matcher.as_ref().unwrap();
        let second = &customer_matches[1];
        assert_eq!(
            second.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: String::from("^/support"),
            })
        );

        let billing = parsed.destinations.get("billing").unwrap();
        assert_eq!(billing.match_count(), 1);
        let billing_match = billing.first_match();
        assert_eq!(
            billing_match.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            billing_match.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );

        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn destination_match_allows_empty_array() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<Vec<DestinationMatch>>,
        }

        let toml_source = r#"
[destinations.empty]
name = "empty"
url = "https://empty.internal"
match = []
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();
        let destination = parsed.destinations.get("empty").unwrap();
        assert!(destination
            .matcher
            .as_ref()
            .map(|entries| entries.is_empty())
            .unwrap_or(false));

        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn validate_accepts_default_config() {
        assert!(validate_config(&CardinalConfig::default()).is_ok());
    }

    #[test]
    fn validate_rejects_invalid_server_address() {
        let mut config = CardinalConfig::default();
        config.server.address = "not-an-address".to_string();

        let msg = expect_message(validate_config(&config));
        assert_eq!(msg, "Invalid server address: not-an-address");
    }

    #[test]
    fn validate_requires_middleware_to_name_a_plugin() {
        let mut destination = test_destination("billing");
        destination.middleware.push(Middleware {
            r#type: MiddlewareType::Inbound,
            name: "Auth".to_string(),
        });
        let mut config = CardinalConfig::default();
        config
            .destinations
            .insert("billing".to_string(), destination);

        let msg = expect_message(validate_config(&config));
        assert_eq!(
            msg,
            "Middleware Auth not found. Auth must be included in the list of plugins."
        );

        config.plugins.push(Plugin::Builtin(BuiltinPlugin {
            name: "Auth".to_string(),
        }));
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn validate_checks_route_path_then_method() {
        let mut destination = test_destination("billing");
        destination.routes.push(Route {
            path: "billing".to_string(),
            method: "GET".to_string(),
        });
        let mut config = CardinalConfig::default();
        config
            .destinations
            .insert("billing".to_string(), destination.clone());
        assert_eq!(
            expect_message(validate_config(&config)),
            "Route path billing must start with a '/'."
        );

        // Method matching is case-sensitive.
        destination.routes[0] = Route {
            path: "/billing".to_string(),
            method: "get".to_string(),
        };
        config
            .destinations
            .insert("billing".to_string(), destination.clone());
        assert_eq!(
            expect_message(validate_config(&config)),
            "Route method get is not supported."
        );

        destination.routes[0].method = "GET".to_string();
        config
            .destinations
            .insert("billing".to_string(), destination);
        assert!(validate_config(&config).is_ok());
    }
}