Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
/// Abstraction over the configuration store: reload from disk, read the
/// current snapshot, and mutate the routing rules.
#[async_trait::async_trait]
pub trait ConfigManagerTrait: Send + Sync {
    /// Re-read the configuration from disk and make it current.
    async fn reload(&self) -> Result<()>;
    /// Return a snapshot of the current configuration.
    fn get_config(&self) -> Arc<Config>;
    /// Replace all routing rules and global scripts.
    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
    /// Append one routing rule.
    fn add_route(&self, rule: ProxyRule) -> Result<()>;
    /// Remove the routing rule at `index`.
    fn remove_route(&self, index: usize) -> Result<()>;
}
23
/// Raw deserialized form of `config.toml`.
///
/// `server` and `tls` fall back to their `Default` values when their section
/// is absent; every other section is `None` when absent.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    /// `[server]` section.
    #[serde(default)]
    pub server: ServerConfig,
    /// `[tls]` section.
    #[serde(default)]
    pub tls: TlsConfig,
    /// `[letsencrypt]` section, if present.
    pub letsencrypt: Option<LetsEncryptConfig>,
    /// `[scripting]` section, if present.
    pub scripting: Option<ScriptingTomlConfig>,
    /// `[admin]` section, if present.
    pub admin: Option<AdminConfig>,
    /// `[circuit_breaker]` section, if present.
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}
35
/// `[circuit_breaker]` section of `config.toml`.
///
/// Every field is optional. How unset values are defaulted is decided where
/// the breaker is constructed, which is not visible here (confirm there).
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    /// Presumably the failure count that trips the breaker — confirm.
    pub failure_threshold: Option<u32>,
    /// Presumably seconds before a tripped breaker is probed again — confirm.
    pub recovery_timeout_secs: Option<u64>,
    /// Presumably successes required to close the breaker — confirm.
    pub success_threshold: Option<u32>,
    /// Presumably HTTP status codes counted as failures — confirm.
    pub failure_status_codes: Option<Vec<u16>>,
}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46    pub enabled: bool,
47    pub bind: String,
48    pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52    fn default() -> Self {
53        Self {
54            enabled: false,
55            bind: "127.0.0.1:9090".to_string(),
56            api_key: None,
57        }
58    }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63    pub enabled: bool,
64    pub scripts_dir: Option<String>,
65    pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Default, Clone, Debug)]
69pub struct ServerConfig {
70    pub bind: String,
71    pub https_port: u16,
72}
73
74#[derive(Deserialize, Serialize, Default, Clone, Debug)]
75pub struct TlsConfig {
76    pub mode: String,
77    pub cache_dir: String,
78}
79
/// `[letsencrypt]` section of `config.toml`. All fields are required when the
/// section is present.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    /// Presumably selects the Let's Encrypt staging environment — confirm.
    pub staging: bool,
    /// Account contact email, presumably sent to the ACME CA — confirm.
    pub email: String,
    /// Whether the operator has agreed to the CA's terms of service.
    pub terms_agreed: bool,
}
86
/// Resolved runtime configuration: the `config.toml` settings combined with
/// the routing rules and global scripts parsed from the routing file
/// (`proxy.conf`).
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    /// `[scripting]` settings; defaults when the section is absent.
    pub scripting: ScriptingTomlConfig,
    /// `[admin]` settings; defaults when the section is absent.
    pub admin: AdminConfig,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    /// Routing rules, in the order they appear in the routing file.
    pub rules: Vec<ProxyRule>,
    /// Scripts named on `[global] @script:` lines, in file order.
    pub global_scripts: Vec<String>,
}
98
/// Load-balancing strategy for a rule's targets.
///
/// Written in the routing file as `@lb:round-robin`, `@lb:weighted` or
/// `@lb:failover`; the serde names are the same. Round-robin is the default.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum LoadBalancingStrategy {
    #[default]
    #[serde(rename = "round-robin")]
    RoundRobin,
    #[serde(rename = "weighted")]
    Weighted,
    #[serde(rename = "failover")]
    Failover,
}
109
/// One routing rule, parsed from a `SOURCE -> TARGETS [directives]` line.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProxyRule {
    /// Which requests the rule applies to (the `SOURCE` side).
    pub matcher: RuleMatcher,
    /// Upstream targets (the comma-separated URLs).
    pub targets: Vec<Target>,
    /// Header rules; `parse_proxy_config` always leaves this empty.
    pub headers: Vec<HeaderRule>,
    /// Script names from `@script:a.lua,b.lua`.
    pub scripts: Vec<String>,
    /// Credentials from `@auth:user:hash` entries; empty when none are given.
    #[serde(default)]
    pub auth: Vec<BasicAuth>,
    /// Strategy from `@lb:...`; round-robin when absent.
    #[serde(default)]
    pub load_balancing: LoadBalancingStrategy,
}
121
/// Left-hand side of a rule.
///
/// How `parse_proxy_config` produces each variant from the source text:
/// - `Default`: `default` or `*`.
/// - `Prefix(p)`: a path ending in `/*`, stored with the trailing `*` removed
///   (so `p` ends in `/`).
/// - `Regex(r)`: a source starting with `~`; the rest is the pattern.
/// - `Exact(p)`: any other path-like source.
/// - `Domain(host)`: a host-like source (no leading `/`, and either contains a
///   `.` or parses as an IP), optionally followed by `/` or `/*`.
/// - `DomainPath(host, path)`: a host-like source with a path. `path` is the
///   text after the first `/` (no leading slash), with the trailing `*`
///   removed when it ends in `/*`.
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    Default,
    Prefix(String),
    Regex(RegexMatcher),
    Exact(String),
    Domain(String),
    DomainPath(String, String),
}
131
/// A compiled [`Regex`] together with its source pattern.
///
/// `Regex` does not implement serde traits, so the pattern text is kept to let
/// `RuleMatcher` serialize regex rules back to their original form.
#[derive(Clone, Debug)]
pub struct RegexMatcher {
    /// Pattern text exactly as given to [`RegexMatcher::new`].
    pub pattern: String,
    /// Compiled form of `pattern`.
    pub regex: Regex,
}
138
139impl RegexMatcher {
140    pub fn new(pattern: &str) -> Result<Self> {
141        Ok(Self {
142            pattern: pattern.to_string(),
143            regex: Regex::new(pattern)?,
144        })
145    }
146
147    pub fn is_match(&self, text: &str) -> bool {
148        self.regex.is_match(text)
149    }
150}
151
152impl Serialize for RuleMatcher {
153    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
154    where
155        S: serde::Serializer,
156    {
157        use serde::ser::SerializeMap;
158        let mut map = serializer.serialize_map(Some(2))?;
159        match self {
160            RuleMatcher::Default => {
161                map.serialize_entry("type", "default")?;
162            }
163            RuleMatcher::Prefix(v) => {
164                map.serialize_entry("type", "prefix")?;
165                map.serialize_entry("value", v)?;
166            }
167            RuleMatcher::Regex(rm) => {
168                map.serialize_entry("type", "regex")?;
169                map.serialize_entry("value", &rm.pattern)?;
170            }
171            RuleMatcher::Exact(v) => {
172                map.serialize_entry("type", "exact")?;
173                map.serialize_entry("value", v)?;
174            }
175            RuleMatcher::Domain(v) => {
176                map.serialize_entry("type", "domain")?;
177                map.serialize_entry("value", v)?;
178            }
179            RuleMatcher::DomainPath(d, p) => {
180                map.serialize_entry("type", "domain_path")?;
181                map.serialize_entry("domain", d)?;
182                map.serialize_entry("path", p)?;
183            }
184        }
185        map.end()
186    }
187}
188
189impl<'de> Deserialize<'de> for RuleMatcher {
190    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
191    where
192        D: serde::Deserializer<'de>,
193    {
194        use serde::de::Error;
195        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
196        let obj = value
197            .as_object()
198            .ok_or_else(|| D::Error::custom("expected object"))?;
199        let matcher_type = obj
200            .get("type")
201            .and_then(|v| v.as_str())
202            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
203
204        match matcher_type {
205            "default" => Ok(RuleMatcher::Default),
206            "exact" => {
207                let v = obj
208                    .get("value")
209                    .and_then(|v| v.as_str())
210                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
211                Ok(RuleMatcher::Exact(v.to_string()))
212            }
213            "prefix" => {
214                let v = obj
215                    .get("value")
216                    .and_then(|v| v.as_str())
217                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
218                Ok(RuleMatcher::Prefix(v.to_string()))
219            }
220            "regex" => {
221                let v = obj
222                    .get("value")
223                    .and_then(|v| v.as_str())
224                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
225                let rm = RegexMatcher::new(v)
226                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
227                Ok(RuleMatcher::Regex(rm))
228            }
229            "domain" => {
230                let v = obj
231                    .get("value")
232                    .and_then(|v| v.as_str())
233                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
234                Ok(RuleMatcher::Domain(v.to_string()))
235            }
236            "domain_path" => {
237                let d = obj
238                    .get("domain")
239                    .and_then(|v| v.as_str())
240                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
241                let p = obj
242                    .get("path")
243                    .and_then(|v| v.as_str())
244                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
245                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
246            }
247            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
248        }
249    }
250}
251
/// One upstream destination of a rule.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    /// Upstream URL, parsed with `Url::parse`. A bare `http://host:port`
    /// therefore gains a trailing `/`.
    pub url: Url,
    /// Relative weight. `parse_proxy_config` always sets 100; presumably only
    /// the weighted strategy reads it — confirm.
    pub weight: u8,
}
257
/// A header name/value pair attached to a rule. Whether it applies to requests
/// or responses is decided by the code that consumes it (not visible here).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    /// Header name.
    pub name: String,
    /// Header value.
    pub value: String,
}
263
264impl Config {
265    /// Extract unique domain names from Domain and DomainPath rules,
266    /// filtering out IPs and "localhost".
267    pub fn acme_domains(&self) -> Vec<String> {
268        let mut domains = Vec::new();
269        let mut seen = std::collections::HashSet::new();
270
271        for rule in &self.rules {
272            let domain = match &rule.matcher {
273                RuleMatcher::Domain(d) => Some(d.as_str()),
274                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
275                _ => None,
276            };
277
278            if let Some(d) = domain {
279                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
280                    continue;
281                }
282                if seen.insert(d.to_string()) {
283                    domains.push(d.to_string());
284                }
285            }
286        }
287
288        domains
289    }
290}
291
292pub struct ConfigManager {
293    config: ArcSwap<Config>,
294    config_path: PathBuf,
295    _watcher: Option<RecommendedWatcher>,
296    suppress_watch: Arc<AtomicBool>,
297}
298
299impl Clone for ConfigManager {
300    fn clone(&self) -> Self {
301        Self {
302            config: ArcSwap::new(self.config.load().clone()),
303            config_path: self.config_path.clone(),
304            _watcher: None,
305            suppress_watch: self.suppress_watch.clone(),
306        }
307    }
308}
309
310impl ConfigManager {
311    pub fn new(config_path: &str) -> Result<Self> {
312        let path = PathBuf::from(config_path);
313        let config = Self::load_config(&path, &path)?;
314        Ok(Self {
315            config: ArcSwap::new(Arc::new(config)),
316            config_path: path,
317            _watcher: None,
318            suppress_watch: Arc::new(AtomicBool::new(false)),
319        })
320    }
321
322    pub fn config_path(&self) -> &Path {
323        &self.config_path
324    }
325
326    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
327        &self.suppress_watch
328    }
329
330    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
331        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
332        let (rules, global_scripts) = parse_proxy_config(&content)?;
333        let toml_content = std::fs::read_to_string(
334            config_path
335                .parent()
336                .unwrap_or(Path::new("."))
337                .join("config.toml"),
338        )
339        .ok();
340        let toml_config: TomlConfig = toml_content
341            .as_ref()
342            .and_then(|c| toml::from_str(c).ok())
343            .unwrap_or_default();
344
345        Ok(Config {
346            server: toml_config.server,
347            tls: toml_config.tls,
348            letsencrypt: toml_config.letsencrypt,
349            scripting: toml_config.scripting.unwrap_or_default(),
350            admin: toml_config.admin.unwrap_or_default(),
351            circuit_breaker: toml_config.circuit_breaker,
352            rules,
353            global_scripts,
354        })
355    }
356
357    pub fn get_config(&self) -> Arc<Config> {
358        self.config.load().clone()
359    }
360
361    pub fn start_watcher(&self) -> Result<()> {
362        // Ensure the file exists so the watcher has something to watch
363        if !self.config_path.exists() {
364            if let Some(parent) = self.config_path.parent() {
365                std::fs::create_dir_all(parent).ok();
366            }
367            std::fs::write(&self.config_path, "")?;
368        }
369
370        let (tx, mut rx) = mpsc::channel(1);
371        let config_path = self.config_path.clone();
372        let suppress = self.suppress_watch.clone();
373
374        let mut watcher = RecommendedWatcher::new(
375            move |res| {
376                let _ = tx.blocking_send(res);
377            },
378            notify::Config::default(),
379        )?;
380
381        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
382
383        tracing::info!("Watching config file: {}", config_path.display());
384
385        std::thread::spawn(move || {
386            while let Some(res) = rx.blocking_recv() {
387                match res {
388                    Ok(event) => {
389                        if event.kind.is_modify() {
390                            if suppress.swap(false, Ordering::SeqCst) {
391                                tracing::debug!(
392                                    "Suppressing file watcher reload (admin API write)"
393                                );
394                                continue;
395                            }
396                            tracing::info!("Config file changed, reloading...");
397                        }
398                    }
399                    Err(e) => tracing::error!("Watch error: {}", e),
400                }
401            }
402        });
403
404        Ok(())
405    }
406
407    pub async fn reload(&self) -> Result<()> {
408        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
409        self.config.store(Arc::new(new_config));
410        tracing::info!("Configuration reloaded successfully");
411        Ok(())
412    }
413
414    /// Persist current rules to proxy.conf and swap in-memory config
415    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
416        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
417        self.suppress_watch.store(true, Ordering::SeqCst);
418        std::fs::write(&self.config_path, &content)?;
419        let mut config = (*self.config.load().as_ref()).clone();
420        config.rules = rules;
421        config.global_scripts = global_scripts;
422        self.config.store(Arc::new(config));
423        tracing::info!("Configuration persisted to {}", self.config_path.display());
424        Ok(())
425    }
426
427    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
428        let cfg = self.get_config();
429        let mut rules = cfg.rules.clone();
430        rules.push(rule);
431        self.persist_rules(rules, cfg.global_scripts.clone())
432    }
433
434    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
435        let cfg = self.get_config();
436        let mut rules = cfg.rules.clone();
437        if index >= rules.len() {
438            anyhow::bail!(
439                "Route index {} out of range (have {} routes)",
440                index,
441                rules.len()
442            );
443        }
444        rules[index] = rule;
445        self.persist_rules(rules, cfg.global_scripts.clone())
446    }
447
448    pub fn remove_route(&self, index: usize) -> Result<()> {
449        let cfg = self.get_config();
450        let mut rules = cfg.rules.clone();
451        if index >= rules.len() {
452            anyhow::bail!(
453                "Route index {} out of range (have {} routes)",
454                index,
455                rules.len()
456            );
457        }
458        rules.remove(index);
459        self.persist_rules(rules, cfg.global_scripts.clone())
460    }
461
462    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
463        self.persist_rules(rules, global_scripts)
464    }
465}
466
467#[async_trait::async_trait]
468impl ConfigManagerTrait for ConfigManager {
469    async fn reload(&self) -> Result<()> {
470        self.reload().await
471    }
472
473    fn get_config(&self) -> Arc<Config> {
474        self.get_config()
475    }
476
477    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
478        self.update_rules(rules, global_scripts)
479    }
480
481    fn add_route(&self, rule: ProxyRule) -> Result<()> {
482        self.add_route(rule)
483    }
484
485    fn remove_route(&self, index: usize) -> Result<()> {
486        self.remove_route(index)
487    }
488}
489
/// Split a `@script:a.lua,b.lua` directive off `s`.
///
/// Returns `(text_before_directive, script_names)`.
/// - Only the first `@script:` is considered.
/// - Whitespace right after the colon is skipped, and the list ends at the next
///   whitespace. Anything after the list is discarded.
/// - Names are trimmed, and empty names (e.g. from `a.lua,,b.lua`) are dropped.
/// - The text before the directive is trimmed. When there is no directive, `s`
///   is returned unchanged with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    match s.split_once("@script:") {
        None => (s, Vec::new()),
        Some((head, tail)) => {
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let names = list
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (head.trim(), names)
        }
    }
}
507
508/// Extract `@auth:user:hash` entries from a string, returning (remaining_str, auth_vec).
509/// Multiple @auth entries can appear: `@auth:user1:hash1 @auth:user2:hash2`
510/// Hash is everything after the second colon (bcrypt hashes start with $2a$, $2b$, $2y$)
511fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
512    let mut auth_entries = Vec::new();
513    let mut remaining = s.to_string();
514
515    while let Some(idx) = remaining.find("@auth:") {
516        // Keep the part BEFORE @auth:
517        let before = &remaining[..idx];
518        let after = &remaining[idx + "@auth:".len()..];
519
520        // Find the end of this auth entry (whitespace or end of string)
521        let end_idx = after
522            .find(|c: char| c.is_whitespace())
523            .unwrap_or(after.len());
524
525        let auth_part = &after[..end_idx];
526
527        // Parse username:hash - hash is everything after the first colon
528        if let Some((username, hash)) = auth_part.split_once(':') {
529            if !username.is_empty() && !hash.is_empty() {
530                auth_entries.push(BasicAuth {
531                    username: username.to_string(),
532                    hash: hash.to_string(),
533                });
534            }
535        }
536
537        // Continue with the part BEFORE this @auth, plus any remaining after it
538        let rest = &after[end_idx..];
539        remaining = if rest.is_empty() {
540            before.to_string()
541        } else {
542            format!("{}{}", before, rest)
543        };
544    }
545
546    (remaining.trim().to_string(), auth_entries)
547}
548
549/// Extract `@lb:strategy` from a string, returning (remaining_str, strategy).
550/// Example: `@lb:round-robin` or `@lb:weighted` or `@lb:failover`
551fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
552    if let Some(idx) = s.find("@lb:") {
553        let before = &s[..idx];
554        let after = &s[idx + "@lb:".len()..];
555
556        let end_idx = after
557            .find(|c: char| c.is_whitespace())
558            .unwrap_or(after.len());
559
560        let strategy_str = &after[..end_idx];
561        let strategy = match strategy_str {
562            "round-robin" => LoadBalancingStrategy::RoundRobin,
563            "weighted" => LoadBalancingStrategy::Weighted,
564            "failover" => LoadBalancingStrategy::Failover,
565            _ => LoadBalancingStrategy::default(),
566        };
567
568        let rest = &after[end_idx..];
569        let remaining = if rest.is_empty() {
570            before.to_string()
571        } else {
572            format!("{}{}", before, rest)
573        };
574
575        (remaining.trim().to_string(), strategy)
576    } else {
577        (s.to_string(), LoadBalancingStrategy::default())
578    }
579}
580
581fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
582    let mut rules = Vec::new();
583    let mut global_scripts = Vec::new();
584
585    // Join continuation lines (backslash at end of line)
586    let mut joined_lines: Vec<String> = Vec::new();
587    for line in content.lines() {
588        if let Some(current) = joined_lines.last_mut() {
589            if current.ends_with('\\') {
590                current.pop(); // remove the backslash
591                current.push_str(line.trim());
592                continue;
593            }
594        }
595        joined_lines.push(line.to_string());
596    }
597
598    for line in &joined_lines {
599        let trimmed = line.trim();
600        if trimmed.is_empty() || trimmed.starts_with('#') {
601            continue;
602        }
603
604        // Handle [global] @script:cors.lua,logging.lua
605        if trimmed.starts_with("[global]") {
606            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
607            let (_, scripts) = extract_scripts(rest);
608            global_scripts.extend(scripts);
609            continue;
610        }
611
612        if let Some((source, target_str)) = trimmed.split_once("->") {
613            let source = source.trim();
614            // Extract @script: from the target side
615            let (target_str, route_scripts) = extract_scripts(target_str.trim());
616            // Extract @auth: entries from the target side
617            let (target_str, auth_entries) = extract_auth(target_str);
618            // Extract @lb: load balancing strategy
619            let (target_str, load_balancing) = extract_load_balancing(&target_str);
620
621            let matcher = if source == "default" || source == "*" {
622                RuleMatcher::Default
623            } else if let Some(pattern) = source.strip_prefix("~") {
624                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
625            } else if !source.starts_with('/')
626                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
627            {
628                if let Some((domain, path)) = source.split_once('/') {
629                    if path.is_empty() || path == "*" {
630                        RuleMatcher::Domain(domain.to_string())
631                    } else if path.ends_with("/*") {
632                        RuleMatcher::DomainPath(
633                            domain.to_string(),
634                            path.trim_end_matches('*').to_string(),
635                        )
636                    } else {
637                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
638                    }
639                } else {
640                    RuleMatcher::Domain(source.to_string())
641                }
642            } else if source.ends_with("/*") {
643                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
644            } else {
645                RuleMatcher::Exact(source.to_string())
646            };
647
648            let targets: Vec<Target> = target_str
649                .split(',')
650                .map(|t| {
651                    Ok(Target {
652                        url: Url::parse(t.trim())?,
653                        weight: 100,
654                    })
655                })
656                .collect::<Result<Vec<_>>>()?;
657
658            rules.push(ProxyRule {
659                matcher,
660                targets,
661                headers: vec![],
662                scripts: route_scripts,
663                auth: auth_entries,
664                load_balancing,
665            });
666        }
667    }
668
669    Ok((rules, global_scripts))
670}
671
/// Unit tests for `parse_proxy_config`: line continuation, `@script`, `@auth`
/// and `@lb` directives.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backslash line continuation -------------------------------------

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        // `Url::parse` normalises a bare host:port by adding a trailing slash.
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                       http://backend2:8080, \\\n\
                       http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash inside a regex (not at end of line) must not join lines.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                       ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                       http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // --- @auth directives -------------------------------------------------

    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        // The hash keeps its `$`-separated bcrypt prefix intact.
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }

    // --- @lb directives ---------------------------------------------------

    #[test]
    fn test_load_balancing_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:weighted";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
    }

    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    #[test]
    fn test_load_balancing_with_scripts() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Unrecognised strategy names fall back to the default rather than erroring.
    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:unknown";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}