1use super::defaults;
6use crate::types::{
7 CacheCapacity, HostName, MaxConnections, MaxErrors, Port, ServerName, ThreadCount,
8 duration_serde, option_duration_serde,
9};
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
/// Routing mode selector.
///
/// With `rename_all = "lowercase"` the variants (de)serialize through serde as
/// `stateful`, `percommand` and `hybrid`. The type also derives
/// `clap::ValueEnum`, so it can be used as a command-line value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "lowercase")]
pub enum RoutingMode {
    /// Stateful 1:1 mode: supports stateful commands, not per-command routing.
    Stateful,
    /// Per-command (stateless) routing: does not support stateful commands.
    PerCommand,
    /// Hybrid routing: supports both per-command routing and stateful commands.
    Hybrid,
}
24
25impl Default for RoutingMode {
26 fn default() -> Self {
30 Self::Hybrid
31 }
32}
33
34impl RoutingMode {
35 #[must_use]
37 pub const fn supports_per_command_routing(&self) -> bool {
38 matches!(self, Self::PerCommand | Self::Hybrid)
39 }
40
41 #[must_use]
43 pub const fn supports_stateful_commands(&self) -> bool {
44 matches!(self, Self::Stateful | Self::Hybrid)
45 }
46
47 #[must_use]
49 pub const fn as_str(&self) -> &'static str {
50 match self {
51 Self::Stateful => "stateful 1:1 mode",
52 Self::PerCommand => "per-command routing mode (stateless)",
53 Self::Hybrid => "hybrid routing mode",
54 }
55 }
56}
57
58impl std::fmt::Display for RoutingMode {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.write_str(self.as_str())
61 }
62}
63
/// Top-level configuration, aggregating every section defined in this module.
///
/// All sections are optional when deserializing: missing sections fall back
/// to their `Default` value, and a missing `cache` becomes `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct Config {
    /// Configured [`Server`] entries; empty when omitted.
    #[serde(default)]
    pub servers: Vec<Server>,
    /// [`Proxy`] settings; `Proxy::default()` when omitted.
    #[serde(default)]
    pub proxy: Proxy,
    /// [`HealthCheck`] settings; `HealthCheck::default()` when omitted.
    #[serde(default)]
    pub health_check: HealthCheck,
    /// Optional [`Cache`] settings; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache: Option<Cache>,
    /// [`ClientAuth`] settings; `ClientAuth::default()` when omitted.
    #[serde(default)]
    pub client_auth: ClientAuth,
}
83
/// Proxy section of the configuration.
///
/// The container-level `#[serde(default)]` means any field missing from the
/// input is taken from `Proxy::default()`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct Proxy {
    /// Host string; defaults to [`Proxy::DEFAULT_HOST`].
    // NOTE(review): presumably the address the proxy binds to — confirm against the listener code.
    pub host: String,
    /// Port; defaults to `Port::default()`.
    pub port: Port,
    /// Thread count; defaults to `ThreadCount::default()`.
    pub threads: ThreadCount,
}
95
impl Proxy {
    /// Host value used by `Proxy::default()`.
    pub const DEFAULT_HOST: &'static str = "0.0.0.0";
}
100
101impl Default for Proxy {
102 fn default() -> Self {
103 Self {
104 host: Self::DEFAULT_HOST.to_string(),
105 port: Port::default(),
106 threads: ThreadCount::default(),
107 }
108 }
109}
110
/// Cache section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Cache {
    /// Maximum cache capacity; `super::defaults::cache_max_capacity()` when omitted.
    #[serde(default = "super::defaults::cache_max_capacity")]
    pub max_capacity: CacheCapacity,
    /// Time-to-live, (de)serialized through `duration_serde`;
    /// `super::defaults::cache_ttl()` when omitted.
    #[serde(with = "duration_serde", default = "super::defaults::cache_ttl")]
    pub ttl: Duration,
}
121
122impl Default for Cache {
123 fn default() -> Self {
124 Self {
125 max_capacity: defaults::cache_max_capacity(),
126 ttl: defaults::cache_ttl(),
127 }
128 }
129}
130
/// Health-check section of the configuration.
///
/// Every field has a per-field serde default taken from `super::defaults`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HealthCheck {
    /// Health-check interval, (de)serialized through `duration_serde`;
    /// `super::defaults::health_check_interval()` when omitted.
    #[serde(
        with = "duration_serde",
        default = "super::defaults::health_check_interval"
    )]
    pub interval: Duration,
    /// Health-check timeout, (de)serialized through `duration_serde`;
    /// `super::defaults::health_check_timeout()` when omitted.
    #[serde(
        with = "duration_serde",
        default = "super::defaults::health_check_timeout"
    )]
    pub timeout: Duration,
    /// Unhealthy threshold; `super::defaults::unhealthy_threshold()` when omitted.
    // NOTE(review): presumably the number of failures before a server is marked unhealthy — confirm.
    #[serde(default = "super::defaults::unhealthy_threshold")]
    pub unhealthy_threshold: MaxErrors,
}
150
151impl Default for HealthCheck {
152 fn default() -> Self {
153 Self {
154 interval: super::defaults::health_check_interval(),
155 timeout: super::defaults::health_check_timeout(),
156 unhealthy_threshold: super::defaults::unhealthy_threshold(),
157 }
158 }
159}
160
/// Client authentication section of the configuration.
///
/// Credentials can be given as a single `username`/`password` pair, as a list
/// of `users`, or both; `ClientAuth::all_users` merges them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ClientAuth {
    /// Username of the single-pair form; only counts when `password` is also set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Password of the single-pair form; only counts when `username` is also set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
    /// Optional greeting string.
    // NOTE(review): how the greeting is used is not visible in this file — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub greeting: Option<String>,
    /// Additional credential pairs; empty when omitted and not serialized when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub users: Vec<UserCredentials>,
}
179
/// One username/password pair in [`ClientAuth::users`]. Both fields are required.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UserCredentials {
    /// Account name.
    pub username: String,
    /// Password for `username`.
    pub password: String,
}
186
187impl ClientAuth {
188 pub fn is_enabled(&self) -> bool {
190 (!self.users.is_empty()) || (self.username.is_some() && self.password.is_some())
192 }
193
194 pub fn all_users(&self) -> Vec<(&str, &str)> {
196 let mut users = Vec::new();
197
198 if let (Some(u), Some(p)) = (&self.username, &self.password) {
200 users.push((u.as_str(), p.as_str()));
201 }
202
203 for user in &self.users {
205 users.push((user.username.as_str(), user.password.as_str()));
206 }
207
208 users
209 }
210}
211
/// Configuration of a single server entry.
///
/// `host`, `port` and `name` are required when deserializing; every other
/// field has a default. Use [`Server::builder`] to construct one in code.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Server {
    /// Validated host name.
    pub host: HostName,
    /// Validated port.
    pub port: Port,
    /// Validated server name.
    pub name: ServerName,
    /// Optional username; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Optional password; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
    /// Connection limit; `super::defaults::max_connections()` when omitted.
    #[serde(default = "super::defaults::max_connections")]
    pub max_connections: MaxConnections,

    /// Whether TLS is used; `false` when omitted.
    #[serde(default)]
    pub use_tls: bool,
    /// Whether the TLS certificate is verified;
    /// `super::defaults::tls_verify_cert()` when omitted.
    #[serde(default = "super::defaults::tls_verify_cert")]
    pub tls_verify_cert: bool,
    /// Optional certificate path; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert_path: Option<String>,
    /// Optional keepalive interval, (de)serialized through
    /// `option_duration_serde`; `None` when omitted and not serialized when `None`.
    #[serde(
        with = "option_duration_serde",
        default,
        skip_serializing_if = "Option::is_none"
    )]
    pub connection_keepalive: Option<Duration>,
    /// Health-check limit per cycle;
    /// `super::defaults::health_check_max_per_cycle()` when omitted.
    // NOTE(review): presumably the max connections checked per health-check cycle — confirm.
    #[serde(default = "super::defaults::health_check_max_per_cycle")]
    pub health_check_max_per_cycle: usize,
    /// Health-check pool timeout, (de)serialized through `duration_serde`;
    /// `super::defaults::health_check_pool_timeout()` when omitted.
    #[serde(
        with = "duration_serde",
        default = "super::defaults::health_check_pool_timeout"
    )]
    pub health_check_pool_timeout: Duration,
}
255
/// Builder for [`Server`].
///
/// Holds raw, unvalidated values; validation into the domain types happens in
/// `build`. Settings left as `None` are filled from the `defaults` module at
/// build time.
///
/// `Debug` is implemented by hand so that the password never appears in debug
/// output.
#[derive(Clone)]
pub struct ServerBuilder {
    /// Raw host string, validated in `build`.
    host: String,
    /// Raw port number, validated in `build`.
    port: u16,
    /// Explicit server name; `build` falls back to `"host:port"` when unset.
    name: Option<String>,
    /// Optional username.
    username: Option<String>,
    /// Optional password (redacted in `Debug` output).
    password: Option<String>,
    /// Connection limit override.
    max_connections: Option<usize>,
    /// Whether TLS is used.
    use_tls: bool,
    /// Whether the TLS certificate is verified.
    tls_verify_cert: bool,
    /// Optional certificate path.
    tls_cert_path: Option<String>,
    /// Optional keepalive interval.
    connection_keepalive: Option<Duration>,
    /// Override for `Server::health_check_max_per_cycle`.
    health_check_max_per_cycle: Option<usize>,
    /// Override for `Server::health_check_pool_timeout`.
    health_check_pool_timeout: Option<Duration>,
}

