1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17 async fn reload(&self) -> Result<()>;
18 fn get_config(&self) -> Arc<Config>;
19 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20 fn add_route(&self, rule: ProxyRule) -> Result<()>;
21 fn remove_route(&self, index: usize) -> Result<()>;
22}
23
24#[derive(Deserialize, Default, Clone, Debug)]
25pub struct TomlConfig {
26 #[serde(default)]
27 pub server: ServerConfig,
28 #[serde(default)]
29 pub tls: TlsConfig,
30 pub letsencrypt: Option<LetsEncryptConfig>,
31 pub scripting: Option<ScriptingTomlConfig>,
32 pub admin: Option<AdminConfig>,
33 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
37pub struct CircuitBreakerTomlConfig {
38 pub failure_threshold: Option<u32>,
39 pub recovery_timeout_secs: Option<u64>,
40 pub success_threshold: Option<u32>,
41 pub failure_status_codes: Option<Vec<u16>>,
42}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46 pub enabled: bool,
47 pub bind: String,
48 pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52 fn default() -> Self {
53 Self {
54 enabled: false,
55 bind: "127.0.0.1:9090".to_string(),
56 api_key: None,
57 }
58 }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63 pub enabled: bool,
64 pub scripts_dir: Option<String>,
65 pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Default, Clone, Debug)]
69pub struct ServerConfig {
70 pub bind: String,
71 pub https_port: u16,
72}
73
74#[derive(Deserialize, Serialize, Default, Clone, Debug)]
75pub struct TlsConfig {
76 pub mode: String,
77 pub cache_dir: String,
78}
79
80#[derive(Deserialize, Serialize, Clone, Debug)]
81pub struct LetsEncryptConfig {
82 pub staging: bool,
83 pub email: String,
84 pub terms_agreed: bool,
85}
86
87#[derive(Clone, Debug, Serialize)]
88pub struct Config {
89 pub server: ServerConfig,
90 pub tls: TlsConfig,
91 pub letsencrypt: Option<LetsEncryptConfig>,
92 pub scripting: ScriptingTomlConfig,
93 pub admin: AdminConfig,
94 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
95 pub rules: Vec<ProxyRule>,
96 pub global_scripts: Vec<String>,
97}
98
99#[derive(Clone, Debug, Serialize, Deserialize)]
100pub struct ProxyRule {
101 pub matcher: RuleMatcher,
102 pub targets: Vec<Target>,
103 pub headers: Vec<HeaderRule>,
104 pub scripts: Vec<String>,
105 #[serde(default)]
106 pub auth: Vec<BasicAuth>,
107}
108
109#[derive(Clone, Debug)]
110pub enum RuleMatcher {
111 Default,
112 Prefix(String),
113 Regex(RegexMatcher),
114 Exact(String),
115 Domain(String),
116 DomainPath(String, String),
117}
118
119#[derive(Clone, Debug)]
121pub struct RegexMatcher {
122 pub pattern: String,
123 pub regex: Regex,
124}
125
126impl RegexMatcher {
127 pub fn new(pattern: &str) -> Result<Self> {
128 Ok(Self {
129 pattern: pattern.to_string(),
130 regex: Regex::new(pattern)?,
131 })
132 }
133
134 pub fn is_match(&self, text: &str) -> bool {
135 self.regex.is_match(text)
136 }
137}
138
139impl Serialize for RuleMatcher {
140 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
141 where
142 S: serde::Serializer,
143 {
144 use serde::ser::SerializeMap;
145 let mut map = serializer.serialize_map(Some(2))?;
146 match self {
147 RuleMatcher::Default => {
148 map.serialize_entry("type", "default")?;
149 }
150 RuleMatcher::Prefix(v) => {
151 map.serialize_entry("type", "prefix")?;
152 map.serialize_entry("value", v)?;
153 }
154 RuleMatcher::Regex(rm) => {
155 map.serialize_entry("type", "regex")?;
156 map.serialize_entry("value", &rm.pattern)?;
157 }
158 RuleMatcher::Exact(v) => {
159 map.serialize_entry("type", "exact")?;
160 map.serialize_entry("value", v)?;
161 }
162 RuleMatcher::Domain(v) => {
163 map.serialize_entry("type", "domain")?;
164 map.serialize_entry("value", v)?;
165 }
166 RuleMatcher::DomainPath(d, p) => {
167 map.serialize_entry("type", "domain_path")?;
168 map.serialize_entry("domain", d)?;
169 map.serialize_entry("path", p)?;
170 }
171 }
172 map.end()
173 }
174}
175
176impl<'de> Deserialize<'de> for RuleMatcher {
177 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
178 where
179 D: serde::Deserializer<'de>,
180 {
181 use serde::de::Error;
182 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
183 let obj = value
184 .as_object()
185 .ok_or_else(|| D::Error::custom("expected object"))?;
186 let matcher_type = obj
187 .get("type")
188 .and_then(|v| v.as_str())
189 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
190
191 match matcher_type {
192 "default" => Ok(RuleMatcher::Default),
193 "exact" => {
194 let v = obj
195 .get("value")
196 .and_then(|v| v.as_str())
197 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
198 Ok(RuleMatcher::Exact(v.to_string()))
199 }
200 "prefix" => {
201 let v = obj
202 .get("value")
203 .and_then(|v| v.as_str())
204 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
205 Ok(RuleMatcher::Prefix(v.to_string()))
206 }
207 "regex" => {
208 let v = obj
209 .get("value")
210 .and_then(|v| v.as_str())
211 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
212 let rm = RegexMatcher::new(v)
213 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
214 Ok(RuleMatcher::Regex(rm))
215 }
216 "domain" => {
217 let v = obj
218 .get("value")
219 .and_then(|v| v.as_str())
220 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
221 Ok(RuleMatcher::Domain(v.to_string()))
222 }
223 "domain_path" => {
224 let d = obj
225 .get("domain")
226 .and_then(|v| v.as_str())
227 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
228 let p = obj
229 .get("path")
230 .and_then(|v| v.as_str())
231 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
232 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
233 }
234 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
235 }
236 }
237}
238
239#[derive(Clone, Debug, Serialize, Deserialize)]
240pub struct Target {
241 pub url: Url,
242 pub weight: u8,
243}
244
245#[derive(Clone, Debug, Serialize, Deserialize)]
246pub struct HeaderRule {
247 pub name: String,
248 pub value: String,
249}
250
251impl Config {
252 pub fn acme_domains(&self) -> Vec<String> {
255 let mut domains = Vec::new();
256 let mut seen = std::collections::HashSet::new();
257
258 for rule in &self.rules {
259 let domain = match &rule.matcher {
260 RuleMatcher::Domain(d) => Some(d.as_str()),
261 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
262 _ => None,
263 };
264
265 if let Some(d) = domain {
266 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
267 continue;
268 }
269 if seen.insert(d.to_string()) {
270 domains.push(d.to_string());
271 }
272 }
273 }
274
275 domains
276 }
277}
278
279pub struct ConfigManager {
280 config: ArcSwap<Config>,
281 config_path: PathBuf,
282 _watcher: Option<RecommendedWatcher>,
283 suppress_watch: Arc<AtomicBool>,
284}
285
286impl Clone for ConfigManager {
287 fn clone(&self) -> Self {
288 Self {
289 config: ArcSwap::new(self.config.load().clone()),
290 config_path: self.config_path.clone(),
291 _watcher: None,
292 suppress_watch: self.suppress_watch.clone(),
293 }
294 }
295}
296
297impl ConfigManager {
298 pub fn new(config_path: &str) -> Result<Self> {
299 let path = PathBuf::from(config_path);
300 let config = Self::load_config(&path, &path)?;
301 Ok(Self {
302 config: ArcSwap::new(Arc::new(config)),
303 config_path: path,
304 _watcher: None,
305 suppress_watch: Arc::new(AtomicBool::new(false)),
306 })
307 }
308
309 pub fn config_path(&self) -> &Path {
310 &self.config_path
311 }
312
313 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
314 &self.suppress_watch
315 }
316
317 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
318 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
319 let (rules, global_scripts) = parse_proxy_config(&content)?;
320 let toml_content = std::fs::read_to_string(
321 config_path
322 .parent()
323 .unwrap_or(Path::new("."))
324 .join("config.toml"),
325 )
326 .ok();
327 let toml_config: TomlConfig = toml_content
328 .as_ref()
329 .and_then(|c| toml::from_str(c).ok())
330 .unwrap_or_default();
331
332 Ok(Config {
333 server: toml_config.server,
334 tls: toml_config.tls,
335 letsencrypt: toml_config.letsencrypt,
336 scripting: toml_config.scripting.unwrap_or_default(),
337 admin: toml_config.admin.unwrap_or_default(),
338 circuit_breaker: toml_config.circuit_breaker,
339 rules,
340 global_scripts,
341 })
342 }
343
344 pub fn get_config(&self) -> Arc<Config> {
345 self.config.load().clone()
346 }
347
348 pub fn start_watcher(&self) -> Result<()> {
349 if !self.config_path.exists() {
351 if let Some(parent) = self.config_path.parent() {
352 std::fs::create_dir_all(parent).ok();
353 }
354 std::fs::write(&self.config_path, "")?;
355 }
356
357 let (tx, mut rx) = mpsc::channel(1);
358 let config_path = self.config_path.clone();
359 let suppress = self.suppress_watch.clone();
360
361 let mut watcher = RecommendedWatcher::new(
362 move |res| {
363 let _ = tx.blocking_send(res);
364 },
365 notify::Config::default(),
366 )?;
367
368 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
369
370 tracing::info!("Watching config file: {}", config_path.display());
371
372 std::thread::spawn(move || {
373 while let Some(res) = rx.blocking_recv() {
374 match res {
375 Ok(event) => {
376 if event.kind.is_modify() {
377 if suppress.swap(false, Ordering::SeqCst) {
378 tracing::debug!(
379 "Suppressing file watcher reload (admin API write)"
380 );
381 continue;
382 }
383 tracing::info!("Config file changed, reloading...");
384 }
385 }
386 Err(e) => tracing::error!("Watch error: {}", e),
387 }
388 }
389 });
390
391 Ok(())
392 }
393
394 pub async fn reload(&self) -> Result<()> {
395 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
396 self.config.store(Arc::new(new_config));
397 tracing::info!("Configuration reloaded successfully");
398 Ok(())
399 }
400
401 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
403 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
404 self.suppress_watch.store(true, Ordering::SeqCst);
405 std::fs::write(&self.config_path, &content)?;
406 let mut config = (*self.config.load().as_ref()).clone();
407 config.rules = rules;
408 config.global_scripts = global_scripts;
409 self.config.store(Arc::new(config));
410 tracing::info!("Configuration persisted to {}", self.config_path.display());
411 Ok(())
412 }
413
414 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
415 let cfg = self.get_config();
416 let mut rules = cfg.rules.clone();
417 rules.push(rule);
418 self.persist_rules(rules, cfg.global_scripts.clone())
419 }
420
421 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
422 let cfg = self.get_config();
423 let mut rules = cfg.rules.clone();
424 if index >= rules.len() {
425 anyhow::bail!(
426 "Route index {} out of range (have {} routes)",
427 index,
428 rules.len()
429 );
430 }
431 rules[index] = rule;
432 self.persist_rules(rules, cfg.global_scripts.clone())
433 }
434
435 pub fn remove_route(&self, index: usize) -> Result<()> {
436 let cfg = self.get_config();
437 let mut rules = cfg.rules.clone();
438 if index >= rules.len() {
439 anyhow::bail!(
440 "Route index {} out of range (have {} routes)",
441 index,
442 rules.len()
443 );
444 }
445 rules.remove(index);
446 self.persist_rules(rules, cfg.global_scripts.clone())
447 }
448
449 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
450 self.persist_rules(rules, global_scripts)
451 }
452}
453
454#[async_trait::async_trait]
455impl ConfigManagerTrait for ConfigManager {
456 async fn reload(&self) -> Result<()> {
457 self.reload().await
458 }
459
460 fn get_config(&self) -> Arc<Config> {
461 self.get_config()
462 }
463
464 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
465 self.update_rules(rules, global_scripts)
466 }
467
468 fn add_route(&self, rule: ProxyRule) -> Result<()> {
469 self.add_route(rule)
470 }
471
472 fn remove_route(&self, index: usize) -> Result<()> {
473 self.remove_route(index)
474 }
475}
476
/// Splits `@script:` directives off a rule's target text.
///
/// Returns the text before the first `@script:` marker (trimmed) together
/// with the script names from every `@script:` directive, in order.
///
/// Each directive contributes the first whitespace-delimited token after its
/// marker, split on `,`. Empty names are dropped. Other text after the first
/// marker is ignored. When no marker is present, `s` is returned unchanged
/// with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";
    match s.find(MARKER) {
        None => (s, Vec::new()),
        Some(first) => {
            let scripts = s[first..]
                .split(MARKER)
                // The slice starts with MARKER, so the first piece is always empty.
                .skip(1)
                .filter_map(|directive| directive.split_whitespace().next())
                .flat_map(|list| list.split(','))
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(str::to_string)
                .collect();
            (s[..first].trim(), scripts)
        }
    }
}
494
495fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
499 let mut auth_entries = Vec::new();
500 let mut remaining = s.to_string();
501
502 while let Some(idx) = remaining.find("@auth:") {
503 let before = &remaining[..idx];
505 let after = &remaining[idx + "@auth:".len()..];
506
507 let end_idx = after
509 .find(|c: char| c.is_whitespace())
510 .unwrap_or(after.len());
511
512 let auth_part = &after[..end_idx];
513
514 if let Some((username, hash)) = auth_part.split_once(':') {
516 if !username.is_empty() && !hash.is_empty() {
517 auth_entries.push(BasicAuth {
518 username: username.to_string(),
519 hash: hash.to_string(),
520 });
521 }
522 }
523
524 let rest = &after[end_idx..];
526 remaining = if rest.is_empty() {
527 before.to_string()
528 } else {
529 format!("{}{}", before, rest)
530 };
531 }
532
533 (remaining.trim().to_string(), auth_entries)
534}
535
536fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
537 let mut rules = Vec::new();
538 let mut global_scripts = Vec::new();
539
540 let mut joined_lines: Vec<String> = Vec::new();
542 for line in content.lines() {
543 if let Some(current) = joined_lines.last_mut() {
544 if current.ends_with('\\') {
545 current.pop(); current.push_str(line.trim());
547 continue;
548 }
549 }
550 joined_lines.push(line.to_string());
551 }
552
553 for line in &joined_lines {
554 let trimmed = line.trim();
555 if trimmed.is_empty() || trimmed.starts_with('#') {
556 continue;
557 }
558
559 if trimmed.starts_with("[global]") {
561 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
562 let (_, scripts) = extract_scripts(rest);
563 global_scripts.extend(scripts);
564 continue;
565 }
566
567 if let Some((source, target_str)) = trimmed.split_once("->") {
568 let source = source.trim();
569 let (target_str, route_scripts) = extract_scripts(target_str.trim());
571 let (target_str, auth_entries) = extract_auth(target_str);
573
574 let matcher = if source == "default" || source == "*" {
575 RuleMatcher::Default
576 } else if let Some(pattern) = source.strip_prefix("~") {
577 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
578 } else if !source.starts_with('/')
579 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
580 {
581 if let Some((domain, path)) = source.split_once('/') {
582 if path.is_empty() || path == "*" {
583 RuleMatcher::Domain(domain.to_string())
584 } else if path.ends_with("/*") {
585 RuleMatcher::DomainPath(
586 domain.to_string(),
587 path.trim_end_matches('*').to_string(),
588 )
589 } else {
590 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
591 }
592 } else {
593 RuleMatcher::Domain(source.to_string())
594 }
595 } else if source.ends_with("/*") {
596 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
597 } else {
598 RuleMatcher::Exact(source.to_string())
599 };
600
601 let targets: Vec<Target> = target_str
602 .split(',')
603 .map(|t| {
604 Ok(Target {
605 url: Url::parse(t.trim())?,
606 weight: 100,
607 })
608 })
609 .collect::<Result<Vec<_>>>()?;
610
611 rules.push(ProxyRule {
612 matcher,
613 targets,
614 headers: vec![],
615 scripts: route_scripts,
616 auth: auth_entries,
617 });
618 }
619 }
620
621 Ok((rules, global_scripts))
622}
623
// Unit tests for `parse_proxy_config`: backslash line continuation and
// `@script:` / `@auth:` directive handling.
#[cfg(test)]
mod tests {
    use super::*;

    // A trailing `\` joins the next physical line onto the same rule, so both
    // URLs become targets of one rule.
    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    // Continuations chain across more than two physical lines. The `\` ending
    // each Rust source line below is a string-continuation escape; the `\\\n`
    // before it is the literal backslash + newline the parser sees.
    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
            http://backend2:8080, \\\n\
            http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash inside a line (here in a regex, `\d`) is not a continuation,
    // so the two rules stay separate.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
            ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // Leading whitespace on continuation lines does not break target parsing.
    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080, \\\n http://b:8080, \\\n http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    // A `@script:` directive on a continuation line still applies to the rule.
    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
            http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Plain one-rule-per-line input, including the `default` source.
    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // `@auth:user:hash` is split at the first `:`, so the `$`-delimited bcrypt
    // hash stays intact, and the directive is removed from the target URL.
    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    // Several `@auth:` directives on one rule are collected in order.
    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }
}