Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17    async fn reload(&self) -> Result<()>;
18    fn get_config(&self) -> Arc<Config>;
19    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20    fn add_route(&self, rule: ProxyRule) -> Result<()>;
21    fn remove_route(&self, index: usize) -> Result<()>;
22}
23
24#[derive(Deserialize, Default, Clone, Debug)]
25pub struct TomlConfig {
26    #[serde(default)]
27    pub server: ServerConfig,
28    #[serde(default)]
29    pub tls: TlsConfig,
30    pub letsencrypt: Option<LetsEncryptConfig>,
31    pub scripting: Option<ScriptingTomlConfig>,
32    pub admin: Option<AdminConfig>,
33    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
37pub struct CircuitBreakerTomlConfig {
38    pub failure_threshold: Option<u32>,
39    pub recovery_timeout_secs: Option<u64>,
40    pub success_threshold: Option<u32>,
41    pub failure_status_codes: Option<Vec<u16>>,
42}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46    pub enabled: bool,
47    pub bind: String,
48    pub api_key: Option<String>,
49    #[serde(default)]
50    pub username: Option<String>,
51    #[serde(default)]
52    pub password_hash: Option<String>,
53}
54
55impl Default for AdminConfig {
56    fn default() -> Self {
57        let _ = dotenv::dotenv();
58        let username = std::env::var("ADMIN_USER").ok();
59        let password_hash = std::env::var("ADMIN_PASSWORD").ok();
60        Self {
61            enabled: false,
62            bind: "127.0.0.1:9090".to_string(),
63            api_key: None,
64            username,
65            password_hash,
66        }
67    }
68}
69
70#[derive(Deserialize, Serialize, Clone, Debug, Default)]
71pub struct ScriptingTomlConfig {
72    pub enabled: bool,
73    pub scripts_dir: Option<String>,
74    pub hook_timeout_ms: Option<u64>,
75}
76
77#[derive(Deserialize, Serialize, Clone, Debug)]
78pub struct ServerConfig {
79    pub bind: String,
80    pub https_port: u16,
81}
82
83impl Default for ServerConfig {
84    fn default() -> Self {
85        Self {
86            bind: "0.0.0.0:8080".to_string(),
87            https_port: 443,
88        }
89    }
90}
91
92#[derive(Deserialize, Serialize, Default, Clone, Debug)]
93pub struct TlsConfig {
94    pub mode: String,
95    pub cache_dir: String,
96}
97
98#[derive(Deserialize, Serialize, Clone, Debug)]
99pub struct LetsEncryptConfig {
100    pub staging: bool,
101    pub email: String,
102    pub terms_agreed: bool,
103}
104
105#[derive(Clone, Debug, Serialize)]
106pub struct Config {
107    pub server: ServerConfig,
108    pub tls: TlsConfig,
109    pub letsencrypt: Option<LetsEncryptConfig>,
110    pub scripting: ScriptingTomlConfig,
111    pub admin: AdminConfig,
112    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
113    pub rules: Vec<ProxyRule>,
114    pub global_scripts: Vec<String>,
115}
116
117#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
118pub enum LoadBalancingStrategy {
119    #[default]
120    #[serde(rename = "round-robin")]
121    RoundRobin,
122    #[serde(rename = "weighted")]
123    Weighted,
124    #[serde(rename = "failover")]
125    Failover,
126}
127
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct ProxyRule {
130    pub matcher: RuleMatcher,
131    pub targets: Vec<Target>,
132    pub headers: Vec<HeaderRule>,
133    pub scripts: Vec<String>,
134    #[serde(default)]
135    pub auth: Vec<BasicAuth>,
136    #[serde(default)]
137    pub load_balancing: LoadBalancingStrategy,
138}
139
140#[derive(Clone, Debug)]
141pub enum RuleMatcher {
142    Default,
143    Prefix(String),
144    Regex(RegexMatcher),
145    Exact(String),
146    Domain(String),
147    DomainPath(String, String),
148}
149
150/// Wrapper around Regex that stores the original pattern for serialization
151#[derive(Clone, Debug)]
152pub struct RegexMatcher {
153    pub pattern: String,
154    pub regex: Regex,
155}
156
157impl RegexMatcher {
158    pub fn new(pattern: &str) -> Result<Self> {
159        Ok(Self {
160            pattern: pattern.to_string(),
161            regex: Regex::new(pattern)?,
162        })
163    }
164
165    pub fn is_match(&self, text: &str) -> bool {
166        self.regex.is_match(text)
167    }
168}
169
170impl Serialize for RuleMatcher {
171    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
172    where
173        S: serde::Serializer,
174    {
175        use serde::ser::SerializeMap;
176        let mut map = serializer.serialize_map(Some(2))?;
177        match self {
178            RuleMatcher::Default => {
179                map.serialize_entry("type", "default")?;
180            }
181            RuleMatcher::Prefix(v) => {
182                map.serialize_entry("type", "prefix")?;
183                map.serialize_entry("value", v)?;
184            }
185            RuleMatcher::Regex(rm) => {
186                map.serialize_entry("type", "regex")?;
187                map.serialize_entry("value", &rm.pattern)?;
188            }
189            RuleMatcher::Exact(v) => {
190                map.serialize_entry("type", "exact")?;
191                map.serialize_entry("value", v)?;
192            }
193            RuleMatcher::Domain(v) => {
194                map.serialize_entry("type", "domain")?;
195                map.serialize_entry("value", v)?;
196            }
197            RuleMatcher::DomainPath(d, p) => {
198                map.serialize_entry("type", "domain_path")?;
199                map.serialize_entry("domain", d)?;
200                map.serialize_entry("path", p)?;
201            }
202        }
203        map.end()
204    }
205}
206
207impl<'de> Deserialize<'de> for RuleMatcher {
208    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
209    where
210        D: serde::Deserializer<'de>,
211    {
212        use serde::de::Error;
213        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
214        let obj = value
215            .as_object()
216            .ok_or_else(|| D::Error::custom("expected object"))?;
217        let matcher_type = obj
218            .get("type")
219            .and_then(|v| v.as_str())
220            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
221
222        match matcher_type {
223            "default" => Ok(RuleMatcher::Default),
224            "exact" => {
225                let v = obj
226                    .get("value")
227                    .and_then(|v| v.as_str())
228                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
229                Ok(RuleMatcher::Exact(v.to_string()))
230            }
231            "prefix" => {
232                let v = obj
233                    .get("value")
234                    .and_then(|v| v.as_str())
235                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
236                Ok(RuleMatcher::Prefix(v.to_string()))
237            }
238            "regex" => {
239                let v = obj
240                    .get("value")
241                    .and_then(|v| v.as_str())
242                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243                let rm = RegexMatcher::new(v)
244                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
245                Ok(RuleMatcher::Regex(rm))
246            }
247            "domain" => {
248                let v = obj
249                    .get("value")
250                    .and_then(|v| v.as_str())
251                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
252                Ok(RuleMatcher::Domain(v.to_string()))
253            }
254            "domain_path" => {
255                let d = obj
256                    .get("domain")
257                    .and_then(|v| v.as_str())
258                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
259                let p = obj
260                    .get("path")
261                    .and_then(|v| v.as_str())
262                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
263                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
264            }
265            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
266        }
267    }
268}
269
270#[derive(Clone, Debug, Serialize, Deserialize)]
271pub struct Target {
272    pub url: Url,
273    pub weight: u8,
274}
275
276#[derive(Clone, Debug, Serialize, Deserialize)]
277pub struct HeaderRule {
278    pub name: String,
279    pub value: String,
280}
281
282impl Config {
283    /// Extract unique domain names from Domain and DomainPath rules,
284    /// filtering out IPs and "localhost".
285    pub fn acme_domains(&self) -> Vec<String> {
286        let mut domains = Vec::new();
287        let mut seen = std::collections::HashSet::new();
288
289        for rule in &self.rules {
290            let domain = match &rule.matcher {
291                RuleMatcher::Domain(d) => Some(d.as_str()),
292                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
293                _ => None,
294            };
295
296            if let Some(d) = domain {
297                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
298                    continue;
299                }
300                if seen.insert(d.to_string()) {
301                    domains.push(d.to_string());
302                }
303            }
304        }
305
306        domains
307    }
308}
309
310pub struct ConfigManager {
311    config: ArcSwap<Config>,
312    config_path: PathBuf,
313    _watcher: Option<RecommendedWatcher>,
314    suppress_watch: Arc<AtomicBool>,
315}
316
317impl Clone for ConfigManager {
318    fn clone(&self) -> Self {
319        Self {
320            config: ArcSwap::new(self.config.load().clone()),
321            config_path: self.config_path.clone(),
322            _watcher: None,
323            suppress_watch: self.suppress_watch.clone(),
324        }
325    }
326}
327
328impl ConfigManager {
329    pub fn new(config_path: &str) -> Result<Self> {
330        let path = PathBuf::from(config_path);
331        let config = Self::load_config(&path, &path)?;
332        Ok(Self {
333            config: ArcSwap::new(Arc::new(config)),
334            config_path: path,
335            _watcher: None,
336            suppress_watch: Arc::new(AtomicBool::new(false)),
337        })
338    }
339
340    pub fn config_path(&self) -> &Path {
341        &self.config_path
342    }
343
344    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
345        &self.suppress_watch
346    }
347
348    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
349        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
350        let (rules, global_scripts) = parse_proxy_config(&content)?;
351        let config_dir = config_path.parent().unwrap_or(Path::new("."));
352        let config_toml_path = config_dir.join("config.toml");
353        let toml_content = if config_toml_path.exists() {
354            std::fs::read_to_string(&config_toml_path).ok()
355        } else {
356            let default_config = r#"# Soli Proxy Configuration
357# Server settings
358[server]
359bind = "0.0.0.0:80"
360https_port = 443
361worker_threads = "auto"
362
363# TLS Configuration
364[tls]
365mode = "auto"
366cache_dir = "./certs"
367
368# Logging Configuration
369[logging]
370level = "info"
371format = "json"
372output = "stdout"
373include_request_body = true
374include_response_body = true
375
376# Metrics Configuration
377[metrics]
378enabled = true
379endpoint = "/metrics"
380
381# Health Check Configuration
382[health]
383enabled = true
384liveness_path = "/health/live"
385readiness_path = "/health/ready"
386
387# Limits Configuration
388[limits]
389max_connections = 10000
390max_request_size = "10MB"
391keep_alive_timeout = 30
392request_timeout = 60
393
394# Rate Limiting Configuration
395[rate_limiting]
396enabled = true
397strategy = "token_bucket"
398requests_per_second = 1000
399burst_size = 2000
400redis_url = "redis://localhost:6379"
401
402# Circuit Breaker Configuration
403[circuit_breaker]
404failure_threshold = 5
405recovery_timeout_secs = 30
406success_threshold = 2
407failure_status_codes = [502, 503, 504]
408
409# Admin REST API Configuration
410[admin]
411enabled = true
412bind = "0.0.0.0:9090"
413
414# Lua Scripting Configuration
415[scripting]
416enabled = false
417scripts_dir = "./scripts/lua"
418hook_timeout_ms = 10
419
420# Authentication Configuration
421[auth]
422enabled = false
423auth_type = "basic"
424realm = "Restricted"
425"#;
426            std::fs::write(&config_toml_path, default_config).ok();
427            Some(default_config.to_string())
428        };
429        let toml_config: TomlConfig = toml_content
430            .as_ref()
431            .and_then(|c| toml::from_str(c).ok())
432            .unwrap_or_default();
433
434        Ok(Config {
435            server: toml_config.server,
436            tls: toml_config.tls,
437            letsencrypt: toml_config.letsencrypt,
438            scripting: toml_config.scripting.unwrap_or_default(),
439            admin: toml_config.admin.unwrap_or_default(),
440            circuit_breaker: toml_config.circuit_breaker,
441            rules,
442            global_scripts,
443        })
444    }
445
446    pub fn get_config(&self) -> Arc<Config> {
447        self.config.load().clone()
448    }
449
450    pub fn start_watcher(&self) -> Result<()> {
451        // Ensure the file exists so the watcher has something to watch
452        if !self.config_path.exists() {
453            if let Some(parent) = self.config_path.parent() {
454                std::fs::create_dir_all(parent).ok();
455            }
456            std::fs::write(&self.config_path, "")?;
457        }
458
459        let (tx, mut rx) = mpsc::channel(1);
460        let config_path = self.config_path.clone();
461        let suppress = self.suppress_watch.clone();
462
463        let mut watcher = RecommendedWatcher::new(
464            move |res| {
465                let _ = tx.blocking_send(res);
466            },
467            notify::Config::default(),
468        )?;
469
470        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
471
472        tracing::info!("Watching config file: {}", config_path.display());
473
474        std::thread::spawn(move || {
475            while let Some(res) = rx.blocking_recv() {
476                match res {
477                    Ok(event) => {
478                        if event.kind.is_modify() {
479                            if suppress.swap(false, Ordering::SeqCst) {
480                                tracing::debug!(
481                                    "Suppressing file watcher reload (admin API write)"
482                                );
483                                continue;
484                            }
485                            tracing::info!("Config file changed, reloading...");
486                        }
487                    }
488                    Err(e) => tracing::error!("Watch error: {}", e),
489                }
490            }
491        });
492
493        Ok(())
494    }
495
496    pub async fn reload(&self) -> Result<()> {
497        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
498        self.config.store(Arc::new(new_config));
499        tracing::info!("Configuration reloaded successfully");
500        Ok(())
501    }
502
503    /// Persist current rules to proxy.conf and swap in-memory config
504    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
505        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
506        self.suppress_watch.store(true, Ordering::SeqCst);
507        std::fs::write(&self.config_path, &content)?;
508        let mut config = (*self.config.load().as_ref()).clone();
509        config.rules = rules;
510        config.global_scripts = global_scripts;
511        self.config.store(Arc::new(config));
512        tracing::info!("Configuration persisted to {}", self.config_path.display());
513        Ok(())
514    }
515
516    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
517        let cfg = self.get_config();
518        let mut rules = cfg.rules.clone();
519        rules.push(rule);
520        self.persist_rules(rules, cfg.global_scripts.clone())
521    }
522
523    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
524        let cfg = self.get_config();
525        let mut rules = cfg.rules.clone();
526        if index >= rules.len() {
527            anyhow::bail!(
528                "Route index {} out of range (have {} routes)",
529                index,
530                rules.len()
531            );
532        }
533        rules[index] = rule;
534        self.persist_rules(rules, cfg.global_scripts.clone())
535    }
536
537    pub fn remove_route(&self, index: usize) -> Result<()> {
538        let cfg = self.get_config();
539        let mut rules = cfg.rules.clone();
540        if index >= rules.len() {
541            anyhow::bail!(
542                "Route index {} out of range (have {} routes)",
543                index,
544                rules.len()
545            );
546        }
547        rules.remove(index);
548        self.persist_rules(rules, cfg.global_scripts.clone())
549    }
550
551    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
552        self.persist_rules(rules, global_scripts)
553    }
554}
555
556#[async_trait::async_trait]
557impl ConfigManagerTrait for ConfigManager {
558    async fn reload(&self) -> Result<()> {
559        self.reload().await
560    }
561
562    fn get_config(&self) -> Arc<Config> {
563        self.get_config()
564    }
565
566    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
567        self.update_rules(rules, global_scripts)
568    }
569
570    fn add_route(&self, rule: ProxyRule) -> Result<()> {
571        self.add_route(rule)
572    }
573
574    fn remove_route(&self, index: usize) -> Result<()> {
575        self.remove_route(index)
576    }
577}
578
/// Split an `@script:a.lua,b.lua` annotation off `s`.
///
/// Returns the trimmed text preceding the annotation, plus the non-empty,
/// comma-separated script names. The list ends at the first whitespace, and any
/// text after it is dropped. Without an annotation, `s` comes back unchanged
/// with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";
    match s.find(MARKER) {
        None => (s, Vec::new()),
        Some(start) => {
            let tail = &s[start + MARKER.len()..];
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let names = list
                .split(',')
                .filter_map(|name| {
                    let name = name.trim();
                    (!name.is_empty()).then(|| name.to_string())
                })
                .collect();
            (s[..start].trim(), names)
        }
    }
}
596
597/// Extract `@auth:user:hash` entries from a string, returning (remaining_str, auth_vec).
598/// Multiple @auth entries can appear: `@auth:user1:hash1 @auth:user2:hash2`
599/// Hash is everything after the second colon (bcrypt hashes start with $2a$, $2b$, $2y$)
600fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
601    let mut auth_entries = Vec::new();
602    let mut remaining = s.to_string();
603
604    while let Some(idx) = remaining.find("@auth:") {
605        // Keep the part BEFORE @auth:
606        let before = &remaining[..idx];
607        let after = &remaining[idx + "@auth:".len()..];
608
609        // Find the end of this auth entry (whitespace or end of string)
610        let end_idx = after
611            .find(|c: char| c.is_whitespace())
612            .unwrap_or(after.len());
613
614        let auth_part = &after[..end_idx];
615
616        // Parse username:hash - hash is everything after the first colon
617        if let Some((username, hash)) = auth_part.split_once(':') {
618            if !username.is_empty() && !hash.is_empty() {
619                auth_entries.push(BasicAuth {
620                    username: username.to_string(),
621                    hash: hash.to_string(),
622                });
623            }
624        }
625
626        // Continue with the part BEFORE this @auth, plus any remaining after it
627        let rest = &after[end_idx..];
628        remaining = if rest.is_empty() {
629            before.to_string()
630        } else {
631            format!("{}{}", before, rest)
632        };
633    }
634
635    (remaining.trim().to_string(), auth_entries)
636}
637
638/// Extract `@lb:strategy` from a string, returning (remaining_str, strategy).
639/// Example: `@lb:round-robin` or `@lb:weighted` or `@lb:failover`
640fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
641    if let Some(idx) = s.find("@lb:") {
642        let before = &s[..idx];
643        let after = &s[idx + "@lb:".len()..];
644
645        let end_idx = after
646            .find(|c: char| c.is_whitespace())
647            .unwrap_or(after.len());
648
649        let strategy_str = &after[..end_idx];
650        let strategy = match strategy_str {
651            "round-robin" => LoadBalancingStrategy::RoundRobin,
652            "weighted" => LoadBalancingStrategy::Weighted,
653            "failover" => LoadBalancingStrategy::Failover,
654            _ => LoadBalancingStrategy::default(),
655        };
656
657        let rest = &after[end_idx..];
658        let remaining = if rest.is_empty() {
659            before.to_string()
660        } else {
661            format!("{}{}", before, rest)
662        };
663
664        (remaining.trim().to_string(), strategy)
665    } else {
666        (s.to_string(), LoadBalancingStrategy::default())
667    }
668}
669
670fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
671    let mut rules = Vec::new();
672    let mut global_scripts = Vec::new();
673
674    // Join continuation lines (backslash at end of line)
675    let mut joined_lines: Vec<String> = Vec::new();
676    for line in content.lines() {
677        if let Some(current) = joined_lines.last_mut() {
678            if current.ends_with('\\') {
679                current.pop(); // remove the backslash
680                current.push_str(line.trim());
681                continue;
682            }
683        }
684        joined_lines.push(line.to_string());
685    }
686
687    for line in &joined_lines {
688        let trimmed = line.trim();
689        if trimmed.is_empty() || trimmed.starts_with('#') {
690            continue;
691        }
692
693        // Handle [global] @script:cors.lua,logging.lua
694        if trimmed.starts_with("[global]") {
695            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
696            let (_, scripts) = extract_scripts(rest);
697            global_scripts.extend(scripts);
698            continue;
699        }
700
701        if let Some((source, target_str)) = trimmed.split_once("->") {
702            let source = source.trim();
703            // Extract @script: from the target side
704            let (target_str, route_scripts) = extract_scripts(target_str.trim());
705            // Extract @auth: entries from the target side
706            let (target_str, auth_entries) = extract_auth(target_str);
707            // Extract @lb: load balancing strategy
708            let (target_str, load_balancing) = extract_load_balancing(&target_str);
709
710            let matcher = if source == "default" || source == "*" {
711                RuleMatcher::Default
712            } else if let Some(pattern) = source.strip_prefix("~") {
713                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
714            } else if !source.starts_with('/')
715                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
716            {
717                if let Some((domain, path)) = source.split_once('/') {
718                    if path.is_empty() || path == "*" {
719                        RuleMatcher::Domain(domain.to_string())
720                    } else if path.ends_with("/*") {
721                        RuleMatcher::DomainPath(
722                            domain.to_string(),
723                            path.trim_end_matches('*').to_string(),
724                        )
725                    } else {
726                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
727                    }
728                } else {
729                    RuleMatcher::Domain(source.to_string())
730                }
731            } else if source.ends_with("/*") {
732                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
733            } else {
734                RuleMatcher::Exact(source.to_string())
735            };
736
737            let targets: Vec<Target> = target_str
738                .split(',')
739                .map(|t| {
740                    Ok(Target {
741                        url: Url::parse(t.trim())?,
742                        weight: 100,
743                    })
744                })
745                .collect::<Result<Vec<_>>>()?;
746
747            rules.push(ProxyRule {
748                matcher,
749                targets,
750                headers: vec![],
751                scripts: route_scripts,
752                auth: auth_entries,
753                load_balancing,
754            });
755        }
756    }
757
758    Ok((rules, global_scripts))
759}
760
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `config` (panicking on error) and return just the rules.
    fn rules_of(config: &str) -> Vec<ProxyRule> {
        let (rules, _globals) = parse_proxy_config(config).unwrap();
        rules
    }

    /// Parse a config that must yield exactly one rule, and return it.
    fn single_rule(config: &str) -> ProxyRule {
        let mut rules = rules_of(config);
        assert_eq!(rules.len(), 1);
        rules.remove(0)
    }

    /// URLs of a rule's targets, in order.
    fn target_urls(rule: &ProxyRule) -> Vec<&str> {
        rule.targets.iter().map(|t| t.url.as_str()).collect()
    }

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let rule =
            single_rule("/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n");
        assert_eq!(
            target_urls(&rule),
            ["http://backend1:8080/", "http://backend2:8080/"]
        );
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                       http://backend2:8080, \\\n\
                       http://backend3:8080\n";
        let rule = single_rule(config);
        assert_eq!(rule.targets.len(), 3);
        assert_eq!(rule.targets[2].url.as_str(), "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                       ~^/foo\\dbar$ -> http://localhost:9090\n";
        assert_eq!(rules_of(config).len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let rule = single_rule(
            "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n",
        );
        assert_eq!(rule.targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                       http://b:8080 @script:auth.lua\n";
        let rule = single_rule(config);
        assert_eq!(rule.targets.len(), 2);
        assert_eq!(rule.scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        assert_eq!(rules_of(config).len(), 2);
    }

    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let rule = single_rule(config);
        assert_eq!(rule.auth.len(), 1);
        assert_eq!(rule.auth[0].username, "demo");
        assert!(rule.auth[0].hash.starts_with("$2b$"));
        assert_eq!(target_urls(&rule), ["http://localhost:8080/"]);
    }

    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let rule = single_rule(config);
        let users: Vec<&str> = rule.auth.iter().map(|a| a.username.as_str()).collect();
        assert_eq!(users, ["admin", "user"]);
    }

    #[test]
    fn test_load_balancing_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rule.targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:weighted");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:failover");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Failover);
    }

    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    #[test]
    fn test_load_balancing_with_scripts() {
        let rule =
            single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rule.scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let rule = single_rule("/api/* -> http://b1:8080, http://b2:8080 @lb:unknown");
        assert_eq!(rule.load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}