1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17 async fn reload(&self) -> Result<()>;
18 fn get_config(&self) -> Arc<Config>;
19 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20 fn add_route(&self, rule: ProxyRule) -> Result<()>;
21 fn remove_route(&self, index: usize) -> Result<()>;
22}
23
24#[derive(Deserialize, Default, Clone, Debug)]
25pub struct TomlConfig {
26 #[serde(default)]
27 pub server: ServerConfig,
28 #[serde(default)]
29 pub tls: TlsConfig,
30 pub letsencrypt: Option<LetsEncryptConfig>,
31 pub scripting: Option<ScriptingTomlConfig>,
32 pub admin: Option<AdminConfig>,
33 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
37pub struct CircuitBreakerTomlConfig {
38 pub failure_threshold: Option<u32>,
39 pub recovery_timeout_secs: Option<u64>,
40 pub success_threshold: Option<u32>,
41 pub failure_status_codes: Option<Vec<u16>>,
42}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46 pub enabled: bool,
47 pub bind: String,
48 pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52 fn default() -> Self {
53 Self {
54 enabled: false,
55 bind: "127.0.0.1:9090".to_string(),
56 api_key: None,
57 }
58 }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63 pub enabled: bool,
64 pub scripts_dir: Option<String>,
65 pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Clone, Debug)]
69pub struct ServerConfig {
70 pub bind: String,
71 pub https_port: u16,
72}
73
74impl Default for ServerConfig {
75 fn default() -> Self {
76 Self {
77 bind: "0.0.0.0:8080".to_string(),
78 https_port: 443,
79 }
80 }
81}
82
83#[derive(Deserialize, Serialize, Default, Clone, Debug)]
84pub struct TlsConfig {
85 pub mode: String,
86 pub cache_dir: String,
87}
88
89#[derive(Deserialize, Serialize, Clone, Debug)]
90pub struct LetsEncryptConfig {
91 pub staging: bool,
92 pub email: String,
93 pub terms_agreed: bool,
94}
95
96#[derive(Clone, Debug, Serialize)]
97pub struct Config {
98 pub server: ServerConfig,
99 pub tls: TlsConfig,
100 pub letsencrypt: Option<LetsEncryptConfig>,
101 pub scripting: ScriptingTomlConfig,
102 pub admin: AdminConfig,
103 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
104 pub rules: Vec<ProxyRule>,
105 pub global_scripts: Vec<String>,
106}
107
108#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
109pub enum LoadBalancingStrategy {
110 #[default]
111 #[serde(rename = "round-robin")]
112 RoundRobin,
113 #[serde(rename = "weighted")]
114 Weighted,
115 #[serde(rename = "failover")]
116 Failover,
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
120pub struct ProxyRule {
121 pub matcher: RuleMatcher,
122 pub targets: Vec<Target>,
123 pub headers: Vec<HeaderRule>,
124 pub scripts: Vec<String>,
125 #[serde(default)]
126 pub auth: Vec<BasicAuth>,
127 #[serde(default)]
128 pub load_balancing: LoadBalancingStrategy,
129}
130
131#[derive(Clone, Debug)]
132pub enum RuleMatcher {
133 Default,
134 Prefix(String),
135 Regex(RegexMatcher),
136 Exact(String),
137 Domain(String),
138 DomainPath(String, String),
139}
140
141#[derive(Clone, Debug)]
143pub struct RegexMatcher {
144 pub pattern: String,
145 pub regex: Regex,
146}
147
148impl RegexMatcher {
149 pub fn new(pattern: &str) -> Result<Self> {
150 Ok(Self {
151 pattern: pattern.to_string(),
152 regex: Regex::new(pattern)?,
153 })
154 }
155
156 pub fn is_match(&self, text: &str) -> bool {
157 self.regex.is_match(text)
158 }
159}
160
161impl Serialize for RuleMatcher {
162 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
163 where
164 S: serde::Serializer,
165 {
166 use serde::ser::SerializeMap;
167 let mut map = serializer.serialize_map(Some(2))?;
168 match self {
169 RuleMatcher::Default => {
170 map.serialize_entry("type", "default")?;
171 }
172 RuleMatcher::Prefix(v) => {
173 map.serialize_entry("type", "prefix")?;
174 map.serialize_entry("value", v)?;
175 }
176 RuleMatcher::Regex(rm) => {
177 map.serialize_entry("type", "regex")?;
178 map.serialize_entry("value", &rm.pattern)?;
179 }
180 RuleMatcher::Exact(v) => {
181 map.serialize_entry("type", "exact")?;
182 map.serialize_entry("value", v)?;
183 }
184 RuleMatcher::Domain(v) => {
185 map.serialize_entry("type", "domain")?;
186 map.serialize_entry("value", v)?;
187 }
188 RuleMatcher::DomainPath(d, p) => {
189 map.serialize_entry("type", "domain_path")?;
190 map.serialize_entry("domain", d)?;
191 map.serialize_entry("path", p)?;
192 }
193 }
194 map.end()
195 }
196}
197
198impl<'de> Deserialize<'de> for RuleMatcher {
199 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
200 where
201 D: serde::Deserializer<'de>,
202 {
203 use serde::de::Error;
204 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
205 let obj = value
206 .as_object()
207 .ok_or_else(|| D::Error::custom("expected object"))?;
208 let matcher_type = obj
209 .get("type")
210 .and_then(|v| v.as_str())
211 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
212
213 match matcher_type {
214 "default" => Ok(RuleMatcher::Default),
215 "exact" => {
216 let v = obj
217 .get("value")
218 .and_then(|v| v.as_str())
219 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
220 Ok(RuleMatcher::Exact(v.to_string()))
221 }
222 "prefix" => {
223 let v = obj
224 .get("value")
225 .and_then(|v| v.as_str())
226 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
227 Ok(RuleMatcher::Prefix(v.to_string()))
228 }
229 "regex" => {
230 let v = obj
231 .get("value")
232 .and_then(|v| v.as_str())
233 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
234 let rm = RegexMatcher::new(v)
235 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
236 Ok(RuleMatcher::Regex(rm))
237 }
238 "domain" => {
239 let v = obj
240 .get("value")
241 .and_then(|v| v.as_str())
242 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243 Ok(RuleMatcher::Domain(v.to_string()))
244 }
245 "domain_path" => {
246 let d = obj
247 .get("domain")
248 .and_then(|v| v.as_str())
249 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
250 let p = obj
251 .get("path")
252 .and_then(|v| v.as_str())
253 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
254 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
255 }
256 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
257 }
258 }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize)]
262pub struct Target {
263 pub url: Url,
264 pub weight: u8,
265}
266
267#[derive(Clone, Debug, Serialize, Deserialize)]
268pub struct HeaderRule {
269 pub name: String,
270 pub value: String,
271}
272
273impl Config {
274 pub fn acme_domains(&self) -> Vec<String> {
277 let mut domains = Vec::new();
278 let mut seen = std::collections::HashSet::new();
279
280 for rule in &self.rules {
281 let domain = match &rule.matcher {
282 RuleMatcher::Domain(d) => Some(d.as_str()),
283 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
284 _ => None,
285 };
286
287 if let Some(d) = domain {
288 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
289 continue;
290 }
291 if seen.insert(d.to_string()) {
292 domains.push(d.to_string());
293 }
294 }
295 }
296
297 domains
298 }
299}
300
301pub struct ConfigManager {
302 config: ArcSwap<Config>,
303 config_path: PathBuf,
304 _watcher: Option<RecommendedWatcher>,
305 suppress_watch: Arc<AtomicBool>,
306}
307
308impl Clone for ConfigManager {
309 fn clone(&self) -> Self {
310 Self {
311 config: ArcSwap::new(self.config.load().clone()),
312 config_path: self.config_path.clone(),
313 _watcher: None,
314 suppress_watch: self.suppress_watch.clone(),
315 }
316 }
317}
318
319impl ConfigManager {
320 pub fn new(config_path: &str) -> Result<Self> {
321 let path = PathBuf::from(config_path);
322 let config = Self::load_config(&path, &path)?;
323 Ok(Self {
324 config: ArcSwap::new(Arc::new(config)),
325 config_path: path,
326 _watcher: None,
327 suppress_watch: Arc::new(AtomicBool::new(false)),
328 })
329 }
330
331 pub fn config_path(&self) -> &Path {
332 &self.config_path
333 }
334
335 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
336 &self.suppress_watch
337 }
338
339 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
340 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
341 let (rules, global_scripts) = parse_proxy_config(&content)?;
342 let toml_content = std::fs::read_to_string(
343 config_path
344 .parent()
345 .unwrap_or(Path::new("."))
346 .join("config.toml"),
347 )
348 .ok();
349 let toml_config: TomlConfig = toml_content
350 .as_ref()
351 .and_then(|c| toml::from_str(c).ok())
352 .unwrap_or_default();
353
354 Ok(Config {
355 server: toml_config.server,
356 tls: toml_config.tls,
357 letsencrypt: toml_config.letsencrypt,
358 scripting: toml_config.scripting.unwrap_or_default(),
359 admin: toml_config.admin.unwrap_or_default(),
360 circuit_breaker: toml_config.circuit_breaker,
361 rules,
362 global_scripts,
363 })
364 }
365
366 pub fn get_config(&self) -> Arc<Config> {
367 self.config.load().clone()
368 }
369
370 pub fn start_watcher(&self) -> Result<()> {
371 if !self.config_path.exists() {
373 if let Some(parent) = self.config_path.parent() {
374 std::fs::create_dir_all(parent).ok();
375 }
376 std::fs::write(&self.config_path, "")?;
377 }
378
379 let (tx, mut rx) = mpsc::channel(1);
380 let config_path = self.config_path.clone();
381 let suppress = self.suppress_watch.clone();
382
383 let mut watcher = RecommendedWatcher::new(
384 move |res| {
385 let _ = tx.blocking_send(res);
386 },
387 notify::Config::default(),
388 )?;
389
390 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
391
392 tracing::info!("Watching config file: {}", config_path.display());
393
394 std::thread::spawn(move || {
395 while let Some(res) = rx.blocking_recv() {
396 match res {
397 Ok(event) => {
398 if event.kind.is_modify() {
399 if suppress.swap(false, Ordering::SeqCst) {
400 tracing::debug!(
401 "Suppressing file watcher reload (admin API write)"
402 );
403 continue;
404 }
405 tracing::info!("Config file changed, reloading...");
406 }
407 }
408 Err(e) => tracing::error!("Watch error: {}", e),
409 }
410 }
411 });
412
413 Ok(())
414 }
415
416 pub async fn reload(&self) -> Result<()> {
417 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
418 self.config.store(Arc::new(new_config));
419 tracing::info!("Configuration reloaded successfully");
420 Ok(())
421 }
422
423 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
425 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
426 self.suppress_watch.store(true, Ordering::SeqCst);
427 std::fs::write(&self.config_path, &content)?;
428 let mut config = (*self.config.load().as_ref()).clone();
429 config.rules = rules;
430 config.global_scripts = global_scripts;
431 self.config.store(Arc::new(config));
432 tracing::info!("Configuration persisted to {}", self.config_path.display());
433 Ok(())
434 }
435
436 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
437 let cfg = self.get_config();
438 let mut rules = cfg.rules.clone();
439 rules.push(rule);
440 self.persist_rules(rules, cfg.global_scripts.clone())
441 }
442
443 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
444 let cfg = self.get_config();
445 let mut rules = cfg.rules.clone();
446 if index >= rules.len() {
447 anyhow::bail!(
448 "Route index {} out of range (have {} routes)",
449 index,
450 rules.len()
451 );
452 }
453 rules[index] = rule;
454 self.persist_rules(rules, cfg.global_scripts.clone())
455 }
456
457 pub fn remove_route(&self, index: usize) -> Result<()> {
458 let cfg = self.get_config();
459 let mut rules = cfg.rules.clone();
460 if index >= rules.len() {
461 anyhow::bail!(
462 "Route index {} out of range (have {} routes)",
463 index,
464 rules.len()
465 );
466 }
467 rules.remove(index);
468 self.persist_rules(rules, cfg.global_scripts.clone())
469 }
470
471 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
472 self.persist_rules(rules, global_scripts)
473 }
474}
475
476#[async_trait::async_trait]
477impl ConfigManagerTrait for ConfigManager {
478 async fn reload(&self) -> Result<()> {
479 self.reload().await
480 }
481
482 fn get_config(&self) -> Arc<Config> {
483 self.get_config()
484 }
485
486 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
487 self.update_rules(rules, global_scripts)
488 }
489
490 fn add_route(&self, rule: ProxyRule) -> Result<()> {
491 self.add_route(rule)
492 }
493
494 fn remove_route(&self, index: usize) -> Result<()> {
495 self.remove_route(index)
496 }
497}
498
/// Splits the first `@script:a,b,...` directive off `s`.
///
/// Returns the trimmed text before the directive and the non-empty, trimmed script names.
/// Only the first whitespace-delimited token after the marker is read, so anything that
/// follows it is discarded. Without a directive, `s` is returned unchanged with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";

    let pos = match s.find(MARKER) {
        Some(pos) => pos,
        None => return (s, Vec::new()),
    };

    let head = s[..pos].trim();
    let tail = &s[pos + MARKER.len()..];
    let token = match tail.split_whitespace().next() {
        Some(first) => first,
        None => tail,
    };

    let mut names = Vec::new();
    for piece in token.split(',') {
        let name = piece.trim();
        if !name.is_empty() {
            names.push(name.to_string());
        }
    }
    (head, names)
}
516
517fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
521 let mut auth_entries = Vec::new();
522 let mut remaining = s.to_string();
523
524 while let Some(idx) = remaining.find("@auth:") {
525 let before = &remaining[..idx];
527 let after = &remaining[idx + "@auth:".len()..];
528
529 let end_idx = after
531 .find(|c: char| c.is_whitespace())
532 .unwrap_or(after.len());
533
534 let auth_part = &after[..end_idx];
535
536 if let Some((username, hash)) = auth_part.split_once(':') {
538 if !username.is_empty() && !hash.is_empty() {
539 auth_entries.push(BasicAuth {
540 username: username.to_string(),
541 hash: hash.to_string(),
542 });
543 }
544 }
545
546 let rest = &after[end_idx..];
548 remaining = if rest.is_empty() {
549 before.to_string()
550 } else {
551 format!("{}{}", before, rest)
552 };
553 }
554
555 (remaining.trim().to_string(), auth_entries)
556}
557
558fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
561 if let Some(idx) = s.find("@lb:") {
562 let before = &s[..idx];
563 let after = &s[idx + "@lb:".len()..];
564
565 let end_idx = after
566 .find(|c: char| c.is_whitespace())
567 .unwrap_or(after.len());
568
569 let strategy_str = &after[..end_idx];
570 let strategy = match strategy_str {
571 "round-robin" => LoadBalancingStrategy::RoundRobin,
572 "weighted" => LoadBalancingStrategy::Weighted,
573 "failover" => LoadBalancingStrategy::Failover,
574 _ => LoadBalancingStrategy::default(),
575 };
576
577 let rest = &after[end_idx..];
578 let remaining = if rest.is_empty() {
579 before.to_string()
580 } else {
581 format!("{}{}", before, rest)
582 };
583
584 (remaining.trim().to_string(), strategy)
585 } else {
586 (s.to_string(), LoadBalancingStrategy::default())
587 }
588}
589
590fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
591 let mut rules = Vec::new();
592 let mut global_scripts = Vec::new();
593
594 let mut joined_lines: Vec<String> = Vec::new();
596 for line in content.lines() {
597 if let Some(current) = joined_lines.last_mut() {
598 if current.ends_with('\\') {
599 current.pop(); current.push_str(line.trim());
601 continue;
602 }
603 }
604 joined_lines.push(line.to_string());
605 }
606
607 for line in &joined_lines {
608 let trimmed = line.trim();
609 if trimmed.is_empty() || trimmed.starts_with('#') {
610 continue;
611 }
612
613 if trimmed.starts_with("[global]") {
615 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
616 let (_, scripts) = extract_scripts(rest);
617 global_scripts.extend(scripts);
618 continue;
619 }
620
621 if let Some((source, target_str)) = trimmed.split_once("->") {
622 let source = source.trim();
623 let (target_str, route_scripts) = extract_scripts(target_str.trim());
625 let (target_str, auth_entries) = extract_auth(target_str);
627 let (target_str, load_balancing) = extract_load_balancing(&target_str);
629
630 let matcher = if source == "default" || source == "*" {
631 RuleMatcher::Default
632 } else if let Some(pattern) = source.strip_prefix("~") {
633 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
634 } else if !source.starts_with('/')
635 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
636 {
637 if let Some((domain, path)) = source.split_once('/') {
638 if path.is_empty() || path == "*" {
639 RuleMatcher::Domain(domain.to_string())
640 } else if path.ends_with("/*") {
641 RuleMatcher::DomainPath(
642 domain.to_string(),
643 path.trim_end_matches('*').to_string(),
644 )
645 } else {
646 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
647 }
648 } else {
649 RuleMatcher::Domain(source.to_string())
650 }
651 } else if source.ends_with("/*") {
652 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
653 } else {
654 RuleMatcher::Exact(source.to_string())
655 };
656
657 let targets: Vec<Target> = target_str
658 .split(',')
659 .map(|t| {
660 Ok(Target {
661 url: Url::parse(t.trim())?,
662 weight: 100,
663 })
664 })
665 .collect::<Result<Vec<_>>>()?;
666
667 rules.push(ProxyRule {
668 matcher,
669 targets,
670 headers: vec![],
671 scripts: route_scripts,
672 auth: auth_entries,
673 load_balancing,
674 });
675 }
676 }
677
678 Ok((rules, global_scripts))
679}
680
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                      http://backend2:8080, \\\n\
                      http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                      ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080, \\\n http://b:8080, \\\n http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                      http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }

    #[test]
    fn test_load_balancing_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:weighted";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
    }

    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    #[test]
    fn test_load_balancing_with_scripts() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:unknown";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    #[test]
    fn test_matcher_classification() {
        let text = "\
default -> http://a:1
* -> http://a:2
/static/* -> http://a:3
/health -> http://a:4
~^/v[0-9]+/ -> http://a:5
api.example.com -> http://a:6
";
        let (rules, _) = parse_proxy_config(text).unwrap();
        assert_eq!(rules.len(), 6);
        assert!(matches!(rules[0].matcher, RuleMatcher::Default));
        assert!(matches!(rules[1].matcher, RuleMatcher::Default));
        assert!(matches!(&rules[2].matcher, RuleMatcher::Prefix(p) if p == "/static/"));
        assert!(matches!(&rules[3].matcher, RuleMatcher::Exact(p) if p == "/health"));
        match &rules[4].matcher {
            RuleMatcher::Regex(rm) => {
                assert_eq!(rm.pattern, "^/v[0-9]+/");
                assert!(rm.is_match("/v2/users"));
                assert!(!rm.is_match("/users"));
            }
            other => panic!("expected regex matcher, got {:?}", other),
        }
        assert!(matches!(&rules[5].matcher, RuleMatcher::Domain(d) if d == "api.example.com"));
    }

    #[test]
    fn test_global_scripts_and_comments() {
        let text = "# leading comment\n\n[global] @script:a.lua,b.lua\n/x -> http://a:1\n";
        let (rules, global_scripts) = parse_proxy_config(text).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(global_scripts, vec!["a.lua", "b.lua"]);
    }

    #[test]
    fn test_invalid_target_url_is_error() {
        assert!(parse_proxy_config("/x -> not a url\n").is_err());
    }

    #[test]
    fn test_invalid_regex_is_error() {
        assert!(parse_proxy_config("~( -> http://a:1\n").is_err());
    }

    #[test]
    fn test_extract_scripts_splits_list() {
        let (rest, scripts) = extract_scripts("http://a:1 @script:x.lua,,y.lua trailing");
        assert_eq!(rest, "http://a:1");
        assert_eq!(scripts, vec!["x.lua", "y.lua"]);
    }

    #[test]
    fn test_extract_auth_skips_malformed_entry() {
        let (rest, auth) = extract_auth("http://a:1 @auth:nohash @auth:u:h");
        assert_eq!(rest, "http://a:1");
        assert_eq!(auth.len(), 1);
        assert_eq!(auth[0].username, "u");
        assert_eq!(auth[0].hash, "h");
    }

    #[test]
    fn test_acme_domains_filters_and_dedups() {
        let text = "\
example.com -> http://a:1
example.com/api -> http://a:2
10.0.0.1 -> http://a:3
b.example.org/* -> http://a:4
/local -> http://a:5
";
        let (mut rules, _) = parse_proxy_config(text).unwrap();
        rules.push(ProxyRule {
            matcher: RuleMatcher::Domain("localhost".to_string()),
            targets: vec![],
            headers: vec![],
            scripts: vec![],
            auth: vec![],
            load_balancing: LoadBalancingStrategy::default(),
        });
        let config = Config {
            server: ServerConfig::default(),
            tls: TlsConfig::default(),
            letsencrypt: None,
            scripting: ScriptingTomlConfig::default(),
            admin: AdminConfig::default(),
            circuit_breaker: None,
            rules,
            global_scripts: vec![],
        };
        assert_eq!(config.acme_domains(), vec!["example.com", "b.example.org"]);
    }
}