gel_dsn/gel/
config.rs

1use super::{
2    error::*, format_duration, BuildContextImpl, CredentialsFile, FromParamStr, InstanceName,
3    Param, Params,
4};
5use crate::{
6    gel::{parse_duration, BuildPhase},
7    host::{Host, HostType, LOCALHOST_HOSTNAME},
8};
9use rustls_pki_types::CertificateDer;
10use serde::{Deserialize, Serialize};
11use std::{
12    borrow::Cow,
13    collections::HashMap,
14    fmt,
15    num::NonZero,
16    path::{Path, PathBuf},
17    str::FromStr,
18    time::Duration,
19};
20use url::Url;
21
22pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
23pub const DEFAULT_WAIT: Duration = Duration::from_secs(30);
24pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60);
25pub const DEFAULT_POOL_SIZE: usize = 10;
26pub const DEFAULT_HOST: &HostType = crate::host::LOCALHOST;
27pub const DEFAULT_PORT: u16 = 5656;
28pub const DEFAULT_USER: &str = crate::gel::branding::BRANDING_DEFAULT_USERNAME_LEGACY;
29pub const DEFAULT_BRANCH: DatabaseBranch = DatabaseBranch::Default;
30
31pub const DEFAULT_DATABASE_NAME: &str = "edgedb";
32
33/// The branch name used when connecting to an existing instance to request
34/// the default branch.
35pub const DEFAULT_BRANCH_NAME_CONNECT: &str = "__default__";
36/// The default branch name used when creating a new instance.
37pub const DEFAULT_BRANCH_NAME_CREATE: &str = "main";
38
39/// The result of building a [`Config`].
40pub struct ConfigResult {
41    pub(crate) result: Result<Config, gel_errors::Error>,
42    pub(crate) warnings: Warnings,
43}
44
45impl std::ops::Deref for ConfigResult {
46    type Target = Result<Config, gel_errors::Error>;
47
48    fn deref(&self) -> &Self::Target {
49        &self.result
50    }
51}
52
53impl From<ConfigResult> for Result<Config, gel_errors::Error> {
54    fn from(val: ConfigResult) -> Self {
55        val.result
56    }
57}
58
59impl ConfigResult {
60    pub fn unwrap(self) -> Config {
61        self.result.unwrap()
62    }
63
64    pub fn expect(self, message: &str) -> Config {
65        self.result.expect(message)
66    }
67
68    pub fn result(&self) -> &Result<Config, gel_errors::Error> {
69        &self.result
70    }
71
72    pub fn into_result(self) -> Result<Config, gel_errors::Error> {
73        self.result
74    }
75
76    pub fn parse_error(&self) -> Option<&ParseError> {
77        use std::error::Error;
78        self.result
79            .as_ref()
80            .err()
81            .and_then(|e| e.source().and_then(|e| e.downcast_ref::<ParseError>()))
82    }
83
84    pub fn warnings(&self) -> &Warnings {
85        &self.warnings
86    }
87}
88
89/// The configuration for a connection.
90#[derive(Debug, Clone, PartialEq, Eq)]
91pub struct Config {
92    pub host: Host,
93    pub db: DatabaseBranch,
94    pub user: String,
95
96    /// If the configuration was loaded from an instance name, this will be present.
97    pub instance_name: Option<InstanceName>,
98
99    pub authentication: Authentication,
100
101    pub client_security: ClientSecurity,
102    pub tls_security: TlsSecurity,
103
104    pub tls_ca: Option<Vec<CertificateDer<'static>>>,
105    pub tls_server_name: Option<String>,
106    pub wait_until_available: Duration,
107
108    pub connect_timeout: Duration,
109    pub max_concurrency: Option<usize>,
110    pub tcp_keepalive: TcpKeepalive,
111
112    pub cloud_certs: Option<CloudCerts>,
113
114    pub server_settings: HashMap<String, String>,
115}
116
117impl Default for Config {
118    fn default() -> Self {
119        Self {
120            host: Host::new(DEFAULT_HOST.clone(), DEFAULT_PORT),
121            db: DatabaseBranch::Default,
122            user: DEFAULT_USER.to_string(),
123            instance_name: None,
124            authentication: Authentication::None,
125            client_security: ClientSecurity::Default,
126            tls_security: TlsSecurity::Strict,
127            tls_ca: None,
128            tls_server_name: None,
129            wait_until_available: DEFAULT_WAIT,
130            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
131            max_concurrency: None,
132            tcp_keepalive: TcpKeepalive::Default,
133            cloud_certs: None,
134            server_settings: HashMap::new(),
135        }
136    }
137}
138
139#[derive(Debug, Clone, PartialEq, Eq, derive_more::Error, derive_more::Display)]
140pub enum CredentialsError {
141    #[display("no TCP address")]
142    NoTcpAddress,
143}
144
145fn to_pem(certs: &[CertificateDer<'static>]) -> String {
146    use base64::Engine;
147    let prefix = "-----BEGIN CERTIFICATE-----\n";
148    let suffix = "-----END CERTIFICATE-----\n";
149    let mut pem = String::new();
150    for cert in certs {
151        pem.push_str(prefix);
152        let mut b64 = vec![0; cert.len() * 4 / 3 + 4];
153        let len = base64::prelude::BASE64_STANDARD
154            .encode_slice(cert.as_ref(), &mut b64)
155            .unwrap();
156        b64.truncate(len);
157        let lines = b64.chunks(64);
158        for line in lines {
159            pem.push_str(std::str::from_utf8(line).unwrap());
160            pem.push('\n');
161        }
162        pem.push_str(suffix);
163    }
164    pem
165}
166
167impl Config {
168    pub fn instance_name(&self) -> Option<&InstanceName> {
169        self.instance_name.as_ref()
170    }
171
172    pub fn local_instance_name(&self) -> Option<&str> {
173        self.instance_name.as_ref().and_then(InstanceName::local)
174    }
175
176    pub fn admin(&self) -> bool {
177        self.host.is_unix()
178    }
179
180    pub fn user(&self) -> &str {
181        &self.user
182    }
183
184    pub fn port(&self) -> u16 {
185        self.host.1
186    }
187
188    pub fn display_addr(&self) -> impl fmt::Display + '_ {
189        self.host.to_string()
190    }
191
192    pub fn secret_key(&self) -> Option<&str> {
193        self.authentication.secret_key()
194    }
195
196    pub fn tls_ca_pem(&self) -> Option<String> {
197        self.tls_ca.as_ref().map(|v| to_pem(v))
198    }
199
200    /// Return HTTP(s) url to server if not connected via unix socket.
201    pub fn http_url(&self, tls: bool) -> Option<String> {
202        if let Some((host, port)) = self.host.target_name().ok()?.tcp() {
203            let s = if tls { "s" } else { "" };
204            Some(format!("http{s}://{host}:{port}"))
205        } else {
206            None
207        }
208    }
209
210    /// Return DSN url to server if not connected via unix socket.
211    ///
212    /// Note that this method is not guaranteed to return a fully-connectable URL.
213    pub fn dsn_url(&self) -> Option<String> {
214        let mut url = Url::parse("gel://").unwrap();
215
216        if let Some((host, port)) = self.host.target_name().ok()?.tcp() {
217            if host != LOCALHOST_HOSTNAME {
218                if port != DEFAULT_PORT {
219                    _ = url.set_host(Some(&host));
220                    _ = url.set_port(Some(port));
221                } else {
222                    _ = url.set_host(Some(&host));
223                }
224            } else if port != DEFAULT_PORT {
225                url.query_pairs_mut().append_pair("port", &port.to_string());
226            }
227        } else {
228            return None;
229        }
230
231        if self.db != DatabaseBranch::Default {
232            if let Some(database) = self.db.database() {
233                url.set_path(database);
234            }
235
236            if let Some(branch) = self.db.branch_for_connect() {
237                url.set_path(branch);
238            }
239        }
240
241        if self.user() != DEFAULT_USER {
242            if url.host().is_none() {
243                url.query_pairs_mut().append_pair("user", self.user());
244            } else {
245                _ = url.set_username(self.user());
246            }
247        }
248
249        if let Some(password) = self.authentication.password() {
250            if url.host().is_none() {
251                url.query_pairs_mut().append_pair("password", password);
252            } else {
253                _ = url.set_password(Some(password));
254            }
255        }
256
257        // NOTE: The user will need to provide a CA file
258        if self.tls_ca.is_some() {
259            url.query_pairs_mut().append_pair("tls_ca_file", "<...>");
260        }
261
262        if let Some(secret_key) = self.authentication.secret_key() {
263            url.query_pairs_mut().append_pair("secret_key", secret_key);
264        }
265
266        if self.tls_security != TlsSecurity::Strict {
267            url.query_pairs_mut()
268                .append_pair("tls_security", &self.tls_security.to_string());
269        }
270
271        if let Some(tls_server_name) = &self.tls_server_name {
272            url.query_pairs_mut()
273                .append_pair("tls_server_name", tls_server_name);
274        }
275
276        if self.wait_until_available != DEFAULT_WAIT {
277            url.query_pairs_mut().append_pair(
278                "wait_until_available",
279                &format_duration(&self.wait_until_available),
280            );
281        }
282
283        for (key, value) in &self.server_settings {
284            url.query_pairs_mut().append_pair(key, value);
285        }
286
287        Some(url.to_string())
288    }
289
290    pub fn with_host(&self, host: &str, port: u16) -> Result<Self, ParseError> {
291        Ok(Self {
292            host: Host::new(HostType::from_str(host)?, port),
293            ..self.clone()
294        })
295    }
296
297    pub fn with_branch(&self, branch: &str) -> Self {
298        Self {
299            db: DatabaseBranch::Branch(branch.to_string()),
300            ..self.clone()
301        }
302    }
303
304    pub fn with_db(&self, db: DatabaseBranch) -> Self {
305        Self { db, ..self.clone() }
306    }
307
308    pub fn with_user(&self, user: &str) -> Self {
309        Self {
310            user: user.to_string(),
311            ..self.clone()
312        }
313    }
314
315    pub fn with_password(&self, password: &str) -> Self {
316        Self {
317            authentication: Authentication::Password(password.to_string()),
318            ..self.clone()
319        }
320    }
321
322    pub fn with_wait_until_available(&self, dur: Duration) -> Self {
323        Self {
324            wait_until_available: dur,
325            ..self.clone()
326        }
327    }
328
329    pub fn with_tls_ca(&self, certs: &[CertificateDer<'static>]) -> Self {
330        Self {
331            tls_ca: Some(certs.to_vec()),
332            ..self.clone()
333        }
334    }
335
336    #[deprecated = "use with_tls_ca instead"]
337    pub fn with_pem_certificates(&self, certs: &str) -> Result<Self, ParseError> {
338        let certs = <Vec<CertificateDer<'static>> as FromParamStr>::from_param_str(
339            certs,
340            &BuildContextImpl::default(),
341        )?;
342        Ok(Self {
343            tls_ca: Some(certs),
344            ..self.clone()
345        })
346    }
347
348    #[cfg(feature = "serde")]
349    pub fn to_json(&self) -> impl serde::Serialize + std::fmt::Display {
350        use serde::Serialize;
351        use std::collections::BTreeMap;
352
353        #[derive(Serialize)]
354        #[allow(non_snake_case)]
355        struct ConfigJson {
356            address: (String, usize),
357            branch: Option<String>,
358            database: Option<String>,
359            password: Option<String>,
360            secretKey: Option<String>,
361            serverSettings: BTreeMap<String, String>,
362            tlsCAData: Option<String>,
363            tlsSecurity: String,
364            tlsServerName: Option<String>,
365            user: String,
366            waitUntilAvailable: String,
367        }
368
369        impl std::fmt::Display for ConfigJson {
370            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371                write!(f, "{}", serde_json::to_string(self).unwrap())
372            }
373        }
374
375        ConfigJson {
376            address: (self.host.0.to_string(), self.host.1 as usize),
377            branch: self.db.branch_for_connect().map(|s| s.to_string()),
378            database: self.db.database().map(|s| s.to_string()),
379            password: self.authentication.password().map(|s| s.to_string()),
380            secretKey: self.authentication.secret_key().map(|s| s.to_string()),
381            serverSettings: BTreeMap::from_iter(self.server_settings.clone()),
382            tlsCAData: self.tls_ca.as_ref().map(|cert| to_pem(cert)),
383            tlsSecurity: self.tls_security.to_string(),
384            tlsServerName: self.tls_server_name.clone(),
385            user: self.user.clone(),
386            waitUntilAvailable: super::duration::Duration::from_micros(
387                self.wait_until_available.as_micros() as i64,
388            )
389            .to_string(),
390        }
391    }
392
393    /// Convert the config lossily to an opaque [`CredentialsFile`].
394    pub fn as_credentials(&self) -> Result<CredentialsFile, CredentialsError> {
395        let target = self
396            .host
397            .target_name()
398            .map_err(|_| CredentialsError::NoTcpAddress)?;
399        let tcp = target.tcp().ok_or(CredentialsError::NoTcpAddress)?;
400        Ok(CredentialsFile {
401            user: Some(self.user.clone()),
402            host: Some(tcp.0.to_string()),
403            port: Some(NonZero::new(tcp.1).expect("invalid zero port")),
404            password: self.authentication.password().map(|s| s.to_string()),
405            secret_key: self.authentication.secret_key().map(|s| s.to_string()),
406            database: self.db.database().map(|s| s.to_string()),
407            branch: self.db.branch_for_connect().map(|s| s.to_string()),
408            tls_ca: self.tls_ca_pem(),
409            tls_security: self.tls_security,
410            tls_server_name: self.tls_server_name.clone(),
411            warnings: vec![],
412        })
413    }
414
415    #[allow(clippy::field_reassign_with_default)]
416    pub fn to_tls(&self) -> gel_stream::TlsParameters {
417        use gel_stream::{TlsAlpn, TlsCert, TlsParameters, TlsServerCertVerify};
418        use std::borrow::Cow;
419        use std::net::IpAddr;
420
421        let mut tls = TlsParameters::default();
422        tls.root_cert = TlsCert::Webpki;
423        match &self.tls_ca {
424            Some(certs) => {
425                tls.root_cert = TlsCert::Custom(certs.to_vec());
426            }
427            None => {
428                if let Some(cloud_certs) = self.cloud_certs {
429                    tls.root_cert = TlsCert::WebpkiPlus(cloud_certs.certificates().to_vec());
430                }
431            }
432        }
433        tls.server_cert_verify = match self.tls_security {
434            TlsSecurity::Insecure => TlsServerCertVerify::Insecure,
435            TlsSecurity::NoHostVerification => TlsServerCertVerify::IgnoreHostname,
436            TlsSecurity::Strict | TlsSecurity::Default => TlsServerCertVerify::VerifyFull,
437        };
438        tls.alpn = TlsAlpn::new_str(&["edgedb-binary", "gel-binary"]);
439        tls.sni_override = match &self.tls_server_name {
440            Some(server_name) => Some(Cow::from(server_name.clone())),
441            None => {
442                if let Ok(host) = self.host.target_name() {
443                    if let Some(host) = host.host() {
444                        if let Ok(ip) = IpAddr::from_str(&host) {
445                            // FIXME: https://github.com/rustls/rustls/issues/184
446                            let host = format!("{ip}.host-for-ip.edgedb.net");
447                            // for ipv6addr
448                            let host = host.replace([':', '%'], "-");
449                            if host.starts_with('-') {
450                                Some(Cow::from(format!("i{host}")))
451                            } else {
452                                Some(Cow::from(host.to_string()))
453                            }
454                        } else {
455                            Some(Cow::from(host.to_string()))
456                        }
457                    } else {
458                        None
459                    }
460                } else {
461                    None
462                }
463            }
464        };
465        tls
466    }
467}
468
/// How the client authenticates to the server.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Authentication {
    /// Neither a password nor a secret key.
    #[default]
    None,
    /// Authenticate with a password.
    Password(String),
    /// Authenticate with a secret key.
    SecretKey(String),
}