impl std::fmt::Debug for ServerBuilder {
    /// Formats every field, replacing a set password with a `<redacted>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only whether a password is set is shown, never its contents.
        let password = self.password.as_ref().map(|_| "<redacted>");

        f.debug_struct("ServerBuilder")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("name", &self.name)
            .field("username", &self.username)
            .field("password", &password)
            .field("max_connections", &self.max_connections)
            .field("use_tls", &self.use_tls)
            .field("tls_verify_cert", &self.tls_verify_cert)
            .field("tls_cert_path", &self.tls_cert_path)
            .field("connection_keepalive", &self.connection_keepalive)
            .field(
                "health_check_max_per_cycle",
                &self.health_check_max_per_cycle,
            )
            .field("health_check_pool_timeout", &self.health_check_pool_timeout)
            .finish()
    }
}
295
296impl ServerBuilder {
297 #[must_use]
303 pub fn new(host: impl Into<String>, port: u16) -> Self {
304 Self {
305 host: host.into(),
306 port,
307 name: None,
308 username: None,
309 password: None,
310 max_connections: None,
311 use_tls: false,
312 tls_verify_cert: true, tls_cert_path: None,
314 connection_keepalive: None,
315 health_check_max_per_cycle: None,
316 health_check_pool_timeout: None,
317 }
318 }
319
320 #[must_use]
322 pub fn name(mut self, name: impl Into<String>) -> Self {
323 self.name = Some(name.into());
324 self
325 }
326
327 #[must_use]
329 pub fn username(mut self, username: impl Into<String>) -> Self {
330 self.username = Some(username.into());
331 self
332 }
333
334 #[must_use]
336 pub fn password(mut self, password: impl Into<String>) -> Self {
337 self.password = Some(password.into());
338 self
339 }
340
341 #[must_use]
343 pub fn max_connections(mut self, max: usize) -> Self {
344 self.max_connections = Some(max);
345 self
346 }
347
348 #[must_use]
350 pub fn use_tls(mut self, enabled: bool) -> Self {
351 self.use_tls = enabled;
352 self
353 }
354
355 #[must_use]
357 pub fn tls_verify_cert(mut self, verify: bool) -> Self {
358 self.tls_verify_cert = verify;
359 self
360 }
361
362 #[must_use]
364 pub fn tls_cert_path(mut self, path: impl Into<String>) -> Self {
365 self.tls_cert_path = Some(path.into());
366 self
367 }
368
369 #[must_use]
371 pub fn connection_keepalive(mut self, interval: Duration) -> Self {
372 self.connection_keepalive = Some(interval);
373 self
374 }
375
376 #[must_use]
378 pub fn health_check_max_per_cycle(mut self, max: usize) -> Self {
379 self.health_check_max_per_cycle = Some(max);
380 self
381 }
382
383 #[must_use]
385 pub fn health_check_pool_timeout(mut self, timeout: Duration) -> Self {
386 self.health_check_pool_timeout = Some(timeout);
387 self
388 }
389
390 pub fn build(self) -> Result<Server, anyhow::Error> {
400 use crate::types::{HostName, MaxConnections, Port, ServerName};
401
402 let host = HostName::new(self.host.clone())?;
403
404 let port = Port::new(self.port)
405 .ok_or_else(|| anyhow::anyhow!("Invalid port: {} (must be 1-65535)", self.port))?;
406
407 let name_str = self
408 .name
409 .unwrap_or_else(|| format!("{}:{}", self.host, self.port));
410 let name = ServerName::new(name_str)?;
411
412 let max_connections = if let Some(max) = self.max_connections {
413 MaxConnections::new(max)
414 .ok_or_else(|| anyhow::anyhow!("Invalid max_connections: {} (must be > 0)", max))?
415 } else {
416 super::defaults::max_connections()
417 };
418
419 let health_check_max_per_cycle = self
420 .health_check_max_per_cycle
421 .unwrap_or_else(super::defaults::health_check_max_per_cycle);
422
423 let health_check_pool_timeout = self
424 .health_check_pool_timeout
425 .unwrap_or_else(super::defaults::health_check_pool_timeout);
426
427 Ok(Server {
428 host,
429 port,
430 name,
431 username: self.username,
432 password: self.password,
433 max_connections,
434 use_tls: self.use_tls,
435 tls_verify_cert: self.tls_verify_cert,
436 tls_cert_path: self.tls_cert_path,
437 connection_keepalive: self.connection_keepalive,
438 health_check_max_per_cycle,
439 health_check_pool_timeout,
440 })
441 }
442}
443
impl Server {
    /// Starts a [`ServerBuilder`] for the given host and port.
    ///
    /// Shorthand for `ServerBuilder::new(host, port)`; call
    /// `ServerBuilder::build` on the result to validate and obtain a `Server`.
    #[must_use]
    pub fn builder(host: impl Into<String>, port: u16) -> ServerBuilder {
        ServerBuilder::new(host, port)
    }
}
463
#[cfg(test)]
mod tests {
    //! Unit tests for the configuration types and the `ServerBuilder`.

