edgedb_client/
builder.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::env;
4use std::error::Error as _;
5use std::ffi::{OsString, OsStr};
6use std::fmt;
7use std::io;
8use std::net::IpAddr;
9use std::path::{Path, PathBuf};
10use std::str::{self, FromStr};
11use std::sync::Arc;
12use std::time::{Instant, Duration};
13
14use async_std::fs;
15use async_std::future::Future;
16use async_std::net::TcpStream;
17use async_std::path::{Path as AsyncPath};
18use async_std::task::sleep;
19use bytes::{Bytes, BytesMut};
20use futures_util::AsyncReadExt;
21use rand::{thread_rng, Rng};
22use rustls::client::ServerCertVerifier;
23use scram::ScramClient;
24use serde_json::from_slice;
25use sha1::Digest;
26use tls_api::{TlsConnectorBox, TlsConnector as _, TlsConnectorBuilder as _};
27use tls_api::{TlsStream, TlsStreamDyn as _};
28use tls_api_not_tls::TlsConnector as PlainConnector;
29use typemap::{TypeMap, DebugAny};
30use webpki::DnsNameRef;
31
32use edgedb_protocol::client_message::{ClientMessage, ClientHandshake};
33use edgedb_protocol::features::ProtocolVersion;
34use edgedb_protocol::server_message::{ServerMessage, Authentication};
35use edgedb_protocol::server_message::{TransactionState, ServerHandshake};
36use edgedb_protocol::server_message::ParameterStatus;
37use edgedb_protocol::value::Value;
38
39use crate::client::{Connection, Sequence, State, PingInterval};
40use crate::credentials::{Credentials, TlsSecurity};
41use crate::errors::{ClientConnectionError, ProtocolError, ProtocolTlsError};
42use crate::errors::{ClientConnectionFailedError, AuthenticationError};
43use crate::errors::{ClientError, ClientConnectionFailedTemporarilyError};
44use crate::errors::{ClientNoCredentialsError, ProtocolEncodingError};
45use crate::errors::{Error, ErrorKind, PasswordRequired, ResultExt};
46use crate::server_params::{PostgresAddress, SystemConfig};
47use crate::tls;
48
/// Default time allowed for a single connection attempt
/// (see `Builder::connect_timeout`).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Default total time to wait for the server to become available
/// (see `Builder::wait_until_available`).
pub const DEFAULT_WAIT: Duration = Duration::from_secs(30);
/// Default maximum number of pooled connections.
pub const DEFAULT_POOL_SIZE: usize = 10;
/// Host used when none is configured.
pub const DEFAULT_HOST: &str = "localhost";
/// Port used when none is configured.
pub const DEFAULT_PORT: u16 = 5656;
54
55type Verifier = Arc<dyn ServerCertVerifier>;
56
57/// A builder used to create connections.
58#[derive(Debug, Clone)]
59pub struct Builder {
60    address: Address,
61    admin: bool,
62    user: String,
63    password: Option<String>,
64    database: String,
65    pem: Option<String>,
66    tls_security: TlsSecurity,
67    instance_name: Option<String>,
68
69    initialized: bool,
70    wait: Duration,
71    connect_timeout: Duration,
72    insecure_dev_mode: bool,
73    creds_file_outdated: bool,
74
75    // Pool configuration
76    pub(crate) max_connections: usize,
77}
78/// Configuration of the client
79///
80/// Use [`Builder`][] to create an instance
81#[derive(Clone)]
82pub struct Config(pub(crate) Arc<ConfigInner>);
83
84pub(crate) struct ConfigInner {
85    pub address: Address,
86    #[allow(dead_code)] // TODO(tailhook) for cli only
87    pub admin: bool,
88    pub user: String,
89    pub password: Option<String>,
90    pub database: String,
91    pub verifier: Arc<dyn ServerCertVerifier>,
92    #[allow(dead_code)] // TODO(tailhook) maybe for errors
93    pub instance_name: Option<String>,
94    pub wait: Duration,
95    pub connect_timeout: Duration,
96    pub tls_security: TlsSecurity,
97    #[allow(dead_code)] // TODO(tailhook) maybe for future things
98    pub insecure_dev_mode: bool,
99
100    // Pool configuration
101    pub max_connections: usize,
102}
103
/// Where the client connects: a TCP `(host, port)` pair or a unix socket.
#[derive(Debug, Clone)]
pub(crate) enum Address {
    /// Host name (or raw IP address, without brackets) and port.
    Tcp((String, u16)),
    /// Full path of a unix domain socket.
    #[allow(dead_code)] // TODO(tailhook), but for cli only
    Unix(PathBuf),
}
110
111struct DisplayAddr<'a>(Option<&'a Address>);
112
113pub async fn timeout<F, T>(dur: Duration, f: F) -> Result<T, Error>
114    where F: Future<Output = Result<T, Error>>,
115{
116    use async_std::future::timeout;
117
118    timeout(dur, f).await
119    .unwrap_or_else(|_| {
120        Err(ClientConnectionFailedTemporarilyError::with_source(
121            io::Error::from(io::ErrorKind::TimedOut)
122        ))
123    })
124}
125
126fn sleep_duration() -> Duration {
127    Duration::from_millis(thread_rng().gen_range(10u64..200u64))
128}
129
130fn is_temporary(e: &Error) -> bool {
131    use io::ErrorKind::{ConnectionRefused, TimedOut, NotFound};
132    use io::ErrorKind::{ConnectionAborted, ConnectionReset, UnexpectedEof};
133    use io::ErrorKind::{AddrNotAvailable};
134
135    if e.is::<ClientConnectionFailedTemporarilyError>() {
136        return true;
137    }
138    if e.is::<ClientConnectionError>() {
139        let io_err = e.source().and_then(|src| {
140            src.downcast_ref::<io::Error>()
141            .or_else(|| src.downcast_ref::<Box<io::Error>>().map(|b| &**b))
142        });
143        if let Some(e) = io_err {
144            match e.kind() {
145                | ConnectionRefused
146                | ConnectionReset
147                | ConnectionAborted
148                | NotFound  // For unix sockets
149                | TimedOut
150                | UnexpectedEof     // For Docker server which is starting up
151                | AddrNotAvailable  // Docker exposed ports not yet bound
152                => return true,
153                _ => {},
154            }
155        }
156    }
157    return false;
158}
159
160fn tls_fail(e: anyhow::Error) -> Error {
161    if let Some(e) = e.downcast_ref::<rustls::Error>() {
162        if matches!(e, rustls::Error::CorruptMessage) {
163            return ProtocolTlsError::with_message(
164                "corrupt message, possibly server \
165                 does not support TLS connection."
166            );
167        }
168    }
169    ClientConnectionError::with_source_ref(e)
170}
171
172fn get_env(name: &str) -> Result<Option<String>, Error> {
173    match env::var(name) {
174        Ok(v) if v.is_empty() => Ok(None),
175        Ok(v) => Ok(Some(v)),
176        Err(env::VarError::NotPresent) => Ok(None),
177        Err(e) => {
178            Err(
179                ClientError::with_source(e)
180                .context(
181                   format!("Cannot decode environment variable {:?}", name))
182            )
183        }
184    }
185}
186
187fn get_port_env() -> Result<Option<String>, Error> {
188    static PORT_WARN: std::sync::Once = std::sync::Once::new();
189
190    let port = get_env("EDGEDB_PORT")?;
191    if let Some(port) = &port {
192        // ignore port if it's docker-specified string
193        if port.starts_with("tcp://") {
194
195            PORT_WARN.call_once(|| {
196                log::warn!("Environment variable `EDGEDB_PORT` contains \
197                           docker-link-like definition. Ingoring...");
198            });
199
200            return Ok(None);
201        }
202    }
203    Ok(port)
204}
205
206fn get_host_port() -> Result<Option<(Option<String>, Option<u16>)>, Error> {
207    let host = get_env("EDGEDB_HOST")?;
208    let port = get_port_env()?.map(|port| {
209        port.parse().map_err(|e| {
210            ClientError::with_source(e)
211                .context("cannot parse env var EDGEDB_PORT")
212        })
213    }).transpose()?;
214    if host.is_some() || port.is_some() {
215        Ok(Some((host, port)))
216    } else {
217        Ok(None)
218    }
219}
220
221pub async fn search_dir(base: &AsyncPath) -> Result<Option<&AsyncPath>, Error>
222{
223    let mut path = base;
224    if path.join("edgedb.toml").exists().await {
225        return Ok(Some(path.into()));
226    }
227    while let Some(parent) = path.parent() {
228        if parent.join("edgedb.toml").exists().await {
229            return Ok(Some(parent.into()));
230        }
231        path = parent;
232    }
233    Ok(None)
234}
235
/// Returns the raw OS bytes of `path` (unix: any byte sequence).
#[cfg(unix)]
fn path_bytes<'x>(path: &'x Path) -> &'x [u8] {
    use std::os::unix::ffi::OsStrExt;
    OsStrExt::as_bytes(path.as_os_str())
}

