xitca_postgres/
config.rs

1//! Connection configuration. copy/paste from `tokio-postgres`
2
3use core::{fmt, iter, mem, str};
4
5use std::{
6    borrow::Cow,
7    path::{Path, PathBuf},
8};
9
10use super::{error::Error, session::TargetSessionAttrs};
11
/// Whether TLS is used for the connection.
///
/// The default is [`SslMode::Prefer`] when the `tls` feature is enabled and [`SslMode::Disable`] otherwise.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Never use TLS.
    #[cfg_attr(not(feature = "tls"), default)]
    Disable,
    /// Try TLS first, but still allow a session without it.
    #[cfg_attr(feature = "tls", default)]
    Prefer,
    /// Only allow a TLS session.
    Require,
}
24
/// How TLS is negotiated with the server.
///
/// Defaults to [`SslNegotiation::Postgres`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslNegotiation {
    /// Negotiate TLS with a PostgreSQL `SslRequest` before the handshake.
    #[default]
    Postgres,
    /// Start the TLS handshake directly, skipping negotiation. Only works with PostgreSQL 17 or newer.
    Direct,
}
35
/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A TCP hostname.
    Tcp(Box<str>),
    /// A hostname for a QUIC connection.
    ///
    /// Not produced by `Config::host` or by the connection string parsers in this module.
    Quic(Box<str>),
    /// A Unix socket path.
    Unix(PathBuf),
}
45
46#[derive(Clone, Eq, PartialEq)]
47pub struct Config {
48    pub(crate) user: Option<Box<str>>,
49    pub(crate) password: Option<Box<[u8]>>,
50    pub(crate) dbname: Option<Box<str>>,
51    pub(crate) options: Option<Box<str>>,
52    pub(crate) application_name: Option<Box<str>>,
53    pub(crate) ssl_mode: SslMode,
54    pub(crate) ssl_negotiation: SslNegotiation,
55    pub(crate) host: Vec<Host>,
56    pub(crate) port: Vec<u16>,
57    target_session_attrs: TargetSessionAttrs,
58    tls_server_end_point: Option<Box<[u8]>>,
59}
60
61impl Default for Config {
62    fn default() -> Config {
63        Config::new()
64    }
65}
66
67impl Config {
68    /// Creates a new configuration.
69    pub fn new() -> Config {
70        Config {
71            user: None,
72            password: None,
73            dbname: None,
74            options: None,
75            application_name: None,
76            ssl_mode: SslMode::default(),
77            ssl_negotiation: SslNegotiation::Postgres,
78            host: Vec::new(),
79            port: Vec::new(),
80            target_session_attrs: TargetSessionAttrs::Any,
81            tls_server_end_point: None,
82        }
83    }
84
85    /// Sets the user to authenticate with.
86    ///
87    /// Required.
88    pub fn user(&mut self, user: &str) -> &mut Config {
89        self.user = Some(Box::from(user));
90        self
91    }
92
93    /// Gets the user to authenticate with, if one has been configured with
94    /// the `user` method.
95    pub fn get_user(&self) -> Option<&str> {
96        self.user.as_deref()
97    }
98
99    /// Sets the password to authenticate with.
100    pub fn password<T>(&mut self, password: T) -> &mut Config
101    where
102        T: AsRef<[u8]>,
103    {
104        self.password = Some(Box::from(password.as_ref()));
105        self
106    }
107
108    /// Gets the password to authenticate with, if one has been configured with
109    /// the `password` method.
110    pub fn get_password(&self) -> Option<&[u8]> {
111        self.password.as_deref()
112    }
113
114    /// Sets the name of the database to connect to.
115    ///
116    /// Defaults to the user.
117    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
118        self.dbname = Some(Box::from(dbname));
119        self
120    }
121
122    /// Gets the name of the database to connect to, if one has been configured
123    /// with the `dbname` method.
124    pub fn get_dbname(&self) -> Option<&str> {
125        self.dbname.as_deref()
126    }
127
128    /// Sets command line options used to configure the server.
129    pub fn options(&mut self, options: &str) -> &mut Config {
130        self.options = Some(Box::from(options));
131        self
132    }
133
134    /// Gets the command line options used to configure the server, if the
135    /// options have been set with the `options` method.
136    pub fn get_options(&self) -> Option<&str> {
137        self.options.as_deref()
138    }
139
140    /// Sets the value of the `application_name` runtime parameter.
141    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
142        self.application_name = Some(Box::from(application_name));
143        self
144    }
145
146    /// Gets the value of the `application_name` runtime parameter, if it has
147    /// been set with the `application_name` method.
148    pub fn get_application_name(&self) -> Option<&str> {
149        self.application_name.as_deref()
150    }
151
152    /// Sets the SSL configuration.
153    ///
154    /// Defaults to `prefer`.
155    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
156        self.ssl_mode = ssl_mode;
157        self
158    }
159
160    /// Gets the SSL configuration.
161    pub fn get_ssl_mode(&self) -> SslMode {
162        self.ssl_mode
163    }
164
165    /// Sets the SSL negotiation method.
166    ///
167    /// Defaults to `postgres`.
168    pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
169        self.ssl_negotiation = ssl_negotiation;
170        self
171    }
172
173    /// Gets the SSL negotiation method.
174    pub fn get_ssl_negotiation(&self) -> SslNegotiation {
175        self.ssl_negotiation
176    }
177
178    pub fn host(&mut self, host: &str) -> &mut Config {
179        if host.starts_with('/') {
180            return self.host_path(host);
181        }
182
183        let host = Host::Tcp(Box::from(host));
184
185        self.host.push(host);
186        self
187    }
188
189    /// Adds a Unix socket host to the configuration.
190    ///
191    /// Unlike `host`, this method allows non-UTF8 paths.
192    pub fn host_path<T>(&mut self, host: T) -> &mut Config
193    where
194        T: AsRef<Path>,
195    {
196        self.host.push(Host::Unix(host.as_ref().to_path_buf()));
197        self
198    }
199
200    /// Gets the hosts that have been added to the configuration with `host`.
201    pub fn get_hosts(&self) -> &[Host] {
202        &self.host
203    }
204
205    /// Adds a port to the configuration.
206    ///
207    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
208    /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
209    /// as hosts.
210    pub fn port(&mut self, port: u16) -> &mut Config {
211        self.port.push(port);
212        self
213    }
214
215    /// Gets the ports that have been added to the configuration with `port`.
216    pub fn get_ports(&self) -> &[u16] {
217        &self.port
218    }
219
220    /// Sets the requirements of the session.
221    ///
222    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
223    /// secondary servers. Defaults to `Any`.
224    pub fn target_session_attrs(&mut self, target_session_attrs: TargetSessionAttrs) -> &mut Config {
225        self.target_session_attrs = target_session_attrs;
226        self
227    }
228
229    /// Gets the requirements of the session.
230    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
231        self.target_session_attrs
232    }
233
234    /// change the remote peer's tls certificates. it's often coupled with [`Postgres::connect_io`] API for manual tls
235    /// session connecting and channel binding authentication.
236    /// # Examples
237    /// ```rust
238    /// use xitca_postgres::{Config, Postgres};
239    ///
240    /// // handle tls connection on your own.
241    /// async fn connect_io() {
242    ///     let mut cfg = Config::try_from("postgres://postgres:postgres@localhost/postgres").unwrap();
243    ///     
244    ///     // an imaginary function where you establish a tls connection to database on your own.
245    ///     // the established connection should be providing valid cert bytes.
246    ///     let (io, certs) = your_tls_connector().await;
247    ///
248    ///     // set cert bytes to configuration
249    ///     cfg.tls_server_end_point(certs);
250    ///
251    ///     // give xitca-postgres the config and established io and finish db session process.
252    ///     let _ = Postgres::new(cfg).connect_io(io).await;
253    /// }
254    ///
255    /// async fn your_tls_connector() -> (MyTlsStream, Vec<u8>) {
256    ///     todo!("your tls connecting logic lives here. the process can be async or not.")
257    /// }
258    ///
259    /// // a possible type representation of your manual tls connection to database
260    /// struct MyTlsStream;
261    ///
262    /// # use std::{io, pin::Pin, task::{Context, Poll}};
263    /// #
264    /// # use xitca_io::io::{AsyncIo, Interest, Ready};
265    /// #   
266    /// # impl AsyncIo for MyTlsStream {
267    /// #   async fn ready(&mut self, interest: Interest) -> io::Result<Ready> {
268    /// #       todo!()
269    /// #   }
270    /// #
271    /// #   fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
272    /// #       todo!()
273    /// #   }
274    /// #   
275    /// #   fn is_vectored_write(&self) -> bool {
276    /// #       false
277    /// #   }
278    /// #   
279    /// #   fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
280    /// #       Poll::Ready(Ok(()))
281    /// #   }
282    /// # }
283    /// #   
284    /// # impl io::Read for MyTlsStream {
285    /// #   fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
286    /// #       todo!()
287    /// #   }
288    /// # }   
289    /// #
290    /// # impl io::Write for MyTlsStream {
291    /// #   fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
292    /// #       todo!()
293    /// #   }
294    /// #   
295    /// #   fn flush(&mut self) -> io::Result<()> {
296    /// #       Ok(())
297    /// #   }
298    /// # }
299    /// ```
300    ///
301    /// [`Postgres::connect_io`]: crate::Postgres::connect_io
302    pub fn tls_server_end_point(&mut self, tls_server_end_point: impl AsRef<[u8]>) -> &mut Self {
303        self.tls_server_end_point = Some(Box::from(tls_server_end_point.as_ref()));
304        self
305    }
306
307    pub fn get_tls_server_end_point(&self) -> Option<&[u8]> {
308        self.tls_server_end_point.as_deref()
309    }
310
311    fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
312        match key {
313            "user" => {
314                self.user(value);
315            }
316            "password" => {
317                self.password(value);
318            }
319            "dbname" => {
320                self.dbname(value);
321            }
322            "options" => {
323                self.options(value);
324            }
325            "application_name" => {
326                self.application_name(value);
327            }
328            "sslmode" => {
329                let mode = match value {
330                    "disable" => SslMode::Disable,
331                    "prefer" => SslMode::Prefer,
332                    "require" => SslMode::Require,
333                    _ => return Err(Error::todo()),
334                };
335                self.ssl_mode(mode);
336            }
337            "sslnegotiation" => {
338                let mode = match value {
339                    "postgres" => SslNegotiation::Postgres,
340                    "direct" => SslNegotiation::Direct,
341                    _ => return Err(Error::todo()),
342                };
343                self.ssl_negotiation(mode);
344            }
345            "host" => {
346                for host in value.split(',') {
347                    self.host(host);
348                }
349            }
350            "port" => {
351                for port in value.split(',') {
352                    let port = if port.is_empty() {
353                        5432
354                    } else {
355                        port.parse().map_err(|_| Error::todo())?
356                    };
357                    self.port(port);
358                }
359            }
360            "target_session_attrs" => {
361                let target_session_attrs = match value {
362                    "any" => TargetSessionAttrs::Any,
363                    "read-write" => TargetSessionAttrs::ReadWrite,
364                    "read-only" => TargetSessionAttrs::ReadOnly,
365                    _ => return Err(Error::todo()),
366                };
367                self.target_session_attrs(target_session_attrs);
368            }
369            _ => {
370                return Err(Error::todo());
371            }
372        }
373
374        Ok(())
375    }
376}
377
378impl TryFrom<String> for Config {
379    type Error = Error;
380
381    fn try_from(s: String) -> Result<Self, Self::Error> {
382        Self::try_from(s.as_str())
383    }
384}
385
386impl TryFrom<&str> for Config {
387    type Error = Error;
388
389    fn try_from(s: &str) -> Result<Self, Self::Error> {
390        match UrlParser::parse(s)? {
391            Some(config) => Ok(config),
392            None => Parser::parse(s),
393        }
394    }
395}
396
397// Omit password from debug output
398impl fmt::Debug for Config {
399    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400        struct Redaction {}
401        impl fmt::Debug for Redaction {
402            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403                write!(f, "_")
404            }
405        }
406
407        f.debug_struct("Config")
408            .field("user", &self.user)
409            .field("password", &self.password.as_ref().map(|_| Redaction {}))
410            .field("dbname", &self.dbname)
411            .field("options", &self.options)
412            .field("application_name", &self.application_name)
413            .field("host", &self.host)
414            .field("port", &self.port)
415            .field("target_session_attrs", &self.target_session_attrs)
416            .finish()
417    }
418}
419
/// Cursor over a key/value connection string such as `host=localhost user=postgres`.
struct Parser<'a> {
    /// The full input; slices of it are handed out as keywords.
    s: &'a str,
    /// Position within `s`, yielding `(byte_offset, char)` pairs.
    it: iter::Peekable<str::CharIndices<'a>>,
}
424
425impl<'a> Parser<'a> {
426    fn parse(s: &'a str) -> Result<Config, Error> {
427        let mut parser = Parser {
428            s,
429            it: s.char_indices().peekable(),
430        };
431
432        let mut config = Config::new();
433
434        while let Some((key, value)) = parser.parameter()? {
435            config.param(key, &value)?;
436        }
437
438        Ok(config)
439    }
440
441    fn skip_ws(&mut self) {
442        self.take_while(char::is_whitespace);
443    }
444
445    fn take_while<F>(&mut self, f: F) -> &'a str
446    where
447        F: Fn(char) -> bool,
448    {
449        let start = match self.it.peek() {
450            Some(&(i, _)) => i,
451            None => return "",
452        };
453
454        loop {
455            match self.it.peek() {
456                Some(&(_, c)) if f(c) => {
457                    self.it.next();
458                }
459                Some(&(i, _)) => return &self.s[start..i],
460                None => return &self.s[start..],
461            }
462        }
463    }
464
465    fn eat(&mut self, target: char) -> Result<(), Error> {
466        match self.it.next() {
467            Some((_, c)) if c == target => Ok(()),
468            Some((i, c)) => {
469                let _m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
470                Err(Error::todo())
471            }
472            None => Err(Error::todo()),
473        }
474    }
475
476    fn eat_if(&mut self, target: char) -> bool {
477        match self.it.peek() {
478            Some(&(_, c)) if c == target => {
479                self.it.next();
480                true
481            }
482            _ => false,
483        }
484    }
485
486    fn keyword(&mut self) -> Option<&'a str> {
487        let s = self.take_while(|c| match c {
488            c if c.is_whitespace() => false,
489            '=' => false,
490            _ => true,
491        });
492
493        if s.is_empty() { None } else { Some(s) }
494    }
495
496    fn value(&mut self) -> Result<String, Error> {
497        let value = if self.eat_if('\'') {
498            let value = self.quoted_value()?;
499            self.eat('\'')?;
500            value
501        } else {
502            self.simple_value()?
503        };
504
505        Ok(value)
506    }
507
508    fn simple_value(&mut self) -> Result<String, Error> {
509        let mut value = String::new();
510
511        while let Some(&(_, c)) = self.it.peek() {
512            if c.is_whitespace() {
513                break;
514            }
515
516            self.it.next();
517            if c == '\\' {
518                if let Some((_, c2)) = self.it.next() {
519                    value.push(c2);
520                }
521            } else {
522                value.push(c);
523            }
524        }
525
526        if value.is_empty() {
527            return Err(Error::todo());
528        }
529
530        Ok(value)
531    }
532
533    fn quoted_value(&mut self) -> Result<String, Error> {
534        let mut value = String::new();
535
536        while let Some(&(_, c)) = self.it.peek() {
537            if c == '\'' {
538                return Ok(value);
539            }
540
541            self.it.next();
542            if c == '\\' {
543                if let Some((_, c2)) = self.it.next() {
544                    value.push(c2);
545                }
546            } else {
547                value.push(c);
548            }
549        }
550
551        Err(Error::todo())
552    }
553
554    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
555        self.skip_ws();
556        let keyword = match self.keyword() {
557            Some(keyword) => keyword,
558            None => return Ok(None),
559        };
560        self.skip_ws();
561        self.eat('=')?;
562        self.skip_ws();
563        let value = self.value()?;
564
565        Ok(Some((keyword, value)))
566    }
567}
568
569// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
570struct UrlParser<'a> {
571    s: &'a str,
572    config: Config,
573}
574
575impl<'a> UrlParser<'a> {
576    fn parse(s: &'a str) -> Result<Option<Config>, Error> {
577        let s = match Self::remove_url_prefix(s) {
578            Some(s) => s,
579            None => return Ok(None),
580        };
581
582        let mut parser = UrlParser {
583            s,
584            config: Config::new(),
585        };
586
587        parser.parse_credentials()?;
588        parser.parse_host()?;
589        parser.parse_path()?;
590        parser.parse_params()?;
591
592        Ok(Some(parser.config))
593    }
594
595    fn remove_url_prefix(s: &str) -> Option<&str> {
596        for prefix in &["postgres://", "postgresql://"] {
597            if let Some(stripped) = s.strip_prefix(prefix) {
598                return Some(stripped);
599            }
600        }
601
602        None
603    }
604
605    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
606        match self.s.find(end) {
607            Some(pos) => {
608                let (head, tail) = self.s.split_at(pos);
609                self.s = tail;
610                Some(head)
611            }
612            None => None,
613        }
614    }
615
616    fn take_all(&mut self) -> &'a str {
617        mem::take(&mut self.s)
618    }
619
620    fn eat_byte(&mut self) {
621        self.s = &self.s[1..];
622    }
623
624    fn parse_credentials(&mut self) -> Result<(), Error> {
625        let creds = match self.take_until(&['@']) {
626            Some(creds) => creds,
627            None => return Ok(()),
628        };
629        self.eat_byte();
630
631        let mut it = creds.splitn(2, ':');
632        let user = self.decode(it.next().unwrap())?;
633        self.config.user(&user);
634
635        if let Some(password) = it.next() {
636            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
637            self.config.password(password);
638        }
639
640        Ok(())
641    }
642
643    fn parse_host(&mut self) -> Result<(), Error> {
644        let host = match self.take_until(&['/', '?']) {
645            Some(host) => host,
646            None => self.take_all(),
647        };
648
649        if host.is_empty() {
650            return Ok(());
651        }
652
653        for chunk in host.split(',') {
654            let (host, port) = if chunk.starts_with('[') {
655                let idx = match chunk.find(']') {
656                    Some(idx) => idx,
657                    None => return Err(Error::todo()),
658                };
659
660                let host = &chunk[1..idx];
661                let remaining = &chunk[idx + 1..];
662                let port = if let Some(port) = remaining.strip_prefix(':') {
663                    Some(port)
664                } else if remaining.is_empty() {
665                    None
666                } else {
667                    return Err(Error::todo());
668                };
669
670                (host, port)
671            } else {
672                let mut it = chunk.splitn(2, ':');
673                (it.next().unwrap(), it.next())
674            };
675
676            self.host_param(host)?;
677            let port = self.decode(port.unwrap_or("5432"))?;
678            self.config.param("port", &port)?;
679        }
680
681        Ok(())
682    }
683
684    fn parse_path(&mut self) -> Result<(), Error> {
685        if !self.s.starts_with('/') {
686            return Ok(());
687        }
688        self.eat_byte();
689
690        let dbname = match self.take_until(&['?']) {
691            Some(dbname) => dbname,
692            None => self.take_all(),
693        };
694
695        if !dbname.is_empty() {
696            self.config.dbname(&self.decode(dbname)?);
697        }
698
699        Ok(())
700    }
701
702    fn parse_params(&mut self) -> Result<(), Error> {
703        if !self.s.starts_with('?') {
704            return Ok(());
705        }
706        self.eat_byte();
707
708        while !self.s.is_empty() {
709            let key = match self.take_until(&['=']) {
710                Some(key) => self.decode(key)?,
711                None => return Err(Error::todo()),
712            };
713            self.eat_byte();
714
715            let value = match self.take_until(&['&']) {
716                Some(value) => {
717                    self.eat_byte();
718                    value
719                }
720                None => self.take_all(),
721            };
722
723            if key == "host" {
724                self.host_param(value)?;
725            } else {
726                let value = self.decode(value)?;
727                self.config.param(&key, &value)?;
728            }
729        }
730
731        Ok(())
732    }
733
734    fn host_param(&mut self, s: &str) -> Result<(), Error> {
735        let s = self.decode(s)?;
736        self.config.param("host", &s)
737    }
738
739    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
740        percent_encoding::percent_decode(s.as_bytes())
741            .decode_utf8()
742            .map_err(|_| Error::todo())
743    }
744}