1pub mod serializer;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use notify::{RecommendedWatcher, RecursiveMode, Watcher};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use url::Url;
13
/// Interface to the proxy's configuration store.
///
/// Implementors must be `Send + Sync`. `reload` is async (through `async_trait`); the
/// other methods are synchronous and return `anyhow::Result` on failure.
#[async_trait::async_trait]
pub trait ConfigManagerTrait: Send + Sync {
    /// Reloads the configuration and makes the result current.
    async fn reload(&self) -> Result<()>;
    /// Returns a shared snapshot of the current configuration.
    fn get_config(&self) -> Arc<Config>;
    /// Replaces the full rule list and the global script list.
    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
    /// Adds a proxy rule (the `ConfigManager` impl appends it to the end).
    fn add_route(&self, rule: ProxyRule) -> Result<()>;
    /// Removes the proxy rule at position `index`.
    fn remove_route(&self, index: usize) -> Result<()>;
}
22
/// Raw settings deserialized from `config.toml`.
///
/// `server` and `tls` fall back to their `Default` when their tables are absent; the
/// remaining sections are `None` when absent.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    #[serde(default)]
    pub server: ServerConfig,
    #[serde(default)]
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: Option<ScriptingTomlConfig>,
    pub admin: Option<AdminConfig>,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}
