cardinal_config/
lib.rs

1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6use ts_rs::TS;
7
8pub mod config;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
11#[ts(export)]
12pub struct HealthCheck {
13    pub path: String,
14    pub interval_ms: u64,
15    pub timeout_ms: u64,
16    pub expect_status: u16,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
20#[ts(export)]
21pub enum MiddlewareType {
22    Inbound,
23    Outbound,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
27#[ts(export)]
28pub struct Middleware {
29    pub r#type: MiddlewareType,
30    pub name: String,
31}
32
33#[derive(Debug, Clone, TS)]
34#[ts(export)]
35pub enum Plugin {
36    Builtin(BuiltinPlugin),
37    Wasm(WasmPluginConfig),
38}
39
40impl Plugin {
41    pub fn name(&self) -> &str {
42        match self {
43            Plugin::Builtin(builtin) => &builtin.name,
44            Plugin::Wasm(wasm) => &wasm.name,
45        }
46    }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
50#[ts(export)]
51pub struct BuiltinPlugin {
52    pub name: String,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
56#[ts(export)]
57pub struct WasmPluginConfig {
58    pub name: String,
59    pub path: String,
60    pub memory_name: Option<String>,
61    pub handle_name: Option<String>,
62}
63
64#[derive(Deserialize, TS)]
65#[serde(untagged)]
66#[ts(export)]
67enum PluginSerde {
68    Name(String),
69    Builtin { builtin: BuiltinPlugin },
70    Wasm { wasm: WasmPluginConfig },
71}
72
73impl<'de> Deserialize<'de> for Plugin {
74    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
75    where
76        D: Deserializer<'de>,
77    {
78        match PluginSerde::deserialize(deserializer)? {
79            PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
80            PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
81            PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
82        }
83    }
84}
85
86impl Serialize for Plugin {
87    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
88    where
89        S: Serializer,
90    {
91        match self {
92            Plugin::Builtin(builtin) => {
93                #[derive(Serialize)]
94                struct Wrapper<'a> {
95                    builtin: &'a BuiltinPlugin,
96                }
97                Wrapper { builtin }.serialize(serializer)
98            }
99            Plugin::Wasm(wasm) => {
100                #[derive(Serialize)]
101                struct Wrapper<'a> {
102                    wasm: &'a WasmPluginConfig,
103                }
104                Wrapper { wasm }.serialize(serializer)
105            }
106        }
107    }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TS)]
111#[serde(untagged)]
112#[ts(export)]
113pub enum DestinationMatchValue {
114    String(String),
115    Regex { regex: String },
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS)]
119#[ts(export)]
120pub struct DestinationMatch {
121    pub host: Option<DestinationMatchValue>, // exact or wildcard “*.tenant.com”
122    pub path_prefix: Option<DestinationMatchValue>, // e.g. “/billing/”
123    pub path_exact: Option<String>,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
127#[ts(export)]
128pub struct DestinationTimeouts {
129    pub connect: Option<u64>,
130    pub read: Option<u64>,
131    pub write: Option<u64>,
132    pub idle: Option<u64>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, Default)]
136#[ts(export)]
137pub enum DestinationRetryBackoffType {
138    Exponential,
139    Linear,
140    #[default]
141    None,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
145#[ts(export)]
146pub struct DestinationRetry {
147    pub max_attempts: u64,
148    pub interval_ms: u64,
149    pub backoff_type: DestinationRetryBackoffType,
150    pub max_interval: Option<u64>,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
154#[ts(export)]
155pub struct Destination {
156    pub name: String,
157    pub url: String,
158    pub health_check: Option<HealthCheck>,
159    #[serde(default)]
160    pub default: bool,
161    #[serde(default)]
162    pub r#match: Option<Vec<DestinationMatch>>,
163    #[serde(default)]
164    pub routes: Vec<Route>,
165    #[serde(default)]
166    pub middleware: Vec<Middleware>,
167    #[serde(default)]
168    pub timeout: Option<DestinationTimeouts>,
169    #[serde(default)]
170    pub retry: Option<DestinationRetry>,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
174#[ts(export)]
175pub struct ServerConfig {
176    pub address: String,
177    pub force_path_parameter: bool,
178    pub log_upstream_response: bool,
179    pub global_request_middleware: Vec<String>,
180    pub global_response_middleware: Vec<String>,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
184#[ts(export)]
185pub struct Route {
186    pub path: String,
187    pub method: String,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder, TS)]
191#[ts(export)]
192pub struct CardinalConfig {
193    pub server: ServerConfig,
194    pub destinations: BTreeMap<String, Destination>,
195    #[serde(default)]
196    pub plugins: Vec<Plugin>,
197}
198
199impl Default for ServerConfig {
200    fn default() -> Self {
201        ServerConfig {
202            address: "0.0.0.0:1704".into(),
203            force_path_parameter: true,
204            log_upstream_response: true,
205            global_response_middleware: vec![],
206            global_request_middleware: vec![],
207        }
208    }
209}
210
211pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
212    let builder = get_config_builder(paths)?;
213    let config: CardinalConfig = builder.build()?.try_deserialize()?;
214    validate_config(&config)?;
215
216    Ok(config)
217}
218
219pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
220    if config
221        .server
222        .address
223        .parse::<std::net::SocketAddr>()
224        .is_err()
225    {
226        return Err(ConfigError::Message(format!(
227            "Invalid server address: {}",
228            config.server.address
229        )));
230    }
231
232    let all_plugin_names = config
233        .plugins
234        .iter()
235        .map(|p| p.name())
236        .collect::<Vec<&str>>();
237
238    for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
239        if !all_plugin_names.contains(&middleware.name.as_str()) {
240            return Err(ConfigError::Message(format!(
241                "Middleware {} not found. {0} must be included in the list of plugins.",
242                middleware.name
243            )));
244        }
245    }
246
247    for destination in config.destinations.values() {
248        for route in &destination.routes {
249            if !route.path.starts_with('/') {
250                return Err(ConfigError::Message(format!(
251                    "Route path {} must start with a '/'.",
252                    route.path
253                )));
254            }
255        }
256    }
257
258    for destination in config.destinations.values() {
259        for route in &destination.routes {
260            if !route.method.eq("GET")
261                && !route.method.eq("POST")
262                && !route.method.eq("PUT")
263                && !route.method.eq("DELETE")
264                && !route.method.eq("PATCH")
265                && !route.method.eq("HEAD")
266                && !route.method.eq("OPTIONS")
267            {
268                return Err(ConfigError::Message(format!(
269                    "Route method {} is not supported.",
270                    route.method
271                )));
272            }
273        }
274    }
275
276    Ok(())
277}
278
279#[cfg(test)]
280mod tests {
281    use super::*;
282    use serde::{Deserialize, Serialize};
283    use serde_json::{json, to_value};
284
285    #[test]
286    fn serialize_builtin_plugin() {
287        let plugin = Plugin::Builtin(BuiltinPlugin {
288            name: "Logger".to_string(),
289        });
290
291        let val = to_value(&plugin).unwrap();
292
293        let expected = json!({
294            "builtin": {
295                "name": "Logger"
296            }
297        });
298
299        assert_eq!(val, expected);
300    }
301
302    #[test]
303    fn serialize_wasm_plugin() {
304        let wasm_cfg = WasmPluginConfig {
305            name: "RateLimit".to_string(),
306            path: "plugins/ratelimit.wasm".to_string(),
307            memory_name: None,
308            handle_name: None,
309        };
310        let plugin = Plugin::Wasm(wasm_cfg);
311
312        let val = to_value(&plugin).unwrap();
313
314        let expected = json!({
315            "wasm": {
316                "name": "RateLimit",
317                "path": "plugins/ratelimit.wasm",
318                "memory_name": null,
319                "handle_name": null
320            }
321        });
322
323        assert_eq!(val, expected);
324    }
325
326    #[test]
327    fn toml_builtin_plugin() {
328        let plugin = Plugin::Builtin(BuiltinPlugin {
329            name: "Logger".to_string(),
330        });
331
332        let toml_str = toml::to_string(&plugin).unwrap();
333
334        let expected = r#"[builtin]
335name = "Logger"
336"#;
337
338        assert_eq!(toml_str, expected);
339    }
340
341    #[test]
342    fn toml_wasm_plugin() {
343        let wasm_cfg = WasmPluginConfig {
344            name: "RateLimit".to_string(),
345            path: "plugins/ratelimit.wasm".to_string(),
346            memory_name: None,
347            handle_name: None,
348        };
349        let plugin = Plugin::Wasm(wasm_cfg);
350
351        let toml_str = toml::to_string(&plugin).unwrap();
352
353        // None fields are skipped
354        let expected = r#"[wasm]
355name = "RateLimit"
356path = "plugins/ratelimit.wasm"
357"#;
358
359        assert_eq!(toml_str, expected);
360    }
361
362    #[test]
363    fn destination_match_value_string_roundtrip_json() {
364        let value = DestinationMatchValue::String("api.example.com".to_string());
365        let serialized = to_value(&value).unwrap();
366
367        assert_eq!(serialized, json!("api.example.com"));
368
369        let from_string: DestinationMatchValue =
370            serde_json::from_value(json!("api.example.com")).unwrap();
371        assert_eq!(from_string, value);
372    }
373
374    #[test]
375    fn destination_match_value_regex_roundtrip_json() {
376        let value = DestinationMatchValue::Regex {
377            regex: "^api\\.".to_string(),
378        };
379        let serialized = to_value(&value).unwrap();
380
381        assert_eq!(serialized, json!({"regex": "^api\\."}));
382
383        let decoded: DestinationMatchValue =
384            serde_json::from_value(json!({"regex": "^api\\."})).unwrap();
385        assert_eq!(decoded, value);
386    }
387
388    #[test]
389    fn destination_match_value_string_roundtrip_toml() {
390        let value = DestinationMatchValue::String("billing".to_string());
391        #[derive(Serialize, Deserialize, Debug, PartialEq)]
392        struct Wrapper {
393            value: DestinationMatchValue,
394        }
395
396        let toml_encoded = toml::to_string(&Wrapper {
397            value: value.clone(),
398        })
399        .unwrap();
400        assert_eq!(toml_encoded, "value = \"billing\"\n");
401
402        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
403        assert_eq!(decoded.value, value);
404    }
405
406    #[test]
407    fn destination_match_value_regex_roundtrip_toml() {
408        let value = DestinationMatchValue::Regex {
409            regex: "^/billing".to_string(),
410        };
411        #[derive(Serialize, Deserialize, Debug, PartialEq)]
412        struct Wrapper {
413            value: DestinationMatchValue,
414        }
415
416        let toml_encoded = toml::to_string(&Wrapper {
417            value: value.clone(),
418        })
419        .unwrap();
420        assert_eq!(toml_encoded, "[value]\nregex = \"^/billing\"\n");
421
422        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
423        assert_eq!(decoded.value, value);
424    }
425
426    #[test]
427    fn destination_struct_match_variants() {
428        let string_toml = r#"
429name = "customer_service"
430url = "https://svc.internal/api"
431
432[[match]]
433host = "support.example.com"
434path_prefix = "/helpdesk"
435"#;
436
437        let customer: Destination = toml::from_str(string_toml).unwrap();
438        let matcher = customer
439            .r#match
440            .as_ref()
441            .and_then(|entries| entries.first())
442            .expect("expected match section");
443        assert_eq!(
444            matcher.host,
445            Some(DestinationMatchValue::String("support.example.com".into()))
446        );
447        assert_eq!(
448            matcher.path_prefix,
449            Some(DestinationMatchValue::String("/helpdesk".into()))
450        );
451        assert_eq!(matcher.path_exact, None);
452
453        let regex_toml = r#"
454name = "billing"
455url = "https://billing.internal"
456
457[[match]]
458host = { regex = '^api\.(eu|us)\.example\.com$' }
459path_prefix = { regex = '^/billing/(v\d+)/' }
460"#;
461
462        let billing: Destination = toml::from_str(regex_toml).unwrap();
463        let matcher = billing
464            .r#match
465            .as_ref()
466            .and_then(|entries| entries.first())
467            .expect("expected match section");
468        assert_eq!(
469            matcher.host,
470            Some(DestinationMatchValue::Regex {
471                regex: r"^api\.(eu|us)\.example\.com$".into()
472            })
473        );
474        assert_eq!(
475            matcher.path_prefix,
476            Some(DestinationMatchValue::Regex {
477                regex: r"^/billing/(v\d+)/".into()
478            })
479        );
480        assert_eq!(matcher.path_exact, None);
481    }
482
483    #[test]
484    fn destination_match_toml_mixed_variants() {
485        #[derive(Serialize, Deserialize, Debug, PartialEq)]
486        struct ConfigHarness {
487            destinations: BTreeMap<String, DestinationHarness>,
488        }
489
490        #[derive(Serialize, Deserialize, Debug, PartialEq)]
491        struct DestinationHarness {
492            name: String,
493            url: String,
494            #[serde(rename = "match")]
495            matcher: Option<Vec<DestinationMatch>>,
496        }
497
498        impl DestinationHarness {
499            fn first_match(&self) -> &DestinationMatch {
500                self.matcher
501                    .as_ref()
502                    .and_then(|entries| entries.first())
503                    .expect("matcher section present")
504            }
505
506            fn match_count(&self) -> usize {
507                self.matcher.as_ref().map(|m| m.len()).unwrap_or(0)
508            }
509        }
510
511        let toml_source = r#"
512[destinations.customer_service]
513name = "customer_service"
514url = "https://svc.internal/api"
515
516[[destinations.customer_service.match]]
517host = "support.example.com"
518path_prefix = "/helpdesk"
519
520[[destinations.customer_service.match]]
521host = "support.example.com"
522path_prefix = { regex = '^/support' }
523
524[destinations.billing]
525name = "billing"
526url = "https://billing.internal"
527
528[[destinations.billing.match]]
529host = { regex = '^api\.(eu|us)\.example\.com$' }
530path_prefix = { regex = '^/billing/(v\d+)/' }
531"#;
532
533        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();
534
535        let customer = parsed.destinations.get("customer_service").unwrap();
536        assert_eq!(customer.match_count(), 2);
537        let customer_match = customer.first_match();
538        assert_eq!(
539            customer_match.host,
540            Some(DestinationMatchValue::String("support.example.com".into()))
541        );
542        assert_eq!(
543            customer_match.path_prefix,
544            Some(DestinationMatchValue::String("/helpdesk".into()))
545        );
546
547        let customer_matches = customer.matcher.as_ref().unwrap();
548        let second = &customer_matches[1];
549        assert_eq!(
550            second.path_prefix,
551            Some(DestinationMatchValue::Regex {
552                regex: String::from("^/support"),
553            })
554        );
555
556        let billing = parsed.destinations.get("billing").unwrap();
557        assert_eq!(billing.match_count(), 1);
558        let billing_match = billing.first_match();
559        assert_eq!(
560            billing_match.host,
561            Some(DestinationMatchValue::Regex {
562                regex: r"^api\.(eu|us)\.example\.com$".into()
563            })
564        );
565        assert_eq!(
566            billing_match.path_prefix,
567            Some(DestinationMatchValue::Regex {
568                regex: r"^/billing/(v\d+)/".into()
569            })
570        );
571
572        let serialized = toml::to_string(&parsed).unwrap();
573        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
574        assert_eq!(reparsed, parsed);
575    }
576
577    #[test]
578    fn destination_match_allows_empty_array() {
579        #[derive(Serialize, Deserialize, Debug, PartialEq)]
580        struct ConfigHarness {
581            destinations: BTreeMap<String, DestinationHarness>,
582        }
583
584        #[derive(Serialize, Deserialize, Debug, PartialEq)]
585        struct DestinationHarness {
586            name: String,
587            url: String,
588            #[serde(rename = "match")]
589            matcher: Option<Vec<DestinationMatch>>,
590        }
591
592        let toml_source = r#"
593[destinations.empty]
594name = "empty"
595url = "https://empty.internal"
596match = []
597"#;
598
599        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();
600        let destination = parsed.destinations.get("empty").unwrap();
601        assert!(destination
602            .matcher
603            .as_ref()
604            .map(|entries| entries.is_empty())
605            .unwrap_or(false));
606
607        let serialized = toml::to_string(&parsed).unwrap();
608        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
609        assert_eq!(reparsed, parsed);
610    }
611}