Skip to main content

sentinel_driver/
config.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::error::{Error, Result};
5
/// How (and whether) the connection is protected with TLS.
///
/// In a connection URI this is the `sslmode` parameter, spelled `disable`,
/// `prefer`, `require`, `verify-ca`, or `verify-full`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SslMode {
    /// Never use TLS; traffic is sent unencrypted.
    Disable,
    /// Attempt TLS, but continue in plaintext when the server does not offer it.
    /// This is the default.
    #[default]
    Prefer,
    /// Insist on TLS and fail when the server does not offer it.
    Require,
    /// Insist on TLS and validate the server certificate.
    VerifyCa,
    /// Insist on TLS, validate the certificate, and check that the host name matches it.
    VerifyFull,
}
21
22/// Connection configuration for sentinel-driver.
23///
24/// # Connection String
25///
26/// ```text
27/// postgres://user:password@host:port/database?sslmode=prefer&application_name=myapp
28/// ```
29///
30/// # Builder
31///
32/// ```rust,no_run
33/// use sentinel_driver::Config;
34///
35/// let config = Config::builder()
36///     .host("localhost")
37///     .port(5432)
38///     .database("mydb")
39///     .user("postgres")
40///     .password("secret")
41///     .build();
42/// ```
43#[derive(Clone)]
44pub struct Config {
45    pub(crate) hosts: Vec<(String, u16)>,
46    pub(crate) database: String,
47    pub(crate) user: String,
48    pub(crate) password: Option<String>,
49    pub(crate) ssl_mode: SslMode,
50    pub(crate) application_name: Option<String>,
51    pub(crate) connect_timeout: Duration,
52    pub(crate) statement_timeout: Option<Duration>,
53    pub(crate) keepalive: Option<Duration>,
54    pub(crate) keepalive_idle: Option<Duration>,
55    pub(crate) target_session_attrs: TargetSessionAttrs,
56    pub(crate) extra_float_digits: Option<i32>,
57    pub(crate) load_balance_hosts: LoadBalanceHosts,
58    /// Path to client certificate file for certificate authentication.
59    pub(crate) ssl_client_cert: Option<std::path::PathBuf>,
60    /// Path to client private key file for certificate authentication.
61    pub(crate) ssl_client_key: Option<std::path::PathBuf>,
62    /// Use direct TLS connection (PG 17+) — skip SSLRequest negotiation.
63    pub(crate) ssl_direct: bool,
64    /// Enable SCRAM-SHA-256 channel binding (SCRAM-PLUS) when TLS is active.
65    pub(crate) channel_binding: ChannelBinding,
66    /// Pluggable instrumentation hook; `None` uses a no-op default.
67    pub(crate) instrumentation: Option<std::sync::Arc<dyn crate::Instrumentation>>,
68}
69
/// Whether SCRAM authentication should use channel binding.
///
/// In a connection URI this is the `channel_binding` parameter, spelled
/// `prefer`, `require`, or `disable`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ChannelBinding {
    /// Bind the channel when possible. This is the default.
    #[default]
    Prefer,
    /// Bind the channel or fail when the server cannot.
    Require,
    /// Never bind the channel.
    Disable,
}
81
/// Which servers are acceptable once a connection has been established.
///
/// In a connection URI this is the `target_session_attrs` parameter, spelled
/// `any`, `read-write`, or `read-only`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TargetSessionAttrs {
    /// Accept whatever server answers. This is the default.
    #[default]
    Any,
    /// Accept only a server that allows writes (a primary).
    ReadWrite,
    /// Accept only a read-only server (a replica).
    ReadOnly,
}
93
/// Order in which the hosts of a multi-host configuration are attempted.
///
/// In a connection URI this is the `load_balance_hosts` parameter, spelled
/// `disable` or `random`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LoadBalanceHosts {
    /// Attempt hosts in the order given. This is the default.
    #[default]
    Disable,
    /// Shuffle the hosts before attempting them.
    Random,
}
103
104impl Config {
105    /// Parse a PostgreSQL connection string.
106    ///
107    /// Supported formats:
108    /// - `postgres://user:password@host:port/database?param=value`
109    /// - `postgresql://user:password@host:port/database?param=value`
110    pub fn parse(s: &str) -> Result<Self> {
111        let s = s.trim();
112
113        let without_scheme = s
114            .strip_prefix("postgres://")
115            .or_else(|| s.strip_prefix("postgresql://"))
116            .ok_or_else(|| {
117                Error::Config(
118                    "connection string must start with postgres:// or postgresql://".into(),
119                )
120            })?;
121
122        let (userinfo, rest) = match without_scheme.split_once('@') {
123            Some((ui, rest)) => (Some(ui), rest),
124            None => (None, without_scheme),
125        };
126
127        let (user, password) = match userinfo {
128            Some(ui) => match ui.split_once(':') {
129                Some((u, p)) => (percent_decode(u)?, Some(percent_decode(p)?)),
130                None => (percent_decode(ui)?, None),
131            },
132            None => (String::new(), None),
133        };
134
135        // Split host:port from database?params
136        let (hostport, db_and_params) = match rest.split_once('/') {
137            Some((hp, rest)) => (hp, Some(rest)),
138            None => (rest, None),
139        };
140
141        // Parse comma-separated host:port pairs
142        let mut hosts: Vec<(String, u16)> = Vec::new();
143        if hostport.is_empty() {
144            // Empty host — will be set via ?host= parameter (Unix socket)
145        } else {
146            for entry in hostport.split(',') {
147                let (h, p) = match entry.rsplit_once(':') {
148                    Some((h, p)) => {
149                        let port: u16 = p
150                            .parse()
151                            .map_err(|_| Error::Config(format!("invalid port: {p}")))?;
152                        (h.to_string(), port)
153                    }
154                    None => (entry.to_string(), 5432),
155                };
156                hosts.push((h, p));
157            }
158        }
159
160        let (database, params_str) = match db_and_params {
161            Some(dp) => match dp.split_once('?') {
162                Some((db, params)) => (percent_decode(db)?, Some(params.to_string())),
163                None => (percent_decode(dp)?, None),
164            },
165            None => (String::new(), None),
166        };
167
168        let mut config = ConfigBuilder::new();
169        for (h, p) in &hosts {
170            config = config.host_port(h.clone(), *p);
171        }
172        config = config.database(database).user(user);
173
174        if let Some(pw) = password {
175            config = config.password(pw);
176        }
177
178        // Parse query parameters
179        if let Some(params) = params_str {
180            for param in params.split('&') {
181                let (key, value) = param
182                    .split_once('=')
183                    .ok_or_else(|| Error::Config(format!("invalid parameter: {param}")))?;
184                let value = percent_decode(value)?;
185
186                match key {
187                    "sslmode" => {
188                        config = config.ssl_mode(match value.as_str() {
189                            "disable" => SslMode::Disable,
190                            "prefer" => SslMode::Prefer,
191                            "require" => SslMode::Require,
192                            "verify-ca" => SslMode::VerifyCa,
193                            "verify-full" => SslMode::VerifyFull,
194                            _ => return Err(Error::Config(format!("invalid sslmode: {value}"))),
195                        });
196                    }
197                    "application_name" => {
198                        config = config.application_name(value);
199                    }
200                    "connect_timeout" => {
201                        let secs: u64 = value.parse().map_err(|_| {
202                            Error::Config(format!("invalid connect_timeout: {value}"))
203                        })?;
204                        config = config.connect_timeout(Duration::from_secs(secs));
205                    }
206                    "statement_timeout" => {
207                        let secs: u64 = value.parse().map_err(|_| {
208                            Error::Config(format!("invalid statement_timeout: {value}"))
209                        })?;
210                        config = config.statement_timeout(Duration::from_secs(secs));
211                    }
212                    "target_session_attrs" => {
213                        config = config.target_session_attrs(match value.as_str() {
214                            "any" => TargetSessionAttrs::Any,
215                            "read-write" => TargetSessionAttrs::ReadWrite,
216                            "read-only" => TargetSessionAttrs::ReadOnly,
217                            _ => {
218                                return Err(Error::Config(format!(
219                                    "invalid target_session_attrs: {value}"
220                                )))
221                            }
222                        });
223                    }
224                    "sslcert" => {
225                        config = config.ssl_client_cert(PathBuf::from(value));
226                    }
227                    "sslkey" => {
228                        config = config.ssl_client_key(PathBuf::from(value));
229                    }
230                    "ssldirect" | "sslnegotiation" => {
231                        let direct = match value.as_str() {
232                            "true" | "direct" => true,
233                            "false" | "postgres" => false,
234                            _ => return Err(Error::Config(format!("invalid {key}: {value}"))),
235                        };
236                        config = config.ssl_direct(direct);
237                    }
238                    "channel_binding" => {
239                        config = config.channel_binding(match value.as_str() {
240                            "prefer" => ChannelBinding::Prefer,
241                            "require" => ChannelBinding::Require,
242                            "disable" => ChannelBinding::Disable,
243                            _ => {
244                                return Err(Error::Config(format!(
245                                    "invalid channel_binding: {value}"
246                                )))
247                            }
248                        });
249                    }
250                    "load_balance_hosts" => {
251                        config = config.load_balance_hosts(match value.as_str() {
252                            "disable" => LoadBalanceHosts::Disable,
253                            "random" => LoadBalanceHosts::Random,
254                            _ => {
255                                return Err(Error::Config(format!(
256                                    "invalid load_balance_hosts: {value}"
257                                )))
258                            }
259                        });
260                    }
261                    "host" => {
262                        // Support ?host=/var/run/postgresql for Unix sockets
263                        config = config.host_port(value, 5432);
264                    }
265                    _ => {
266                        // Ignore unknown parameters for forward compatibility
267                    }
268                }
269            }
270        }
271
272        Ok(config.build())
273    }
274
275    /// Create a new builder for `Config`.
276    pub fn builder() -> ConfigBuilder {
277        ConfigBuilder::new()
278    }
279
280    // Accessor methods
281
282    /// Returns the first host (for backward compatibility and single-host use).
283    pub fn host(&self) -> &str {
284        self.hosts.first().map_or("localhost", |(h, _)| h.as_str())
285    }
286
287    /// Returns the first port (for backward compatibility and single-host use).
288    pub fn port(&self) -> u16 {
289        self.hosts.first().map_or(5432, |(_, p)| *p)
290    }
291
292    /// Returns all configured host/port pairs.
293    pub fn hosts(&self) -> &[(String, u16)] {
294        &self.hosts
295    }
296
297    /// Load balancing strategy for multi-host connections.
298    pub fn load_balance_hosts(&self) -> LoadBalanceHosts {
299        self.load_balance_hosts
300    }
301
302    /// Target session attributes for connection routing.
303    pub fn target_session_attrs(&self) -> TargetSessionAttrs {
304        self.target_session_attrs
305    }
306
307    pub fn database(&self) -> &str {
308        &self.database
309    }
310
311    pub fn user(&self) -> &str {
312        &self.user
313    }
314
315    pub fn password(&self) -> Option<&str> {
316        self.password.as_deref()
317    }
318
319    pub fn ssl_mode(&self) -> SslMode {
320        self.ssl_mode
321    }
322
323    pub fn application_name(&self) -> Option<&str> {
324        self.application_name.as_deref()
325    }
326
327    pub fn connect_timeout(&self) -> Duration {
328        self.connect_timeout
329    }
330
331    pub fn statement_timeout(&self) -> Option<Duration> {
332        self.statement_timeout
333    }
334
335    /// Path to client certificate for certificate authentication.
336    pub fn ssl_client_cert(&self) -> Option<&std::path::Path> {
337        self.ssl_client_cert.as_deref()
338    }
339
340    /// Path to client private key for certificate authentication.
341    pub fn ssl_client_key(&self) -> Option<&std::path::Path> {
342        self.ssl_client_key.as_deref()
343    }
344
345    /// Whether direct TLS (PG 17+) is enabled.
346    pub fn ssl_direct(&self) -> bool {
347        self.ssl_direct
348    }
349
350    /// Channel binding preference for SCRAM authentication.
351    pub fn channel_binding(&self) -> ChannelBinding {
352        self.channel_binding
353    }
354
355    /// Install an `Instrumentation` impl. Inherited by every `Connection` and
356    /// `Pool` built from this `Config`.
357    pub fn with_instrumentation(
358        mut self,
359        instr: std::sync::Arc<dyn crate::Instrumentation>,
360    ) -> Self {
361        self.instrumentation = Some(instr);
362        self
363    }
364}
365
366impl std::fmt::Debug for Config {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        f.debug_struct("Config")
369            .field("hosts", &self.hosts)
370            .field("database", &self.database)
371            .field("user", &self.user)
372            .field("password", &self.password.as_ref().map(|_| "..."))
373            .field("ssl_mode", &self.ssl_mode)
374            .field("application_name", &self.application_name)
375            .field("connect_timeout", &self.connect_timeout)
376            .field("statement_timeout", &self.statement_timeout)
377            .field("_keepalive", &self.keepalive)
378            .field("_keepalive_idle", &self.keepalive_idle)
379            .field("_extra_float_digits", &self.extra_float_digits)
380            .field("target_session_attrs", &self.target_session_attrs)
381            .field("load_balance_hosts", &self.load_balance_hosts)
382            .field("ssl_client_cert", &self.ssl_client_cert)
383            .field("ssl_client_key", &self.ssl_client_key)
384            .field("ssl_direct", &self.ssl_direct)
385            .field("channel_binding", &self.channel_binding)
386            .field(
387                "instrumentation",
388                &self.instrumentation.as_ref().map(|_| "..."),
389            )
390            .finish()
391    }
392}
393
394/// Builder for [`Config`].
395#[derive(Debug, Clone)]
396pub struct ConfigBuilder {
397    hosts: Vec<(String, u16)>,
398    default_port: u16,
399    database: String,
400    user: String,
401    password: Option<String>,
402    ssl_mode: SslMode,
403    application_name: Option<String>,
404    connect_timeout: Duration,
405    statement_timeout: Option<Duration>,
406    keepalive: Option<Duration>,
407    keepalive_idle: Option<Duration>,
408    target_session_attrs: TargetSessionAttrs,
409    extra_float_digits: Option<i32>,
410    load_balance_hosts: LoadBalanceHosts,
411    ssl_client_cert: Option<PathBuf>,
412    ssl_client_key: Option<PathBuf>,
413    ssl_direct: bool,
414    channel_binding: ChannelBinding,
415}
416
417impl ConfigBuilder {
418    fn new() -> Self {
419        Self {
420            hosts: Vec::new(),
421            default_port: 5432,
422            database: String::new(),
423            user: String::new(),
424            password: None,
425            ssl_mode: SslMode::default(),
426            application_name: None,
427            connect_timeout: Duration::from_secs(10),
428            statement_timeout: None,
429            keepalive: Some(Duration::from_secs(60)),
430            keepalive_idle: None,
431            target_session_attrs: TargetSessionAttrs::default(),
432            extra_float_digits: Some(3),
433            load_balance_hosts: LoadBalanceHosts::default(),
434            ssl_client_cert: None,
435            ssl_client_key: None,
436            ssl_direct: false,
437            channel_binding: ChannelBinding::default(),
438        }
439    }
440
441    /// Append a host with the current default port.
442    pub fn host(mut self, host: impl Into<String>) -> Self {
443        self.hosts.push((host.into(), self.default_port));
444        self
445    }
446
447    /// Append a host with a specific port.
448    pub fn host_port(mut self, host: impl Into<String>, port: u16) -> Self {
449        self.hosts.push((host.into(), port));
450        self
451    }
452
453    /// Set the default port for subsequent `.host()` calls and update
454    /// any hosts that still have the old default port.
455    pub fn port(mut self, port: u16) -> Self {
456        let old_default = self.default_port;
457        self.default_port = port;
458        for (_, p) in &mut self.hosts {
459            if *p == old_default {
460                *p = port;
461            }
462        }
463        self
464    }
465
466    pub fn load_balance_hosts(mut self, strategy: LoadBalanceHosts) -> Self {
467        self.load_balance_hosts = strategy;
468        self
469    }
470
471    pub fn database(mut self, database: impl Into<String>) -> Self {
472        self.database = database.into();
473        self
474    }
475
476    pub fn user(mut self, user: impl Into<String>) -> Self {
477        self.user = user.into();
478        self
479    }
480
481    pub fn password(mut self, password: impl Into<String>) -> Self {
482        self.password = Some(password.into());
483        self
484    }
485
486    pub fn ssl_mode(mut self, ssl_mode: SslMode) -> Self {
487        self.ssl_mode = ssl_mode;
488        self
489    }
490
491    pub fn application_name(mut self, name: impl Into<String>) -> Self {
492        self.application_name = Some(name.into());
493        self
494    }
495
496    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
497        self.connect_timeout = timeout;
498        self
499    }
500
501    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
502        self.statement_timeout = Some(timeout);
503        self
504    }
505
506    pub fn keepalive(mut self, interval: Duration) -> Self {
507        self.keepalive = Some(interval);
508        self
509    }
510
511    pub fn target_session_attrs(mut self, attrs: TargetSessionAttrs) -> Self {
512        self.target_session_attrs = attrs;
513        self
514    }
515
516    /// Set the path to the client certificate file for certificate authentication.
517    pub fn ssl_client_cert(mut self, path: impl Into<PathBuf>) -> Self {
518        self.ssl_client_cert = Some(path.into());
519        self
520    }
521
522    /// Set the path to the client private key file for certificate authentication.
523    pub fn ssl_client_key(mut self, path: impl Into<PathBuf>) -> Self {
524        self.ssl_client_key = Some(path.into());
525        self
526    }
527
528    /// Enable direct TLS connection (PG 17+), skipping SSLRequest negotiation.
529    pub fn ssl_direct(mut self, direct: bool) -> Self {
530        self.ssl_direct = direct;
531        self
532    }
533
534    /// Set the channel binding preference for SCRAM authentication.
535    pub fn channel_binding(mut self, binding: ChannelBinding) -> Self {
536        self.channel_binding = binding;
537        self
538    }
539
540    /// Build the final `Config`.
541    pub fn build(self) -> Config {
542        let hosts = if self.hosts.is_empty() {
543            vec![("localhost".to_string(), self.default_port)]
544        } else {
545            self.hosts
546        };
547        Config {
548            hosts,
549            database: self.database,
550            user: self.user,
551            password: self.password,
552            ssl_mode: self.ssl_mode,
553            application_name: self.application_name,
554            connect_timeout: self.connect_timeout,
555            statement_timeout: self.statement_timeout,
556            keepalive: self.keepalive,
557            keepalive_idle: self.keepalive_idle,
558            target_session_attrs: self.target_session_attrs,
559            extra_float_digits: self.extra_float_digits,
560            load_balance_hosts: self.load_balance_hosts,
561            ssl_client_cert: self.ssl_client_cert,
562            ssl_client_key: self.ssl_client_key,
563            ssl_direct: self.ssl_direct,
564            channel_binding: self.channel_binding,
565            instrumentation: None,
566        }
567    }
568}
569
570/// Percent-decode a URL component.
571fn percent_decode(s: &str) -> Result<String> {
572    let mut result = String::with_capacity(s.len());
573    let mut chars = s.as_bytes().iter();
574
575    while let Some(&b) = chars.next() {
576        if b == b'%' {
577            let hi = chars
578                .next()
579                .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
580            let lo = chars
581                .next()
582                .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
583            let byte = hex_digit(*hi)? << 4 | hex_digit(*lo)?;
584            result.push(byte as char);
585        } else {
586            result.push(b as char);
587        }
588    }
589
590    Ok(result)
591}
592
593fn hex_digit(b: u8) -> Result<u8> {
594    match b {
595        b'0'..=b'9' => Ok(b - b'0'),
596        b'a'..=b'f' => Ok(b - b'a' + 10),
597        b'A'..=b'F' => Ok(b - b'A' + 10),
598        _ => Err(Error::Config(format!("invalid hex digit: {}", b as char))),
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_builder_build_populates_all_fields() {
        let cfg = ConfigBuilder::new()
            .host("localhost".to_string())
            .port(5432)
            .database("test".to_string())
            .user("postgres".to_string())
            .password("secret".to_string())
            .application_name("test_app".to_string())
            .connect_timeout(std::time::Duration::from_secs(10))
            .channel_binding(ChannelBinding::Prefer)
            .build();
        assert_eq!(cfg.user, "postgres");
        assert_eq!(cfg.database, "test");
        assert_eq!(cfg.application_name.as_deref(), Some("test_app"));
        assert_eq!(cfg.channel_binding, ChannelBinding::Prefer);
        // keepalive defaults to Some(60s) in ConfigBuilder::new()
        assert_eq!(cfg.keepalive, Some(std::time::Duration::from_secs(60)));
        assert!(cfg.keepalive_idle.is_none());
        // extra_float_digits defaults to Some(3) in ConfigBuilder::new()
        assert_eq!(cfg.extra_float_digits, Some(3));
        assert!(cfg.instrumentation.is_none());
    }

    #[test]
    fn channel_binding_accessor() {
        let cfg = ConfigBuilder::new()
            .channel_binding(ChannelBinding::Require)
            .build();
        assert_eq!(cfg.channel_binding(), ChannelBinding::Require);
    }

    #[test]
    fn with_instrumentation_sets_field() {
        struct NoOp;
        impl crate::Instrumentation for NoOp {
            fn on_event(&self, _: &crate::Event<'_>) {}
        }
        let cfg = ConfigBuilder::new()
            .build()
            .with_instrumentation(std::sync::Arc::new(NoOp));
        assert!(cfg.instrumentation.is_some());
    }

    #[test]
    fn debug_redacts_password() {
        let cfg = ConfigBuilder::new()
            .password("super_secret".to_string())
            .build();
        let s = format!("{cfg:?}");
        assert!(!s.contains("super_secret"));
        assert!(s.contains("password"));
    }

    #[test]
    fn builder_port_applies_to_host_added_earlier() {
        let cfg = ConfigBuilder::new().host("db").port(6000).build();
        assert_eq!(cfg.hosts().to_vec(), vec![("db".to_string(), 6000)]);
    }

    #[test]
    fn parse_full_uri() {
        let cfg = Config::parse(
            "postgres://alice:s3cret@db.example.com:6543/app?sslmode=require&application_name=svc",
        )
        .expect("valid URI");
        assert_eq!(
            cfg.hosts().to_vec(),
            vec![("db.example.com".to_string(), 6543)]
        );
        assert_eq!(cfg.user(), "alice");
        assert_eq!(cfg.password(), Some("s3cret"));
        assert_eq!(cfg.database(), "app");
        assert_eq!(cfg.ssl_mode(), SslMode::Require);
        assert_eq!(cfg.application_name(), Some("svc"));
    }

    #[test]
    fn parse_multiple_hosts_and_default_port() {
        let cfg = Config::parse("postgresql://a:5433,b/db").expect("valid URI");
        assert_eq!(
            cfg.hosts().to_vec(),
            vec![("a".to_string(), 5433), ("b".to_string(), 5432)]
        );
        assert_eq!(cfg.host(), "a");
        assert_eq!(cfg.port(), 5433);
    }

    #[test]
    fn parse_percent_decodes_credentials_and_database() {
        let cfg = Config::parse("postgres://u%3Ax:p%40ss@h/my%20db").expect("valid URI");
        assert_eq!(cfg.user(), "u:x");
        assert_eq!(cfg.password(), Some("p@ss"));
        assert_eq!(cfg.database(), "my db");
    }

    #[test]
    fn parse_empty_authority_defaults_to_localhost() {
        let cfg = Config::parse("postgres://").expect("valid URI");
        assert_eq!(cfg.host(), "localhost");
        assert_eq!(cfg.port(), 5432);
        assert_eq!(cfg.database(), "");
        assert_eq!(cfg.password(), None);
    }

    #[test]
    fn parse_rejects_bad_input() {
        assert!(Config::parse("mysql://localhost/db").is_err());
        assert!(Config::parse("postgres://h:notaport/db").is_err());
        assert!(Config::parse("postgres://h/db?sslmode=bogus").is_err());
        assert!(Config::parse("postgres://h/db?novalue").is_err());
    }
}