34
/// `[circuit_breaker]` section of `config.toml`.
///
/// Every field is optional; an omitted key deserializes as `None`.
// NOTE(review): how `None` values are defaulted is decided by the consumer of this
// struct, which is not in this file — confirm there.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    pub failure_threshold: Option<u32>,
    pub recovery_timeout_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    pub failure_status_codes: Option<Vec<u16>>,
}
42
43#[derive(Deserialize, Serialize, Clone, Debug)]
44pub struct AdminConfig {
45 pub enabled: bool,
46 pub bind: String,
47 pub api_key: Option<String>,
48}
49
50impl Default for AdminConfig {
51 fn default() -> Self {
52 Self {
53 enabled: false,
54 bind: "127.0.0.1:9090".to_string(),
55 api_key: None,
56 }
57 }
58}
59
60#[derive(Deserialize, Serialize, Clone, Debug, Default)]
61pub struct ScriptingTomlConfig {
62 pub enabled: bool,
63 pub scripts_dir: Option<String>,
64 pub hook_timeout_ms: Option<u64>,
65}
66
67#[derive(Deserialize, Serialize, Default, Clone, Debug)]
68pub struct ServerConfig {
69 pub bind: String,
70 pub https_port: u16,
71}
72
73#[derive(Deserialize, Serialize, Default, Clone, Debug)]
74pub struct TlsConfig {
75 pub mode: String,
76 pub cache_dir: String,
77}
78
/// `[letsencrypt]` section of `config.toml`.
///
/// No field has a serde default, so all three are required whenever the table is present.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    pub staging: bool,
    pub email: String,
    pub terms_agreed: bool,
}
85
/// Runtime configuration: settings from `config.toml` merged with the rules and global
/// scripts parsed from the proxy rules file.
///
/// `scripting` and `admin` are always populated; `ConfigManager::load_config` fills
/// them with their defaults when the TOML sections are missing.
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: ScriptingTomlConfig,
    pub admin: AdminConfig,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    pub rules: Vec<ProxyRule>,
    pub global_scripts: Vec<String>,
}
97
/// One routing rule: a request matcher, its upstream targets, headers and scripts.
///
/// Rules parsed from the proxy rules file always have empty `headers`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProxyRule {
    pub matcher: RuleMatcher,
    pub targets: Vec<Target>,
    pub headers: Vec<HeaderRule>,
    // Script names taken from the rule's `@script:` suffix.
    pub scripts: Vec<String>,
}
105
/// How a [`ProxyRule`] selects requests.
///
/// The hand-written `Serialize`/`Deserialize` impls encode it as a map tagged by
/// `"type"` (`default`, `prefix`, `regex`, `exact`, `domain`, `domain_path`).
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    /// Catch-all; written as `default` or `*` in the proxy rules file.
    Default,
    /// Path prefix; parsed from a source ending in `/*`, stored with the `*` removed.
    Prefix(String),
    /// Regular expression; parsed from a source starting with `~`.
    Regex(RegexMatcher),
    /// Source text used verbatim when it fits none of the other forms.
    Exact(String),
    /// Host (contains a `.` or is an IP literal) with no path, `/`, or `/*`.
    Domain(String),
    /// Host plus path. The path is the text after the first `/` (no leading slash),
    /// with the trailing `*` removed for `host/path/*` sources.
    DomainPath(String, String),
}
115
/// A compiled regex kept together with its source pattern.
///
/// The pattern is retained because serialization writes `pattern`, not the compiled
/// regex. Both fields are public, so nothing here forces them to stay in sync.
#[derive(Clone, Debug)]
pub struct RegexMatcher {
    pub pattern: String,
    pub regex: Regex,
}
122
123impl RegexMatcher {
124 pub fn new(pattern: &str) -> Result<Self> {
125 Ok(Self {
126 pattern: pattern.to_string(),
127 regex: Regex::new(pattern)?,
128 })
129 }
130
131 pub fn is_match(&self, text: &str) -> bool {
132 self.regex.is_match(text)
133 }
134}
135
136impl Serialize for RuleMatcher {
137 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
138 where
139 S: serde::Serializer,
140 {
141 use serde::ser::SerializeMap;
142 let mut map = serializer.serialize_map(Some(2))?;
143 match self {
144 RuleMatcher::Default => {
145 map.serialize_entry("type", "default")?;
146 }
147 RuleMatcher::Prefix(v) => {
148 map.serialize_entry("type", "prefix")?;
149 map.serialize_entry("value", v)?;
150 }
151 RuleMatcher::Regex(rm) => {
152 map.serialize_entry("type", "regex")?;
153 map.serialize_entry("value", &rm.pattern)?;
154 }
155 RuleMatcher::Exact(v) => {
156 map.serialize_entry("type", "exact")?;
157 map.serialize_entry("value", v)?;
158 }
159 RuleMatcher::Domain(v) => {
160 map.serialize_entry("type", "domain")?;
161 map.serialize_entry("value", v)?;
162 }
163 RuleMatcher::DomainPath(d, p) => {
164 map.serialize_entry("type", "domain_path")?;
165 map.serialize_entry("domain", d)?;
166 map.serialize_entry("path", p)?;
167 }
168 }
169 map.end()
170 }
171}
172
173impl<'de> Deserialize<'de> for RuleMatcher {
174 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
175 where
176 D: serde::Deserializer<'de>,
177 {
178 use serde::de::Error;
179 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
180 let obj = value
181 .as_object()
182 .ok_or_else(|| D::Error::custom("expected object"))?;
183 let matcher_type = obj
184 .get("type")
185 .and_then(|v| v.as_str())
186 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
187
188 match matcher_type {
189 "default" => Ok(RuleMatcher::Default),
190 "exact" => {
191 let v = obj
192 .get("value")
193 .and_then(|v| v.as_str())
194 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
195 Ok(RuleMatcher::Exact(v.to_string()))
196 }
197 "prefix" => {
198 let v = obj
199 .get("value")
200 .and_then(|v| v.as_str())
201 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
202 Ok(RuleMatcher::Prefix(v.to_string()))
203 }
204 "regex" => {
205 let v = obj
206 .get("value")
207 .and_then(|v| v.as_str())
208 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
209 let rm = RegexMatcher::new(v)
210 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
211 Ok(RuleMatcher::Regex(rm))
212 }
213 "domain" => {
214 let v = obj
215 .get("value")
216 .and_then(|v| v.as_str())
217 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
218 Ok(RuleMatcher::Domain(v.to_string()))
219 }
220 "domain_path" => {
221 let d = obj
222 .get("domain")
223 .and_then(|v| v.as_str())
224 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
225 let p = obj
226 .get("path")
227 .and_then(|v| v.as_str())
228 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
229 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
230 }
231 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
232 }
233 }
234}
235
/// An upstream URL a rule can forward to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    pub url: Url,
    // `parse_proxy_config` always sets 100; presumably a relative balancing weight —
    // confirm against the load balancer.
    pub weight: u8,
}
241
/// A header name/value pair attached to a [`ProxyRule`].
// NOTE(review): whether this sets request or response headers is decided by code
// outside this file — confirm there.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    pub name: String,
    pub value: String,
}
247
248impl Config {
249 pub fn acme_domains(&self) -> Vec<String> {
252 let mut domains = Vec::new();
253 let mut seen = std::collections::HashSet::new();
254
255 for rule in &self.rules {
256 let domain = match &rule.matcher {
257 RuleMatcher::Domain(d) => Some(d.as_str()),
258 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
259 _ => None,
260 };
261
262 if let Some(d) = domain {
263 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
264 continue;
265 }
266 if seen.insert(d.to_string()) {
267 domains.push(d.to_string());
268 }
269 }
270 }
271
272 domains
273 }
274}
275
276pub struct ConfigManager {
277 config: ArcSwap<Config>,
278 config_path: PathBuf,
279 _watcher: Option<RecommendedWatcher>,
280 suppress_watch: Arc<AtomicBool>,
281}
282
283impl Clone for ConfigManager {
284 fn clone(&self) -> Self {
285 Self {
286 config: ArcSwap::new(self.config.load().clone()),
287 config_path: self.config_path.clone(),
288 _watcher: None,
289 suppress_watch: self.suppress_watch.clone(),
290 }
291 }
292}
293
294impl ConfigManager {
295 pub fn new(config_path: &str) -> Result<Self> {
296 let path = PathBuf::from(config_path);
297 let config = Self::load_config(&path, &path)?;
298 Ok(Self {
299 config: ArcSwap::new(Arc::new(config)),
300 config_path: path,
301 _watcher: None,
302 suppress_watch: Arc::new(AtomicBool::new(false)),
303 })
304 }
305
306 pub fn config_path(&self) -> &Path {
307 &self.config_path
308 }
309
310 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
311 &self.suppress_watch
312 }
313
314 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
315 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
316 let (rules, global_scripts) = parse_proxy_config(&content)?;
317 let toml_content = std::fs::read_to_string(
318 config_path
319 .parent()
320 .unwrap_or(Path::new("."))
321 .join("config.toml"),
322 )
323 .ok();
324 let toml_config: TomlConfig = toml_content
325 .as_ref()
326 .and_then(|c| toml::from_str(c).ok())
327 .unwrap_or_default();
328
329 Ok(Config {
330 server: toml_config.server,
331 tls: toml_config.tls,
332 letsencrypt: toml_config.letsencrypt,
333 scripting: toml_config.scripting.unwrap_or_default(),
334 admin: toml_config.admin.unwrap_or_default(),
335 circuit_breaker: toml_config.circuit_breaker,
336 rules,
337 global_scripts,
338 })
339 }
340
341 pub fn get_config(&self) -> Arc<Config> {
342 self.config.load().clone()
343 }
344
345 pub fn start_watcher(&self) -> Result<()> {
346 if !self.config_path.exists() {
348 if let Some(parent) = self.config_path.parent() {
349 std::fs::create_dir_all(parent).ok();
350 }
351 std::fs::write(&self.config_path, "")?;
352 }
353
354 let (tx, mut rx) = mpsc::channel(1);
355 let config_path = self.config_path.clone();
356 let suppress = self.suppress_watch.clone();
357
358 let mut watcher = RecommendedWatcher::new(
359 move |res| {
360 let _ = tx.blocking_send(res);
361 },
362 notify::Config::default(),
363 )?;
364
365 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
366
367 tracing::info!("Watching config file: {}", config_path.display());
368
369 std::thread::spawn(move || {
370 while let Some(res) = rx.blocking_recv() {
371 match res {
372 Ok(event) => {
373 if event.kind.is_modify() {
374 if suppress.swap(false, Ordering::SeqCst) {
375 tracing::debug!(
376 "Suppressing file watcher reload (admin API write)"
377 );
378 continue;
379 }
380 tracing::info!("Config file changed, reloading...");
381 }
382 }
383 Err(e) => tracing::error!("Watch error: {}", e),
384 }
385 }
386 });
387
388 Ok(())
389 }
390
391 pub async fn reload(&self) -> Result<()> {
392 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
393 self.config.store(Arc::new(new_config));
394 tracing::info!("Configuration reloaded successfully");
395 Ok(())
396 }
397
398 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
400 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
401 self.suppress_watch.store(true, Ordering::SeqCst);
402 std::fs::write(&self.config_path, &content)?;
403 let mut config = (*self.config.load().as_ref()).clone();
404 config.rules = rules;
405 config.global_scripts = global_scripts;
406 self.config.store(Arc::new(config));
407 tracing::info!("Configuration persisted to {}", self.config_path.display());
408 Ok(())
409 }
410
411 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
412 let cfg = self.get_config();
413 let mut rules = cfg.rules.clone();
414 rules.push(rule);
415 self.persist_rules(rules, cfg.global_scripts.clone())
416 }
417
418 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
419 let cfg = self.get_config();
420 let mut rules = cfg.rules.clone();
421 if index >= rules.len() {
422 anyhow::bail!(
423 "Route index {} out of range (have {} routes)",
424 index,
425 rules.len()
426 );
427 }
428 rules[index] = rule;
429 self.persist_rules(rules, cfg.global_scripts.clone())
430 }
431
432 pub fn remove_route(&self, index: usize) -> Result<()> {
433 let cfg = self.get_config();
434 let mut rules = cfg.rules.clone();
435 if index >= rules.len() {
436 anyhow::bail!(
437 "Route index {} out of range (have {} routes)",
438 index,
439 rules.len()
440 );
441 }
442 rules.remove(index);
443 self.persist_rules(rules, cfg.global_scripts.clone())
444 }
445
446 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
447 self.persist_rules(rules, global_scripts)
448 }
449}
450
451#[async_trait::async_trait]
452impl ConfigManagerTrait for ConfigManager {
453 async fn reload(&self) -> Result<()> {
454 self.reload().await
455 }
456
457 fn get_config(&self) -> Arc<Config> {
458 self.get_config()
459 }
460
461 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
462 self.update_rules(rules, global_scripts)
463 }
464
465 fn add_route(&self, rule: ProxyRule) -> Result<()> {
466 self.add_route(rule)
467 }
468
469 fn remove_route(&self, index: usize) -> Result<()> {
470 self.remove_route(index)
471 }
472}
473
/// Splits an `@script:` suffix off `s`.
///
/// Returns the trimmed text before the first `@script:` and the comma-separated
/// script names that follow it. Only the first whitespace-delimited token after the
/// marker is read, and empty names are dropped. Without a marker, `s` is returned
/// untouched with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";
    match s.split_once(MARKER) {
        None => (s, Vec::new()),
        Some((head, tail)) => {
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let mut scripts = Vec::new();
            for name in list.split(',') {
                let name = name.trim();
                if !name.is_empty() {
                    scripts.push(name.to_string());
                }
            }
            (head.trim(), scripts)
        }
    }
}
491
492fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
493 let mut rules = Vec::new();
494 let mut global_scripts = Vec::new();
495
496 let mut joined_lines: Vec<String> = Vec::new();
498 for line in content.lines() {
499 if let Some(current) = joined_lines.last_mut() {
500 if current.ends_with('\\') {
501 current.pop(); current.push_str(line.trim());
503 continue;
504 }
505 }
506 joined_lines.push(line.to_string());
507 }
508
509 for line in &joined_lines {
510 let trimmed = line.trim();
511 if trimmed.is_empty() || trimmed.starts_with('#') {
512 continue;
513 }
514
515 if trimmed.starts_with("[global]") {
517 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
518 let (_, scripts) = extract_scripts(rest);
519 global_scripts.extend(scripts);
520 continue;
521 }
522
523 if let Some((source, target_str)) = trimmed.split_once("->") {
524 let source = source.trim();
525 let (target_str, route_scripts) = extract_scripts(target_str.trim());
527
528 let matcher = if source == "default" || source == "*" {
529 RuleMatcher::Default
530 } else if let Some(pattern) = source.strip_prefix("~") {
531 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
532 } else if !source.starts_with('/')
533 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
534 {
535 if let Some((domain, path)) = source.split_once('/') {
536 if path.is_empty() || path == "*" {
537 RuleMatcher::Domain(domain.to_string())
538 } else if path.ends_with("/*") {
539 RuleMatcher::DomainPath(
540 domain.to_string(),
541 path.trim_end_matches('*').to_string(),
542 )
543 } else {
544 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
545 }
546 } else {
547 RuleMatcher::Domain(source.to_string())
548 }
549 } else if source.ends_with("/*") {
550 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
551 } else {
552 RuleMatcher::Exact(source.to_string())
553 };
554
555 let targets: Vec<Target> = target_str
556 .split(',')
557 .map(|t| {
558 Ok(Target {
559 url: Url::parse(t.trim())?,
560 weight: 100,
561 })
562 })
563 .collect::<Result<Vec<_>>>()?;
564
565 rules.push(ProxyRule {
566 matcher,
567 targets,
568 headers: vec![],
569 scripts: route_scripts,
570 });
571 }
572 }
573
574 Ok((rules, global_scripts))
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `config` (panicking on failure) and keeps only the rules.
    fn parse_rules(config: &str) -> Vec<ProxyRule> {
        let (rules, _global_scripts) = parse_proxy_config(config).unwrap();
        rules
    }

    /// String form of the `idx`-th target URL of `rule`.
    fn target_url(rule: &ProxyRule, idx: usize) -> &str {
        rule.targets[idx].url.as_str()
    }

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let rules = parse_rules("/api/* -> http://backend1:8080, \\\n http://backend2:8080\n");
        assert_eq!(rules.len(), 1);
        let rule = &rules[0];
        assert_eq!(rule.targets.len(), 2);
        assert_eq!(target_url(rule, 0), "http://backend1:8080/");
        assert_eq!(target_url(rule, 1), "http://backend2:8080/");
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let rules = parse_rules(
            "/api/* -> http://backend1:8080, \\\nhttp://backend2:8080, \\\nhttp://backend3:8080\n",
        );
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(target_url(&rules[0], 2), "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let rules = parse_rules(
            "/path -> http://localhost:8080\n~^/foo\\dbar$ -> http://localhost:9090\n",
        );
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let rules =
            parse_rules("/api/* -> http://a:8080, \\\n   http://b:8080, \\\n   http://c:8080\n");
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let rules = parse_rules("/api/* -> http://a:8080, \\\nhttp://b:8080 @script:auth.lua\n");
        assert_eq!(rules.len(), 1);
        let rule = &rules[0];
        assert_eq!(rule.targets.len(), 2);
        assert_eq!(rule.scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let rules = parse_rules("/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n");
        assert_eq!(rules.len(), 2);
    }
}