1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17 async fn reload(&self) -> Result<()>;
18 fn get_config(&self) -> Arc<Config>;
19 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20 fn add_route(&self, rule: ProxyRule) -> Result<()>;
21 fn remove_route(&self, index: usize) -> Result<()>;
22}
23
24#[derive(Deserialize, Default, Clone, Debug)]
25pub struct TomlConfig {
26 #[serde(default)]
27 pub server: ServerConfig,
28 #[serde(default)]
29 pub tls: TlsConfig,
30 pub letsencrypt: Option<LetsEncryptConfig>,
31 pub scripting: Option<ScriptingTomlConfig>,
32 pub admin: Option<AdminConfig>,
33 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
37pub struct CircuitBreakerTomlConfig {
38 pub failure_threshold: Option<u32>,
39 pub recovery_timeout_secs: Option<u64>,
40 pub success_threshold: Option<u32>,
41 pub failure_status_codes: Option<Vec<u16>>,
42}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46 pub enabled: bool,
47 pub bind: String,
48 pub api_key: Option<String>,
49 #[serde(default)]
50 pub username: Option<String>,
51 #[serde(default)]
52 pub password_hash: Option<String>,
53}
54
55impl Default for AdminConfig {
56 fn default() -> Self {
57 let _ = dotenv::dotenv();
58 let username = std::env::var("ADMIN_USER").ok();
59 let password_hash = std::env::var("ADMIN_PASSWORD").ok();
60 Self {
61 enabled: false,
62 bind: "127.0.0.1:9090".to_string(),
63 api_key: None,
64 username,
65 password_hash,
66 }
67 }
68}
69
70#[derive(Deserialize, Serialize, Clone, Debug, Default)]
71pub struct ScriptingTomlConfig {
72 pub enabled: bool,
73 pub scripts_dir: Option<String>,
74 pub hook_timeout_ms: Option<u64>,
75}
76
77#[derive(Deserialize, Serialize, Clone, Debug)]
78pub struct ServerConfig {
79 pub bind: String,
80 pub https_port: u16,
81}
82
83impl Default for ServerConfig {
84 fn default() -> Self {
85 Self {
86 bind: "0.0.0.0:8080".to_string(),
87 https_port: 443,
88 }
89 }
90}
91
92#[derive(Deserialize, Serialize, Default, Clone, Debug)]
93pub struct TlsConfig {
94 pub mode: String,
95 pub cache_dir: String,
96}
97
98#[derive(Deserialize, Serialize, Clone, Debug)]
99pub struct LetsEncryptConfig {
100 pub staging: bool,
101 pub email: String,
102 pub terms_agreed: bool,
103}
104
105#[derive(Clone, Debug, Serialize)]
106pub struct Config {
107 pub server: ServerConfig,
108 pub tls: TlsConfig,
109 pub letsencrypt: Option<LetsEncryptConfig>,
110 pub scripting: ScriptingTomlConfig,
111 pub admin: AdminConfig,
112 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
113 pub rules: Vec<ProxyRule>,
114 pub global_scripts: Vec<String>,
115}
116
117#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
118pub enum LoadBalancingStrategy {
119 #[default]
120 #[serde(rename = "round-robin")]
121 RoundRobin,
122 #[serde(rename = "weighted")]
123 Weighted,
124 #[serde(rename = "failover")]
125 Failover,
126}
127
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct ProxyRule {
130 pub matcher: RuleMatcher,
131 pub targets: Vec<Target>,
132 pub headers: Vec<HeaderRule>,
133 pub scripts: Vec<String>,
134 #[serde(default)]
135 pub auth: Vec<BasicAuth>,
136 #[serde(default)]
137 pub load_balancing: LoadBalancingStrategy,
138}
139
140#[derive(Clone, Debug)]
141pub enum RuleMatcher {
142 Default,
143 Prefix(String),
144 Regex(RegexMatcher),
145 Exact(String),
146 Domain(String),
147 DomainPath(String, String),
148}
149
150#[derive(Clone, Debug)]
152pub struct RegexMatcher {
153 pub pattern: String,
154 pub regex: Regex,
155}
156
157impl RegexMatcher {
158 pub fn new(pattern: &str) -> Result<Self> {
159 Ok(Self {
160 pattern: pattern.to_string(),
161 regex: Regex::new(pattern)?,
162 })
163 }
164
165 pub fn is_match(&self, text: &str) -> bool {
166 self.regex.is_match(text)
167 }
168}
169
170impl Serialize for RuleMatcher {
171 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
172 where
173 S: serde::Serializer,
174 {
175 use serde::ser::SerializeMap;
176 let mut map = serializer.serialize_map(Some(2))?;
177 match self {
178 RuleMatcher::Default => {
179 map.serialize_entry("type", "default")?;
180 }
181 RuleMatcher::Prefix(v) => {
182 map.serialize_entry("type", "prefix")?;
183 map.serialize_entry("value", v)?;
184 }
185 RuleMatcher::Regex(rm) => {
186 map.serialize_entry("type", "regex")?;
187 map.serialize_entry("value", &rm.pattern)?;
188 }
189 RuleMatcher::Exact(v) => {
190 map.serialize_entry("type", "exact")?;
191 map.serialize_entry("value", v)?;
192 }
193 RuleMatcher::Domain(v) => {
194 map.serialize_entry("type", "domain")?;
195 map.serialize_entry("value", v)?;
196 }
197 RuleMatcher::DomainPath(d, p) => {
198 map.serialize_entry("type", "domain_path")?;
199 map.serialize_entry("domain", d)?;
200 map.serialize_entry("path", p)?;
201 }
202 }
203 map.end()
204 }
205}
206
207impl<'de> Deserialize<'de> for RuleMatcher {
208 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
209 where
210 D: serde::Deserializer<'de>,
211 {
212 use serde::de::Error;
213 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
214 let obj = value
215 .as_object()
216 .ok_or_else(|| D::Error::custom("expected object"))?;
217 let matcher_type = obj
218 .get("type")
219 .and_then(|v| v.as_str())
220 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
221
222 match matcher_type {
223 "default" => Ok(RuleMatcher::Default),
224 "exact" => {
225 let v = obj
226 .get("value")
227 .and_then(|v| v.as_str())
228 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
229 Ok(RuleMatcher::Exact(v.to_string()))
230 }
231 "prefix" => {
232 let v = obj
233 .get("value")
234 .and_then(|v| v.as_str())
235 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
236 Ok(RuleMatcher::Prefix(v.to_string()))
237 }
238 "regex" => {
239 let v = obj
240 .get("value")
241 .and_then(|v| v.as_str())
242 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243 let rm = RegexMatcher::new(v)
244 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
245 Ok(RuleMatcher::Regex(rm))
246 }
247 "domain" => {
248 let v = obj
249 .get("value")
250 .and_then(|v| v.as_str())
251 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
252 Ok(RuleMatcher::Domain(v.to_string()))
253 }
254 "domain_path" => {
255 let d = obj
256 .get("domain")
257 .and_then(|v| v.as_str())
258 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
259 let p = obj
260 .get("path")
261 .and_then(|v| v.as_str())
262 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
263 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
264 }
265 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
266 }
267 }
268}
269
270#[derive(Clone, Debug, Serialize, Deserialize)]
271pub struct Target {
272 pub url: Url,
273 pub weight: u8,
274}
275
276#[derive(Clone, Debug, Serialize, Deserialize)]
277pub struct HeaderRule {
278 pub name: String,
279 pub value: String,
280}
281
282impl Config {
283 pub fn acme_domains(&self) -> Vec<String> {
286 let mut domains = Vec::new();
287 let mut seen = std::collections::HashSet::new();
288
289 for rule in &self.rules {
290 let domain = match &rule.matcher {
291 RuleMatcher::Domain(d) => Some(d.as_str()),
292 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
293 _ => None,
294 };
295
296 if let Some(d) = domain {
297 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
298 continue;
299 }
300 if seen.insert(d.to_string()) {
301 domains.push(d.to_string());
302 }
303 }
304 }
305
306 domains
307 }
308}
309
310pub struct ConfigManager {
311 config: ArcSwap<Config>,
312 config_path: PathBuf,
313 _watcher: Option<RecommendedWatcher>,
314 suppress_watch: Arc<AtomicBool>,
315}
316
317impl Clone for ConfigManager {
318 fn clone(&self) -> Self {
319 Self {
320 config: ArcSwap::new(self.config.load().clone()),
321 config_path: self.config_path.clone(),
322 _watcher: None,
323 suppress_watch: self.suppress_watch.clone(),
324 }
325 }
326}
327
328impl ConfigManager {
329 pub fn new(config_path: &str) -> Result<Self> {
330 let path = PathBuf::from(config_path);
331 let config = Self::load_config(&path, &path)?;
332 Ok(Self {
333 config: ArcSwap::new(Arc::new(config)),
334 config_path: path,
335 _watcher: None,
336 suppress_watch: Arc::new(AtomicBool::new(false)),
337 })
338 }
339
340 pub fn config_path(&self) -> &Path {
341 &self.config_path
342 }
343
344 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
345 &self.suppress_watch
346 }
347
348 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
349 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
350 let (rules, global_scripts) = parse_proxy_config(&content)?;
351 let config_dir = config_path.parent().unwrap_or(Path::new("."));
352 let config_toml_path = config_dir.join("config.toml");
353 let toml_content = if config_toml_path.exists() {
354 std::fs::read_to_string(&config_toml_path).ok()
355 } else {
356 let default_config = r#"# Soli Proxy Configuration
357# Server settings
358[server]
359bind = "0.0.0.0:80"
360https_port = 443
361worker_threads = "auto"
362
363# TLS Configuration
364[tls]
365mode = "auto"
366cache_dir = "./certs"
367
368# Logging Configuration
369[logging]
370level = "info"
371format = "json"
372output = "stdout"
373include_request_body = true
374include_response_body = true
375
376# Metrics Configuration
377[metrics]
378enabled = true
379endpoint = "/metrics"
380
381# Health Check Configuration
382[health]
383enabled = true
384liveness_path = "/health/live"
385readiness_path = "/health/ready"
386
387# Limits Configuration
388[limits]
389max_connections = 10000
390max_request_size = "10MB"
391keep_alive_timeout = 30
392request_timeout = 60
393
394# Rate Limiting Configuration
395[rate_limiting]
396enabled = true
397strategy = "token_bucket"
398requests_per_second = 1000
399burst_size = 2000
400redis_url = "redis://localhost:6379"
401
402# Circuit Breaker Configuration
403[circuit_breaker]
404failure_threshold = 5
405recovery_timeout_secs = 30
406success_threshold = 2
407failure_status_codes = [502, 503, 504]
408
409# Admin REST API Configuration
410[admin]
411enabled = true
412bind = "0.0.0.0:9090"
413
414# Lua Scripting Configuration
415[scripting]
416enabled = false
417scripts_dir = "./scripts/lua"
418hook_timeout_ms = 10
419
420# Authentication Configuration
421[auth]
422enabled = false
423auth_type = "basic"
424realm = "Restricted"
425"#;
426 std::fs::write(&config_toml_path, default_config).ok();
427 Some(default_config.to_string())
428 };
429 let toml_config: TomlConfig = toml_content
430 .as_ref()
431 .and_then(|c| toml::from_str(c).ok())
432 .unwrap_or_default();
433
434 Ok(Config {
435 server: toml_config.server,
436 tls: toml_config.tls,
437 letsencrypt: toml_config.letsencrypt,
438 scripting: toml_config.scripting.unwrap_or_default(),
439 admin: toml_config.admin.unwrap_or_default(),
440 circuit_breaker: toml_config.circuit_breaker,
441 rules,
442 global_scripts,
443 })
444 }
445
446 pub fn get_config(&self) -> Arc<Config> {
447 self.config.load().clone()
448 }
449
450 pub fn start_watcher(&self) -> Result<()> {
451 if !self.config_path.exists() {
453 if let Some(parent) = self.config_path.parent() {
454 std::fs::create_dir_all(parent).ok();
455 }
456 std::fs::write(&self.config_path, "")?;
457 }
458
459 let (tx, mut rx) = mpsc::channel(1);
460 let config_path = self.config_path.clone();
461 let suppress = self.suppress_watch.clone();
462
463 let mut watcher = RecommendedWatcher::new(
464 move |res| {
465 let _ = tx.blocking_send(res);
466 },
467 notify::Config::default(),
468 )?;
469
470 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
471
472 tracing::info!("Watching config file: {}", config_path.display());
473
474 std::thread::spawn(move || {
475 while let Some(res) = rx.blocking_recv() {
476 match res {
477 Ok(event) => {
478 if event.kind.is_modify() {
479 if suppress.swap(false, Ordering::SeqCst) {
480 tracing::debug!(
481 "Suppressing file watcher reload (admin API write)"
482 );
483 continue;
484 }
485 tracing::info!("Config file changed, reloading...");
486 }
487 }
488 Err(e) => tracing::error!("Watch error: {}", e),
489 }
490 }
491 });
492
493 Ok(())
494 }
495
496 pub async fn reload(&self) -> Result<()> {
497 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
498 self.config.store(Arc::new(new_config));
499 tracing::info!("Configuration reloaded successfully");
500 Ok(())
501 }
502
503 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
505 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
506 self.suppress_watch.store(true, Ordering::SeqCst);
507 std::fs::write(&self.config_path, &content)?;
508 let mut config = (*self.config.load().as_ref()).clone();
509 config.rules = rules;
510 config.global_scripts = global_scripts;
511 self.config.store(Arc::new(config));
512 tracing::info!("Configuration persisted to {}", self.config_path.display());
513 Ok(())
514 }
515
516 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
517 let cfg = self.get_config();
518 let mut rules = cfg.rules.clone();
519 rules.push(rule);
520 self.persist_rules(rules, cfg.global_scripts.clone())
521 }
522
523 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
524 let cfg = self.get_config();
525 let mut rules = cfg.rules.clone();
526 if index >= rules.len() {
527 anyhow::bail!(
528 "Route index {} out of range (have {} routes)",
529 index,
530 rules.len()
531 );
532 }
533 rules[index] = rule;
534 self.persist_rules(rules, cfg.global_scripts.clone())
535 }
536
537 pub fn remove_route(&self, index: usize) -> Result<()> {
538 let cfg = self.get_config();
539 let mut rules = cfg.rules.clone();
540 if index >= rules.len() {
541 anyhow::bail!(
542 "Route index {} out of range (have {} routes)",
543 index,
544 rules.len()
545 );
546 }
547 rules.remove(index);
548 self.persist_rules(rules, cfg.global_scripts.clone())
549 }
550
551 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
552 self.persist_rules(rules, global_scripts)
553 }
554}
555
556#[async_trait::async_trait]
557impl ConfigManagerTrait for ConfigManager {
558 async fn reload(&self) -> Result<()> {
559 self.reload().await
560 }
561
562 fn get_config(&self) -> Arc<Config> {
563 self.get_config()
564 }
565
566 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
567 self.update_rules(rules, global_scripts)
568 }
569
570 fn add_route(&self, rule: ProxyRule) -> Result<()> {
571 self.add_route(rule)
572 }
573
574 fn remove_route(&self, index: usize) -> Result<()> {
575 self.remove_route(index)
576 }
577}
578
579fn extract_scripts(s: &str) -> (&str, Vec<String>) {
581 if let Some(idx) = s.find("@script:") {
582 let before = s[..idx].trim();
583 let after = &s[idx + "@script:".len()..];
584 let script_part = after.split_whitespace().next().unwrap_or(after);
586 let scripts: Vec<String> = script_part
587 .split(',')
588 .map(|s| s.trim().to_string())
589 .filter(|s| !s.is_empty())
590 .collect();
591 (before, scripts)
592 } else {
593 (s, Vec::new())
594 }
595}
596
597fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
601 let mut auth_entries = Vec::new();
602 let mut remaining = s.to_string();
603
604 while let Some(idx) = remaining.find("@auth:") {
605 let before = &remaining[..idx];
607 let after = &remaining[idx + "@auth:".len()..];
608
609 let end_idx = after
611 .find(|c: char| c.is_whitespace())
612 .unwrap_or(after.len());
613
614 let auth_part = &after[..end_idx];
615
616 if let Some((username, hash)) = auth_part.split_once(':') {
618 if !username.is_empty() && !hash.is_empty() {
619 auth_entries.push(BasicAuth {
620 username: username.to_string(),
621 hash: hash.to_string(),
622 });
623 }
624 }
625
626 let rest = &after[end_idx..];
628 remaining = if rest.is_empty() {
629 before.to_string()
630 } else {
631 format!("{}{}", before, rest)
632 };
633 }
634
635 (remaining.trim().to_string(), auth_entries)
636}
637
638fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
641 if let Some(idx) = s.find("@lb:") {
642 let before = &s[..idx];
643 let after = &s[idx + "@lb:".len()..];
644
645 let end_idx = after
646 .find(|c: char| c.is_whitespace())
647 .unwrap_or(after.len());
648
649 let strategy_str = &after[..end_idx];
650 let strategy = match strategy_str {
651 "round-robin" => LoadBalancingStrategy::RoundRobin,
652 "weighted" => LoadBalancingStrategy::Weighted,
653 "failover" => LoadBalancingStrategy::Failover,
654 _ => LoadBalancingStrategy::default(),
655 };
656
657 let rest = &after[end_idx..];
658 let remaining = if rest.is_empty() {
659 before.to_string()
660 } else {
661 format!("{}{}", before, rest)
662 };
663
664 (remaining.trim().to_string(), strategy)
665 } else {
666 (s.to_string(), LoadBalancingStrategy::default())
667 }
668}
669
670fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
671 let mut rules = Vec::new();
672 let mut global_scripts = Vec::new();
673
674 let mut joined_lines: Vec<String> = Vec::new();
676 for line in content.lines() {
677 if let Some(current) = joined_lines.last_mut() {
678 if current.ends_with('\\') {
679 current.pop(); current.push_str(line.trim());
681 continue;
682 }
683 }
684 joined_lines.push(line.to_string());
685 }
686
687 for line in &joined_lines {
688 let trimmed = line.trim();
689 if trimmed.is_empty() || trimmed.starts_with('#') {
690 continue;
691 }
692
693 if trimmed.starts_with("[global]") {
695 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
696 let (_, scripts) = extract_scripts(rest);
697 global_scripts.extend(scripts);
698 continue;
699 }
700
701 if let Some((source, target_str)) = trimmed.split_once("->") {
702 let source = source.trim();
703 let (target_str, route_scripts) = extract_scripts(target_str.trim());
705 let (target_str, auth_entries) = extract_auth(target_str);
707 let (target_str, load_balancing) = extract_load_balancing(&target_str);
709
710 let matcher = if source == "default" || source == "*" {
711 RuleMatcher::Default
712 } else if let Some(pattern) = source.strip_prefix("~") {
713 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
714 } else if !source.starts_with('/')
715 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
716 {
717 if let Some((domain, path)) = source.split_once('/') {
718 if path.is_empty() || path == "*" {
719 RuleMatcher::Domain(domain.to_string())
720 } else if path.ends_with("/*") {
721 RuleMatcher::DomainPath(
722 domain.to_string(),
723 path.trim_end_matches('*').to_string(),
724 )
725 } else {
726 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
727 }
728 } else {
729 RuleMatcher::Domain(source.to_string())
730 }
731 } else if source.ends_with("/*") {
732 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
733 } else {
734 RuleMatcher::Exact(source.to_string())
735 };
736
737 let targets: Vec<Target> = target_str
738 .split(',')
739 .map(|t| {
740 Ok(Target {
741 url: Url::parse(t.trim())?,
742 weight: 100,
743 })
744 })
745 .collect::<Result<Vec<_>>>()?;
746
747 rules.push(ProxyRule {
748 matcher,
749 targets,
750 headers: vec![],
751 scripts: route_scripts,
752 auth: auth_entries,
753 load_balancing,
754 });
755 }
756 }
757
758 Ok((rules, global_scripts))
759}
760
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `config` and returns its rules, failing the test on a parse error.
    fn rules_of(config: &str) -> Vec<ProxyRule> {
        let (rules, _global_scripts) = parse_proxy_config(config).unwrap();
        rules
    }

    /// Parses `config`, asserts it yields exactly one rule, and returns that rule.
    fn single_rule(config: &str) -> ProxyRule {
        let mut rules = rules_of(config);
        assert_eq!(rules.len(), 1);
        rules.remove(0)
    }

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let rule = single_rule("/api/* -> http://backend1:8080, \\\n http://backend2:8080\n");
        let urls: Vec<&str> = rule.targets.iter().map(|t| t.url.as_str()).collect();
        assert_eq!(urls, ["http://backend1:8080/", "http://backend2:8080/"]);
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let rule = single_rule(
            "/api/* -> http://backend1:8080, \\\n\
             http://backend2:8080, \\\n\
             http://backend3:8080\n",
        );
        assert_eq!(rule.targets.len(), 3);
        assert_eq!(rule.targets[2].url.as_str(), "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                      ~^/foo\\dbar$ -> http://localhost:9090\n";
        assert_eq!(rules_of(config).len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let rule =
            single_rule("/api/* -> http://a:8080, \\\n http://b:8080, \\\n http://c:8080\n");
        assert_eq!(rule.targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let rule = single_rule(
            "/api/* -> http://a:8080, \\\n\
             http://b:8080 @script:auth.lua\n",
        );
        assert_eq!(rule.targets.len(), 2);
        assert_eq!(rule.scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        assert_eq!(rules_of(config).len(), 2);
    }

    #[test]
    fn test_auth_parsing() {
        let rule = single_rule(
            r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#,
        );
        assert_eq!(rule.auth.len(), 1);
        assert_eq!(rule.auth[0].username, "demo");
        assert!(rule.auth[0].hash.starts_with("$2b$"));
        assert_eq!(rule.targets.len(), 1);
        assert_eq!(rule.targets[0].url.as_str(), "http://localhost:8080/");
    }

    #[test]
    fn test_multiple_auth_users() {
        let rule = single_rule(
            r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#,
        );
        let users: Vec<&str> = rule.auth.iter().map(|a| a.username.as_str()).collect();
        assert_eq!(users, ["admin", "user"]);
    }

    #[test]
    fn test_load_balancing_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rule.targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:weighted");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:failover");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Failover);
    }

    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    #[test]
    fn test_load_balancing_with_scripts() {
        let rule =
            single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rule.scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:unknown");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}