1#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect_raw::connect_raw;
6#[cfg(feature = "runtime")]
7use crate::tls::MakeTlsConnect;
8use crate::tls::TlsConnect;
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{Client, Connection, Error};
12use std::borrow::Cow;
13#[cfg(unix)]
14use std::ffi::OsStr;
15#[cfg(unix)]
16use std::os::unix::ffi::OsStrExt;
17#[cfg(unix)]
18use std::path::{Path, PathBuf};
19use std::str;
20use std::str::FromStr;
21use std::time::Duration;
22use std::{error, fmt, iter, mem};
23use tokio::io::{AsyncRead, AsyncWrite};
24
/// Properties required of the session a connection ends up with.
///
/// Chosen programmatically through [`Config::target_session_attrs`] or via the
/// `target_session_attrs` connection-string option. [`Config::new`] selects
/// [`TargetSessionAttrs::Any`].
///
/// The enum is fieldless, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Corresponds to the option value `any`.
    Any,
    /// Corresponds to the option value `read-write`.
    ReadWrite,
}
34
/// TLS policy requested for a connection.
///
/// Chosen programmatically through [`Config::ssl_mode`] or via the `sslmode`
/// connection-string option. [`Config::new`] selects [`SslMode::Prefer`].
///
/// The enum is fieldless, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// Corresponds to the option value `disable`.
    Disable,
    /// Corresponds to the option value `prefer`.
    Prefer,
    /// Corresponds to the option value `require`.
    Require,
}
46
/// Channel binding configuration.
///
/// Chosen programmatically through [`Config::channel_binding`] or via the
/// `channel_binding` connection-string option. [`Config::new`] selects
/// [`ChannelBinding::Prefer`].
///
/// The enum is fieldless, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Corresponds to the option value `disable`.
    Disable,
    /// Corresponds to the option value `prefer`.
    Prefer,
    /// Corresponds to the option value `require`.
    Require,
}
58
/// A host entry recorded in a [`Config`].
///
/// Both payload types (`String` and `PathBuf`) have total equality, so `Eq` is
/// derived alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A host name or address stored as text.
    Tcp(String),
    /// A filesystem path; this variant only exists on Unix targets.
    ///
    /// [`Config::host`] produces it for host strings that start with `/`.
    // NOTE(review): presumably this locates a Unix-domain socket; confirm
    // against the connection code, which is not part of this file.
    #[cfg(unix)]
    Unix(PathBuf),
}
70
/// Connection configuration.
///
/// Values are filled in through the builder-style setters on `Config`, or by
/// parsing a connection string with the `FromStr` implementation. `Debug` is
/// implemented by hand (not derived) so that the password is never printed.
#[derive(PartialEq, Clone)]
pub struct Config {
    /// User name, if one was set.
    pub(crate) user: Option<String>,
    /// Password as raw bytes, if one was set.
    pub(crate) password: Option<Vec<u8>>,
    /// Database name, if one was set.
    pub(crate) dbname: Option<String>,
    /// Value of the `options` setting, if one was set.
    pub(crate) options: Option<String>,
    /// Value of the `application_name` setting, if one was set.
    pub(crate) application_name: Option<String>,
    /// Requested TLS policy.
    pub(crate) ssl_mode: SslMode,
    /// Hosts, in the order they were added.
    pub(crate) host: Vec<Host>,
    /// Ports, in the order they were added.
    pub(crate) port: Vec<u16>,
    /// Connect timeout, if one was set.
    pub(crate) connect_timeout: Option<Duration>,
    /// Whether keepalives are enabled.
    pub(crate) keepalives: bool,
    /// Keepalive idle time.
    pub(crate) keepalives_idle: Duration,
    /// Required session properties.
    pub(crate) target_session_attrs: TargetSessionAttrs,
    /// Channel binding configuration.
    pub(crate) channel_binding: ChannelBinding,
}
164
165impl Default for Config {
166 fn default() -> Config {
167 Config::new()
168 }
169}
170
171impl Config {
172 pub fn new() -> Config {
174 Config {
175 user: None,
176 password: None,
177 dbname: None,
178 options: None,
179 application_name: None,
180 ssl_mode: SslMode::Prefer,
181 host: vec![],
182 port: vec![],
183 connect_timeout: None,
184 keepalives: true,
185 keepalives_idle: Duration::from_secs(2 * 60 * 60),
186 target_session_attrs: TargetSessionAttrs::Any,
187 channel_binding: ChannelBinding::Prefer,
188 }
189 }
190
191 pub fn user(&mut self, user: &str) -> &mut Config {
195 self.user = Some(user.to_string());
196 self
197 }
198
199 pub fn get_user(&self) -> Option<&str> {
202 self.user.as_deref()
203 }
204
205 pub fn password<T>(&mut self, password: T) -> &mut Config
207 where
208 T: AsRef<[u8]>,
209 {
210 self.password = Some(password.as_ref().to_vec());
211 self
212 }
213
214 pub fn get_password(&self) -> Option<&[u8]> {
217 self.password.as_deref()
218 }
219
220 pub fn dbname(&mut self, dbname: &str) -> &mut Config {
224 self.dbname = Some(dbname.to_string());
225 self
226 }
227
228 pub fn get_dbname(&self) -> Option<&str> {
231 self.dbname.as_deref()
232 }
233
234 pub fn options(&mut self, options: &str) -> &mut Config {
236 self.options = Some(options.to_string());
237 self
238 }
239
240 pub fn get_options(&self) -> Option<&str> {
243 self.options.as_deref()
244 }
245
246 pub fn application_name(&mut self, application_name: &str) -> &mut Config {
248 self.application_name = Some(application_name.to_string());
249 self
250 }
251
252 pub fn get_application_name(&self) -> Option<&str> {
255 self.application_name.as_deref()
256 }
257
258 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
262 self.ssl_mode = ssl_mode;
263 self
264 }
265
266 pub fn get_ssl_mode(&self) -> SslMode {
268 self.ssl_mode
269 }
270
271 pub fn host(&mut self, host: &str) -> &mut Config {
276 #[cfg(unix)]
277 {
278 if host.starts_with('/') {
279 return self.host_path(host);
280 }
281 }
282
283 self.host.push(Host::Tcp(host.to_string()));
284 self
285 }
286
287 pub fn get_hosts(&self) -> &[Host] {
289 &self.host
290 }
291
292 #[cfg(unix)]
296 pub fn host_path<T>(&mut self, host: T) -> &mut Config
297 where
298 T: AsRef<Path>,
299 {
300 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
301 self
302 }
303
304 pub fn port(&mut self, port: u16) -> &mut Config {
310 self.port.push(port);
311 self
312 }
313
314 pub fn get_ports(&self) -> &[u16] {
316 &self.port
317 }
318
319 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
324 self.connect_timeout = Some(connect_timeout);
325 self
326 }
327
328 pub fn get_connect_timeout(&self) -> Option<&Duration> {
331 self.connect_timeout.as_ref()
332 }
333
334 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
338 self.keepalives = keepalives;
339 self
340 }
341
342 pub fn get_keepalives(&self) -> bool {
344 self.keepalives
345 }
346
347 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
351 self.keepalives_idle = keepalives_idle;
352 self
353 }
354
355 pub fn get_keepalives_idle(&self) -> Duration {
358 self.keepalives_idle
359 }
360
361 pub fn target_session_attrs(
366 &mut self,
367 target_session_attrs: TargetSessionAttrs,
368 ) -> &mut Config {
369 self.target_session_attrs = target_session_attrs;
370 self
371 }
372
373 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
375 self.target_session_attrs
376 }
377
378 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
382 self.channel_binding = channel_binding;
383 self
384 }
385
386 pub fn get_channel_binding(&self) -> ChannelBinding {
388 self.channel_binding
389 }
390
391 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
392 match key {
393 "user" => {
394 self.user(value);
395 }
396 "password" => {
397 self.password(value);
398 }
399 "dbname" => {
400 self.dbname(value);
401 }
402 "options" => {
403 self.options(value);
404 }
405 "application_name" => {
406 self.application_name(value);
407 }
408 "sslmode" => {
409 let mode = match value {
410 "disable" => SslMode::Disable,
411 "prefer" => SslMode::Prefer,
412 "require" => SslMode::Require,
413 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
414 };
415 self.ssl_mode(mode);
416 }
417 "host" => {
418 for host in value.split(',') {
419 self.host(host);
420 }
421 }
422 "port" => {
423 for port in value.split(',') {
424 let port = if port.is_empty() {
425 5432
426 } else {
427 port.parse()
428 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
429 };
430 self.port(port);
431 }
432 }
433 "connect_timeout" => {
434 let timeout = value
435 .parse::<i64>()
436 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
437 if timeout > 0 {
438 self.connect_timeout(Duration::from_secs(timeout as u64));
439 }
440 }
441 "keepalives" => {
442 let keepalives = value
443 .parse::<u64>()
444 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
445 self.keepalives(keepalives != 0);
446 }
447 "keepalives_idle" => {
448 let keepalives_idle = value
449 .parse::<i64>()
450 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
451 if keepalives_idle > 0 {
452 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
453 }
454 }
455 "target_session_attrs" => {
456 let target_session_attrs = match &*value {
457 "any" => TargetSessionAttrs::Any,
458 "read-write" => TargetSessionAttrs::ReadWrite,
459 _ => {
460 return Err(Error::config_parse(Box::new(InvalidValue(
461 "target_session_attrs",
462 ))));
463 }
464 };
465 self.target_session_attrs(target_session_attrs);
466 }
467 "channel_binding" => {
468 let channel_binding = match value {
469 "disable" => ChannelBinding::Disable,
470 "prefer" => ChannelBinding::Prefer,
471 "require" => ChannelBinding::Require,
472 _ => {
473 return Err(Error::config_parse(Box::new(InvalidValue(
474 "channel_binding",
475 ))))
476 }
477 };
478 self.channel_binding(channel_binding);
479 }
480 key => {
481 return Err(Error::config_parse(Box::new(UnknownOption(
482 key.to_string(),
483 ))));
484 }
485 }
486
487 Ok(())
488 }
489
490 #[cfg(feature = "runtime")]
494 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
495 where
496 T: MakeTlsConnect<Socket>,
497 {
498 connect(tls, self).await
499 }
500
501 pub async fn connect_raw<S, T>(
505 &self,
506 stream: S,
507 tls: T,
508 ) -> Result<(Client, Connection<S, T::Stream>), Error>
509 where
510 S: AsyncRead + AsyncWrite + Unpin,
511 T: TlsConnect<S>,
512 {
513 connect_raw(stream, tls, self).await
514 }
515}
516
517impl FromStr for Config {
518 type Err = Error;
519
520 fn from_str(s: &str) -> Result<Config, Error> {
521 match UrlParser::parse(s)? {
522 Some(config) => Ok(config),
523 None => Parser::parse(s),
524 }
525 }
526}
527
528impl fmt::Debug for Config {
530 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531 struct Redaction {}
532 impl fmt::Debug for Redaction {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 write!(f, "_")
535 }
536 }
537
538 f.debug_struct("Config")
539 .field("user", &self.user)
540 .field("password", &self.password.as_ref().map(|_| Redaction {}))
541 .field("dbname", &self.dbname)
542 .field("options", &self.options)
543 .field("application_name", &self.application_name)
544 .field("ssl_mode", &self.ssl_mode)
545 .field("host", &self.host)
546 .field("port", &self.port)
547 .field("connect_timeout", &self.connect_timeout)
548 .field("keepalives", &self.keepalives)
549 .field("keepalives_idle", &self.keepalives_idle)
550 .field("target_session_attrs", &self.target_session_attrs)
551 .field("channel_binding", &self.channel_binding)
552 .finish()
553 }
554}
555
/// Error payload for a connection-string key that is not recognised.
///
/// Holds the offending key.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(fmt, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
566
/// Error payload for an option whose value could not be interpreted.
///
/// Holds the name of the option.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = self;
        write!(fmt, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
577
578struct Parser<'a> {
579 s: &'a str,
580 it: iter::Peekable<str::CharIndices<'a>>,
581}
582
583impl<'a> Parser<'a> {
584 fn parse(s: &'a str) -> Result<Config, Error> {
585 let mut parser = Parser {
586 s,
587 it: s.char_indices().peekable(),
588 };
589
590 let mut config = Config::new();
591
592 while let Some((key, value)) = parser.parameter()? {
593 config.param(key, &value)?;
594 }
595
596 Ok(config)
597 }
598
599 fn skip_ws(&mut self) {
600 self.take_while(char::is_whitespace);
601 }
602
603 fn take_while<F>(&mut self, f: F) -> &'a str
604 where
605 F: Fn(char) -> bool,
606 {
607 let start = match self.it.peek() {
608 Some(&(i, _)) => i,
609 None => return "",
610 };
611
612 loop {
613 match self.it.peek() {
614 Some(&(_, c)) if f(c) => {
615 self.it.next();
616 }
617 Some(&(i, _)) => return &self.s[start..i],
618 None => return &self.s[start..],
619 }
620 }
621 }
622
623 fn eat(&mut self, target: char) -> Result<(), Error> {
624 match self.it.next() {
625 Some((_, c)) if c == target => Ok(()),
626 Some((i, c)) => {
627 let m = format!(
628 "unexpected character at byte {}: expected `{}` but got `{}`",
629 i, target, c
630 );
631 Err(Error::config_parse(m.into()))
632 }
633 None => Err(Error::config_parse("unexpected EOF".into())),
634 }
635 }
636
637 fn eat_if(&mut self, target: char) -> bool {
638 match self.it.peek() {
639 Some(&(_, c)) if c == target => {
640 self.it.next();
641 true
642 }
643 _ => false,
644 }
645 }
646
647 fn keyword(&mut self) -> Option<&'a str> {
648 let s = self.take_while(|c| match c {
649 c if c.is_whitespace() => false,
650 '=' => false,
651 _ => true,
652 });
653
654 if s.is_empty() {
655 None
656 } else {
657 Some(s)
658 }
659 }
660
661 fn value(&mut self) -> Result<String, Error> {
662 let value = if self.eat_if('\'') {
663 let value = self.quoted_value()?;
664 self.eat('\'')?;
665 value
666 } else {
667 self.simple_value()?
668 };
669
670 Ok(value)
671 }
672
673 fn simple_value(&mut self) -> Result<String, Error> {
674 let mut value = String::new();
675
676 while let Some(&(_, c)) = self.it.peek() {
677 if c.is_whitespace() {
678 break;
679 }
680
681 self.it.next();
682 if c == '\\' {
683 if let Some((_, c2)) = self.it.next() {
684 value.push(c2);
685 }
686 } else {
687 value.push(c);
688 }
689 }
690
691 if value.is_empty() {
692 return Err(Error::config_parse("unexpected EOF".into()));
693 }
694
695 Ok(value)
696 }
697
698 fn quoted_value(&mut self) -> Result<String, Error> {
699 let mut value = String::new();
700
701 while let Some(&(_, c)) = self.it.peek() {
702 if c == '\'' {
703 return Ok(value);
704 }
705
706 self.it.next();
707 if c == '\\' {
708 if let Some((_, c2)) = self.it.next() {
709 value.push(c2);
710 }
711 } else {
712 value.push(c);
713 }
714 }
715
716 Err(Error::config_parse(
717 "unterminated quoted connection parameter value".into(),
718 ))
719 }
720
721 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
722 self.skip_ws();
723 let keyword = match self.keyword() {
724 Some(keyword) => keyword,
725 None => return Ok(None),
726 };
727 self.skip_ws();
728 self.eat('=')?;
729 self.skip_ws();
730 let value = self.value()?;
731
732 Ok(Some((keyword, value)))
733 }
734}
735
736struct UrlParser<'a> {
738 s: &'a str,
739 config: Config,
740}
741
742impl<'a> UrlParser<'a> {
743 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
744 let s = match Self::remove_url_prefix(s) {
745 Some(s) => s,
746 None => return Ok(None),
747 };
748
749 let mut parser = UrlParser {
750 s,
751 config: Config::new(),
752 };
753
754 parser.parse_credentials()?;
755 parser.parse_host()?;
756 parser.parse_path()?;
757 parser.parse_params()?;
758
759 Ok(Some(parser.config))
760 }
761
762 fn remove_url_prefix(s: &str) -> Option<&str> {
763 for prefix in &["postgres://", "postgresql://", "opengauss://"] {
764 if let Some(stripped) = s.strip_prefix(prefix) {
765 return Some(stripped);
766 }
767 }
768
769 None
770 }
771
772 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
773 match self.s.find(end) {
774 Some(pos) => {
775 let (head, tail) = self.s.split_at(pos);
776 self.s = tail;
777 Some(head)
778 }
779 None => None,
780 }
781 }
782
783 fn take_all(&mut self) -> &'a str {
784 mem::replace(&mut self.s, "")
785 }
786
787 fn eat_byte(&mut self) {
788 self.s = &self.s[1..];
789 }
790
791 fn parse_credentials(&mut self) -> Result<(), Error> {
792 let creds = match self.take_until(&['@']) {
793 Some(creds) => creds,
794 None => return Ok(()),
795 };
796 self.eat_byte();
797
798 let mut it = creds.splitn(2, ':');
799 let user = self.decode(it.next().unwrap())?;
800 self.config.user(&user);
801
802 if let Some(password) = it.next() {
803 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
804 self.config.password(password);
805 }
806
807 Ok(())
808 }
809
810 fn parse_host(&mut self) -> Result<(), Error> {
811 let host = match self.take_until(&['/', '?']) {
812 Some(host) => host,
813 None => self.take_all(),
814 };
815
816 if host.is_empty() {
817 return Ok(());
818 }
819
820 for chunk in host.split(',') {
821 let (host, port) = if chunk.starts_with('[') {
822 let idx = match chunk.find(']') {
823 Some(idx) => idx,
824 None => return Err(Error::config_parse(InvalidValue("host").into())),
825 };
826
827 let host = &chunk[1..idx];
828 let remaining = &chunk[idx + 1..];
829 let port = if let Some(port) = remaining.strip_prefix(':') {
830 Some(port)
831 } else if remaining.is_empty() {
832 None
833 } else {
834 return Err(Error::config_parse(InvalidValue("host").into()));
835 };
836
837 (host, port)
838 } else {
839 let mut it = chunk.splitn(2, ':');
840 (it.next().unwrap(), it.next())
841 };
842
843 self.host_param(host)?;
844 let port = self.decode(port.unwrap_or("5432"))?;
845 self.config.param("port", &port)?;
846 }
847
848 Ok(())
849 }
850
851 fn parse_path(&mut self) -> Result<(), Error> {
852 if !self.s.starts_with('/') {
853 return Ok(());
854 }
855 self.eat_byte();
856
857 let dbname = match self.take_until(&['?']) {
858 Some(dbname) => dbname,
859 None => self.take_all(),
860 };
861
862 if !dbname.is_empty() {
863 self.config.dbname(&self.decode(dbname)?);
864 }
865
866 Ok(())
867 }
868
869 fn parse_params(&mut self) -> Result<(), Error> {
870 if !self.s.starts_with('?') {
871 return Ok(());
872 }
873 self.eat_byte();
874
875 while !self.s.is_empty() {
876 let key = match self.take_until(&['=']) {
877 Some(key) => self.decode(key)?,
878 None => return Err(Error::config_parse("unterminated parameter".into())),
879 };
880 self.eat_byte();
881
882 let value = match self.take_until(&['&']) {
883 Some(value) => {
884 self.eat_byte();
885 value
886 }
887 None => self.take_all(),
888 };
889
890 if key == "host" {
891 self.host_param(value)?;
892 } else {
893 let value = self.decode(value)?;
894 self.config.param(&key, &value)?;
895 }
896 }
897
898 Ok(())
899 }
900
901 #[cfg(unix)]
902 fn host_param(&mut self, s: &str) -> Result<(), Error> {
903 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
904 if decoded.get(0) == Some(&b'/') {
905 self.config.host_path(OsStr::from_bytes(&decoded));
906 } else {
907 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
908 self.config.host(decoded);
909 }
910
911 Ok(())
912 }
913
914 #[cfg(not(unix))]
915 fn host_param(&mut self, s: &str) -> Result<(), Error> {
916 let s = self.decode(s)?;
917 self.config.param("host", &s)
918 }
919
920 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
921 percent_encoding::percent_decode(s.as_bytes())
922 .decode_utf8()
923 .map_err(|e| Error::config_parse(e.into()))
924 }
925}