1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6use ts_rs::TS;
7
8pub mod config;
9
/// Health-check settings that can be attached to a [`Destination`] via
/// `Destination::health_check`.
///
/// Nothing in this file reads these values. How and when a probe is issued is decided by
/// whichever component consumes the config, so the field notes below are hedged.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct HealthCheck {
    /// Path to probe (presumably relative to `Destination::url` — TODO confirm).
    pub path: String,
    /// Interval between probes; the `_ms` suffix suggests milliseconds — TODO confirm.
    pub interval_ms: u64,
    /// Timeout for a single probe; the `_ms` suffix suggests milliseconds — TODO confirm.
    pub timeout_ms: u64,
    /// Status code a healthy probe is expected to return (presumably HTTP — verify against
    /// the checker).
    pub expect_status: u16,
}
18
/// Which side of the proxied exchange a [`Middleware`] is attached to.
///
/// NOTE(review): the names suggest request-side (inbound) vs response-side (outbound)
/// processing, but the dispatch code is not in this file — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[ts(export)]
pub enum MiddlewareType {
    Inbound,
    Outbound,
}
25
/// A middleware reference listed on a [`Destination`].
///
/// `validate_config` rejects the configuration unless `name` equals the name of one of the
/// entries in `CardinalConfig::plugins`.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct Middleware {
    // `type` is a Rust keyword, hence the raw identifier. Serde strips the `r#` prefix, so
    // the config key is `type`; the tests rely on the same behaviour for `r#match`.
    pub r#type: MiddlewareType,
    /// Name of the plugin that implements this middleware.
    pub name: String,
}
32
33#[derive(Debug, Clone, TS)]
34#[ts(export)]
35pub enum Plugin {
36 Builtin(BuiltinPlugin),
37 Wasm(WasmPluginConfig),
38}
39
40impl Plugin {
41 pub fn name(&self) -> &str {
42 match self {
43 Plugin::Builtin(builtin) => &builtin.name,
44 Plugin::Wasm(wasm) => &wasm.name,
45 }
46 }
47}
48
/// Configuration for a built-in plugin, which carries nothing but its name.
///
/// A bare string in the `plugins` list deserializes into this type (see `PluginSerde`).
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct BuiltinPlugin {
    /// Name the plugin is referenced by (see `Plugin::name`).
    pub name: String,
}
54
/// Configuration for a plugin backed by a WASM module.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct WasmPluginConfig {
    /// Name the plugin is referenced by (see `Plugin::name`).
    pub name: String,
    /// Location of the module. The tests use a relative file path
    /// (`"plugins/ratelimit.wasm"`); how it is resolved is up to the loader — TODO confirm.
    pub path: String,
    /// Optional name of the module's memory export. `None` presumably means the loader's
    /// default — TODO confirm.
    pub memory_name: Option<String>,
    /// Optional name of the module's handler export. `None` presumably means the loader's
    /// default — TODO confirm.
    pub handle_name: Option<String>,
}
63
/// Input shapes accepted when deserializing a [`Plugin`]:
/// - a bare string (a built-in plugin name);
/// - `{ builtin = { ... } }`;
/// - `{ wasm = { ... } }`.
///
/// Because the enum is `untagged`, serde tries the variants in declaration order, so the
/// bare-string form is checked first.
///
/// NOTE(review): this private, deserialize-only helper is also `#[ts(export)]`ed. Confirm
/// that a TypeScript binding for the *input* format is actually wanted.
#[derive(Deserialize, TS)]
#[serde(untagged)]
#[ts(export)]
enum PluginSerde {
    Name(String),
    Builtin { builtin: BuiltinPlugin },
    Wasm { wasm: WasmPluginConfig },
}
72
73impl<'de> Deserialize<'de> for Plugin {
74 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
75 where
76 D: Deserializer<'de>,
77 {
78 match PluginSerde::deserialize(deserializer)? {
79 PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
80 PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
81 PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
82 }
83 }
84}
85
86impl Serialize for Plugin {
87 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
88 where
89 S: Serializer,
90 {
91 match self {
92 Plugin::Builtin(builtin) => {
93 #[derive(Serialize)]
94 struct Wrapper<'a> {
95 builtin: &'a BuiltinPlugin,
96 }
97 Wrapper { builtin }.serialize(serializer)
98 }
99 Plugin::Wasm(wasm) => {
100 #[derive(Serialize)]
101 struct Wrapper<'a> {
102 wasm: &'a WasmPluginConfig,
103 }
104 Wrapper { wasm }.serialize(serializer)
105 }
106 }
107 }
108}
109
/// A value that a [`DestinationMatch`] field is compared against.
///
/// Because the enum is `untagged`:
/// - a bare string deserializes as `String`;
/// - a `{ regex = "..." }` table deserializes as `Regex`.
///
/// The round-trip tests in this file pin both forms. How each is compared (literal equality
/// vs prefix, regex flavour) is decided by the router, which is not in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TS)]
#[serde(untagged)]
#[ts(export)]
pub enum DestinationMatchValue {
    String(String),
    Regex { regex: String },
}
117
/// One routing rule of a [`Destination`]; a destination may list several under `match`.
///
/// Every field is optional. How the fields that are set combine (e.g. whether all of them
/// must match) is decided by the router, which is not in this file — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS)]
#[ts(export)]
pub struct DestinationMatch {
    /// Host criterion, either a literal string or a regex.
    pub host: Option<DestinationMatchValue>,
    /// Path-prefix criterion, either a literal string or a regex.
    pub path_prefix: Option<DestinationMatchValue>,
    /// Exact-path criterion (literal string only).
    pub path_exact: Option<String>,
}
125
/// Optional per-destination timeouts. `Default` leaves every field unset (`None`).
///
/// NOTE(review): no units are stated for these values. Unlike `HealthCheck::interval_ms`,
/// the names carry no `_ms` suffix. Confirm the unit with the consumer and document it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
#[ts(export)]
pub struct DestinationTimeouts {
    pub connect: Option<u64>,
    pub read: Option<u64>,
    pub write: Option<u64>,
    pub idle: Option<u64>,
}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
136#[ts(export)]
137pub enum DestinationRetryBackoffType {
138 Exponential,
139 Linear,
140 None,
141}
142
143impl Default for DestinationRetryBackoffType {
144 fn default() -> Self {
145 DestinationRetryBackoffType::None
146 }
147}
148
/// Retry policy for requests to a [`Destination`].
///
/// `Default` yields zeros, `DestinationRetryBackoffType::None` and no `max_interval`.
/// Note that `max_attempts`, `interval_ms` and `backoff_type` are still required keys when
/// a `retry` section is present in the config, since none of them has `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder, TS, Default)]
#[ts(export)]
pub struct DestinationRetry {
    /// Maximum number of attempts. Whether the first try counts is up to the retry loop —
    /// TODO confirm.
    pub max_attempts: u64,
    /// Base wait between attempts; the `_ms` suffix suggests milliseconds.
    pub interval_ms: u64,
    /// Backoff strategy applied between attempts.
    pub backoff_type: DestinationRetryBackoffType,
    // NOTE(review): unlike `interval_ms`, this has no unit suffix. Confirm it is also
    // milliseconds.
    pub max_interval: Option<u64>,
}
157
/// An upstream service, stored by key in `CardinalConfig::destinations`.
///
/// Only `name` and `url` are required. Every other field may be omitted from the config;
/// the tests parse destinations that give nothing but `name`, `url` and `match`.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct Destination {
    // NOTE(review): nothing checks that `name` equals this destination's key in
    // `CardinalConfig::destinations`.
    pub name: String,
    /// Upstream base URL (tests use values such as `https://svc.internal/api`).
    pub url: String,
    /// Optional health-probe settings.
    pub health_check: Option<HealthCheck>,
    // Defaults to `false` when omitted. NOTE(review): nothing here enforces that at most
    // one destination sets it, and its exact routing role (fallback?) is decided elsewhere.
    #[serde(default)]
    pub default: bool,
    // Config key is `match`. `None` when omitted; `match = []` yields `Some(vec![])`.
    #[serde(default)]
    pub r#match: Option<Vec<DestinationMatch>>,
    // `validate_config` requires every path to start with '/' and every method to be one
    // of a fixed, upper-case list.
    #[serde(default)]
    pub routes: Vec<Route>,
    // Each entry's `name` must match a configured plugin (enforced by `validate_config`).
    #[serde(default)]
    pub middleware: Vec<Middleware>,
    /// Optional timeout overrides.
    #[serde(default)]
    pub timeout: Option<DestinationTimeouts>,
    /// Optional retry policy.
    #[serde(default)]
    pub retry: Option<DestinationRetry>,
}
177
/// Server-wide settings.
///
/// NOTE(review): every field is a required key when deserializing. The `Default` impl in
/// this file is not consulted by serde for missing keys, since there is no
/// `#[serde(default)]`. Confirm that is intended.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct ServerConfig {
    /// Server socket address; `validate_config` rejects values that do not parse as
    /// `std::net::SocketAddr`.
    pub address: String,
    // NOTE(review): semantics not visible in this file.
    pub force_path_parameter: bool,
    /// Whether upstream responses are logged (presumably — consumer not in this file).
    pub log_upstream_response: bool,
    // NOTE(review): unlike destination middleware, the names in these two lists are not
    // checked against `CardinalConfig::plugins` by `validate_config`.
    pub global_request_middleware: Vec<String>,
    pub global_response_middleware: Vec<String>,
}
187
/// A path/method pair declared on a [`Destination`].
#[derive(Debug, Clone, Serialize, Deserialize, Builder, TS)]
#[ts(export)]
pub struct Route {
    /// Must start with '/' (checked by `validate_config`).
    pub path: String,
    /// HTTP method. `validate_config` accepts only the exact upper-case strings GET, POST,
    /// PUT, DELETE, PATCH, HEAD and OPTIONS.
    pub method: String,
}
194
/// Root configuration object, as returned by `load_config`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder, TS)]
#[ts(export)]
pub struct CardinalConfig {
    /// Server-wide settings.
    pub server: ServerConfig,
    /// Destinations keyed by name. This is a `BTreeMap`, so iteration (and therefore the
    /// order in which `validate_config` reports problems) follows key order.
    pub destinations: BTreeMap<String, Destination>,
    /// Declared plugins; empty when the key is omitted.
    #[serde(default)]
    pub plugins: Vec<Plugin>,
}
203
204impl Default for ServerConfig {
205 fn default() -> Self {
206 ServerConfig {
207 address: "0.0.0.0:1704".into(),
208 force_path_parameter: true,
209 log_upstream_response: true,
210 global_response_middleware: vec![],
211 global_request_middleware: vec![],
212 }
213 }
214}
215
216pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
217 let builder = get_config_builder(paths)?;
218 let config: CardinalConfig = builder.build()?.try_deserialize()?;
219 validate_config(&config)?;
220
221 Ok(config)
222}
223
224pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
225 if config
226 .server
227 .address
228 .parse::<std::net::SocketAddr>()
229 .is_err()
230 {
231 return Err(ConfigError::Message(format!(
232 "Invalid server address: {}",
233 config.server.address
234 )));
235 }
236
237 let all_plugin_names = config
238 .plugins
239 .iter()
240 .map(|p| p.name())
241 .collect::<Vec<&str>>();
242
243 for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
244 if !all_plugin_names.contains(&middleware.name.as_str()) {
245 return Err(ConfigError::Message(format!(
246 "Middleware {} not found. {0} must be included in the list of plugins.",
247 middleware.name
248 )));
249 }
250 }
251
252 for destination in config.destinations.values() {
253 for route in &destination.routes {
254 if !route.path.starts_with('/') {
255 return Err(ConfigError::Message(format!(
256 "Route path {} must start with a '/'.",
257 route.path
258 )));
259 }
260 }
261 }
262
263 for destination in config.destinations.values() {
264 for route in &destination.routes {
265 if !route.method.eq("GET")
266 && !route.method.eq("POST")
267 && !route.method.eq("PUT")
268 && !route.method.eq("DELETE")
269 && !route.method.eq("PATCH")
270 && !route.method.eq("HEAD")
271 && !route.method.eq("OPTIONS")
272 {
273 return Err(ConfigError::Message(format!(
274 "Route method {} is not supported.",
275 route.method
276 )));
277 }
278 }
279 }
280
281 Ok(())
282}
283
// Tests for the serde representations of the config types, exercised through `serde_json`
// and `toml`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, to_value};

    // A builtin plugin serializes as a single-key object: { "builtin": { "name": ... } }.
    #[test]
    fn serialize_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "builtin": {
                "name": "Logger"
            }
        });

        assert_eq!(val, expected);
    }

    // A wasm plugin serializes under the `wasm` key. In JSON the unset optional fields are
    // emitted as `null`.
    #[test]
    fn serialize_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "wasm": {
                "name": "RateLimit",
                "path": "plugins/ratelimit.wasm",
                "memory_name": null,
                "handle_name": null
            }
        });

        assert_eq!(val, expected);
    }

    // In TOML the builtin wrapper renders as a `[builtin]` table.
    #[test]
    fn toml_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[builtin]
