1use core::{fmt, iter, mem, str};
4
5use std::{
6 borrow::Cow,
7 path::{Path, PathBuf},
8};
9
10use super::{error::Error, session::TargetSessionAttrs};
11
/// TLS mode, selected by the `sslmode` connection parameter.
///
/// The connection-string spellings accepted for each variant are `disable`, `prefer` and
/// `require`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslMode {
    /// `sslmode=disable`.
    Disable,
    /// `sslmode=prefer`; this is the default.
    #[default]
    Prefer,
    /// `sslmode=require`.
    Require,
}
23
/// TLS negotiation style, selected by the `sslnegotiation` connection parameter.
///
/// The connection-string spellings accepted for each variant are `postgres` and `direct`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslNegotiation {
    /// `sslnegotiation=postgres`; this is the default.
    #[default]
    Postgres,
    /// `sslnegotiation=direct`.
    Direct,
}
34
/// One connection target recorded in a configuration.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Host {
    /// A network host name or address, stored exactly as given.
    Tcp(Box<str>),
    /// A host name stored for QUIC transport. NOTE(review): nothing in this module constructs
    /// this variant — confirm where it is produced.
    Quic(Box<str>),
    /// A filesystem path; `Config::host` stores values that start with `/` here.
    Unix(PathBuf),
}
44
45#[derive(Clone, Eq, PartialEq)]
46pub struct Config {
47 pub(crate) user: Option<Box<str>>,
48 pub(crate) password: Option<Box<[u8]>>,
49 pub(crate) dbname: Option<Box<str>>,
50 pub(crate) options: Option<Box<str>>,
51 pub(crate) application_name: Option<Box<str>>,
52 pub(crate) ssl_mode: SslMode,
53 pub(crate) ssl_negotiation: SslNegotiation,
54 pub(crate) host: Vec<Host>,
55 pub(crate) port: Vec<u16>,
56 target_session_attrs: TargetSessionAttrs,
57 tls_server_end_point: Option<Box<[u8]>>,
58}
59
60impl Default for Config {
61 fn default() -> Config {
62 Config::new()
63 }
64}
65
66impl Config {
67 pub const fn new() -> Config {
69 Config {
70 user: None,
71 password: None,
72 dbname: None,
73 options: None,
74 application_name: None,
75 ssl_mode: SslMode::Prefer,
76 ssl_negotiation: SslNegotiation::Postgres,
77 host: Vec::new(),
78 port: Vec::new(),
79 target_session_attrs: TargetSessionAttrs::Any,
80 tls_server_end_point: None,
81 }
82 }
83
84 pub fn user(&mut self, user: &str) -> &mut Config {
88 self.user = Some(Box::from(user));
89 self
90 }
91
92 pub fn get_user(&self) -> Option<&str> {
95 self.user.as_deref()
96 }
97
98 pub fn password<T>(&mut self, password: T) -> &mut Config
100 where
101 T: AsRef<[u8]>,
102 {
103 self.password = Some(Box::from(password.as_ref()));
104 self
105 }
106
107 pub fn get_password(&self) -> Option<&[u8]> {
110 self.password.as_deref()
111 }
112
113 pub fn dbname(&mut self, dbname: &str) -> &mut Config {
117 self.dbname = Some(Box::from(dbname));
118 self
119 }
120
121 pub fn get_dbname(&self) -> Option<&str> {
124 self.dbname.as_deref()
125 }
126
127 pub fn options(&mut self, options: &str) -> &mut Config {
129 self.options = Some(Box::from(options));
130 self
131 }
132
133 pub fn get_options(&self) -> Option<&str> {
136 self.options.as_deref()
137 }
138
139 pub fn application_name(&mut self, application_name: &str) -> &mut Config {
141 self.application_name = Some(Box::from(application_name));
142 self
143 }
144
145 pub fn get_application_name(&self) -> Option<&str> {
148 self.application_name.as_deref()
149 }
150
151 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
155 self.ssl_mode = ssl_mode;
156 self
157 }
158
159 pub fn get_ssl_mode(&self) -> SslMode {
161 self.ssl_mode
162 }
163
164 pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
168 self.ssl_negotiation = ssl_negotiation;
169 self
170 }
171
172 pub fn get_ssl_negotiation(&self) -> SslNegotiation {
174 self.ssl_negotiation
175 }
176
177 pub fn host(&mut self, host: &str) -> &mut Config {
178 if host.starts_with('/') {
179 return self.host_path(host);
180 }
181
182 let host = Host::Tcp(Box::from(host));
183
184 self.host.push(host);
185 self
186 }
187
188 pub fn host_path<T>(&mut self, host: T) -> &mut Config
192 where
193 T: AsRef<Path>,
194 {
195 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
196 self
197 }
198
199 pub fn get_hosts(&self) -> &[Host] {
201 &self.host
202 }
203
204 pub fn port(&mut self, port: u16) -> &mut Config {
210 self.port.push(port);
211 self
212 }
213
214 pub fn get_ports(&self) -> &[u16] {
216 &self.port
217 }
218
219 pub fn target_session_attrs(&mut self, target_session_attrs: TargetSessionAttrs) -> &mut Config {
224 self.target_session_attrs = target_session_attrs;
225 self
226 }
227
228 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
230 self.target_session_attrs
231 }
232
233 pub fn tls_server_end_point(&mut self, tls_server_end_point: impl AsRef<[u8]>) -> &mut Self {
302 self.tls_server_end_point = Some(Box::from(tls_server_end_point.as_ref()));
303 self
304 }
305
306 pub fn get_tls_server_end_point(&self) -> Option<&[u8]> {
307 self.tls_server_end_point.as_deref()
308 }
309
310 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
311 match key {
312 "user" => {
313 self.user(value);
314 }
315 "password" => {
316 self.password(value);
317 }
318 "dbname" => {
319 self.dbname(value);
320 }
321 "options" => {
322 self.options(value);
323 }
324 "application_name" => {
325 self.application_name(value);
326 }
327 "sslmode" => {
328 let mode = match value {
329 "disable" => SslMode::Disable,
330 "prefer" => SslMode::Prefer,
331 "require" => SslMode::Require,
332 _ => return Err(Error::todo()),
333 };
334 self.ssl_mode(mode);
335 }
336 "sslnegotiation" => {
337 let mode = match value {
338 "postgres" => SslNegotiation::Postgres,
339 "direct" => SslNegotiation::Direct,
340 _ => return Err(Error::todo()),
341 };
342 self.ssl_negotiation(mode);
343 }
344 "host" => {
345 for host in value.split(',') {
346 self.host(host);
347 }
348 }
349 "port" => {
350 for port in value.split(',') {
351 let port = if port.is_empty() {
352 5432
353 } else {
354 port.parse().map_err(|_| Error::todo())?
355 };
356 self.port(port);
357 }
358 }
359 "target_session_attrs" => {
360 let target_session_attrs = match value {
361 "any" => TargetSessionAttrs::Any,
362 "read-write" => TargetSessionAttrs::ReadWrite,
363 "read-only" => TargetSessionAttrs::ReadOnly,
364 _ => return Err(Error::todo()),
365 };
366 self.target_session_attrs(target_session_attrs);
367 }
368 _ => {
369 return Err(Error::todo());
370 }
371 }
372
373 Ok(())
374 }
375}
376
377impl TryFrom<String> for Config {
378 type Error = Error;
379
380 fn try_from(s: String) -> Result<Self, Self::Error> {
381 Self::try_from(s.as_str())
382 }
383}
384
385impl TryFrom<&str> for Config {
386 type Error = Error;
387
388 fn try_from(s: &str) -> Result<Self, Self::Error> {
389 match UrlParser::parse(s)? {
390 Some(config) => Ok(config),
391 None => Parser::parse(s),
392 }
393 }
394}
395
396impl fmt::Debug for Config {
398 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399 struct Redaction {}
400 impl fmt::Debug for Redaction {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 write!(f, "_")
403 }
404 }
405
406 f.debug_struct("Config")
407 .field("user", &self.user)
408 .field("password", &self.password.as_ref().map(|_| Redaction {}))
409 .field("dbname", &self.dbname)
410 .field("options", &self.options)
411 .field("application_name", &self.application_name)
412 .field("host", &self.host)
413 .field("port", &self.port)
414 .field("target_session_attrs", &self.target_session_attrs)
415 .finish()
416 }
417}
418
/// Cursor over a libpq-style `key=value` connection string.
struct Parser<'a> {
    /// Character cursor over `s`; the byte offsets it yields index into `s`.
    it: iter::Peekable<str::CharIndices<'a>>,
    /// The complete input, used to slice out spans by byte offset.
    s: &'a str,
}
423
424impl<'a> Parser<'a> {
425 fn parse(s: &'a str) -> Result<Config, Error> {
426 let mut parser = Parser {
427 s,
428 it: s.char_indices().peekable(),
429 };
430
431 let mut config = Config::new();
432
433 while let Some((key, value)) = parser.parameter()? {
434 config.param(key, &value)?;
435 }
436
437 Ok(config)
438 }
439
440 fn skip_ws(&mut self) {
441 self.take_while(char::is_whitespace);
442 }
443
444 fn take_while<F>(&mut self, f: F) -> &'a str
445 where
446 F: Fn(char) -> bool,
447 {
448 let start = match self.it.peek() {
449 Some(&(i, _)) => i,
450 None => return "",
451 };
452
453 loop {
454 match self.it.peek() {
455 Some(&(_, c)) if f(c) => {
456 self.it.next();
457 }
458 Some(&(i, _)) => return &self.s[start..i],
459 None => return &self.s[start..],
460 }
461 }
462 }
463
464 fn eat(&mut self, target: char) -> Result<(), Error> {
465 match self.it.next() {
466 Some((_, c)) if c == target => Ok(()),
467 Some((i, c)) => {
468 let _m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
469 Err(Error::todo())
470 }
471 None => Err(Error::todo()),
472 }
473 }
474
475 fn eat_if(&mut self, target: char) -> bool {
476 match self.it.peek() {
477 Some(&(_, c)) if c == target => {
478 self.it.next();
479 true
480 }
481 _ => false,
482 }
483 }
484
485 fn keyword(&mut self) -> Option<&'a str> {
486 let s = self.take_while(|c| match c {
487 c if c.is_whitespace() => false,
488 '=' => false,
489 _ => true,
490 });
491
492 if s.is_empty() {
493 None
494 } else {
495 Some(s)
496 }
497 }
498
499 fn value(&mut self) -> Result<String, Error> {
500 let value = if self.eat_if('\'') {
501 let value = self.quoted_value()?;
502 self.eat('\'')?;
503 value
504 } else {
505 self.simple_value()?
506 };
507
508 Ok(value)
509 }
510
511 fn simple_value(&mut self) -> Result<String, Error> {
512 let mut value = String::new();
513
514 while let Some(&(_, c)) = self.it.peek() {
515 if c.is_whitespace() {
516 break;
517 }
518
519 self.it.next();
520 if c == '\\' {
521 if let Some((_, c2)) = self.it.next() {
522 value.push(c2);
523 }
524 } else {
525 value.push(c);
526 }
527 }
528
529 if value.is_empty() {
530 return Err(Error::todo());
531 }
532
533 Ok(value)
534 }
535
536 fn quoted_value(&mut self) -> Result<String, Error> {
537 let mut value = String::new();
538
539 while let Some(&(_, c)) = self.it.peek() {
540 if c == '\'' {
541 return Ok(value);
542 }
543
544 self.it.next();
545 if c == '\\' {
546 if let Some((_, c2)) = self.it.next() {
547 value.push(c2);
548 }
549 } else {
550 value.push(c);
551 }
552 }
553
554 Err(Error::todo())
555 }
556
557 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
558 self.skip_ws();
559 let keyword = match self.keyword() {
560 Some(keyword) => keyword,
561 None => return Ok(None),
562 };
563 self.skip_ws();
564 self.eat('=')?;
565 self.skip_ws();
566 let value = self.value()?;
567
568 Ok(Some((keyword, value)))
569 }
570}
571
572struct UrlParser<'a> {
574 s: &'a str,
575 config: Config,
576}
577
578impl<'a> UrlParser<'a> {
579 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
580 let s = match Self::remove_url_prefix(s) {
581 Some(s) => s,
582 None => return Ok(None),
583 };
584
585 let mut parser = UrlParser {
586 s,
587 config: Config::new(),
588 };
589
590 parser.parse_credentials()?;
591 parser.parse_host()?;
592 parser.parse_path()?;
593 parser.parse_params()?;
594
595 Ok(Some(parser.config))
596 }
597
598 fn remove_url_prefix(s: &str) -> Option<&str> {
599 for prefix in &["postgres://", "postgresql://"] {
600 if let Some(stripped) = s.strip_prefix(prefix) {
601 return Some(stripped);
602 }
603 }
604
605 None
606 }
607
608 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
609 match self.s.find(end) {
610 Some(pos) => {
611 let (head, tail) = self.s.split_at(pos);
612 self.s = tail;
613 Some(head)
614 }
615 None => None,
616 }
617 }
618
619 fn take_all(&mut self) -> &'a str {
620 mem::take(&mut self.s)
621 }
622
623 fn eat_byte(&mut self) {
624 self.s = &self.s[1..];
625 }
626
627 fn parse_credentials(&mut self) -> Result<(), Error> {
628 let creds = match self.take_until(&['@']) {
629 Some(creds) => creds,
630 None => return Ok(()),
631 };
632 self.eat_byte();
633
634 let mut it = creds.splitn(2, ':');
635 let user = self.decode(it.next().unwrap())?;
636 self.config.user(&user);
637
638 if let Some(password) = it.next() {
639 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
640 self.config.password(password);
641 }
642
643 Ok(())
644 }
645
646 fn parse_host(&mut self) -> Result<(), Error> {
647 let host = match self.take_until(&['/', '?']) {
648 Some(host) => host,
649 None => self.take_all(),
650 };
651
652 if host.is_empty() {
653 return Ok(());
654 }
655
656 for chunk in host.split(',') {
657 let (host, port) = if chunk.starts_with('[') {
658 let idx = match chunk.find(']') {
659 Some(idx) => idx,
660 None => return Err(Error::todo()),
661 };
662
663 let host = &chunk[1..idx];
664 let remaining = &chunk[idx + 1..];
665 let port = if let Some(port) = remaining.strip_prefix(':') {
666 Some(port)
667 } else if remaining.is_empty() {
668 None
669 } else {
670 return Err(Error::todo());
671 };
672
673 (host, port)
674 } else {
675 let mut it = chunk.splitn(2, ':');
676 (it.next().unwrap(), it.next())
677 };
678
679 self.host_param(host)?;
680 let port = self.decode(port.unwrap_or("5432"))?;
681 self.config.param("port", &port)?;
682 }
683
684 Ok(())
685 }
686
687 fn parse_path(&mut self) -> Result<(), Error> {
688 if !self.s.starts_with('/') {
689 return Ok(());
690 }
691 self.eat_byte();
692
693 let dbname = match self.take_until(&['?']) {
694 Some(dbname) => dbname,
695 None => self.take_all(),
696 };
697
698 if !dbname.is_empty() {
699 self.config.dbname(&self.decode(dbname)?);
700 }
701
702 Ok(())
703 }
704
705 fn parse_params(&mut self) -> Result<(), Error> {
706 if !self.s.starts_with('?') {
707 return Ok(());
708 }
709 self.eat_byte();
710
711 while !self.s.is_empty() {
712 let key = match self.take_until(&['=']) {
713 Some(key) => self.decode(key)?,
714 None => return Err(Error::todo()),
715 };
716 self.eat_byte();
717
718 let value = match self.take_until(&['&']) {
719 Some(value) => {
720 self.eat_byte();
721 value
722 }
723 None => self.take_all(),
724 };
725
726 if key == "host" {
727 self.host_param(value)?;
728 } else {
729 let value = self.decode(value)?;
730 self.config.param(&key, &value)?;
731 }
732 }
733
734 Ok(())
735 }
736
737 fn host_param(&mut self, s: &str) -> Result<(), Error> {
738 let s = self.decode(s)?;
739 self.config.param("host", &s)
740 }
741
742 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
743 percent_encoding::percent_decode(s.as_bytes())
744 .decode_utf8()
745 .map_err(|_| Error::todo())
746 }
747}