1use crate::config::get_config_builder;
2use ::config::ConfigError;
3use derive_builder::Builder;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::BTreeMap;
6
7pub mod config;
8
/// Health-check settings that a [`Destination`] may optionally carry.
///
/// NOTE(review): this file only declares the settings; the field meanings
/// below are inferred from the field names and should be confirmed against
/// the code that actually performs the checks.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct HealthCheck {
    /// Path to probe (presumably a request path on the destination — confirm).
    pub path: String,
    /// Probe interval; the `_ms` suffix suggests milliseconds.
    pub interval_ms: u64,
    /// Probe timeout; the `_ms` suffix suggests milliseconds.
    pub timeout_ms: u64,
    /// Status code expected from the probe (presumably an HTTP status — confirm).
    pub expect_status: u16,
}
16
/// Kind of a [`Middleware`] entry.
///
/// NOTE(review): presumably `Inbound` applies on the request path and
/// `Outbound` on the response path — confirm with the middleware runner.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MiddlewareType {
    Inbound,
    Outbound,
}
22
/// A middleware attached to a [`Destination`].
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Middleware {
    /// Whether the middleware is inbound or outbound.
    pub r#type: MiddlewareType,
    /// Name of the middleware; `validate_config` requires it to equal the
    /// name of one of the configured plugins.
    pub name: String,
}
28
/// A configured plugin, either built in or backed by a WASM module.
///
/// `Serialize` and `Deserialize` are implemented by hand in this file rather
/// than derived: values are written as a single-key table (`builtin` or
/// `wasm`), and a bare string is additionally accepted on input as a builtin
/// plugin name.
#[derive(Debug, Clone)]
pub enum Plugin {
    /// Plugin identified by name only.
    Builtin(BuiltinPlugin),
    /// Plugin described by a [`WasmPluginConfig`].
    Wasm(WasmPluginConfig),
}
34
35impl Plugin {
36 pub fn name(&self) -> &str {
37 match self {
38 Plugin::Builtin(builtin) => &builtin.name,
39 Plugin::Wasm(wasm) => &wasm.name,
40 }
41 }
42}
43
/// Configuration of a builtin plugin, which is identified by name alone.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct BuiltinPlugin {
    /// Name of the plugin; this is what [`Plugin::name`] returns.
    pub name: String,
}
48
/// Configuration of a WASM-backed plugin.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct WasmPluginConfig {
    /// Name of the plugin; this is what [`Plugin::name`] returns.
    pub name: String,
    /// Location of the plugin (presumably a filesystem path to the `.wasm`
    /// module, as in the tests' `plugins/ratelimit.wasm` — confirm).
    pub path: String,
    /// NOTE(review): presumably the name of the module's exported memory;
    /// the fallback used when `None` is not visible in this file.
    pub memory_name: Option<String>,
    /// NOTE(review): presumably the name of the exported handler function;
    /// the fallback used when `None` is not visible in this file.
    pub handle_name: Option<String>,
}
56
/// Private helper describing the input shapes accepted for a [`Plugin`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that matches; do not reorder them casually.
#[derive(Deserialize)]
#[serde(untagged)]
enum PluginSerde {
    /// A bare string, treated as the name of a builtin plugin.
    Name(String),
    /// A table with a single `builtin` key.
    Builtin { builtin: BuiltinPlugin },
    /// A table with a single `wasm` key.
    Wasm { wasm: WasmPluginConfig },
}
64
65impl<'de> Deserialize<'de> for Plugin {
66 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67 where
68 D: Deserializer<'de>,
69 {
70 match PluginSerde::deserialize(deserializer)? {
71 PluginSerde::Name(name) => Ok(Plugin::Builtin(BuiltinPlugin { name })),
72 PluginSerde::Builtin { builtin } => Ok(Plugin::Builtin(builtin)),
73 PluginSerde::Wasm { wasm } => Ok(Plugin::Wasm(wasm)),
74 }
75 }
76}
77
78impl Serialize for Plugin {
79 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
80 where
81 S: Serializer,
82 {
83 match self {
84 Plugin::Builtin(builtin) => {
85 #[derive(Serialize)]
86 struct Wrapper<'a> {
87 builtin: &'a BuiltinPlugin,
88 }
89 Wrapper { builtin }.serialize(serializer)
90 }
91 Plugin::Wasm(wasm) => {
92 #[derive(Serialize)]
93 struct Wrapper<'a> {
94 wasm: &'a WasmPluginConfig,
95 }
96 Wrapper { wasm }.serialize(serializer)
97 }
98 }
99 }
100}
101
/// A matcher value that is either a plain string or a regular expression.
///
/// The enum is `untagged`: a `String` is represented as a bare string and a
/// `Regex` as a table with a single `regex` key (for example
/// `{ regex = "^api\\." }` in TOML).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DestinationMatchValue {
    /// Plain string value.
    String(String),
    /// Regular-expression source text. It is stored as written; nothing in
    /// this file compiles or validates it.
    Regex { regex: String },
}
108
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Builder)]
110pub struct DestinationMatch {
111 pub host: Option<DestinationMatchValue>, pub path_prefix: Option<DestinationMatchValue>, pub path_exact: Option<String>,
114}
115
/// An upstream destination together with its routing configuration.
///
/// `default`, `match`, `routes` and `middleware` are marked
/// `#[serde(default)]`, so they may be omitted from the configuration.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Destination {
    /// Name of the destination.
    pub name: String,
    /// URL of the destination (presumably the upstream base URL — confirm).
    pub url: String,
    /// Optional health-check settings.
    pub health_check: Option<HealthCheck>,
    /// NOTE(review): presumably marks the fallback destination used when no
    /// matcher applies — confirm with the router. `false` when omitted.
    #[serde(default)]
    pub default: bool,
    /// Optional matching criteria (the `match` section of the config).
    #[serde(default)]
    pub r#match: Option<DestinationMatch>,
    /// Routes declared for this destination; `validate_config` checks each
    /// route's path and method.
    #[serde(default)]
    pub routes: Vec<Route>,
    /// Middleware attached to this destination; `validate_config` requires
    /// every entry to name a configured plugin.
    #[serde(default)]
    pub middleware: Vec<Middleware>,
}
130
/// Server-level settings.
///
/// A `Default` implementation is provided separately in this file.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct ServerConfig {
    /// Listen address; `validate_config` requires it to parse as a
    /// `std::net::SocketAddr` (for example `0.0.0.0:1704`).
    pub address: String,
    /// NOTE(review): exact effect is not visible in this file — confirm with
    /// the routing code.
    pub force_path_parameter: bool,
    /// Presumably toggles logging of upstream responses — confirm.
    pub log_upstream_response: bool,
    /// Names of middleware applied globally to requests (presumably plugin
    /// names; `validate_config` does not check them).
    pub global_request_middleware: Vec<String>,
    /// Names of middleware applied globally to responses (presumably plugin
    /// names; `validate_config` does not check them).
    pub global_response_middleware: Vec<String>,
}
139
/// A single route declared on a [`Destination`].
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
pub struct Route {
    /// Route path; `validate_config` requires it to start with `/`.
    pub path: String,
    /// HTTP method; `validate_config` accepts exactly `GET`, `POST`, `PUT`,
    /// `DELETE`, `PATCH`, `HEAD` or `OPTIONS` (case-sensitive).
    pub method: String,
}
145
/// Root configuration object produced by `load_config`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
pub struct CardinalConfig {
    /// Server-level settings.
    pub server: ServerConfig,
    /// Destinations, held in a `BTreeMap` keyed by string and therefore
    /// iterated in key order.
    pub destinations: BTreeMap<String, Destination>,
    /// Configured plugins; an empty list when omitted from the input.
    #[serde(default)]
    pub plugins: Vec<Plugin>,
}
153
154impl Default for ServerConfig {
155 fn default() -> Self {
156 ServerConfig {
157 address: "0.0.0.0:1704".into(),
158 force_path_parameter: true,
159 log_upstream_response: true,
160 global_response_middleware: vec![],
161 global_request_middleware: vec![],
162 }
163 }
164}
165
166pub fn load_config(paths: &[String]) -> Result<CardinalConfig, ConfigError> {
167 let builder = get_config_builder(paths)?;
168 let config: CardinalConfig = builder.build()?.try_deserialize()?;
169 validate_config(&config)?;
170
171 Ok(config)
172}
173
174pub fn validate_config(config: &CardinalConfig) -> Result<(), ConfigError> {
175 if config
176 .server
177 .address
178 .parse::<std::net::SocketAddr>()
179 .is_err()
180 {
181 return Err(ConfigError::Message(format!(
182 "Invalid server address: {}",
183 config.server.address
184 )));
185 }
186
187 let all_plugin_names = config
188 .plugins
189 .iter()
190 .map(|p| p.name())
191 .collect::<Vec<&str>>();
192
193 for middleware in config.destinations.values().flat_map(|d| &d.middleware) {
194 if !all_plugin_names.contains(&middleware.name.as_str()) {
195 return Err(ConfigError::Message(format!(
196 "Middleware {} not found. {0} must be included in the list of plugins.",
197 middleware.name
198 )));
199 }
200 }
201
202 for destination in config.destinations.values() {
203 for route in &destination.routes {
204 if !route.path.starts_with('/') {
205 return Err(ConfigError::Message(format!(
206 "Route path {} must start with a '/'.",
207 route.path
208 )));
209 }
210 }
211 }
212
213 for destination in config.destinations.values() {
214 for route in &destination.routes {
215 if !route.method.eq("GET")
216 && !route.method.eq("POST")
217 && !route.method.eq("PUT")
218 && !route.method.eq("DELETE")
219 && !route.method.eq("PATCH")
220 && !route.method.eq("HEAD")
221 && !route.method.eq("OPTIONS")
222 {
223 return Err(ConfigError::Message(format!(
224 "Route method {} is not supported.",
225 route.method
226 )));
227 }
228 }
229 }
230
231 Ok(())
232}
233
/// Unit tests for the hand-written `Plugin` serialization and for the
/// string/regex forms of `DestinationMatchValue` in JSON and TOML.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, to_value};

    // A builtin plugin serializes to JSON as `{"builtin": {...}}`.
    #[test]
    fn serialize_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "builtin": {
                "name": "Logger"
            }
        });

        assert_eq!(val, expected);
    }

    // A WASM plugin serializes to JSON as `{"wasm": {...}}`; unset optional
    // fields appear as `null`.
    #[test]
    fn serialize_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let val = to_value(&plugin).unwrap();

        let expected = json!({
            "wasm": {
                "name": "RateLimit",
                "path": "plugins/ratelimit.wasm",
                "memory_name": null,
                "handle_name": null
            }
        });

        assert_eq!(val, expected);
    }

    // A builtin plugin serializes to TOML as a `[builtin]` table.
    #[test]
    fn toml_builtin_plugin() {
        let plugin = Plugin::Builtin(BuiltinPlugin {
            name: "Logger".to_string(),
        });

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[builtin]
name = "Logger"
"#;

        assert_eq!(toml_str, expected);
    }

    // A WASM plugin serializes to TOML as a `[wasm]` table; the expected
    // output contains no keys for the unset optional fields.
    #[test]
    fn toml_wasm_plugin() {
        let wasm_cfg = WasmPluginConfig {
            name: "RateLimit".to_string(),
            path: "plugins/ratelimit.wasm".to_string(),
            memory_name: None,
            handle_name: None,
        };
        let plugin = Plugin::Wasm(wasm_cfg);

        let toml_str = toml::to_string(&plugin).unwrap();

        let expected = r#"[wasm]
name = "RateLimit"
path = "plugins/ratelimit.wasm"
"#;

        assert_eq!(toml_str, expected);
    }

    // The `String` variant is a bare JSON string in both directions.
    #[test]
    fn destination_match_value_string_roundtrip_json() {
        let value = DestinationMatchValue::String("api.example.com".to_string());
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!("api.example.com"));

        let from_string: DestinationMatchValue =
            serde_json::from_value(json!("api.example.com")).unwrap();
        assert_eq!(from_string, value);
    }

    // The `Regex` variant is a JSON object with a single `regex` key.
    #[test]
    fn destination_match_value_regex_roundtrip_json() {
        let value = DestinationMatchValue::Regex {
            regex: "^api\\.".to_string(),
        };
        let serialized = to_value(&value).unwrap();

        assert_eq!(serialized, json!({"regex": "^api\\."}));

        let decoded: DestinationMatchValue =
            serde_json::from_value(json!({"regex": "^api\\."})).unwrap();
        assert_eq!(decoded, value);
    }

    // The `String` variant round-trips through TOML as a plain string value.
    #[test]
    fn destination_match_value_string_roundtrip_toml() {
        let value = DestinationMatchValue::String("billing".to_string());
        // The value is wrapped in a struct so it is serialized under a key.
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "value = \"billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    // The `Regex` variant round-trips through TOML as a `[value]` table.
    #[test]
    fn destination_match_value_regex_roundtrip_toml() {
        let value = DestinationMatchValue::Regex {
            regex: "^/billing".to_string(),
        };
        // The value is wrapped in a struct so it is serialized under a key.
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct Wrapper {
            value: DestinationMatchValue,
        }

        let toml_encoded = toml::to_string(&Wrapper {
            value: value.clone(),
        })
        .unwrap();
        assert_eq!(toml_encoded, "[value]\nregex = \"^/billing\"\n");

        let decoded: Wrapper = toml::from_str(&toml_encoded).unwrap();
        assert_eq!(decoded.value, value);
    }

    // A full `Destination` parses from TOML with either string or inline
    // regex matchers in its `[match]` section.
    #[test]
    fn destination_struct_match_variants() {
        let string_toml = r#"
name = "customer_service"
url = "https://svc.internal/api"

[match]
host = "support.example.com"
path_prefix = "/helpdesk"
"#;

        let customer: Destination = toml::from_str(string_toml).unwrap();
        let matcher = customer.r#match.as_ref().expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );
        assert_eq!(matcher.path_exact, None);

        let regex_toml = r#"
name = "billing"
url = "https://billing.internal"

[match]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let billing: Destination = toml::from_str(regex_toml).unwrap();
        let matcher = billing.r#match.as_ref().expect("expected match section");
        assert_eq!(
            matcher.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            matcher.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );
        assert_eq!(matcher.path_exact, None);
    }

    // String and regex matchers can coexist in one document, and the parsed
    // structure survives a serialize/parse round trip.
    #[test]
    fn destination_match_toml_mixed_variants() {
        // Minimal stand-ins for the real config types, deriving `PartialEq`
        // so the round-tripped value can be compared with the parsed one.
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct ConfigHarness {
            destinations: BTreeMap<String, DestinationHarness>,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct DestinationHarness {
            name: String,
            url: String,
            #[serde(rename = "match")]
            matcher: Option<DestinationMatch>,
        }

        impl DestinationHarness {
            fn matcher(&self) -> &DestinationMatch {
                self.matcher.as_ref().expect("matcher section present")
            }
        }

        let toml_source = r#"
[destinations.customer_service]
name = "customer_service"
url = "https://svc.internal/api"

[destinations.customer_service.match]
host = "support.example.com"
path_prefix = "/helpdesk"

[destinations.billing]
name = "billing"
url = "https://billing.internal"

[destinations.billing.match]
host = { regex = '^api\.(eu|us)\.example\.com$' }
path_prefix = { regex = '^/billing/(v\d+)/' }
"#;

        let parsed: ConfigHarness = toml::from_str(toml_source).unwrap();

        let customer = parsed
            .destinations
            .get("customer_service")
            .unwrap()
            .matcher();
        assert_eq!(
            customer.host,
            Some(DestinationMatchValue::String("support.example.com".into()))
        );
        assert_eq!(
            customer.path_prefix,
            Some(DestinationMatchValue::String("/helpdesk".into()))
        );

        let billing = parsed.destinations.get("billing").unwrap().matcher();
        assert_eq!(
            billing.host,
            Some(DestinationMatchValue::Regex {
                regex: r"^api\.(eu|us)\.example\.com$".into()
            })
        );
        assert_eq!(
            billing.path_prefix,
            Some(DestinationMatchValue::Regex {
                regex: r"^/billing/(v\d+)/".into()
            })
        );

        // Round trip: serializing and re-parsing must give an equal value.
        let serialized = toml::to_string(&parsed).unwrap();
        let reparsed: ConfigHarness = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed, parsed);
    }
}