name = "Logger"
"#;

        assert_eq!(toml_str, expected);
    }

    // In TOML the wasm wrapper renders as a `[wasm]` table. The `None` optional fields are
    // omitted entirely, unlike the JSON `null`s in `serialize_wasm_plugin`.
    #[test]
    fn toml_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[wasm]
name = "RateLimit"
path = "plugins/ratelimit.wasm"
"#;

        assert_eq!(toml_str, expected);
    }

    // The untagged `String` variant is a bare JSON string in both directions.
    #[test]
    fn destination_match_value_string_roundtrip_json() {
        let value = DestinationMatchValue::String("api.example.com".to_string());
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!("api.example.com"));

        let from_string: DestinationMatchValue =
            serde_json::from_value(json!("api.example.com")).unwrap();
        assert_eq!(from_string, value);
    }

    // The untagged `Regex` variant is a { "regex": ... } object in both directions.
    #[test]
    fn destination_match_value_regex_roundtrip_json() {
        let value = DestinationMatchValue::Regex {
            regex: "^api\\.".to_string(),
        };
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!({"regex": "^api\\."}));

        let decoded: DestinationMatchValue =
            serde_json::from_value(json!({"regex": "^api\\."})).unwrap();
        assert_eq!(decoded, value);
    }

    // In TOML the `String` variant is written as an inline string value.
    #[test]
    fn destination_match_value_string_roundtrip_toml() {
        let value = DestinationMatchValue::String("billing".to_string());
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "value = \"billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    // In TOML the `Regex` variant is written as a `[value]` sub-table.
    #[test]
    fn destination_match_value_regex_roundtrip_toml() {
        let value = DestinationMatchValue::Regex {
            regex: "^/billing".to_string(),
        };
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "[value]\nregex = \"^/billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    // Parses standalone `Destination`s that give only `name`, `url` and `[[match]]`.
    // The first uses plain-string matchers, the second `{ regex = ... }` matchers.
    #[test]
    fn destination_struct_match_variants() {
        let string_toml = r#"
name = "customer_service"
url = "https://svc.internal/api"

[[match]]
host = "support.example.com"
path_prefix = "/helpdesk"
"#;

        let customer: Destination = toml::from_str(string_toml).unwrap();
        let matcher = customer
            .r#match
            .as_ref()
            .and_then(|entries| entries.first())
            .expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );
        assert_eq!(matcher.path_exact, None);

        let regex_toml = r#"
name = "billing"
url = "https://billing.internal"

[[match]]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let billing: Destination = toml::from_str(regex_toml).unwrap();
        let matcher = billing
            .r#match
            .as_ref()
            .and_then(|entries| entries.first())
            .expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );
        assert_eq!(matcher.path_exact, None);
    }

    // Checks mixed string/regex matchers across several destinations in a map, including a
    // destination with two `match` entries. It uses a local harness that carries only
    // `name`, `url` and `match`. It then checks that a TOML serialize/parse round trip
    // reproduces the parsed value.
    #[test]
    fn destination_match_toml_mixed_variants() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<Vec<DestinationMatch>>,
        }

        impl DestinationHarness {
            fn first_match(&self) -> &DestinationMatch {
                self.matcher
                    .as_ref()
                    .and_then(|entries| entries.first())
                    .expect("matcher section present")
            }

            fn match_count(&self) -> usize {
                self.matcher.as_ref().map(|m| m.len()).unwrap_or(0)
            }
        }

        let toml_source = r#"
