1#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect_raw::connect_raw;
6#[cfg(feature = "runtime")]
7use crate::tls::MakeTlsConnect;
8use crate::tls::TlsConnect;
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{Client, Connection, Error};
12use std::borrow::Cow;
13#[cfg(all(unix, not(madsim)))]
14use std::ffi::OsStr;
15#[cfg(all(unix, not(madsim)))]
16use std::os::unix::ffi::OsStrExt;
17#[cfg(all(unix, not(madsim)))]
18use std::path::{Path, PathBuf};
19use std::str;
20use std::str::FromStr;
21use std::time::Duration;
22use std::{error, fmt, iter, mem};
23use tokio::io::{AsyncRead, AsyncWrite};
24
/// Properties required of a session.
///
/// Set with `Config::target_session_attrs`, or parsed from the
/// `target_session_attrs` connection-string option.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// No particular session attributes are requested (`target_session_attrs=any`).
    Any,
    /// Requested with `target_session_attrs=read-write`.
    ReadWrite,
}
34
/// TLS configuration.
///
/// Set with `Config::ssl_mode`, or parsed from the `sslmode` connection-string
/// option.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// Corresponds to `sslmode=disable`.
    Disable,
    /// Corresponds to `sslmode=prefer` (the `Config::new` default).
    Prefer,
    /// Corresponds to `sslmode=require`.
    Require,
}
46
/// Channel binding configuration.
///
/// Set with `Config::channel_binding`, or parsed from the `channel_binding`
/// connection-string option.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Corresponds to `channel_binding=disable`.
    Disable,
    /// Corresponds to `channel_binding=prefer` (the `Config::new` default).
    Prefer,
    /// Corresponds to `channel_binding=require`.
    Require,
}
58
/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A host given as a plain string (added by `Config::host` for values that
    /// are not treated as paths).
    Tcp(String),
    /// A filesystem path. `Config::host` produces this for values beginning
    /// with `/`; it is only available on Unix targets outside madsim.
    #[cfg(all(unix, not(madsim)))]
    Unix(PathBuf),
}
70
71#[derive(PartialEq, Clone)]
148pub struct Config {
149 pub(crate) user: Option<String>,
150 pub(crate) password: Option<Vec<u8>>,
151 pub(crate) dbname: Option<String>,
152 pub(crate) options: Option<String>,
153 pub(crate) application_name: Option<String>,
154 pub(crate) ssl_mode: SslMode,
155 pub(crate) host: Vec<Host>,
156 pub(crate) port: Vec<u16>,
157 pub(crate) connect_timeout: Option<Duration>,
158 pub(crate) keepalives: bool,
159 pub(crate) keepalives_idle: Duration,
160 pub(crate) target_session_attrs: TargetSessionAttrs,
161 pub(crate) channel_binding: ChannelBinding,
162}
163
164impl Default for Config {
165 fn default() -> Config {
166 Config::new()
167 }
168}
169
170impl Config {
171 pub fn new() -> Config {
173 Config {
174 user: None,
175 password: None,
176 dbname: None,
177 options: None,
178 application_name: None,
179 ssl_mode: SslMode::Prefer,
180 host: vec![],
181 port: vec![],
182 connect_timeout: None,
183 keepalives: true,
184 keepalives_idle: Duration::from_secs(2 * 60 * 60),
185 target_session_attrs: TargetSessionAttrs::Any,
186 channel_binding: ChannelBinding::Prefer,
187 }
188 }
189
190 pub fn user(&mut self, user: &str) -> &mut Config {
194 self.user = Some(user.to_string());
195 self
196 }
197
198 pub fn get_user(&self) -> Option<&str> {
201 self.user.as_deref()
202 }
203
204 pub fn password<T>(&mut self, password: T) -> &mut Config
206 where
207 T: AsRef<[u8]>,
208 {
209 self.password = Some(password.as_ref().to_vec());
210 self
211 }
212
213 pub fn get_password(&self) -> Option<&[u8]> {
216 self.password.as_deref()
217 }
218
219 pub fn dbname(&mut self, dbname: &str) -> &mut Config {
223 self.dbname = Some(dbname.to_string());
224 self
225 }
226
227 pub fn get_dbname(&self) -> Option<&str> {
230 self.dbname.as_deref()
231 }
232
233 pub fn options(&mut self, options: &str) -> &mut Config {
235 self.options = Some(options.to_string());
236 self
237 }
238
239 pub fn get_options(&self) -> Option<&str> {
242 self.options.as_deref()
243 }
244
245 pub fn application_name(&mut self, application_name: &str) -> &mut Config {
247 self.application_name = Some(application_name.to_string());
248 self
249 }
250
251 pub fn get_application_name(&self) -> Option<&str> {
254 self.application_name.as_deref()
255 }
256
257 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
261 self.ssl_mode = ssl_mode;
262 self
263 }
264
265 pub fn get_ssl_mode(&self) -> SslMode {
267 self.ssl_mode
268 }
269
270 pub fn host(&mut self, host: &str) -> &mut Config {
275 #[cfg(all(unix, not(madsim)))]
276 {
277 if host.starts_with('/') {
278 return self.host_path(host);
279 }
280 }
281
282 self.host.push(Host::Tcp(host.to_string()));
283 self
284 }
285
286 pub fn get_hosts(&self) -> &[Host] {
288 &self.host
289 }
290
291 #[cfg(all(unix, not(madsim)))]
295 pub fn host_path<T>(&mut self, host: T) -> &mut Config
296 where
297 T: AsRef<Path>,
298 {
299 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
300 self
301 }
302
303 pub fn port(&mut self, port: u16) -> &mut Config {
309 self.port.push(port);
310 self
311 }
312
313 pub fn get_ports(&self) -> &[u16] {
315 &self.port
316 }
317
318 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
323 self.connect_timeout = Some(connect_timeout);
324 self
325 }
326
327 pub fn get_connect_timeout(&self) -> Option<&Duration> {
330 self.connect_timeout.as_ref()
331 }
332
333 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
337 self.keepalives = keepalives;
338 self
339 }
340
341 pub fn get_keepalives(&self) -> bool {
343 self.keepalives
344 }
345
346 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
350 self.keepalives_idle = keepalives_idle;
351 self
352 }
353
354 pub fn get_keepalives_idle(&self) -> Duration {
357 self.keepalives_idle
358 }
359
360 pub fn target_session_attrs(
365 &mut self,
366 target_session_attrs: TargetSessionAttrs,
367 ) -> &mut Config {
368 self.target_session_attrs = target_session_attrs;
369 self
370 }
371
372 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
374 self.target_session_attrs
375 }
376
377 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
381 self.channel_binding = channel_binding;
382 self
383 }
384
385 pub fn get_channel_binding(&self) -> ChannelBinding {
387 self.channel_binding
388 }
389
390 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
391 match key {
392 "user" => {
393 self.user(value);
394 }
395 "password" => {
396 self.password(value);
397 }
398 "dbname" => {
399 self.dbname(value);
400 }
401 "options" => {
402 self.options(value);
403 }
404 "application_name" => {
405 self.application_name(value);
406 }
407 "sslmode" => {
408 let mode = match value {
409 "disable" => SslMode::Disable,
410 "prefer" => SslMode::Prefer,
411 "require" => SslMode::Require,
412 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
413 };
414 self.ssl_mode(mode);
415 }
416 "host" => {
417 for host in value.split(',') {
418 self.host(host);
419 }
420 }
421 "port" => {
422 for port in value.split(',') {
423 let port = if port.is_empty() {
424 5432
425 } else {
426 port.parse()
427 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
428 };
429 self.port(port);
430 }
431 }
432 "connect_timeout" => {
433 let timeout = value
434 .parse::<i64>()
435 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
436 if timeout > 0 {
437 self.connect_timeout(Duration::from_secs(timeout as u64));
438 }
439 }
440 "keepalives" => {
441 let keepalives = value
442 .parse::<u64>()
443 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
444 self.keepalives(keepalives != 0);
445 }
446 "keepalives_idle" => {
447 let keepalives_idle = value
448 .parse::<i64>()
449 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
450 if keepalives_idle > 0 {
451 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
452 }
453 }
454 "target_session_attrs" => {
455 let target_session_attrs = match &*value {
456 "any" => TargetSessionAttrs::Any,
457 "read-write" => TargetSessionAttrs::ReadWrite,
458 _ => {
459 return Err(Error::config_parse(Box::new(InvalidValue(
460 "target_session_attrs",
461 ))));
462 }
463 };
464 self.target_session_attrs(target_session_attrs);
465 }
466 "channel_binding" => {
467 let channel_binding = match value {
468 "disable" => ChannelBinding::Disable,
469 "prefer" => ChannelBinding::Prefer,
470 "require" => ChannelBinding::Require,
471 _ => {
472 return Err(Error::config_parse(Box::new(InvalidValue(
473 "channel_binding",
474 ))))
475 }
476 };
477 self.channel_binding(channel_binding);
478 }
479 key => {
480 return Err(Error::config_parse(Box::new(UnknownOption(
481 key.to_string(),
482 ))));
483 }
484 }
485
486 Ok(())
487 }
488
489 #[cfg(feature = "runtime")]
493 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
494 where
495 T: MakeTlsConnect<Socket>,
496 {
497 connect(tls, self).await
498 }
499
500 pub async fn connect_raw<S, T>(
504 &self,
505 stream: S,
506 tls: T,
507 ) -> Result<(Client, Connection<S, T::Stream>), Error>
508 where
509 S: AsyncRead + AsyncWrite + Unpin,
510 T: TlsConnect<S>,
511 {
512 connect_raw(stream, tls, self).await
513 }
514}
515
516impl FromStr for Config {
517 type Err = Error;
518
519 fn from_str(s: &str) -> Result<Config, Error> {
520 match UrlParser::parse(s)? {
521 Some(config) => Ok(config),
522 None => Parser::parse(s),
523 }
524 }
525}
526
527impl fmt::Debug for Config {
529 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 struct Redaction {}
531 impl fmt::Debug for Redaction {
532 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
533 write!(f, "_")
534 }
535 }
536
537 f.debug_struct("Config")
538 .field("user", &self.user)
539 .field("password", &self.password.as_ref().map(|_| Redaction {}))
540 .field("dbname", &self.dbname)
541 .field("options", &self.options)
542 .field("application_name", &self.application_name)
543 .field("ssl_mode", &self.ssl_mode)
544 .field("host", &self.host)
545 .field("port", &self.port)
546 .field("connect_timeout", &self.connect_timeout)
547 .field("keepalives", &self.keepalives)
548 .field("keepalives_idle", &self.keepalives_idle)
549 .field("target_session_attrs", &self.target_session_attrs)
550 .field("channel_binding", &self.channel_binding)
551 .finish()
552 }
553}
554
/// Error raised when a connection string uses a key that `Config::param` does
/// not recognise. Holds the offending key.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(fmt, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
565
/// Error raised when a recognised connection-string option has a value that
/// cannot be parsed. Holds the option's name.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(fmt, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
576
/// Cursor over a `key=value` connection string.
struct Parser<'a> {
    /// The complete input; the offsets produced by `it` index into it.
    s: &'a str,
    /// The remaining characters with their byte offsets, with one character of
    /// lookahead.
    it: iter::Peekable<str::CharIndices<'a>>,
}
581
582impl<'a> Parser<'a> {
583 fn parse(s: &'a str) -> Result<Config, Error> {
584 let mut parser = Parser {
585 s,
586 it: s.char_indices().peekable(),
587 };
588
589 let mut config = Config::new();
590
591 while let Some((key, value)) = parser.parameter()? {
592 config.param(key, &value)?;
593 }
594
595 Ok(config)
596 }
597
598 fn skip_ws(&mut self) {
599 self.take_while(char::is_whitespace);
600 }
601
602 fn take_while<F>(&mut self, f: F) -> &'a str
603 where
604 F: Fn(char) -> bool,
605 {
606 let start = match self.it.peek() {
607 Some(&(i, _)) => i,
608 None => return "",
609 };
610
611 loop {
612 match self.it.peek() {
613 Some(&(_, c)) if f(c) => {
614 self.it.next();
615 }
616 Some(&(i, _)) => return &self.s[start..i],
617 None => return &self.s[start..],
618 }
619 }
620 }
621
622 fn eat(&mut self, target: char) -> Result<(), Error> {
623 match self.it.next() {
624 Some((_, c)) if c == target => Ok(()),
625 Some((i, c)) => {
626 let m = format!(
627 "unexpected character at byte {}: expected `{}` but got `{}`",
628 i, target, c
629 );
630 Err(Error::config_parse(m.into()))
631 }
632 None => Err(Error::config_parse("unexpected EOF".into())),
633 }
634 }
635
636 fn eat_if(&mut self, target: char) -> bool {
637 match self.it.peek() {
638 Some(&(_, c)) if c == target => {
639 self.it.next();
640 true
641 }
642 _ => false,
643 }
644 }
645
646 fn keyword(&mut self) -> Option<&'a str> {
647 let s = self.take_while(|c| match c {
648 c if c.is_whitespace() => false,
649 '=' => false,
650 _ => true,
651 });
652
653 if s.is_empty() {
654 None
655 } else {
656 Some(s)
657 }
658 }
659
660 fn value(&mut self) -> Result<String, Error> {
661 let value = if self.eat_if('\'') {
662 let value = self.quoted_value()?;
663 self.eat('\'')?;
664 value
665 } else {
666 self.simple_value()?
667 };
668
669 Ok(value)
670 }
671
672 fn simple_value(&mut self) -> Result<String, Error> {
673 let mut value = String::new();
674
675 while let Some(&(_, c)) = self.it.peek() {
676 if c.is_whitespace() {
677 break;
678 }
679
680 self.it.next();
681 if c == '\\' {
682 if let Some((_, c2)) = self.it.next() {
683 value.push(c2);
684 }
685 } else {
686 value.push(c);
687 }
688 }
689
690 if value.is_empty() {
691 return Err(Error::config_parse("unexpected EOF".into()));
692 }
693
694 Ok(value)
695 }
696
697 fn quoted_value(&mut self) -> Result<String, Error> {
698 let mut value = String::new();
699
700 while let Some(&(_, c)) = self.it.peek() {
701 if c == '\'' {
702 return Ok(value);
703 }
704
705 self.it.next();
706 if c == '\\' {
707 if let Some((_, c2)) = self.it.next() {
708 value.push(c2);
709 }
710 } else {
711 value.push(c);
712 }
713 }
714
715 Err(Error::config_parse(
716 "unterminated quoted connection parameter value".into(),
717 ))
718 }
719
720 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
721 self.skip_ws();
722 let keyword = match self.keyword() {
723 Some(keyword) => keyword,
724 None => return Ok(None),
725 };
726 self.skip_ws();
727 self.eat('=')?;
728 self.skip_ws();
729 let value = self.value()?;
730
731 Ok(Some((keyword, value)))
732 }
733}
734
735struct UrlParser<'a> {
737 s: &'a str,
738 config: Config,
739}
740
741impl<'a> UrlParser<'a> {
742 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
743 let s = match Self::remove_url_prefix(s) {
744 Some(s) => s,
745 None => return Ok(None),
746 };
747
748 let mut parser = UrlParser {
749 s,
750 config: Config::new(),
751 };
752
753 parser.parse_credentials()?;
754 parser.parse_host()?;
755 parser.parse_path()?;
756 parser.parse_params()?;
757
758 Ok(Some(parser.config))
759 }
760
761 fn remove_url_prefix(s: &str) -> Option<&str> {
762 for prefix in &["postgres://", "postgresql://"] {
763 if let Some(stripped) = s.strip_prefix(prefix) {
764 return Some(stripped);
765 }
766 }
767
768 None
769 }
770
771 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
772 match self.s.find(end) {
773 Some(pos) => {
774 let (head, tail) = self.s.split_at(pos);
775 self.s = tail;
776 Some(head)
777 }
778 None => None,
779 }
780 }
781
782 fn take_all(&mut self) -> &'a str {
783 mem::take(&mut self.s)
784 }
785
786 fn eat_byte(&mut self) {
787 self.s = &self.s[1..];
788 }
789
790 fn parse_credentials(&mut self) -> Result<(), Error> {
791 let creds = match self.take_until(&['@']) {
792 Some(creds) => creds,
793 None => return Ok(()),
794 };
795 self.eat_byte();
796
797 let mut it = creds.splitn(2, ':');
798 let user = self.decode(it.next().unwrap())?;
799 self.config.user(&user);
800
801 if let Some(password) = it.next() {
802 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
803 self.config.password(password);
804 }
805
806 Ok(())
807 }
808
809 fn parse_host(&mut self) -> Result<(), Error> {
810 let host = match self.take_until(&['/', '?']) {
811 Some(host) => host,
812 None => self.take_all(),
813 };
814
815 if host.is_empty() {
816 return Ok(());
817 }
818
819 for chunk in host.split(',') {
820 let (host, port) = if chunk.starts_with('[') {
821 let idx = match chunk.find(']') {
822 Some(idx) => idx,
823 None => return Err(Error::config_parse(InvalidValue("host").into())),
824 };
825
826 let host = &chunk[1..idx];
827 let remaining = &chunk[idx + 1..];
828 let port = if let Some(port) = remaining.strip_prefix(':') {
829 Some(port)
830 } else if remaining.is_empty() {
831 None
832 } else {
833 return Err(Error::config_parse(InvalidValue("host").into()));
834 };
835
836 (host, port)
837 } else {
838 let mut it = chunk.splitn(2, ':');
839 (it.next().unwrap(), it.next())
840 };
841
842 self.host_param(host)?;
843 let port = self.decode(port.unwrap_or("5432"))?;
844 self.config.param("port", &port)?;
845 }
846
847 Ok(())
848 }
849
850 fn parse_path(&mut self) -> Result<(), Error> {
851 if !self.s.starts_with('/') {
852 return Ok(());
853 }
854 self.eat_byte();
855
856 let dbname = match self.take_until(&['?']) {
857 Some(dbname) => dbname,
858 None => self.take_all(),
859 };
860
861 if !dbname.is_empty() {
862 self.config.dbname(&self.decode(dbname)?);
863 }
864
865 Ok(())
866 }
867
868 fn parse_params(&mut self) -> Result<(), Error> {
869 if !self.s.starts_with('?') {
870 return Ok(());
871 }
872 self.eat_byte();
873
874 while !self.s.is_empty() {
875 let key = match self.take_until(&['=']) {
876 Some(key) => self.decode(key)?,
877 None => return Err(Error::config_parse("unterminated parameter".into())),
878 };
879 self.eat_byte();
880
881 let value = match self.take_until(&['&']) {
882 Some(value) => {
883 self.eat_byte();
884 value
885 }
886 None => self.take_all(),
887 };
888
889 if key == "host" {
890 self.host_param(value)?;
891 } else {
892 let value = self.decode(value)?;
893 self.config.param(&key, &value)?;
894 }
895 }
896
897 Ok(())
898 }
899
900 #[cfg(all(unix, not(madsim)))]
901 fn host_param(&mut self, s: &str) -> Result<(), Error> {
902 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
903 if decoded.get(0) == Some(&b'/') {
904 self.config.host_path(OsStr::from_bytes(&decoded));
905 } else {
906 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
907 self.config.host(decoded);
908 }
909
910 Ok(())
911 }
912
913 #[cfg(not(all(unix, not(madsim))))]
914 fn host_param(&mut self, s: &str) -> Result<(), Error> {
915 let s = self.decode(s)?;
916 self.config.param("host", &s)
917 }
918
919 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
920 percent_encoding::percent_decode(s.as_bytes())
921 .decode_utf8()
922 .map_err(|e| Error::config_parse(e.into()))
923 }
924}