/// Returns the UTF-8 bytes of `path`.
///
/// # Panics
/// Panics if `path` is not valid Unicode (`Path::to_str` returns `None`).
/// NOTE(review): Windows file names may contain unpaired UTF-16 surrogates,
/// so the `expect` message below is optimistic.
#[cfg(windows)]
fn path_bytes<'x>(path: &'x Path) -> &'x [u8] {
    let text = path.to_str().expect("windows paths are always valid UTF-16");
    text.as_bytes()
}
246
247fn hash(path: &Path) -> String {
248    format!("{:x}", sha1::Sha1::new_with_prefix(path_bytes(path)).finalize())
249}
250
251fn stash_name(path: &Path) -> OsString {
252    let hash = hash(path);
253    let base = path.file_name().unwrap_or(OsStr::new(""));
254    let mut base = base.to_os_string();
255    base.push("-");
256    base.push(&hash);
257    return base;
258}
259
260fn config_dir() -> Result<PathBuf, Error> {
261    let dir = if cfg!(windows) {
262        dirs::data_local_dir()
263            .ok_or_else(|| ClientError::with_message(
264                "cannot determine local data directory"))?
265            .join("EdgeDB")
266            .join("config")
267    } else {
268        dirs::config_dir()
269            .ok_or_else(|| ClientError::with_message(
270                "cannot determine config directory"))?
271            .join("edgedb")
272    };
273    Ok(dir)
274}
275
276#[allow(dead_code)]
277#[cfg(target_os="linux")]
278fn default_runtime_base() -> Result<PathBuf, Error> {
279    extern "C" {
280        fn geteuid() -> u32;
281    }
282    Ok(Path::new("/run/user").join(unsafe { geteuid() }.to_string()))
283}
284
285#[allow(dead_code)]
286#[cfg(not(target_os="linux"))]
287fn default_runtime_base() -> Result<PathBuf, Error> {
288    Err(ClientError::with_message("no default runtime dir for the platform"))
289}
290
291fn stash_path(project_dir: &Path) -> Result<PathBuf, Error> {
292    Ok(config_dir()?.join("projects").join(stash_name(project_dir)))
293}
294
/// Checks that `name` is a valid instance name.
///
/// A valid name is non-empty, starts with an ASCII letter or `_`, and
/// otherwise contains only ASCII alphanumerics and `_`.
fn is_valid_instance_name(name: &str) -> bool {
    let mut chars = name.chars();
    let good_start = matches!(
        chars.next(),
        Some(first) if first.is_ascii_alphabetic() || first == '_'
    );
    good_start && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
}
308
309
310impl fmt::Display for DisplayAddr<'_> {
311    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
312        match &self.0 {
313            Some(Address::Tcp((host, port))) => {
314                write!(f, "{}:{}", host, port)
315            }
316            Some(Address::Unix(path)) => write!(f, "unix:{}", path.display()),
317            None => write!(f, "<no address>"),
318        }
319    }
320}
321
322impl Builder {
323
324    /// Initializes a Builder using environment variables or project config.
325    pub async fn from_env() -> Result<Builder, Error> {
326        let mut builder = Builder::uninitialized();
327
328        // optimize discovering project if defined by environment variable
329        if get_env("EDGEDB_HOST")?.is_none() &&
330           get_port_env()?.is_none() &&
331           get_env("EDGEDB_INSTANCE")?.is_none() &&
332           get_env("EDGEDB_DSN")?.is_none() &&
333           get_env("EDGEDB_CONFIGURATION_FILE")?.is_none()
334        {
335            builder.read_project(None, false).await?;
336        }
337
338        builder.read_env_vars().await?;
339        Ok(builder)
340    }
341
342    /// Reads the project config if it exists.
343    ///
344    /// Projects are initialized using the command-line tool:
345    /// ```shell
346    /// edgedb project init
347    /// ```
348    /// Linking to an already running EdgeDB is also possible:
349    /// ```shell
350    /// edgedb project init --link
351    /// ```
352    ///
353    /// Returns a boolean value indicating whether the project was found.
354    pub async fn read_project(&mut self,
355        override_dir: Option<&Path>, search_parents: bool)
356        -> Result<&mut Self, Error>
357    {
358        let dir = match override_dir {
359            Some(v) => Cow::Borrowed(v.as_ref()),
360            None => {
361                Cow::Owned(env::current_dir()
362                    .map_err(|e| ClientError::with_source(e)
363                        .context("failed to get current directory"))?
364                    .into())
365            }
366        };
367
368        let dir = if search_parents {
369            if let Some(ancestor) = search_dir(&dir).await? {
370                Cow::Borrowed(ancestor)
371            } else {
372                return Ok(self);
373            }
374        } else {
375            if !dir.join("edgedb.toml").exists().await {
376                return Ok(self);
377            }
378            dir
379        };
380        let canon = fs::canonicalize(&dir).await
381            .map_err(|e| ClientError::with_source(e).context(
382                format!("failed to canonicalize dir {:?}", dir)
383            ))?;
384        let stash_path = stash_path(canon.as_ref())?;
385        if AsRef::<AsyncPath>::as_ref(&stash_path).exists().await {
386            let instance =
387                fs::read_to_string(stash_path.join("instance-name")).await
388                .map_err(|e| ClientError::with_source(e).context(
389                    format!("error reading project settings {:?}", dir)
390                ))?;
391            self.read_instance(instance.trim()).await?;
392
393        }
394        Ok(self)
395    }
396
397    /// Indicates whether credentials are set for this builder.
398    pub fn is_initialized(&self) -> bool {
399        self.initialized
400    }
401    /// Read environment variables and set respective configuration parameters.
402    ///
403    /// This function initializes the builder if one of the following is set:
404    ///
405    /// * `EDGEDB_CREDENTIALS_FILE`
406    /// * `EDGEDB_INSTANCE`
407    /// * `EDGEDB_DSN`
408    /// * `EDGEDB_HOST` or `EDGEDB_PORT`
409    ///
410    /// If it finds one of these then it will reset all previously set
411    /// credentials.
412    ///
413    /// If one of the following are set:
414    ///
415    /// * `EDGEDB_DATABASE`
416    /// * `EDGEDB_USER`
417    /// * `EDGEDB_PASSWORD`
418    ///
419    /// Then the value of that environment variable will be used to set just
420    /// the parameter matching that environment variable.
421    ///
422    /// The `insecure_dev_mode` and connection parameters are never modified by
423    /// this function for now.
424    pub async fn read_env_vars(&mut self) -> Result<&mut Self, Error> {
425        if let Some((host, port)) = get_host_port()? {
426            self.host_port(host, port);
427        } else if let Some(path) = get_env("EDGEDB_CREDENTIALS_FILE")? {
428            self.read_credentials(path).await?;
429        } else if let Some(instance) = get_env("EDGEDB_INSTANCE")? {
430            self.read_instance(&instance).await?;
431        } else if let Some(dsn) = get_env("EDGEDB_DSN")? {
432            self.read_dsn(&dsn).await.map_err(|e|
433                e.context("cannot parse env var EDGEDB_DNS"))?;
434        }
435        if let Some(database) = get_env("EDGEDB_DATABASE")? {
436            self.database = database;
437        }
438        if let Some(user) = get_env("EDGEDB_USER")? {
439            self.user = user;
440        }
441        if let Some(password) = get_env("EDGEDB_PASSWORD")? {
442            self.password = Some(password);
443        }
444        if let Some(sec) = get_env("EDGEDB_CLIENT_TLS_SECURITY")? {
445            self.tls_security = match &sec[..] {
446                "default" => TlsSecurity::Default,
447                "insecure" => TlsSecurity::Insecure,
448                "no_host_verification" => TlsSecurity::NoHostVerification,
449                "strict" => TlsSecurity::Strict,
450                _ => {
451                    return Err(ClientError::with_message(
452                        format!("Invalid value {:?} for env var \
453                                EDGEDB_CLIENT_TLS_SECURITY. \
454                                Options: default, insecure, \
455                                no_host_verification, strict.",
456                                sec)
457                    ));
458                }
459            };
460        }
461        self.read_extra_env_vars()?;
462        Ok(self)
463    }
464    /// Read environment variables that aren't credentials
465    pub fn read_extra_env_vars(&mut self) -> Result<&mut Self, Error> {
466        if let Some(mode) = get_env("EDGEDB_CLIENT_SECURITY")? {
467            self.insecure_dev_mode = match &mode[..] {
468                "default" => false,
469                "insecure_dev_mode" => true,
470                _ => {
471                    return Err(ClientError::with_message(
472                        format!("Invalid value {:?} for env var \
473                                EDGEDB_CLIENT_SECURITY. \
474                                Options: default, insecure_dev_mode.",
475                                mode)
476                    ));
477                }
478            };
479        }
480        Ok(self)
481    }
482
483    /// Set all credentials.
484    ///
485    /// This marks the builder as initialized.
486    pub fn credentials(&mut self, credentials: &Credentials)
487        -> Result<&mut Self, Error>
488    {
489        if let Some(cert_data) = &credentials.tls_ca {
490            validate_certs(&cert_data)
491                .context("invalid certificates in `tls_ca`")?;
492        }
493        self.reset_compound();
494        self.address = Address::Tcp((
495            credentials.host.clone()
496                .unwrap_or_else(|| DEFAULT_HOST.into()),
497            credentials.port,
498        ));
499        self.admin = false;
500        self.user = credentials.user.clone();
501        self.password = credentials.password.clone();
502        self.database = credentials.database.clone()
503                .unwrap_or_else(|| "edgedb".into());
504        self.creds_file_outdated = credentials.file_outdated;
505        self.tls_security = credentials.tls_security;
506        self.pem = credentials.tls_ca.clone();
507        self.initialized = true;
508        Ok(self)
509    }
510
/// Returns the instance name, but only when the credentials file it was read
/// from is flagged as outdated.
#[cfg(feature="unstable")]
pub fn get_instance_name_for_creds_update(&self) -> Option<&str> {
    self.instance_name.as_deref().filter(|_| self.creds_file_outdated)
}
520
521    /// Read credentials from the named instance.
522    ///
523    /// Named instances are created using the command-line tool, either
524    /// directly:
525    /// ```shell
526    /// edgedb instance create <name>
527    /// ```
528    /// or when initializing a project:
529    /// ```shell
530    /// edgedb project init
531    /// ```
532    /// In the latter case you should use [`read_project()`][Builder::read_project]
533    /// instead if possible.
534    ///
535    /// This will mark the builder as initialized (if reading is successful)
536    /// and overwrite all credentials. However, `insecure_dev_mode`, pools
537    /// sizes, and timeouts are kept intact.
538    pub async fn read_instance(&mut self, name: &str)
539        -> Result<&mut Self, Error>
540    {
541        if !is_valid_instance_name(name) {
542            return Err(ClientError::with_message(format!(
543                "instance name {:?} contains unsupported characters", name)));
544        }
545        self.read_credentials(
546            config_dir()?.join("credentials").join(format!("{}.json", name))
547        ).await?;
548        self.instance_name = Some(name.into());
549        Ok(self)
550    }
551
552    /// Read credentials from a file.
553    ///
554    /// This will mark the builder as initialized (if reading is successful)
555    /// and overwrite all credentials. However, `insecure_dev_mode`, pools
556    /// sizes, and timeouts are kept intact.
557    pub async fn read_credentials(&mut self, path: impl AsRef<Path>)
558        -> Result<&mut Self, Error>
559    {
560        let path = path.as_ref();
561        async {
562            let data = fs::read(path).await
563                .map_err(ClientError::with_source)?;
564            let creds = serde_json::from_slice(&data)
565                .map_err(ClientError::with_source)?;
566            self.credentials(&creds)?;
567            Ok(())
568        }.await.map_err(|e: Error| e.context(
569            format!("cannot read credentials file {}", path.display())
570        ))?;
571        Ok(self)
572    }
573
574    /// Initialize credentials using data source name (DSN).
575    ///
576    /// DSN's that EdgeDB like are URL with `egdgedb::/scheme`:
577    /// ```text
578    /// edgedb://user:secret@localhost:5656/
579    /// ```
580    /// All the credentials can be specified using a DSN, although parsing a
581    /// DSN may also lead to reading of environment variables (if query
582    /// arguments of the for `*_env` are specified) and local files (for query
583    /// arguments named `*_file`).
584    ///
585    /// This will mark the builder as initialized (if reading is successful)
586    /// and overwrite all the credentials. However, `insecure_dev_mode`, pools
587    /// sizes, and timeouts are kept intact.
588    pub async fn read_dsn(&mut self, dsn: &str) -> Result<&mut Self, Error> {
589        let admin = dsn.starts_with("edgedbadmin://");
590        if !dsn.starts_with("edgedb://") && !admin {
591            return Err(ClientError::with_message(format!(
592                "String {:?} is not a valid DSN", dsn)));
593        };
594        let url = url::Url::parse(dsn)
595            .map_err(|e| ClientError::with_source(e)
596                .context(format!("cannot parse DSN {:?}", dsn)))?;
597        self.reset_compound();
598        let host = if let Some(url::Host::Ipv6(host)) = url.host() {
599            // async-std uses raw IPv6 address without "[]"
600            host.to_string()
601        } else {
602            url.host_str().unwrap_or(DEFAULT_HOST).to_owned()
603        };
604        let port = url.port().unwrap_or(DEFAULT_PORT);
605        self.address = Address::Tcp((host, port));
606        self.admin = admin;
607        self.user = if url.username().is_empty() {
608            "edgedb".to_owned()
609        } else {
610            url.username().to_owned()
611        };
612        self.password = url.password().map(|s| s.to_owned());
613        self.database = url.path().strip_prefix("/")
614                .unwrap_or("edgedb").to_owned();
615        self.initialized = true;
616        Ok(self)
617    }
618    /// Creates a new builder that has to be intialized by calling some methods.
619    ///
620    /// This is only useful if you have connections to multiple unrelated
621    /// databases, or you want to have total control of the database
622    /// initialization process.
623    ///
624    /// Usually, `Builder::from_env()` should be used instead.
625    pub fn uninitialized() -> Builder {
626        Builder {
627            address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)),
628            admin: false,
629            user: "edgedb".into(),
630            password: None,
631            database: "edgedb".into(),
632            tls_security: TlsSecurity::Default,
633            pem: None,
634            instance_name: None,
635
636            wait: DEFAULT_WAIT,
637            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
638            initialized: false,
639            insecure_dev_mode: false,
640            creds_file_outdated: false,
641
642            max_connections: DEFAULT_POOL_SIZE,
643        }
644    }
645    fn reset_compound(&mut self) {
646        *self = Builder {
647            // replace all of them
648            address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)),
649            admin: false,
650            user: "edgedb".into(),
651            password: None,
652            database: "edgedb".into(),
653            tls_security: TlsSecurity::Default,
654            pem: None,
655            instance_name: None,
656
657            initialized: false,
658            // keep old values
659            wait: self.wait,
660            connect_timeout: self.connect_timeout,
661            insecure_dev_mode: self.insecure_dev_mode,
662            creds_file_outdated: false,
663
664            max_connections: self.max_connections,
665        };
666    }
667    /// Extract credentials from the [Builder] so they can be saved as JSON.
668    pub fn as_credentials(&self) -> Result<Credentials, Error> {
669        let (host, port) = match &self.address {
670            Address::Tcp(pair) => pair,
671            Address::Unix(_) => {
672                return Err(ClientError::with_message(
673                    "Unix socket address cannot \
674                    be saved as credentials file"));
675            }
676        };
677        Ok(Credentials {
678            host: Some(host.clone()),
679            port: *port,
680            user: self.user.clone(),
681            password: self.password.clone(),
682            database: Some( self.database.clone()),
683            tls_ca: self.pem.clone(),
684            tls_security: self.tls_security,
685            file_outdated: false,
686            cloud_instance_id: None,
687            cloud_original_dsn: None,
688        })
689    }
690    /// Get the `host` this builder is configured to connect to.
691    ///
692    /// For unix-socket-configured builder (only if `admin_socket` feature is
693    /// enabled) returns "localhost"
694    pub fn get_host(&self) -> &str {
695        match &self.address {
696            Address::Tcp((host, _)) => host,
697            Address::Unix(_) => "localhost",
698        }
699    }
700    /// Get the `port` this builder is configured to connect to.
701    pub fn get_port(&self) -> u16 {
702        match &self.address {
703            Address::Tcp((_, port)) => *port,
704            Address::Unix(_) => 5656
705        }
706    }
707    /// Initialize credentials using host/port data.
708    ///
709    /// If either of host or port is `None`, they are replaced with the
710    /// default of `localhost` and `5656` respectively.
711    ///
712    /// This will mark the builder as initialized and overwrite all the
713    /// credentials. However, `insecure_dev_mode`, pools sizes, and timeouts
714    /// are kept intact.
715    pub fn host_port(&mut self,
716        host: Option<impl Into<String>>, port: Option<u16>)
717        -> &mut Self
718    {
719        self.reset_compound();
720        self.address = Address::Tcp((
721            host.map_or_else(|| DEFAULT_HOST.into(), |h| h.into()),
722            port.unwrap_or(DEFAULT_PORT),
723        ));
724        self.initialized = true;
725        self
726    }
727
#[cfg(feature="admin_socket")]
/// Use the admin unix socket instead of the normal TCP connection.
///
/// With an instance name, the socket directory is `<runtime dir>/edgedb-<name>`,
/// or `<cache dir>/edgedb/run/<name>` if there is no runtime dir. Without a
/// name it is `/var/run/edgedb` on macOS and `/run/edgedb` elsewhere. The
/// socket file is `.s.EDGEDB.admin.<port>`. A builder that already uses a
/// unix socket is left unchanged.
///
/// # Errors
/// Fails on Windows when an instance name is set, or when the cache
/// directory is needed but cannot be determined.
pub fn admin(&mut self) -> Result<&mut Self, Error> {
    let prefix: PathBuf = match &self.instance_name {
        None if cfg!(target_os="macos") => "/var/run/edgedb".into(),
        None => "/run/edgedb".into(),
        Some(_) if cfg!(windows) => {
            return Err(ClientError::with_message(
                "unix sockets are not supported on Windows"));
        }
        Some(name) => match dirs::runtime_dir() {
            Some(runtime) => runtime.join(format!("edgedb-{}", name)),
            None => dirs::cache_dir()
                .ok_or_else(|| ClientError::with_message(
                    "cannot determine cache directory"))?
                .join("edgedb")
                .join("run")
                .join(name),
        },
    };
    if let Address::Tcp((_, port)) = self.address {
        self.address = Address::Unix(
            prefix.join(format!(".s.EDGEDB.admin.{}", port))
        );
    }
    Ok(self)
}
762
#[cfg(feature="admin_socket")]
/// Initialize credentials to connect through a unix socket.
///
/// If the file name of `path` contains `.s.EDGEDB`, `path` is used as the
/// full socket path. Otherwise `path` is treated as a directory, and
/// `.s.EDGEDB.<port>` (or `.s.EDGEDB.admin.<port>` when `admin` is set) is
/// appended, with `port` defaulting to 5656.
///
/// This marks the builder as initialized and overwrites all the
/// credentials. `insecure_dev_mode`, pool size and timeouts are kept.
pub fn unix_path(&mut self, path: impl Into<PathBuf>,
                 port: Option<u16>, admin: bool)
    -> &mut Self
{
    self.reset_compound();
    self.admin = admin;
    let path: PathBuf = path.into();
    let is_full_socket_path = matches!(
        path.file_name().and_then(|name| name.to_str()),
        Some(name) if name.contains(".s.EDGEDB")
    );
    let socket = if is_full_socket_path {
        path
    } else {
        let port = port.unwrap_or(DEFAULT_PORT);
        let socket_name = match admin {
            true => format!(".s.EDGEDB.admin.{}", port),
            false => format!(".s.EDGEDB.{}", port),
        };
        path.join(socket_name)
    };
    // TODO(tailhook) figure out whether it's a prefix or full socket?
    self.address = Address::Unix(socket);
    self.initialized = true;
    self
}
793    /// Get the user name for SCRAM authentication.
794    pub fn get_user(&self) -> &str {
795        &self.user
796    }
797    /// Set the user name for SCRAM authentication.
798    pub fn user(&mut self, user: impl Into<String>) -> &mut Self {
799        self.user = user.into();
800        self
801    }
802    /// Set the password for SCRAM authentication.
803    pub fn password(&mut self, password: impl Into<String>) -> &mut Self {
804        self.password = Some(password.into());
805        self
806    }
807    /// Set the database name.
808    pub fn database(&mut self, database: impl Into<String>) -> &mut Self {
809        self.database = database.into();
810        self
811    }
812    /// Get the database name.
813    pub fn get_database(&self) -> &str {
814        &self.database
815    }
816    /// Set the time to wait for the database server to become available.
817    ///
818    /// This works by ignoring certain errors known to happen while the
819    /// database is starting up or restarting (e.g. "connection refused" or
820    /// early "connection reset").
821    ///
822    /// Note: the amount of time establishing a connection can take is the sum
823    /// of `wait_until_available` plus `connect_timeout`
824    pub fn wait_until_available(&mut self, time: Duration) -> &mut Self {
825        self.wait = time;
826        self
827    }
828    /// A timeout for a single connect attempt.
829    ///
830    /// The default is 10 seconds. A subsecond timeout should be fine for most
831    /// networks. However, in some cases this can be much slower. That's
832    /// because this timeout includes authentication, during which:
833    /// * The password is checked (slow by design).
834    /// * A compiler process is launched (slow now, may be optimized later).
835    ///
836    /// So in a concurrent case on slower VMs (such as CI with parallel
837    /// tests), 10 seconds is more reasonable default.
838    ///
839    /// The `wait_until_available` setting should be larger than this value to
840    /// allow multiple attempts.
841    ///
842    /// Note: the amount of time establishing a connection can take is the sum
843    /// of `wait_until_available` plus `connect_timeout`
844    pub fn connect_timeout(&mut self, timeout: Duration) -> &mut Self {
845        self.connect_timeout = timeout;
846        self
847    }
848
849    /// Set the allowed certificate as a PEM file.
850    pub fn pem_certificates(&mut self, cert_data: &String)
851        -> Result<&mut Self, Error>
852    {
853        validate_certs(cert_data).context("invalid PEM certificate")?;
854        self.pem = Some(cert_data.clone());
855        Ok(self)
856    }
857
858    /// Updates the client TLS security mode.
859    ///
860    /// By default, the certificate chain is always verified; but hostname
861    /// verification is disabled if configured to use only a
862    /// specific certificate, and enabled if root certificates are used.
863    pub fn tls_security(&mut self, value: TlsSecurity) -> &mut Self {
864        self.tls_security = value;
865        self
866    }
867
868    /// Enables insecure dev mode.
869    ///
870    /// This disables certificate validation entirely.
871    pub fn insecure_dev_mode(&mut self, value: bool) -> &mut Self {
872        self.insecure_dev_mode = value;
873        self
874    }
875
876    /// A displayable form for an address this builder will connect to
877    pub fn display_addr<'x>(&'x self) -> impl fmt::Display + 'x {
878        if self.initialized {
879            DisplayAddr(Some(&self.address))
880        } else {
881            DisplayAddr(None)
882        }
883    }
884
885    fn insecure(&self) -> bool {
886        use TlsSecurity::Insecure;
887        self.insecure_dev_mode || self.tls_security == Insecure
888    }
889
890    fn trust_anchors(&self) -> Result<Vec<tls::OwnedTrustAnchor>, Error> {
891        tls::OwnedTrustAnchor::read_all(
892            self.pem.as_deref().unwrap_or("")
893        ).map_err(ClientError::with_source_ref)
894    }
895
896    #[cfg(feature="unstable")]
897    /// Returns certificate store
898    pub fn root_cert_store(&self) -> Result<rustls::RootCertStore, Error> {
899        self._root_cert_store()
900    }
901
902    fn _root_cert_store(&self) -> Result<rustls::RootCertStore, Error> {
903        let mut roots = rustls::RootCertStore::empty();
904        if self.pem.is_some() {
905            roots.add_server_trust_anchors(
906                self.trust_anchors()?.into_iter().map(Into::into)
907            );
908        } else {
909            roots.add_server_trust_anchors(
910                webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
911                    rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
912                        ta.subject,
913                        ta.spki,
914                        ta.name_constraints,
915                    )
916                })
917            );
918        }
919        Ok(roots)
920    }
921
922    /// Build connection and pool configuration object
923    pub fn build(&self) -> Result<Config, Error> {
924        use TlsSecurity::*;
925
926        if !self.initialized {
927            return Err(ClientNoCredentialsError::with_message(
928                "EdgeDB connection options are not initialized. \
929                Run `edgedb project init` or use environment variables \
930                to configure connection."));
931        }
932        let verifier = match self.tls_security {
933            _ if self.insecure() => Arc::new(tls::NullVerifier) as Verifier,
934            Insecure => Arc::new(tls::NullVerifier) as Verifier,
935            NoHostVerification => {
936                Arc::new(tls::NoHostnameVerifier::new(
937                        self.trust_anchors()?
938                )) as Verifier
939            }
940            Strict => {
941                Arc::new(rustls::client::WebPkiVerifier::new(
942                    self._root_cert_store()?,
943                    None,
944                )) as Verifier
945            }
946            Default => match self.pem {
947                Some(_) => {
948                    Arc::new(tls::NoHostnameVerifier::new(
949                            self.trust_anchors()?
950                    )) as Verifier
951                }
952                None => {
953                    Arc::new(rustls::client::WebPkiVerifier::new(
954                        self._root_cert_store()?,
955                        None,
956                    )) as Verifier
957                }
958            },
959        };
960
961        Ok(Config(Arc::new(ConfigInner {
962            address: self.address.clone(),
963            admin: self.admin,
964            user: self.user.clone(),
965            password: self.password.clone(),
966            database: self.database.clone(),
967            verifier,
968            instance_name: self.instance_name.clone(),
969            wait: self.wait,
970            connect_timeout: self.connect_timeout,
971            tls_security: self.tls_security,
972            insecure_dev_mode: self.insecure_dev_mode,
973
974            // Pool configuration
975            max_connections: self.max_connections,
976        })))
977    }
978
979    /// Set the maximum number of underlying database connections.
980    pub fn max_connections(&mut self, value: usize) -> &mut Self {
981        self.max_connections = value;
982        self
983    }
984
985    /// Get the path of the Unix socket if that is configured to be used.
986    ///
987    /// This is a deprecated API and should only be used by the command-line
988    /// tool.
989    #[cfg(feature="admin_socket")]
990    pub fn get_unix_path(&self) -> Option<PathBuf> {
991        self._get_unix_path().unwrap_or(None)
992    }
993    fn _get_unix_path(&self) -> Result<Option<PathBuf>, Error> {
994        match &self.address {
995            Address::Unix(path) => Ok(Some(path.clone())),
996            Address::Tcp(_) => Ok(None),
997        }
998    }
999}
1000
1001fn validate_certs(data: &str) -> Result<(), Error> {
1002    let anchors = tls::OwnedTrustAnchor::read_all(data)
1003        .map_err(|e| ClientError::with_source_ref(e))?;
1004    if anchors.is_empty() {
1005        return Err(ClientError::with_message(
1006                "PEM data contains no certificate"));
1007    }
1008    Ok(())
1009}
1010
1011impl Config {
1012
1013    /// A displayable form for an address this builder will connect to
1014    pub fn display_addr<'x>(&'x self) -> impl fmt::Display + 'x {
1015        DisplayAddr(Some(&self.0.address))
1016    }
1017
1018    /// Connect with a custom certificate verifier.
1019    ///
1020    /// Unstable API
1021    #[cfg(feature="unstable")]
1022    pub async fn connect_with_cert_verifier(
1023        &self, cert_verifier: Arc<dyn ServerCertVerifier>
1024    ) -> Result<Connection, Error> {
1025        self._connect_with_cert_verifier(cert_verifier).await
1026    }
1027
1028    async fn _connect_with_cert_verifier(
1029        &self, cert_verifier: Arc<dyn ServerCertVerifier>
1030    ) -> Result<Connection, Error> {
1031        self.connect_inner(cert_verifier).await.map_err(|e| {
1032            if e.is::<ClientConnectionError>() {
1033                e.refine_kind::<ClientConnectionFailedError>()
1034            } else {
1035                e
1036            }
1037        })
1038    }
1039
1040    /// Get the path of the Unix socket if that is configured to be used.
1041    ///
1042    /// This is a deprecated API and should only be used by the command-line
1043    /// tool.
1044    #[cfg(feature="admin_socket")]
1045    pub fn get_unix_path(&self) -> Option<PathBuf> {
1046        self._get_unix_path().unwrap_or(None)
1047    }
1048    fn _get_unix_path(&self) -> Result<Option<PathBuf>, Error> {
1049        match &self.0.address {
1050            Address::Unix(path) => Ok(Some(path.clone())),
1051            Address::Tcp(_) => Ok(None),
1052        }
1053    }
1054
1055    async fn connect_inner(
1056        &self, cert_verifier: Arc<dyn ServerCertVerifier>
1057    ) -> Result<Connection, Error> {
1058        let tls = tls::connector(cert_verifier).map_err(tls_fail)?;
1059        if log::log_enabled!(log::Level::Info) {
1060            match &self.0.address {
1061                Address::Unix(path) => {
1062                    log::info!("Connecting via Unix `{}`", path.display());
1063                }
1064                Address::Tcp((host, port)) => {
1065                    log::info!("Connecting via TCP {host}:{port}");
1066                }
1067            }
1068        }
1069
1070        let start = Instant::now();
1071        let ref mut warned = false;
1072        let conn = loop {
1073            match
1074                timeout(self.0.connect_timeout,
1075                        self._connect(&tls, warned)).await
1076            {
1077                Err(e) if is_temporary(&e) => {
1078                    log::debug!("Temporary connection error: {:#}", e);
1079                    if self.0.wait > start.elapsed() {
1080                        sleep(sleep_duration()).await;
1081                        continue;
1082                    } else if self.0.wait > Duration::new(0, 0) {
1083                        return Err(e.context(
1084                            format!("cannot establish connection for {:?}",
1085                                    self.0.wait)));
1086                    } else {
1087                        return Err(e);
1088                    }
1089                }
1090                Err(e) => {
1091                    log::debug!("Connection error: {:#}", e);
1092                    return Err(e)?;
1093                }
1094                Ok(conn) => break conn,
1095            }
1096        };
1097        Ok(conn)
1098    }
1099
1100
1101    fn do_verify_hostname(&self) -> Option<bool> {
1102        use TlsSecurity::*;
1103        if self.0.insecure_dev_mode {
1104            return Some(false);
1105        }
1106        match self.0.tls_security {
1107            Insecure => Some(false),
1108            NoHostVerification => Some(false),
1109            Strict => Some(true),
1110            Default => None,
1111        }
1112    }
1113    /// Return a single connection.
1114    #[cfg(feature="unstable")]
1115    pub async fn connect(&self) -> Result<Connection, Error> {
1116        self.private_connect().await
1117    }
1118
1119    pub(crate) async fn private_connect(&self) -> Result<Connection, Error> {
1120        let verify_host = self.do_verify_hostname();
1121        match (&self.0.address, verify_host) {
1122            (Address::Tcp((host, _)), Some(true))
1123                if IpAddr::from_str(host).is_ok() => {
1124                    return Err(ClientError::with_message(
1125                        "Cannot use `verify_hostname` or system \
1126                        root certificates with an IP address"));
1127                }
1128            _ => {}
1129        }
1130        self._connect_with_cert_verifier(self.0.verifier.clone()).await
1131    }
1132    async fn _connect(&self, tls: &TlsConnectorBox, warned: &mut bool)
1133        -> Result<Connection, Error>
1134    {
1135        let stream = match self._connect_stream(tls).await {
1136            Err(e) if e.is::<ProtocolTlsError>() => {
1137                if !*warned {
1138                    log::warn!("TLS connection failed. \
1139                        Trying plaintext...");
1140                    *warned = true;
1141                }
1142                self._connect_stream(
1143                    &PlainConnector::builder()
1144                        .map_err(ClientError::with_source_ref)?
1145                        .build().map_err(ClientError::with_source_ref)?
1146                        .into_dyn()
1147                ).await?
1148            }
1149            Err(e) => return Err(e),
1150            Ok(r) => match r.get_alpn_protocol() {
1151                Ok(Some(protocol)) if protocol == b"edgedb-binary" => r,
1152                _ => match self._get_unix_path()? {
1153                    None => Err(ClientConnectionFailedError::with_message(
1154                        "Server does not support the EdgeDB binary protocol."
1155                    ))?,
1156                    Some(_) => r,  // don't check ALPN on UNIX stream
1157                }
1158            }
1159        };
1160        self._connect_with(stream).await
1161    }
1162
1163    async fn _connect_stream(&self, tls: &TlsConnectorBox)
1164        -> Result<TlsStream, Error>
1165    {
1166        match &self.0.address {
1167            Address::Tcp((host, port)) => {
1168                let conn = TcpStream::connect(&(&host[..], *port)).await
1169                    .map_err(ClientConnectionError::with_source)?;
1170                let is_valid_dns_name = DnsNameRef::try_from_ascii_str(host)
1171                    .is_ok();
1172                let host = if !is_valid_dns_name {
1173                    // FIXME: https://github.com/rustls/rustls/issues/184
1174                    // If self.host is neither an IP address nor a valid DNS
1175                    // name, the hacks below won't make it valid anyways.
1176                    let host = format!("{}.host-for-ip.edgedb.net", host);
1177                    // for ipv6addr
1178                    let host = host.replace(":", "-").replace("%", "-");
1179                    if host.starts_with("-") {
1180                        Cow::from(format!("i{}", host))
1181                    } else {
1182                        Cow::from(host)
1183                    }
1184                } else {
1185                    Cow::from(&host[..])
1186                };
1187                Ok(tls.connect(&host[..], conn).await.map_err(tls_fail)?)
1188            }
1189            Address::Unix(path) => {
1190                #[cfg(windows)] {
1191                    return Err(ClientError::with_message(
1192                        "Unix socket are not supported on windows",
1193                    ));
1194                }
1195                #[cfg(unix)] {
1196                    use async_std::os::unix::net::UnixStream;
1197                    let conn = UnixStream::connect(&path).await
1198                        .map_err(ClientConnectionError::with_source)?;
1199                    Ok(
1200                        PlainConnector::builder()
1201                            .map_err(ClientError::with_source_ref)?
1202                            .build().map_err(ClientError::with_source_ref)?
1203                            .into_dyn()
1204                        .connect("localhost", conn).await.map_err(tls_fail)?
1205                    )
1206                }
1207            }
1208        }
1209    }
1210
1211    async fn _connect_with(&self, stream: TlsStream)
1212        -> Result<Connection, Error>
1213    {
1214        let mut version = ProtocolVersion::current();
1215        let (input, output) = stream.split();
1216        let mut conn = Connection {
1217            ping_interval: PingInterval::Unknown,
1218            input,
1219            output,
1220            input_buf: BytesMut::with_capacity(8192),
1221            output_buf: BytesMut::with_capacity(8192),
1222            params: TypeMap::custom(),
1223            transaction_state: TransactionState::NotInTransaction,
1224            state: State::Normal {
1225                idle_since: Instant::now(),
1226            },
1227            version: version.clone(),
1228        };
1229        let mut seq = conn.start_sequence().await?;
1230        let mut params = HashMap::new();
1231        params.insert(String::from("user"), self.0.user.clone());
1232        params.insert(String::from("database"), self.0.database.clone());
1233
1234        let (major_ver, minor_ver) = version.version_tuple();
1235        seq.send_messages(&[
1236            ClientMessage::ClientHandshake(ClientHandshake {
1237                major_ver,
1238                minor_ver,
1239                params,
1240                extensions: HashMap::new(),
1241            }),
1242        ]).await?;
1243
1244        let mut msg = seq.message().await?;
1245        if let ServerMessage::ServerHandshake(ServerHandshake {
1246            major_ver, minor_ver, extensions: _
1247        }) = msg {
1248            version = ProtocolVersion::new(major_ver, minor_ver);
1249            // TODO(tailhook) record extensions
1250            msg = seq.message().await?;
1251        }
1252        match msg {
1253            ServerMessage::Authentication(Authentication::Ok) => {}
1254            ServerMessage::Authentication(Authentication::Sasl { methods })
1255            => {
1256                if methods.iter().any(|x| x == "SCRAM-SHA-256") {
1257                    if let Some(password) = &self.0.password {
1258                        scram(&mut seq, &self.0.user, password).await
1259                            .map_err(ClientError::with_source)?;
1260                    } else {
1261                        return Err(PasswordRequired::with_message(
1262                            "Password required for the specified user/host"));
1263                    }
1264                } else {
1265                    return Err(AuthenticationError::with_message(format!(
1266                        "No supported authentication \
1267                        methods: {:?}", methods)));
1268                }
1269            }
1270            ServerMessage::ErrorResponse(err) => {
1271                return Err(err.into());
1272            }
1273            msg => {
1274                return Err(ProtocolError::with_message(format!(
1275                    "Error authenticating, unexpected message {:?}", msg)));
1276            }
1277        }
1278
1279        let mut server_params = TypeMap::custom();
1280        loop {
1281            let msg = seq.message().await?;
1282            match msg {
1283                ServerMessage::ReadyForCommand(ready) => {
1284                    seq.reader.consume_ready(ready);
1285                    seq.end_clean();
1286                    break;
1287                }
1288                ServerMessage::ServerKeyData(_) => {
1289                    // TODO(tailhook) store it somehow?
1290                }
1291                ServerMessage::ParameterStatus(par) => {
1292                    match &par.name[..] {
1293                        b"pgaddr" => {
1294                            let pgaddr: PostgresAddress;
1295                            pgaddr = match from_slice(&par.value[..]) {
1296                                Ok(a) => a,
1297                                Err(e) => {
1298                                    log::warn!("Can't decode param {:?}: {}",
1299                                        par.name, e);
1300                                    continue;
1301                                }
1302                            };
1303                            server_params.insert::<PostgresAddress>(pgaddr);
1304                        }
1305                        b"system_config" => {
1306                            self.handle_system_config(par, &mut server_params)?;
1307                        }
1308                        _ => {}
1309                    }
1310                }
1311                _ => {
1312                    log::warn!("unsolicited message {:?}", msg);
1313                }
1314            }
1315        }
1316        conn.version = version;
1317        conn.params = server_params;
1318        conn.state = State::Normal {
1319            idle_since: Instant::now()
1320        };
1321        Ok(conn)
1322    }
1323
1324    fn handle_system_config(
1325        &self,
1326        param_status: ParameterStatus,
1327        server_params: &mut TypeMap<dyn DebugAny + Send + Sync>
1328    ) -> Result<(), Error> {
1329        let (typedesc, data) = param_status.parse_system_config()
1330            .map_err(ProtocolEncodingError::with_source)?;
1331        let codec = typedesc.build_codec()
1332            .map_err(ProtocolEncodingError::with_source)?;
1333        let system_config = codec.decode(data.as_ref())
1334            .map_err(ProtocolEncodingError::with_source)?;
1335        let mut config = SystemConfig {
1336            session_idle_timeout: None,
1337        };
1338        if let Value::Object { shape, fields } = system_config {
1339            for (el, field) in shape.elements.iter().zip(fields) {
1340                match el.name.as_str() {
1341                    "id" => {},
1342                    "session_idle_timeout" => {
1343                        config.session_idle_timeout = match field {
1344                            Some(Value::Duration(timeout)) =>
1345                                Some(timeout.abs_duration()),
1346                            _ => {
1347                                log::warn!(
1348                                    "Wrong protocol: {}={:?}", el.name, field
1349                                );
1350                                None
1351                            },
1352                        };
1353                    }
1354                    name => {
1355                        log::debug!(
1356                            "Unhandled system config: {}={:?}", name, field
1357                        );
1358                    }
1359                }
1360            }
1361            server_params.insert::<SystemConfig>(config);
1362        } else {
1363            log::warn!("Received empty system config message.");
1364        }
1365        Ok(())
1366    }
1367}
1368
1369async fn scram(seq: &mut Sequence<'_>, user: &str, password: &str)
1370    -> Result<(), Error>
1371{
1372    use edgedb_protocol::client_message::SaslInitialResponse;
1373    use edgedb_protocol::client_message::SaslResponse;
1374
1375    let scram = ScramClient::new(&user, &password, None);
1376
1377    let (scram, first) = scram.client_first();
1378    seq.send_messages(&[
1379        ClientMessage::AuthenticationSaslInitialResponse(
1380            SaslInitialResponse {
1381            method: "SCRAM-SHA-256".into(),
1382            data: Bytes::copy_from_slice(first.as_bytes()),
1383        }),
1384    ]).await?;
1385    let msg = seq.message().await?;
1386    let data = match msg {
1387        ServerMessage::Authentication(
1388            Authentication::SaslContinue { data }
1389        ) => data,
1390        ServerMessage::ErrorResponse(err) => {
1391            return Err(err.into());
1392        }
1393        msg => {
1394            return Err(ProtocolError::with_message(format!(
1395                "Bad auth response: {:?}", msg)));
1396        }
1397    };
1398    let data = str::from_utf8(&data[..])
1399        .map_err(|e| ProtocolError::with_source(e).context(
1400            "invalid utf-8 in SCRAM-SHA-256 auth"))?;
1401    let scram = scram.handle_server_first(&data)
1402        .map_err(AuthenticationError::with_source)?;
1403    let (scram, data) = scram.client_final();
1404    seq.send_messages(&[
1405        ClientMessage::AuthenticationSaslResponse(
1406            SaslResponse {
1407                data: Bytes::copy_from_slice(data.as_bytes()),
1408            }),
1409    ]).await?;
1410    let msg = seq.message().await?;
1411    let data = match msg {
1412        ServerMessage::Authentication(Authentication::SaslFinal { data })
1413        => data,
1414        ServerMessage::ErrorResponse(err) => {
1415            return Err(err.into());
1416        }
1417        msg => {
1418            return Err(ProtocolError::with_message(format!(
1419                "auth response: {:?}", msg)));
1420        }
1421    };
1422    let data = str::from_utf8(&data[..])
1423        .map_err(|_| ProtocolError::with_message(
1424            "invalid utf-8 in SCRAM-SHA-256 auth"))?;
1425    scram.handle_server_final(&data)
1426        .map_err(|e| AuthenticationError::with_message(format!(
1427            "Authentication error: {}", e)))?;
1428    loop {
1429        let msg = seq.message().await?;
1430        match msg {
1431            ServerMessage::Authentication(Authentication::Ok) => break,
1432            msg => {
1433                log::warn!("unsolicited message {:?}", msg);
1434            }
1435        };
1436    }
1437    Ok(())
1438}
1439
1440impl fmt::Debug for Config {
1441    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1442        f.debug_struct("Config")
1443            .field("address", &self.0.address)
1444            .field("max_connections", &self.0.max_connections)
1445            // TODO(tailhook) more fields
1446            .finish()
1447    }
1448}
1449
1450#[test]
1451fn read_credentials() {
1452    let mut bld = Builder::uninitialized();
1453    async_std::task::block_on(
1454        bld.read_credentials("tests/credentials1.json")).unwrap();
1455    assert!(matches!(&bld.address, Address::Tcp((_, 10702))));
1456    assert_eq!(&bld.user, "test3n");
1457    assert_eq!(&bld.database, "test3n");
1458    assert_eq!(bld.password, Some("lZTBy1RVCfOpBAOwSCwIyBIR".into()));
1459}
1460
1461#[test]
1462fn display() {
1463    let mut bld = Builder::uninitialized();
1464    async_std::task::block_on(
1465        bld.read_dsn("edgedb://localhost:1756")).unwrap();
1466    assert!(matches!(
1467        &bld.address,
1468        Address::Tcp((host, 1756)) if host == "localhost"
1469    ));
1470    /* TODO(tailhook)
1471    bld.unix_path("/test/my.sock");
1472    assert_eq!(bld.build().unwrap()._get_unix_path().unwrap(),
1473               Some("/test/my.sock/.s.EDGEDB.5656".into()));
1474    */
1475    #[cfg(feature="admin_socket")] {
1476        bld.unix_path("/test/.s.EDGEDB.8888", None, false);
1477        assert_eq!(bld.build().unwrap()._get_unix_path().unwrap(),
1478                   Some("/test/.s.EDGEDB.8888".into()));
1479        bld.unix_path("/test", Some(8888), false);
1480        assert_eq!(bld.build().unwrap()._get_unix_path().unwrap(),
1481                   Some("/test/.s.EDGEDB.8888".into()));
1482    }
1483}
1484
1485#[test]
1486fn from_dsn() {
1487    let mut bld = Builder::uninitialized();
1488    async_std::task::block_on(bld.read_dsn(
1489        "edgedb://user1:EiPhohl7@edb-0134.elb.us-east-2.amazonaws.com/db2"
1490    )).unwrap();
1491    assert!(matches!(
1492        &bld.address,
1493        Address::Tcp((host, 5656))
1494        if host == "edb-0134.elb.us-east-2.amazonaws.com",
1495    ));
1496    assert_eq!(&bld.user, "user1");
1497    assert_eq!(&bld.database, "db2");
1498    assert_eq!(bld.password, Some("EiPhohl7".into()));
1499
1500    let mut bld = Builder::uninitialized();
1501    async_std::task::block_on(bld.read_dsn(
1502        "edgedb://user2@edb-0134.elb.us-east-2.amazonaws.com:1756/db2"
1503    )).unwrap();
1504    assert!(matches!(
1505        &bld.address,
1506        Address::Tcp((host, 1756))
1507        if host == "edb-0134.elb.us-east-2.amazonaws.com",
1508    ));
1509    assert_eq!(&bld.user, "user2");
1510    assert_eq!(&bld.database, "db2");
1511    assert_eq!(bld.password, None);
1512
1513    // Tests overriding
1514    async_std::task::block_on(bld.read_dsn(
1515        "edgedb://edb-0134.elb.us-east-2.amazonaws.com:1756"
1516    )).unwrap();
1517    assert!(matches!(
1518        &bld.address,
1519        Address::Tcp((host, 1756))
1520        if host == "edb-0134.elb.us-east-2.amazonaws.com",
1521    ));
1522    assert_eq!(&bld.user, "edgedb");
1523    assert_eq!(&bld.database, "edgedb");
1524    assert_eq!(bld.password, None);
1525
1526    async_std::task::block_on(bld.read_dsn(
1527        "edgedb://user3:123123@[::1]:5555/abcdef"
1528    )).unwrap();
1529    assert!(matches!(
1530        &bld.address,
1531        Address::Tcp((host, 5555)) if host == "::1",
1532    ));
1533    assert_eq!(&bld.user, "user3");
1534    assert_eq!(&bld.database, "abcdef");
1535    assert_eq!(bld.password, Some("123123".into()));
1536}
1537
1538#[test]
1539#[should_panic]  // servo/rust-url#424
1540fn from_dsn_ipv6_scoped_address() {
1541    let mut bld = Builder::uninitialized();
1542    async_std::task::block_on(bld.read_dsn(
1543        "edgedb://user3@[fe80::1ff:fe23:4567:890a%25eth0]:3000/ab"
1544    )).unwrap();
1545    assert!(matches!(
1546        &bld.address,
1547        Address::Tcp((host, 3000)) if host == "fe80::1ff:fe23:4567:890a%eth0",
1548    ));
1549    assert_eq!(&bld.user, "user3");
1550    assert_eq!(&bld.database, "ab");
1551    assert_eq!(bld.password, None);
1552}