1use super::defaults;
6use crate::types::{
7 CacheCapacity, HostName, MaxConnections, MaxErrors, Port, ServerName, ThreadCount,
8 duration_serde, option_duration_serde,
9};
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
15#[serde(rename_all = "lowercase")]
16pub enum RoutingMode {
17 Stateful,
19 PerCommand,
21 Hybrid,
23}
24
25impl Default for RoutingMode {
26 fn default() -> Self {
30 Self::Hybrid
31 }
32}
33
34impl RoutingMode {
35 #[must_use]
37 pub const fn supports_per_command_routing(&self) -> bool {
38 matches!(self, Self::PerCommand | Self::Hybrid)
39 }
40
41 #[must_use]
43 pub const fn supports_stateful_commands(&self) -> bool {
44 matches!(self, Self::Stateful | Self::Hybrid)
45 }
46
47 #[must_use]
49 pub const fn short_name(&self) -> &'static str {
50 match self {
51 Self::Stateful => "stateful",
52 Self::PerCommand => "per-command",
53 Self::Hybrid => "hybrid",
54 }
55 }
56
57 #[must_use]
59 pub const fn as_str(&self) -> &'static str {
60 match self {
61 Self::Stateful => "stateful 1:1 mode",
62 Self::PerCommand => "per-command routing mode (stateless)",
63 Self::Hybrid => "hybrid routing mode",
64 }
65 }
66}
67
68impl std::fmt::Display for RoutingMode {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.write_str(self.as_str())
71 }
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
76#[serde(rename_all = "kebab-case")]
77pub enum BackendSelectionStrategy {
78 WeightedRoundRobin,
80 LeastLoaded,
82}
83
84impl Default for BackendSelectionStrategy {
85 fn default() -> Self {
87 Self::LeastLoaded
88 }
89}
90
91impl BackendSelectionStrategy {
92 #[must_use]
94 pub const fn as_str(&self) -> &'static str {
95 match self {
96 Self::WeightedRoundRobin => "weighted round-robin",
97 Self::LeastLoaded => "least-loaded",
98 }
99 }
100}
101
102impl std::fmt::Display for BackendSelectionStrategy {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.write_str(self.as_str())
105 }
106}
107
108#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
110pub struct Config {
111 #[serde(default)]
113 pub servers: Vec<Server>,
114 #[serde(default)]
116 pub proxy: Proxy,
117 #[serde(default)]
119 pub health_check: HealthCheck,
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub cache: Option<Cache>,
123 #[serde(default)]
125 pub client_auth: ClientAuth,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130#[serde(default)]
131pub struct Proxy {
132 pub host: String,
134 pub port: Port,
136 pub threads: ThreadCount,
138 pub backend_selection: BackendSelectionStrategy,
140 pub validate_yenc: bool,
142}
143
144impl Proxy {
145 pub const DEFAULT_HOST: &'static str = "0.0.0.0";
147}
148
149impl Default for Proxy {
150 fn default() -> Self {
151 Self {
152 host: Self::DEFAULT_HOST.to_string(),
153 port: Port::default(),
154 threads: ThreadCount::default(),
155 backend_selection: BackendSelectionStrategy::default(),
156 validate_yenc: true,
157 }
158 }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
163pub struct Cache {
164 #[serde(default = "super::defaults::cache_max_capacity")]
172 pub max_capacity: CacheCapacity,
173 #[serde(with = "duration_serde", default = "super::defaults::cache_ttl")]
175 pub ttl: Duration,
176 #[serde(default = "super::defaults::cache_articles")]
187 pub cache_articles: bool,
188 #[serde(default = "super::defaults::adaptive_precheck")]
202 pub adaptive_precheck: bool,
203
204 #[serde(default, skip_serializing_if = "Option::is_none")]
209 pub disk: Option<DiskCache>,
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
220pub struct DiskCache {
221 #[serde(default = "super::defaults::disk_cache_path")]
226 pub path: std::path::PathBuf,
227
228 #[serde(default = "super::defaults::disk_cache_capacity")]
235 pub capacity: CacheCapacity,
236
237 #[serde(default = "super::defaults::disk_cache_compression")]
242 pub compression: bool,
243
244 #[serde(default = "super::defaults::disk_cache_shards")]
248 pub shards: usize,
249}
250
251impl Default for DiskCache {
252 fn default() -> Self {
253 Self {
254 path: defaults::disk_cache_path(),
255 capacity: defaults::disk_cache_capacity(),
256 compression: defaults::disk_cache_compression(),
257 shards: defaults::disk_cache_shards(),
258 }
259 }
260}
261
262impl Default for Cache {
263 fn default() -> Self {
264 Self {
265 max_capacity: defaults::cache_max_capacity(),
266 ttl: defaults::cache_ttl(),
267 cache_articles: defaults::cache_articles(),
268 adaptive_precheck: defaults::adaptive_precheck(),
269 disk: None,
270 }
271 }
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
276pub struct HealthCheck {
277 #[serde(
279 with = "duration_serde",
280 default = "super::defaults::health_check_interval"
281 )]
282 pub interval: Duration,
283 #[serde(
285 with = "duration_serde",
286 default = "super::defaults::health_check_timeout"
287 )]
288 pub timeout: Duration,
289 #[serde(default = "super::defaults::unhealthy_threshold")]
291 pub unhealthy_threshold: MaxErrors,
292}
293
294impl Default for HealthCheck {
295 fn default() -> Self {
296 Self {
297 interval: super::defaults::health_check_interval(),
298 timeout: super::defaults::health_check_timeout(),
299 unhealthy_threshold: super::defaults::unhealthy_threshold(),
300 }
301 }
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
306pub struct ClientAuth {
307 #[serde(skip_serializing_if = "Option::is_none")]
309 pub greeting: Option<String>,
310 #[serde(default, skip_serializing_if = "Vec::is_empty")]
312 pub users: Vec<UserCredentials>,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
317pub struct UserCredentials {
318 pub username: String,
319 pub password: String,
320}
321
322impl ClientAuth {
323 pub fn is_enabled(&self) -> bool {
325 !self.users.is_empty()
326 }
327
328 pub fn all_users(&self) -> Vec<(&str, &str)> {
330 self.users
331 .iter()
332 .map(|user| (user.username.as_str(), user.password.as_str()))
333 .collect()
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
339pub struct Server {
340 pub host: HostName,
341 pub port: Port,
342 pub name: ServerName,
343 #[serde(skip_serializing_if = "Option::is_none")]
344 pub username: Option<String>,
345 #[serde(skip_serializing_if = "Option::is_none")]
346 pub password: Option<String>,
347 #[serde(default = "super::defaults::max_connections")]
349 pub max_connections: MaxConnections,
350
351 #[serde(default)]
353 pub use_tls: bool,
354 #[serde(default = "super::defaults::tls_verify_cert")]
356 pub tls_verify_cert: bool,
357 #[serde(skip_serializing_if = "Option::is_none")]
359 pub tls_cert_path: Option<String>,
360 #[serde(
363 with = "option_duration_serde",
364 default,
365 skip_serializing_if = "Option::is_none"
366 )]
367 pub connection_keepalive: Option<Duration>,
368 #[serde(default = "super::defaults::health_check_max_per_cycle")]
371 pub health_check_max_per_cycle: usize,
372 #[serde(
375 with = "duration_serde",
376 default = "super::defaults::health_check_pool_timeout"
377 )]
378 pub health_check_pool_timeout: Duration,
379 #[serde(default)]
382 pub tier: u8,
383}
384
385pub struct ServerBuilder {
412 host: String,
413 port: Port,
414 name: Option<String>,
415 username: Option<String>,
416 password: Option<String>,
417 max_connections: Option<MaxConnections>,
418 use_tls: bool,
419 tls_verify_cert: bool,
420 tls_cert_path: Option<String>,
421 connection_keepalive: Option<Duration>,
422 health_check_max_per_cycle: Option<usize>,
423 health_check_pool_timeout: Option<Duration>,
424 tier: u8,
425}
426
427impl ServerBuilder {
428 #[must_use]
434 pub fn new(host: impl Into<String>, port: Port) -> Self {
435 Self {
436 host: host.into(),
437 port,
438 name: None,
439 username: None,
440 password: None,
441 max_connections: None,
442 use_tls: false,
443 tls_verify_cert: true, tls_cert_path: None,
445 connection_keepalive: None,
446 health_check_max_per_cycle: None,
447 health_check_pool_timeout: None,
448 tier: 0,
449 }
450 }
451
452 #[must_use]
454 pub fn name(mut self, name: impl Into<String>) -> Self {
455 self.name = Some(name.into());
456 self
457 }
458
459 #[must_use]
461 pub fn username(mut self, username: impl Into<String>) -> Self {
462 self.username = Some(username.into());
463 self
464 }
465
466 #[must_use]
468 pub fn password(mut self, password: impl Into<String>) -> Self {
469 self.password = Some(password.into());
470 self
471 }
472
473 #[must_use]
475 pub fn max_connections(mut self, max: MaxConnections) -> Self {
476 self.max_connections = Some(max);
477 self
478 }
479
480 #[must_use]
482 pub fn use_tls(mut self, enabled: bool) -> Self {
483 self.use_tls = enabled;
484 self
485 }
486
487 #[must_use]
489 pub fn tls_verify_cert(mut self, verify: bool) -> Self {
490 self.tls_verify_cert = verify;
491 self
492 }
493
494 #[must_use]
496 pub fn tls_cert_path(mut self, path: impl Into<String>) -> Self {
497 self.tls_cert_path = Some(path.into());
498 self
499 }
500
501 #[must_use]
503 pub fn connection_keepalive(mut self, interval: Duration) -> Self {
504 self.connection_keepalive = Some(interval);
505 self
506 }
507
508 #[must_use]
510 pub fn health_check_max_per_cycle(mut self, max: usize) -> Self {
511 self.health_check_max_per_cycle = Some(max);
512 self
513 }
514
515 #[must_use]
517 pub fn health_check_pool_timeout(mut self, timeout: Duration) -> Self {
518 self.health_check_pool_timeout = Some(timeout);
519 self
520 }
521
522 #[must_use]
524 pub fn tier(mut self, tier: u8) -> Self {
525 self.tier = tier;
526 self
527 }
528
529 pub fn build(self) -> Result<Server, anyhow::Error> {
539 use crate::types::{HostName, ServerName};
540
541 let host = HostName::try_new(self.host.clone())?;
542 let port = self.port; let name_str = self
544 .name
545 .unwrap_or_else(|| format!("{}:{}", self.host, self.port.get()));
546 let name = ServerName::try_new(name_str)?;
547
548 let max_connections = self
549 .max_connections
550 .unwrap_or_else(super::defaults::max_connections);
551
552 let health_check_max_per_cycle = self
553 .health_check_max_per_cycle
554 .unwrap_or_else(super::defaults::health_check_max_per_cycle);
555
556 let health_check_pool_timeout = self
557 .health_check_pool_timeout
558 .unwrap_or_else(super::defaults::health_check_pool_timeout);
559
560 Ok(Server {
561 host,
562 port,
563 name,
564 username: self.username,
565 password: self.password,
566 max_connections,
567 use_tls: self.use_tls,
568 tls_verify_cert: self.tls_verify_cert,
569 tls_cert_path: self.tls_cert_path,
570 connection_keepalive: self.connection_keepalive,
571 health_check_max_per_cycle,
572 health_check_pool_timeout,
573 tier: self.tier,
574 })
575 }
576}
577
578impl Server {
579 #[must_use]
594 pub fn builder(host: impl Into<String>, port: Port) -> ServerBuilder {
595 ServerBuilder::new(host, port)
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RoutingMode ------------------------------------------------------

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Hybrid);
    }

    #[test]
    fn test_routing_mode_supports_per_command() {
        assert!(RoutingMode::PerCommand.supports_per_command_routing());
        assert!(RoutingMode::Hybrid.supports_per_command_routing());
        assert!(!RoutingMode::Stateful.supports_per_command_routing());
    }

    #[test]
    fn test_routing_mode_supports_stateful() {
        assert!(RoutingMode::Stateful.supports_stateful_commands());
        assert!(RoutingMode::Hybrid.supports_stateful_commands());
        assert!(!RoutingMode::PerCommand.supports_stateful_commands());
    }

    #[test]
    fn test_routing_mode_short_name() {
        assert_eq!(RoutingMode::Stateful.short_name(), "stateful");
        assert_eq!(RoutingMode::PerCommand.short_name(), "per-command");
        assert_eq!(RoutingMode::Hybrid.short_name(), "hybrid");
    }

    #[test]
    fn test_routing_mode_as_str() {
        assert_eq!(RoutingMode::Stateful.as_str(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.as_str(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.as_str(), "hybrid routing mode");
    }

    #[test]
    fn test_routing_mode_display() {
        // Display must agree with as_str for every variant.
        for mode in [
            RoutingMode::Stateful,
            RoutingMode::PerCommand,
            RoutingMode::Hybrid,
        ] {
            assert_eq!(mode.to_string(), mode.as_str());
        }
        assert_eq!(RoutingMode::Stateful.to_string(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.to_string(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.to_string(), "hybrid routing mode");
    }

    // ---- BackendSelectionStrategy ------------------------------------------

    #[test]
    fn test_backend_selection_default() {
        assert_eq!(
            BackendSelectionStrategy::default(),
            BackendSelectionStrategy::LeastLoaded
        );
    }

    #[test]
    fn test_backend_selection_display() {
        assert_eq!(
            BackendSelectionStrategy::WeightedRoundRobin.to_string(),
            "weighted round-robin"
        );
        assert_eq!(
            BackendSelectionStrategy::LeastLoaded.to_string(),
            "least-loaded"
        );
    }

    // ---- Proxy --------------------------------------------------------------

    #[test]
    fn test_proxy_default() {
        let proxy = Proxy::default();
        assert_eq!(proxy.host, "0.0.0.0");
        assert_eq!(proxy.port.get(), 8119);
        assert_eq!(
            proxy.backend_selection,
            BackendSelectionStrategy::LeastLoaded
        );
        assert!(proxy.validate_yenc);
    }

    #[test]
    fn test_proxy_default_host_constant() {
        assert_eq!(Proxy::DEFAULT_HOST, "0.0.0.0");
    }

    // ---- Cache / HealthCheck ------------------------------------------------

    #[test]
    fn test_cache_default() {
        let cache = Cache::default();
        assert_eq!(cache.max_capacity.get(), 64 * 1024 * 1024);
        assert_eq!(cache.ttl, Duration::from_secs(3600));
        assert!(cache.disk.is_none());
    }

    #[test]
    fn test_health_check_default() {
        let hc = HealthCheck::default();
        assert_eq!(hc.interval, Duration::from_secs(30));
        assert_eq!(hc.timeout, Duration::from_secs(5));
        assert_eq!(hc.unhealthy_threshold.get(), 3);
    }

    // ---- ClientAuth -----------------------------------------------------------

    #[test]
    fn test_client_auth_is_enabled() {
        let mut auth = ClientAuth::default();
        assert!(!auth.is_enabled());

        auth.users.push(UserCredentials {
            username: "user".to_string(),
            password: "pass".to_string(),
        });
        assert!(auth.is_enabled());
    }

    #[test]
    fn test_client_auth_is_enabled_multi_user() {
        let mut auth = ClientAuth::default();
        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "secret".to_string(),
        });
        auth.users.push(UserCredentials {
            username: "bob".to_string(),
            password: "hunter2".to_string(),
        });
        assert_eq!(auth.users.len(), 2);
        assert!(auth.is_enabled());
    }

    #[test]
    fn test_client_auth_all_users_empty() {
        let auth = ClientAuth::default();
        assert!(auth.all_users().is_empty());
    }

    #[test]
    fn test_client_auth_all_users_single() {
        let mut auth = ClientAuth::default();
        auth.users.push(UserCredentials {
            username: "user".to_string(),
            password: "pass".to_string(),
        });

        let users = auth.all_users();
        assert_eq!(users.len(), 1);
        assert_eq!(users[0], ("user", "pass"));
    }

    #[test]
    fn test_client_auth_all_users_multi() {
        let mut auth = ClientAuth::default();
        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "alice_pw".to_string(),
        });
        auth.users.push(UserCredentials {
            username: "bob".to_string(),
            password: "bob_pw".to_string(),
        });

        let users = auth.all_users();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], ("alice", "alice_pw"));
        assert_eq!(users[1], ("bob", "bob_pw"));
    }

    // ---- ServerBuilder --------------------------------------------------------

    #[test]
    fn test_server_builder_minimal() {
        let server = Server::builder("news.example.com", Port::try_new(119).unwrap())
            .build()
            .unwrap();

        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 119);
        assert_eq!(server.name.as_str(), "news.example.com:119");
        assert_eq!(server.max_connections.get(), 10);
        assert!(!server.use_tls);
        assert!(server.tls_verify_cert);
        // Optional settings stay unset unless explicitly configured.
        assert!(server.username.is_none());
        assert!(server.password.is_none());
        assert!(server.tls_cert_path.is_none());
        assert!(server.connection_keepalive.is_none());
        assert_eq!(server.tier, 0);
    }

    #[test]
    fn test_server_builder_with_name() {
        let server = Server::builder("localhost", Port::try_new(119).unwrap())
            .name("Test Server")
            .build()
            .unwrap();

        assert_eq!(server.name.as_str(), "Test Server");
    }

    #[test]
    fn test_server_builder_with_auth() {
        let server = Server::builder("news.example.com", Port::try_new(119).unwrap())
            .username("testuser")
            .password("testpass")
            .build()
            .unwrap();

        assert_eq!(server.username.as_deref(), Some("testuser"));
        assert_eq!(server.password.as_deref(), Some("testpass"));
    }

    #[test]
    fn test_server_builder_with_max_connections() {
        let server = Server::builder("localhost", Port::try_new(119).unwrap())
            .max_connections(MaxConnections::try_new(20).unwrap())
            .build()
            .unwrap();

        assert_eq!(server.max_connections.get(), 20);
    }

    #[test]
    fn test_server_builder_with_tls() {
        let server = Server::builder("secure.example.com", Port::try_new(563).unwrap())
            .use_tls(true)
            .tls_verify_cert(false)
            .tls_cert_path("/path/to/cert.pem")
            .build()
            .unwrap();

        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
        assert_eq!(server.tls_cert_path.as_deref(), Some("/path/to/cert.pem"));
    }

    #[test]
    fn test_server_builder_with_keepalive() {
        let keepalive = Duration::from_secs(300);
        let server = Server::builder("localhost", Port::try_new(119).unwrap())
            .connection_keepalive(keepalive)
            .build()
            .unwrap();

        assert_eq!(server.connection_keepalive, Some(keepalive));
    }

    #[test]
    fn test_server_builder_with_health_check_settings() {
        let timeout = Duration::from_millis(500);
        let server = Server::builder("localhost", Port::try_new(119).unwrap())
            .health_check_max_per_cycle(5)
            .health_check_pool_timeout(timeout)
            .build()
            .unwrap();

        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(server.health_check_pool_timeout, timeout);
    }

    #[test]
    fn test_server_builder_with_tier() {
        let server = Server::builder("localhost", Port::try_new(119).unwrap())
            .tier(2)
            .build()
            .unwrap();

        assert_eq!(server.tier, 2);
    }

    #[test]
    fn test_server_builder_chaining() {
        let server = Server::builder("news.example.com", Port::try_new(563).unwrap())
            .name("Production Server")
            .username("admin")
            .password("secret")
            .max_connections(MaxConnections::try_new(25).unwrap())
            .use_tls(true)
            .tls_verify_cert(true)
            .build()
            .unwrap();

        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 563);
        assert_eq!(server.name.as_str(), "Production Server");
        assert_eq!(server.username.as_deref(), Some("admin"));
        assert_eq!(server.password.as_deref(), Some("secret"));
        assert_eq!(server.max_connections.get(), 25);
        assert!(server.use_tls);
        assert!(server.tls_verify_cert);
    }

    // ---- Config -----------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert!(config.servers.is_empty());
        assert_eq!(config.proxy.host, "0.0.0.0");
        assert_eq!(config.proxy, Proxy::default());
        assert_eq!(config.health_check, HealthCheck::default());
        assert!(config.cache.is_none());
        assert!(!config.client_auth.is_enabled());
    }
}