    use super::*;

    // --- RoutingMode -------------------------------------------------------

    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Hybrid);
    }

    #[test]
    fn test_routing_mode_supports_per_command() {
        assert!(RoutingMode::PerCommand.supports_per_command_routing());
        assert!(RoutingMode::Hybrid.supports_per_command_routing());
        assert!(!RoutingMode::Stateful.supports_per_command_routing());
    }

    #[test]
    fn test_routing_mode_supports_stateful() {
        assert!(RoutingMode::Stateful.supports_stateful_commands());
        assert!(RoutingMode::Hybrid.supports_stateful_commands());
        assert!(!RoutingMode::PerCommand.supports_stateful_commands());
    }

    #[test]
    fn test_routing_mode_as_str() {
        assert_eq!(RoutingMode::Stateful.as_str(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.as_str(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.as_str(), "hybrid routing mode");
    }

    #[test]
    fn test_routing_mode_display() {
        assert_eq!(RoutingMode::Stateful.to_string(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.to_string(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.to_string(), "hybrid routing mode");
    }

    // --- Proxy / Cache / HealthCheck defaults ------------------------------

    #[test]
    fn test_proxy_default() {
        let proxy = Proxy::default();
        assert_eq!(proxy.host, "0.0.0.0");
        assert_eq!(proxy.port.get(), 8119);
    }

    #[test]
    fn test_proxy_default_host_constant() {
        assert_eq!(Proxy::DEFAULT_HOST, "0.0.0.0");
    }

    #[test]
    fn test_cache_default() {
        let cache = Cache::default();
        assert_eq!(cache.max_capacity.get(), 10000);
        assert_eq!(cache.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_health_check_default() {
        let hc = HealthCheck::default();
        assert_eq!(hc.interval, Duration::from_secs(30));
        assert_eq!(hc.timeout, Duration::from_secs(5));
        assert_eq!(hc.unhealthy_threshold.get(), 3);
    }

    // --- ClientAuth --------------------------------------------------------

    #[test]
    fn test_client_auth_is_enabled_legacy() {
        let mut auth = ClientAuth::default();
        assert!(!auth.is_enabled());

        // A username without a password is not a usable credential pair.
        auth.username = Some("user".to_string());
        assert!(!auth.is_enabled());

        auth.password = Some("pass".to_string());
        assert!(auth.is_enabled());

        // Likewise a password without a username.
        auth.username = None;
        assert!(!auth.is_enabled());
    }

    #[test]
    fn test_client_auth_is_enabled_multi_user() {
        let mut auth = ClientAuth::default();
        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "secret".to_string(),
        });
        assert!(auth.is_enabled());
    }

    #[test]
    fn test_client_auth_all_users_legacy() {
        let mut auth = ClientAuth::default();
        auth.username = Some("user".to_string());
        auth.password = Some("pass".to_string());

        let users = auth.all_users();
        assert_eq!(users.len(), 1);
        assert_eq!(users[0], ("user", "pass"));
    }

    #[test]
    fn test_client_auth_all_users_partial_legacy_ignored() {
        let mut auth = ClientAuth::default();
        assert!(auth.all_users().is_empty());

        // An incomplete single pair must not appear in the merged list.
        auth.username = Some("user".to_string());
        assert!(auth.all_users().is_empty());

        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "alice_pw".to_string(),
        });
        assert_eq!(auth.all_users(), vec![("alice", "alice_pw")]);
    }

    #[test]
    fn test_client_auth_all_users_multi() {
        let mut auth = ClientAuth::default();
        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "alice_pw".to_string(),
        });
        auth.users.push(UserCredentials {
            username: "bob".to_string(),
            password: "bob_pw".to_string(),
        });

        let users = auth.all_users();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], ("alice", "alice_pw"));
        assert_eq!(users[1], ("bob", "bob_pw"));
    }

    #[test]
    fn test_client_auth_all_users_combined() {
        let mut auth = ClientAuth::default();
        auth.username = Some("legacy".to_string());
        auth.password = Some("legacy_pw".to_string());
        auth.users.push(UserCredentials {
            username: "alice".to_string(),
            password: "alice_pw".to_string(),
        });

        let users = auth.all_users();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], ("legacy", "legacy_pw"));
        assert_eq!(users[1], ("alice", "alice_pw"));
    }

    // --- ServerBuilder -----------------------------------------------------

    #[test]
    fn test_server_builder_minimal() {
        let server = Server::builder("news.example.com", 119).build().unwrap();

        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 119);
        assert_eq!(server.name.as_str(), "news.example.com:119");
        assert_eq!(server.max_connections.get(), 10);
        assert!(!server.use_tls);
        assert!(server.tls_verify_cert);

        // Settings never set on the builder stay unset on the server.
        assert!(server.username.is_none());
        assert!(server.password.is_none());
        assert!(server.tls_cert_path.is_none());
        assert!(server.connection_keepalive.is_none());
    }

    #[test]
    fn test_server_builder_with_name() {
        let server = Server::builder("localhost", 119)
            .name("Test Server")
            .build()
            .unwrap();

        assert_eq!(server.name.as_str(), "Test Server");
    }

    #[test]
    fn test_server_builder_with_auth() {
        let server = Server::builder("news.example.com", 119)
            .username("testuser")
            .password("testpass")
            .build()
            .unwrap();

        assert_eq!(server.username.as_ref().unwrap(), "testuser");
        assert_eq!(server.password.as_ref().unwrap(), "testpass");
    }

    #[test]
    fn test_server_builder_with_max_connections() {
        let server = Server::builder("localhost", 119)
            .max_connections(20)
            .build()
            .unwrap();

        assert_eq!(server.max_connections.get(), 20);
    }

    #[test]
    fn test_server_builder_with_tls() {
        let server = Server::builder("secure.example.com", 563)
            .use_tls(true)
            .tls_verify_cert(false)
            .tls_cert_path("/path/to/cert.pem")
            .build()
            .unwrap();

        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
        assert_eq!(server.tls_cert_path.as_ref().unwrap(), "/path/to/cert.pem");
    }

    #[test]
    fn test_server_builder_with_keepalive() {
        let keepalive = Duration::from_secs(300);
        let server = Server::builder("localhost", 119)
            .connection_keepalive(keepalive)
            .build()
            .unwrap();

        assert_eq!(server.connection_keepalive, Some(keepalive));
    }

    #[test]
    fn test_server_builder_with_health_check_settings() {
        let timeout = Duration::from_millis(500);
        let server = Server::builder("localhost", 119)
            .health_check_max_per_cycle(5)
            .health_check_pool_timeout(timeout)
            .build()
            .unwrap();

        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(server.health_check_pool_timeout, timeout);
    }

    #[test]
    fn test_server_builder_chaining() {
        let server = Server::builder("news.example.com", 563)
            .name("Production Server")
            .username("admin")
            .password("secret")
            .max_connections(25)
            .use_tls(true)
            .tls_verify_cert(true)
            .build()
            .unwrap();

        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 563);
        assert_eq!(server.name.as_str(), "Production Server");
        assert_eq!(server.username.as_deref(), Some("admin"));
        assert_eq!(server.password.as_deref(), Some("secret"));
        assert_eq!(server.max_connections.get(), 25);
        assert!(server.use_tls);
        assert!(server.tls_verify_cert);
    }

    #[test]
    fn test_server_builder_invalid_host() {
        let result = Server::builder("", 119).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_builder_invalid_port() {
        let result = Server::builder("localhost", 0).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_builder_invalid_max_connections() {
        let result = Server::builder("localhost", 119).max_connections(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_builder_from_server_method() {
        let builder = Server::builder("localhost", 119);
        let server = builder.build().unwrap();
        assert_eq!(server.host.as_str(), "localhost");
    }

    // --- Config ------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert!(config.servers.is_empty());
        assert_eq!(config.proxy.host, "0.0.0.0");
        assert_eq!(config.proxy, Proxy::default());
        assert_eq!(config.health_check, HealthCheck::default());
        assert!(config.cache.is_none());
        assert!(!config.client_auth.is_enabled());
    }
}