[destinations.customer_service]
name = "customer_service"
url = "https://svc.internal/api"

[[destinations.customer_service.match]]
host = "support.example.com"
path_prefix = "/helpdesk"

[[destinations.customer_service.match]]
host = "support.example.com"
path_prefix = { regex = '^/support' }

[destinations.billing]
name = "billing"
url = "https://billing.internal"

[[destinations.billing.match]]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();

        let customer = parsed.destinations.get("customer_service").unwrap();
        assert_eq!(customer.match_count(), 2);
        let customer_match = customer.first_match();
        assert_eq!(
            customer_match.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            customer_match.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );

        let customer_matches = customer.matcher.as_ref().unwrap();
        let second = &customer_matches[1];
        assert_eq!(
            second.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: String::from("^/support"),
            })
        );

        let billing = parsed.destinations.get("billing").unwrap();
        assert_eq!(billing.match_count(), 1);
        let billing_match = billing.first_match();
        assert_eq!(
            billing_match.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            billing_match.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );

        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }

    // `match = []` parses as `Some(empty)` rather than `None`, and it survives a TOML
    // round trip.
    #[test]
    fn destination_match_allows_empty_array() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<Vec<DestinationMatch>>,
        }

        let toml_source = r#"
[destinations.empty]
name = "empty"
url = "https://empty.internal"
match = []
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();
        let destination = parsed.destinations.get("empty").unwrap();
        assert!(destination
            .matcher
            .as_ref()
            .map(|entries| entries.is_empty())
            .unwrap_or(false));

        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }
}