impl Authentication {
    /// Returns the password when this is [`Authentication::Password`].
    pub fn password(&self) -> Option<&str> {
        if let Self::Password(password) = self {
            Some(password.as_str())
        } else {
            None
        }
    }

    /// Returns the key when this is [`Authentication::SecretKey`].
    pub fn secret_key(&self) -> Option<&str> {
        match self {
            Self::SecretKey(key) => Some(key.as_str()),
            Self::None | Self::Password(_) => None,
        }
    }
}
493
494/// The database or branch to use for the connection.
495#[derive(Debug, Clone, Default, PartialEq, Eq)]
496pub enum DatabaseBranch {
497    #[default]
498    Default,
499    Database(String),
500    Branch(String),
501    Ambiguous(String),
502}
503
504impl std::fmt::Display for DatabaseBranch {
505    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506        // Alternate display is the short form, just the name or (default)
507        if f.alternate() {
508            match self {
509                Self::Database(database) => write!(f, "{database}"),
510                Self::Branch(branch) => write!(f, "{branch}"),
511                Self::Ambiguous(ambiguous) => write!(f, "{ambiguous}"),
512                Self::Default => write!(f, "(default)"),
513            }
514        } else {
515            match self {
516                Self::Database(database) => write!(f, "database '{database}'"),
517                Self::Branch(branch) => write!(f, "branch '{branch}'"),
518                Self::Ambiguous(ambiguous) => write!(f, "'{ambiguous}'"),
519                Self::Default => write!(f, "default database/branch"),
520            }
521        }
522    }
523}
524
525impl DatabaseBranch {
526    pub fn database(&self) -> Option<&str> {
527        match self {
528            Self::Database(database) => Some(database),
529            // Special case: we return branch here
530            Self::Branch(branch) => Some(branch),
531            Self::Ambiguous(ambiguous) => Some(ambiguous),
532            Self::Default => Some(DEFAULT_DATABASE_NAME),
533        }
534    }
535
536    pub fn branch_for_connect(&self) -> Option<&str> {
537        match self {
538            Self::Branch(branch) => Some(branch),
539            // Special case: we return database here
540            Self::Database(database) => Some(database),
541            Self::Ambiguous(ambiguous) => Some(ambiguous),
542            Self::Default => Some(DEFAULT_BRANCH_NAME_CONNECT),
543        }
544    }
545
546    pub fn branch_for_create(&self) -> Option<&str> {
547        match self {
548            Self::Branch(branch) => Some(branch),
549            // Special case: we return database here
550            Self::Database(database) => Some(database),
551            Self::Ambiguous(ambiguous) => Some(ambiguous),
552            Self::Default => Some(DEFAULT_BRANCH_NAME_CREATE),
553        }
554    }
555
556    pub fn name(&self) -> Option<&str> {
557        match self {
558            Self::Database(database) => Some(database),
559            Self::Branch(branch) => Some(branch),
560            Self::Ambiguous(ambiguous) => Some(ambiguous),
561            Self::Default => None,
562        }
563    }
564}
565
566/// Client security mode.
567#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
568pub enum ClientSecurity {
569    /// Disable security checks
570    InsecureDevMode,
571    /// Always verify domain an certificate
572    Strict,
573    /// Verify domain only if no specific certificate is configured
574    #[default]
575    Default,
576}
577
578impl FromStr for ClientSecurity {
579    type Err = ParseError;
580    fn from_str(s: &str) -> Result<ClientSecurity, Self::Err> {
581        use ClientSecurity::*;
582
583        match s {
584            "default" => Ok(Default),
585            "strict" => Ok(Strict),
586            "insecure_dev_mode" => Ok(InsecureDevMode),
587            // TODO: this should have its own error?
588            _ => Err(ParseError::InvalidTlsSecurity(
589                TlsSecurityError::InvalidValue,
590            )),
591        }
592    }
593}
594
595/// The type of cloud certificate to use.
596#[derive(Debug, Clone, Copy, PartialEq, Eq)]
597pub enum CloudCerts {
598    Staging,
599    Local,
600}
601
602impl FromStr for CloudCerts {
603    type Err = ParseError;
604    fn from_str(s: &str) -> Result<CloudCerts, Self::Err> {
605        use CloudCerts::*;
606
607        match s {
608            "staging" => Ok(Staging),
609            "local" => Ok(Local),
610            // TODO: incorrect error
611            _ => Err(ParseError::FileNotFound),
612        }
613    }
614}
615
616impl CloudCerts {
617    pub fn certificates(&self) -> &[CertificateDer<'static>] {
618        match self {
619            Self::Staging => {
620                static CERT: std::sync::OnceLock<Vec<CertificateDer<'static>>> =
621                    std::sync::OnceLock::new();
622                CERT.get_or_init(|| {
623                    Self::read_static_certs(Self::Staging.certificates_pem().as_bytes())
624                })
625            }
626            Self::Local => {
627                static CERT: std::sync::OnceLock<Vec<CertificateDer<'static>>> =
628                    std::sync::OnceLock::new();
629                CERT.get_or_init(|| {
630                    Self::read_static_certs(Self::Local.certificates_pem().as_bytes())
631                })
632            }
633        }
634    }
635
636    pub fn certificates_pem(&self) -> &'static str {
637        match self {
638            Self::Staging => include_str!("certs/letsencrypt_staging.pem"),
639            Self::Local => include_str!("certs/nebula_development.pem"),
640        }
641    }
642
643    fn read_static_certs(bytes: &'static [u8]) -> Vec<CertificateDer<'static>> {
644        let mut cursor = std::io::Cursor::new(bytes);
645        let mut certs = Vec::new();
646        for item in rustls_pemfile::read_all(&mut cursor) {
647            match item {
648                Ok(rustls_pemfile::Item::X509Certificate(data)) => certs.push(data),
649                _ => panic!(),
650            }
651        }
652        certs
653    }
654}
655
656/// TLS Client Security Mode
657#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
658#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
659#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
660pub enum TlsSecurity {
661    /// Allow any certificate for TLS connection
662    Insecure,
663    /// Verify certificate against trusted chain but allow any host name
664    ///
665    /// This is useful for localhost (you can't make trusted chain certificate
666    /// for localhost). And when certificate of specific server is stored in
667    /// credentials file so it's secure regardless of which host name was used
668    /// to expose the server to the network.
669    NoHostVerification,
670    /// Normal TLS certificate check (checks trusted chain and hostname)
671    Strict,
672    /// If there is a specific certificate in credentials, do not check
673    /// the host name, otherwise use `Strict` mode
674    #[default]
675    Default,
676}
677
678impl FromStr for TlsSecurity {
679    type Err = ParseError;
680    fn from_str(val: &str) -> Result<Self, Self::Err> {
681        match val {
682            "default" => Ok(TlsSecurity::Default),
683            "insecure" => Ok(TlsSecurity::Insecure),
684            "no_host_verification" => Ok(TlsSecurity::NoHostVerification),
685            "strict" => Ok(TlsSecurity::Strict),
686            _ => Err(ParseError::InvalidTlsSecurity(
687                TlsSecurityError::InvalidValue,
688            )),
689        }
690    }
691}
692
693impl fmt::Display for TlsSecurity {
694    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
695        match self {
696            Self::Insecure => write!(f, "insecure"),
697            Self::NoHostVerification => write!(f, "no_host_verification"),
698            Self::Strict => write!(f, "strict"),
699            Self::Default => write!(f, "default"),
700        }
701    }
702}
703
704/// TCP keepalive configuration.
705#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
706pub enum TcpKeepalive {
707    /// Disable TCP keepalive probes.
708    Disabled,
709    /// Explicit duration between TCP keepalive probes.
710    Explicit(Duration),
711    /// Default: 60 seconds.
712    #[default]
713    Default,
714}
715
716impl TcpKeepalive {
717    pub fn as_keepalive(&self) -> Option<Duration> {
718        match self {
719            TcpKeepalive::Disabled => None,
720            TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE),
721            TcpKeepalive::Explicit(duration) => Some(*duration),
722        }
723    }
724}
725
726impl FromStr for TcpKeepalive {
727    type Err = ParseError;
728    fn from_str(s: &str) -> Result<Self, Self::Err> {
729        use TcpKeepalive::*;
730
731        match s {
732            "disabled" => Ok(Disabled),
733            "default" => Ok(Default),
734            _ => Ok(Explicit(
735                parse_duration(s).map_err(|_| ParseError::InvalidDuration)?,
736            )),
737        }
738    }
739}
740
741#[derive(derive_more::Debug, Clone, PartialEq, Eq)]
742enum UnixPathInner {
743    /// The selected port will be appended to the path.
744    #[debug("{:?}{{port}}", _0)]
745    PortSuffixed(PathBuf),
746    /// The path will be used as-is.
747    #[debug("{:?}", _0)]
748    Exact(PathBuf),
749}
750
751/// A path to a Unix socket.
752#[derive(Clone, PartialEq, Eq, derive_more::Debug)]
753pub struct UnixPath {
754    #[debug("{:?}", inner)]
755    inner: UnixPathInner,
756}
757
758impl UnixPath {
759    /// The selected port will be appended to the path.
760    pub fn with_port_suffix(path: PathBuf) -> Self {
761        UnixPath {
762            inner: UnixPathInner::PortSuffixed(path),
763        }
764    }
765
766    /// The path will be used as-is.
767    pub fn exact(path: PathBuf) -> Self {
768        UnixPath {
769            inner: UnixPathInner::Exact(path),
770        }
771    }
772
773    /// Returns a path with the port suffix appended.
774    pub fn path_with_port(&self, port: u16) -> Cow<Path> {
775        match &self.inner {
776            UnixPathInner::PortSuffixed(path) => {
777                let Some(filename) = path.file_name() else {
778                    return Cow::Owned(path.join(port.to_string()));
779                };
780                let mut path = path.clone();
781                let mut filename = filename.to_owned();
782                filename.push(port.to_string());
783                path.set_file_name(filename);
784                Cow::Owned(path)
785            }
786            UnixPathInner::Exact(path) => Cow::Borrowed(path),
787        }
788    }
789}
790
791impl FromStr for UnixPath {
792    type Err = ParseError;
793    fn from_str(s: &str) -> Result<Self, Self::Err> {
794        Ok(UnixPath::exact(PathBuf::from(s)))
795    }
796}
797
798impl<T: Into<PathBuf>> From<T> for UnixPath {
799    fn from(path: T) -> Self {
800        UnixPath::exact(path.into())
801    }
802}
803
804/// Classic-style connection options.
805#[derive(Clone, Default, Debug, Serialize, Deserialize)]
806#[serde(default)]
807#[serde(deny_unknown_fields)]
808pub struct ConnectionOptions {
809    pub dsn: Option<String>,
810    pub user: Option<String>,
811    pub password: Option<String>,
812    pub instance: Option<String>,
813    pub database: Option<String>,
814    pub host: Option<String>,
815    #[serde(deserialize_with = "deserialize_string_or_number")]
816    pub port: Option<String>,
817    pub branch: Option<String>,
818    #[serde(rename = "tlsSecurity")]
819    pub tls_security: Option<String>,
820    #[serde(rename = "tlsCA")]
821    pub tls_ca: Option<String>,
822    #[serde(rename = "tlsCAFile")]
823    pub tls_ca_file: Option<String>,
824    #[serde(rename = "tlsServerName")]
825    pub tls_server_name: Option<String>,
826    #[serde(rename = "waitUntilAvailable")]
827    pub wait_until_available: Option<String>,
828    #[serde(rename = "serverSettings")]
829    pub server_settings: Option<HashMap<String, String>>,
830    #[serde(rename = "credentialsFile")]
831    pub credentials_file: Option<String>,
832    pub credentials: Option<String>,
833    #[serde(rename = "secretKey")]
834    pub secret_key: Option<String>,
835}
836
/// Deserializes a JSON value that may be a string or a number into a string.
///
/// Behaviour by input:
/// - Strings are taken as-is.
/// - An explicit `null` yields `None`, like a missing field.
/// - Any other value (numbers, etc.) is rendered as its JSON text, so
///   `5656` becomes `"5656"`.
///
/// NOTE(review): `ConnectionOptions` derives serde unconditionally but this
/// helper is feature-gated; confirm the `serde` feature is always enabled.
#[cfg(feature = "serde")]
fn deserialize_string_or_number<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    Ok(match serde_json::Value::deserialize(deserializer)? {
        // Previously `null` was stringified to `"null"` and later failed to
        // parse as a port.
        serde_json::Value::Null => None,
        serde_json::Value::String(text) => Some(text),
        other => Some(other.to_string()),
    })
}
849
850impl TryInto<Params> for ConnectionOptions {
851    type Error = ParseError;
852
853    fn try_into(self) -> Result<Params, Self::Error> {
854        if self.credentials.is_some() && self.credentials_file.is_some() {
855            return Err(ParseError::MultipleCompound(
856                BuildPhase::Options,
857                vec![
858                    CompoundSource::CredentialsFile,
859                    CompoundSource::CredentialsFile,
860                ],
861            ));
862        }
863
864        if self.tls_ca.is_some() && self.tls_ca_file.is_some() {
865            return Err(ParseError::ExclusiveOptions(
866                "tls_ca".to_string(),
867                "tls_ca_file".to_string(),
868            ));
869        }
870
871        if self.branch.is_some() && self.database.is_some() {
872            return Err(ParseError::ExclusiveOptions(
873                "branch".to_string(),
874                "database".to_string(),
875            ));
876        }
877
878        let mut credentials = Param::from_file(self.credentials_file.clone());
879        if credentials.is_none() {
880            credentials = Param::from_unparsed(self.credentials.clone());
881        }
882
883        let mut tls_ca = Param::from_unparsed(self.tls_ca.clone());
884        if tls_ca.is_none() {
885            tls_ca = Param::from_file(self.tls_ca_file.clone());
886        }
887
888        let explicit = Params {
889            dsn: Param::from_unparsed(self.dsn.clone()),
890            credentials,
891            user: Param::from_unparsed(self.user.clone()),
892            password: Param::from_unparsed(self.password.clone()),
893            instance: Param::from_unparsed(self.instance.clone()),
894            database: Param::from_unparsed(self.database.clone()),
895            host: Param::from_unparsed(self.host.clone()),
896            port: Param::from_unparsed(self.port.as_ref().map(|n| n.to_string())),
897            branch: Param::from_unparsed(self.branch.clone()),
898            secret_key: Param::from_unparsed(self.secret_key.clone()),
899            tls_security: Param::from_unparsed(self.tls_security.clone()),
900            tls_ca,
901            tls_server_name: Param::from_unparsed(self.tls_server_name.clone()),
902            server_settings: self.server_settings.unwrap_or_default(),
903            wait_until_available: Param::from_unparsed(self.wait_until_available.clone()),
904            ..Default::default()
905        };
906
907        Ok(explicit)
908    }
909}
910
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_as_credentials() {
        let credentials = Config::default()
            .as_credentials()
            .expect("default config has a TCP address");
        assert_eq!(credentials.host.as_deref(), Some("localhost"));
    }

    #[test]
    fn test_dsn_url() {
        let localhost = || Config::default().with_host("localhost", 5656).unwrap();
        let main_branch = || DatabaseBranch::Branch("main".to_string());

        let cases = [
            (Config::default(), "gel://"),
            (
                Config::default().with_host("example.com", 1234).unwrap(),
                "gel://example.com:1234",
            ),
            (
                localhost().with_db(DatabaseBranch::Database("edgedb".to_string())),
                "gel:///edgedb",
            ),
            (
                Config::default()
                    .with_host("example.com", 5656)
                    .unwrap()
                    .with_db(main_branch()),
                "gel://example.com/main",
            ),
            (
                localhost().with_db(main_branch()).with_user("user"),
                "gel:///main?user=user",
            ),
            (
                localhost()
                    .with_db(main_branch())
                    .with_user("user")
                    .with_password("%[]{}"),
                "gel:///main?user=user&password=%25%5B%5D%7B%7D",
            ),
        ];

        for (config, expected) in cases {
            assert_eq!(config.dsn_url().unwrap(), expected, "{config:?}");
        }
    }
}