1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::env;
4use std::error::Error as _;
5use std::ffi::{OsString, OsStr};
6use std::fmt;
7use std::io;
8use std::net::IpAddr;
9use std::path::{Path, PathBuf};
10use std::str::{self, FromStr};
11use std::sync::Arc;
12use std::time::{Instant, Duration};
13
14use async_std::fs;
15use async_std::future::Future;
16use async_std::net::TcpStream;
17use async_std::path::{Path as AsyncPath};
18use async_std::task::sleep;
19use bytes::{Bytes, BytesMut};
20use futures_util::AsyncReadExt;
21use rand::{thread_rng, Rng};
22use rustls::client::ServerCertVerifier;
23use scram::ScramClient;
24use serde_json::from_slice;
25use sha1::Digest;
26use tls_api::{TlsConnectorBox, TlsConnector as _, TlsConnectorBuilder as _};
27use tls_api::{TlsStream, TlsStreamDyn as _};
28use tls_api_not_tls::TlsConnector as PlainConnector;
29use typemap::{TypeMap, DebugAny};
30use webpki::DnsNameRef;
31
32use edgedb_protocol::client_message::{ClientMessage, ClientHandshake};
33use edgedb_protocol::features::ProtocolVersion;
34use edgedb_protocol::server_message::{ServerMessage, Authentication};
35use edgedb_protocol::server_message::{TransactionState, ServerHandshake};
36use edgedb_protocol::server_message::ParameterStatus;
37use edgedb_protocol::value::Value;
38
39use crate::client::{Connection, Sequence, State, PingInterval};
40use crate::credentials::{Credentials, TlsSecurity};
41use crate::errors::{ClientConnectionError, ProtocolError, ProtocolTlsError};
42use crate::errors::{ClientConnectionFailedError, AuthenticationError};
43use crate::errors::{ClientError, ClientConnectionFailedTemporarilyError};
44use crate::errors::{ClientNoCredentialsError, ProtocolEncodingError};
45use crate::errors::{Error, ErrorKind, PasswordRequired, ResultExt};
46use crate::server_params::{PostgresAddress, SystemConfig};
47use crate::tls;
48
/// Default time limit for a single connection attempt
/// (initial value of `Builder::connect_timeout`).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Default total time to keep retrying on temporary connection errors
/// (initial value of `Builder::wait`).
pub const DEFAULT_WAIT: Duration = Duration::from_secs(30);
/// Initial value of `Builder::max_connections`.
pub const DEFAULT_POOL_SIZE: usize = 10;
/// Host used when credentials, DSN or `host_port` do not supply one.
pub const DEFAULT_HOST: &str = "localhost";
/// Port used when DSN or `host_port` do not supply one.
pub const DEFAULT_PORT: u16 = 5656;

/// Shared, type-erased TLS server certificate verifier.
type Verifier = Arc<dyn ServerCertVerifier>;
56
/// Mutable collection of connection options; `build()` turns it into a
/// `Config`.
#[derive(Debug, Clone)]
pub struct Builder {
    // Where to connect (TCP host/port or Unix socket path).
    address: Address,
    // Copied into the built config; set from the DSN scheme or `unix_path`.
    admin: bool,
    user: String,
    password: Option<String>,
    database: String,
    // PEM text of trusted certificates, validated by `validate_certs`
    // before being stored.
    pem: Option<String>,
    tls_security: TlsSecurity,
    // Set by `read_instance`; cleared by `reset_compound`.
    instance_name: Option<String>,

    // True once an address source (credentials, DSN, host/port, unix path)
    // has been applied; `build()` refuses to run while this is false.
    initialized: bool,
    // Total time to keep retrying on temporary connection errors.
    wait: Duration,
    // Time limit of each individual connection attempt.
    connect_timeout: Duration,
    // When true, certificate verification is disabled regardless of
    // `tls_security`.
    insecure_dev_mode: bool,
    // Copied from `Credentials::file_outdated`.
    creds_file_outdated: bool,

    pub(crate) max_connections: usize,
}
/// Finished connection configuration produced by `Builder::build`.
///
/// The data lives behind an `Arc`, so cloning only bumps a reference count.
#[derive(Clone)]
pub struct Config(pub(crate) Arc<ConfigInner>);
83
/// Payload of `Config`: a snapshot of the `Builder` fields plus the
/// certificate verifier chosen by `Builder::build`.
pub(crate) struct ConfigInner {
    pub address: Address,
    #[allow(dead_code)] pub admin: bool,
    pub user: String,
    pub password: Option<String>,
    pub database: String,
    // Verifier selected from `tls_security`, `insecure_dev_mode` and the
    // presence of custom PEM certificates.
    pub verifier: Arc<dyn ServerCertVerifier>,
    #[allow(dead_code)] pub instance_name: Option<String>,
    // Total time to keep retrying on temporary connection errors.
    pub wait: Duration,
    // Time limit of each individual connection attempt.
    pub connect_timeout: Duration,
    pub tls_security: TlsSecurity,
    #[allow(dead_code)] pub insecure_dev_mode: bool,

    pub max_connections: usize,
}
103
/// Network location of the server.
#[derive(Debug, Clone)]
pub(crate) enum Address {
    /// `(host, port)` pair for a TCP connection.
    Tcp((String, u16)),
    /// Filesystem path of a Unix domain socket.
    #[allow(dead_code)] Unix(PathBuf),
}
110
/// `Display` adapter for an optional address; `None` is rendered as
/// `<no address>`.
struct DisplayAddr<'a>(Option<&'a Address>);
112
113pub async fn timeout<F, T>(dur: Duration, f: F) -> Result<T, Error>
114 where F: Future<Output = Result<T, Error>>,
115{
116 use async_std::future::timeout;
117
118 timeout(dur, f).await
119 .unwrap_or_else(|_| {
120 Err(ClientConnectionFailedTemporarilyError::with_source(
121 io::Error::from(io::ErrorKind::TimedOut)
122 ))
123 })
124}
125
126fn sleep_duration() -> Duration {
127 Duration::from_millis(thread_rng().gen_range(10u64..200u64))
128}
129
130fn is_temporary(e: &Error) -> bool {
131 use io::ErrorKind::{ConnectionRefused, TimedOut, NotFound};
132 use io::ErrorKind::{ConnectionAborted, ConnectionReset, UnexpectedEof};
133 use io::ErrorKind::{AddrNotAvailable};
134
135 if e.is::<ClientConnectionFailedTemporarilyError>() {
136 return true;
137 }
138 if e.is::<ClientConnectionError>() {
139 let io_err = e.source().and_then(|src| {
140 src.downcast_ref::<io::Error>()
141 .or_else(|| src.downcast_ref::<Box<io::Error>>().map(|b| &**b))
142 });
143 if let Some(e) = io_err {
144 match e.kind() {
145 | ConnectionRefused
146 | ConnectionReset
147 | ConnectionAborted
148 | NotFound | TimedOut
150 | UnexpectedEof | AddrNotAvailable => return true,
153 _ => {},
154 }
155 }
156 }
157 return false;
158}
159
160fn tls_fail(e: anyhow::Error) -> Error {
161 if let Some(e) = e.downcast_ref::<rustls::Error>() {
162 if matches!(e, rustls::Error::CorruptMessage) {
163 return ProtocolTlsError::with_message(
164 "corrupt message, possibly server \
165 does not support TLS connection."
166 );
167 }
168 }
169 ClientConnectionError::with_source_ref(e)
170}
171
172fn get_env(name: &str) -> Result<Option<String>, Error> {
173 match env::var(name) {
174 Ok(v) if v.is_empty() => Ok(None),
175 Ok(v) => Ok(Some(v)),
176 Err(env::VarError::NotPresent) => Ok(None),
177 Err(e) => {
178 Err(
179 ClientError::with_source(e)
180 .context(
181 format!("Cannot decode environment variable {:?}", name))
182 )
183 }
184 }
185}
186
187fn get_port_env() -> Result<Option<String>, Error> {
188 static PORT_WARN: std::sync::Once = std::sync::Once::new();
189
190 let port = get_env("EDGEDB_PORT")?;
191 if let Some(port) = &port {
192 if port.starts_with("tcp://") {
194
195 PORT_WARN.call_once(|| {
196 log::warn!("Environment variable `EDGEDB_PORT` contains \
197 docker-link-like definition. Ingoring...");
198 });
199
200 return Ok(None);
201 }
202 }
203 Ok(port)
204}
205
206fn get_host_port() -> Result<Option<(Option<String>, Option<u16>)>, Error> {
207 let host = get_env("EDGEDB_HOST")?;
208 let port = get_port_env()?.map(|port| {
209 port.parse().map_err(|e| {
210 ClientError::with_source(e)
211 .context("cannot parse env var EDGEDB_PORT")
212 })
213 }).transpose()?;
214 if host.is_some() || port.is_some() {
215 Ok(Some((host, port)))
216 } else {
217 Ok(None)
218 }
219}
220
221pub async fn search_dir(base: &AsyncPath) -> Result<Option<&AsyncPath>, Error>
222{
223 let mut path = base;
224 if path.join("edgedb.toml").exists().await {
225 return Ok(Some(path.into()));
226 }
227 while let Some(parent) = path.parent() {
228 if parent.join("edgedb.toml").exists().await {
229 return Ok(Some(parent.into()));
230 }
231 path = parent;
232 }
233 Ok(None)
234}
235
/// Returns the raw bytes of `path` as stored by the OS string.
#[cfg(unix)]
fn path_bytes<'x>(path: &'x Path) -> &'x [u8] {
    std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str())
}

/// Returns the UTF-8 bytes of `path`.
///
/// Panics when the path cannot be represented as UTF-8.
#[cfg(windows)]
fn path_bytes<'x>(path: &'x Path) -> &'x [u8] {
    let text = path.to_str().expect("windows paths are always valid UTF-16");
    text.as_bytes()
}
246
247fn hash(path: &Path) -> String {
248 format!("{:x}", sha1::Sha1::new_with_prefix(path_bytes(path)).finalize())
249}
250
251fn stash_name(path: &Path) -> OsString {
252 let hash = hash(path);
253 let base = path.file_name().unwrap_or(OsStr::new(""));
254 let mut base = base.to_os_string();
255 base.push("-");
256 base.push(&hash);
257 return base;
258}
259
260fn config_dir() -> Result<PathBuf, Error> {
261 let dir = if cfg!(windows) {
262 dirs::data_local_dir()
263 .ok_or_else(|| ClientError::with_message(
264 "cannot determine local data directory"))?
265 .join("EdgeDB")
266 .join("config")
267 } else {
268 dirs::config_dir()
269 .ok_or_else(|| ClientError::with_message(
270 "cannot determine config directory"))?
271 .join("edgedb")
272 };
273 Ok(dir)
274}
275
276#[allow(dead_code)]
277#[cfg(target_os="linux")]
278fn default_runtime_base() -> Result<PathBuf, Error> {
279 extern "C" {
280 fn geteuid() -> u32;
281 }
282 Ok(Path::new("/run/user").join(unsafe { geteuid() }.to_string()))
283}
284
285#[allow(dead_code)]
286#[cfg(not(target_os="linux"))]
287fn default_runtime_base() -> Result<PathBuf, Error> {
288 Err(ClientError::with_message("no default runtime dir for the platform"))
289}
290
291fn stash_path(project_dir: &Path) -> Result<PathBuf, Error> {
292 Ok(config_dir()?.join("projects").join(stash_name(project_dir)))
293}
294
/// Checks that `name` is a non-empty identifier: an ASCII letter or `_`
/// followed only by ASCII letters, digits or `_`.
fn is_valid_instance_name(name: &str) -> bool {
    let is_word_char = |c: char| c.is_ascii_alphanumeric() || c == '_';
    let starts_ok = name
        .chars()
        .next()
        .map_or(false, |c| c.is_ascii_alphabetic() || c == '_');
    starts_ok && name.chars().skip(1).all(is_word_char)
}
308
309
310impl fmt::Display for DisplayAddr<'_> {
311 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
312 match &self.0 {
313 Some(Address::Tcp((host, port))) => {
314 write!(f, "{}:{}", host, port)
315 }
316 Some(Address::Unix(path)) => write!(f, "unix:{}", path.display()),
317 None => write!(f, "<no address>"),
318 }
319 }
320}
321
322impl Builder {
323
324 pub async fn from_env() -> Result<Builder, Error> {
326 let mut builder = Builder::uninitialized();
327
328 if get_env("EDGEDB_HOST")?.is_none() &&
330 get_port_env()?.is_none() &&
331 get_env("EDGEDB_INSTANCE")?.is_none() &&
332 get_env("EDGEDB_DSN")?.is_none() &&
333 get_env("EDGEDB_CONFIGURATION_FILE")?.is_none()
334 {
335 builder.read_project(None, false).await?;
336 }
337
338 builder.read_env_vars().await?;
339 Ok(builder)
340 }
341
342 pub async fn read_project(&mut self,
355 override_dir: Option<&Path>, search_parents: bool)
356 -> Result<&mut Self, Error>
357 {
358 let dir = match override_dir {
359 Some(v) => Cow::Borrowed(v.as_ref()),
360 None => {
361 Cow::Owned(env::current_dir()
362 .map_err(|e| ClientError::with_source(e)
363 .context("failed to get current directory"))?
364 .into())
365 }
366 };
367
368 let dir = if search_parents {
369 if let Some(ancestor) = search_dir(&dir).await? {
370 Cow::Borrowed(ancestor)
371 } else {
372 return Ok(self);
373 }
374 } else {
375 if !dir.join("edgedb.toml").exists().await {
376 return Ok(self);
377 }
378 dir
379 };
380 let canon = fs::canonicalize(&dir).await
381 .map_err(|e| ClientError::with_source(e).context(
382 format!("failed to canonicalize dir {:?}", dir)
383 ))?;
384 let stash_path = stash_path(canon.as_ref())?;
385 if AsRef::<AsyncPath>::as_ref(&stash_path).exists().await {
386 let instance =
387 fs::read_to_string(stash_path.join("instance-name")).await
388 .map_err(|e| ClientError::with_source(e).context(
389 format!("error reading project settings {:?}", dir)
390 ))?;
391 self.read_instance(instance.trim()).await?;
392
393 }
394 Ok(self)
395 }
396
    /// Returns true once a connection target has been applied, i.e. when
    /// `build()` will not fail for lack of credentials.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }
401 pub async fn read_env_vars(&mut self) -> Result<&mut Self, Error> {
425 if let Some((host, port)) = get_host_port()? {
426 self.host_port(host, port);
427 } else if let Some(path) = get_env("EDGEDB_CREDENTIALS_FILE")? {
428 self.read_credentials(path).await?;
429 } else if let Some(instance) = get_env("EDGEDB_INSTANCE")? {
430 self.read_instance(&instance).await?;
431 } else if let Some(dsn) = get_env("EDGEDB_DSN")? {
432 self.read_dsn(&dsn).await.map_err(|e|
433 e.context("cannot parse env var EDGEDB_DNS"))?;
434 }
435 if let Some(database) = get_env("EDGEDB_DATABASE")? {
436 self.database = database;
437 }
438 if let Some(user) = get_env("EDGEDB_USER")? {
439 self.user = user;
440 }
441 if let Some(password) = get_env("EDGEDB_PASSWORD")? {
442 self.password = Some(password);
443 }
444 if let Some(sec) = get_env("EDGEDB_CLIENT_TLS_SECURITY")? {
445 self.tls_security = match &sec[..] {
446 "default" => TlsSecurity::Default,
447 "insecure" => TlsSecurity::Insecure,
448 "no_host_verification" => TlsSecurity::NoHostVerification,
449 "strict" => TlsSecurity::Strict,
450 _ => {
451 return Err(ClientError::with_message(
452 format!("Invalid value {:?} for env var \
453 EDGEDB_CLIENT_TLS_SECURITY. \
454 Options: default, insecure, \
455 no_host_verification, strict.",
456 sec)
457 ));
458 }
459 };
460 }
461 self.read_extra_env_vars()?;
462 Ok(self)
463 }
464 pub fn read_extra_env_vars(&mut self) -> Result<&mut Self, Error> {
466 if let Some(mode) = get_env("EDGEDB_CLIENT_SECURITY")? {
467 self.insecure_dev_mode = match &mode[..] {
468 "default" => false,
469 "insecure_dev_mode" => true,
470 _ => {
471 return Err(ClientError::with_message(
472 format!("Invalid value {:?} for env var \
473 EDGEDB_CLIENT_SECURITY. \
474 Options: default, insecure_dev_mode.",
475 mode)
476 ));
477 }
478 };
479 }
480 Ok(self)
481 }
482
483 pub fn credentials(&mut self, credentials: &Credentials)
487 -> Result<&mut Self, Error>
488 {
489 if let Some(cert_data) = &credentials.tls_ca {
490 validate_certs(&cert_data)
491 .context("invalid certificates in `tls_ca`")?;
492 }
493 self.reset_compound();
494 self.address = Address::Tcp((
495 credentials.host.clone()
496 .unwrap_or_else(|| DEFAULT_HOST.into()),
497 credentials.port,
498 ));
499 self.admin = false;
500 self.user = credentials.user.clone();
501 self.password = credentials.password.clone();
502 self.database = credentials.database.clone()
503 .unwrap_or_else(|| "edgedb".into());
504 self.creds_file_outdated = credentials.file_outdated;
505 self.tls_security = credentials.tls_security;
506 self.pem = credentials.tls_ca.clone();
507 self.initialized = true;
508 Ok(self)
509 }
510
511 #[cfg(feature="unstable")]
513 pub fn get_instance_name_for_creds_update(&self) -> Option<&str> {
514 if self.creds_file_outdated {
515 self.instance_name.as_deref()
516 } else {
517 None
518 }
519 }
520
521 pub async fn read_instance(&mut self, name: &str)
539 -> Result<&mut Self, Error>
540 {
541 if !is_valid_instance_name(name) {
542 return Err(ClientError::with_message(format!(
543 "instance name {:?} contains unsupported characters", name)));
544 }
545 self.read_credentials(
546 config_dir()?.join("credentials").join(format!("{}.json", name))
547 ).await?;
548 self.instance_name = Some(name.into());
549 Ok(self)
550 }
551
552 pub async fn read_credentials(&mut self, path: impl AsRef<Path>)
558 -> Result<&mut Self, Error>
559 {
560 let path = path.as_ref();
561 async {
562 let data = fs::read(path).await
563 .map_err(ClientError::with_source)?;
564 let creds = serde_json::from_slice(&data)
565 .map_err(ClientError::with_source)?;
566 self.credentials(&creds)?;
567 Ok(())
568 }.await.map_err(|e: Error| e.context(
569 format!("cannot read credentials file {}", path.display())
570 ))?;
571 Ok(self)
572 }
573
574 pub async fn read_dsn(&mut self, dsn: &str) -> Result<&mut Self, Error> {
589 let admin = dsn.starts_with("edgedbadmin://");
590 if !dsn.starts_with("edgedb://") && !admin {
591 return Err(ClientError::with_message(format!(
592 "String {:?} is not a valid DSN", dsn)));
593 };
594 let url = url::Url::parse(dsn)
595 .map_err(|e| ClientError::with_source(e)
596 .context(format!("cannot parse DSN {:?}", dsn)))?;
597 self.reset_compound();
598 let host = if let Some(url::Host::Ipv6(host)) = url.host() {
599 host.to_string()
601 } else {
602 url.host_str().unwrap_or(DEFAULT_HOST).to_owned()
603 };
604 let port = url.port().unwrap_or(DEFAULT_PORT);
605 self.address = Address::Tcp((host, port));
606 self.admin = admin;
607 self.user = if url.username().is_empty() {
608 "edgedb".to_owned()
609 } else {
610 url.username().to_owned()
611 };
612 self.password = url.password().map(|s| s.to_owned());
613 self.database = url.path().strip_prefix("/")
614 .unwrap_or("edgedb").to_owned();
615 self.initialized = true;
616 Ok(self)
617 }
618 pub fn uninitialized() -> Builder {
626 Builder {
627 address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)),
628 admin: false,
629 user: "edgedb".into(),
630 password: None,
631 database: "edgedb".into(),
632 tls_security: TlsSecurity::Default,
633 pem: None,
634 instance_name: None,
635
636 wait: DEFAULT_WAIT,
637 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
638 initialized: false,
639 insecure_dev_mode: false,
640 creds_file_outdated: false,
641
642 max_connections: DEFAULT_POOL_SIZE,
643 }
644 }
645 fn reset_compound(&mut self) {
646 *self = Builder {
647 address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)),
649 admin: false,
650 user: "edgedb".into(),
651 password: None,
652 database: "edgedb".into(),
653 tls_security: TlsSecurity::Default,
654 pem: None,
655 instance_name: None,
656
657 initialized: false,
658 wait: self.wait,
660 connect_timeout: self.connect_timeout,
661 insecure_dev_mode: self.insecure_dev_mode,
662 creds_file_outdated: false,
663
664 max_connections: self.max_connections,
665 };
666 }
667 pub fn as_credentials(&self) -> Result<Credentials, Error> {
669 let (host, port) = match &self.address {
670 Address::Tcp(pair) => pair,
671 Address::Unix(_) => {
672 return Err(ClientError::with_message(
673 "Unix socket address cannot \
674 be saved as credentials file"));
675 }
676 };
677 Ok(Credentials {
678 host: Some(host.clone()),
679 port: *port,
680 user: self.user.clone(),
681 password: self.password.clone(),
682 database: Some( self.database.clone()),
683 tls_ca: self.pem.clone(),
684 tls_security: self.tls_security,
685 file_outdated: false,
686 cloud_instance_id: None,
687 cloud_original_dsn: None,
688 })
689 }
690 pub fn get_host(&self) -> &str {
695 match &self.address {
696 Address::Tcp((host, _)) => host,
697 Address::Unix(_) => "localhost",
698 }
699 }
700 pub fn get_port(&self) -> u16 {
702 match &self.address {
703 Address::Tcp((_, port)) => *port,
704 Address::Unix(_) => 5656
705 }
706 }
707 pub fn host_port(&mut self,
716 host: Option<impl Into<String>>, port: Option<u16>)
717 -> &mut Self
718 {
719 self.reset_compound();
720 self.address = Address::Tcp((
721 host.map_or_else(|| DEFAULT_HOST.into(), |h| h.into()),
722 port.unwrap_or(DEFAULT_PORT),
723 ));
724 self.initialized = true;
725 self
726 }
727
728 #[cfg(feature="admin_socket")]
729 pub fn admin(&mut self) -> Result<&mut Self, Error> {
731 let prefix = if let Some(name) = &self.instance_name {
732 if cfg!(windows) {
733 return Err(ClientError::with_message(
734 "unix sockets are not supported on Windows"));
735 } else if let Some(dir) = dirs::runtime_dir() {
736 dir.join(format!("edgedb-{}", name))
737 } else {
738 dirs::cache_dir()
739 .ok_or_else(|| ClientError::with_message(
740 "cannot determine cache directory"))?
741 .join("edgedb")
742 .join("run")
743 .join(name)
744 }
745 } else {
746 if cfg!(target_os="macos") {
747 "/var/run/edgedb".into()
748 } else {
749 "/run/edgedb".into()
750 }
751 };
752 match self.address {
753 Address::Tcp((_, port)) => {
754 self.address = Address::Unix(
755 prefix.join(format!(".s.EDGEDB.admin.{}", port))
756 );
757 }
758 Address::Unix(_) => {},
759 }
760 Ok(self)
761 }
762
763 #[cfg(feature="admin_socket")]
764 pub fn unix_path(&mut self, path: impl Into<PathBuf>,
766 port: Option<u16>, admin: bool)
767 -> &mut Self
768 {
769 self.reset_compound();
770 self.admin = admin;
771 let path = path.into();
772 let has_socket_name = path.file_name()
773 .and_then(|x| x.to_str())
774 .map(|x| x.contains(".s.EDGEDB"))
775 .unwrap_or(false);
776 let path = if has_socket_name {
777 path
779 } else {
780 let port = port.unwrap_or(5656);
781 let socket_name = if admin {
782 format!(".s.EDGEDB.admin.{}", port)
783 } else {
784 format!(".s.EDGEDB.{}", port)
785 };
786 path.join(socket_name)
787 };
788 self.address = Address::Unix(path.into());
790 self.initialized = true;
791 self
792 }
    /// Returns the configured user name.
    pub fn get_user(&self) -> &str {
        &self.user
    }
    /// Sets the user name.
    pub fn user(&mut self, user: impl Into<String>) -> &mut Self {
        self.user = user.into();
        self
    }
    /// Sets the password.
    pub fn password(&mut self, password: impl Into<String>) -> &mut Self {
        self.password = Some(password.into());
        self
    }
    /// Sets the database name.
    pub fn database(&mut self, database: impl Into<String>) -> &mut Self {
        self.database = database.into();
        self
    }
    /// Returns the configured database name.
    pub fn get_database(&self) -> &str {
        &self.database
    }
    /// Sets the total time to keep retrying on temporary connection errors.
    pub fn wait_until_available(&mut self, time: Duration) -> &mut Self {
        self.wait = time;
        self
    }
    /// Sets the time limit of each individual connection attempt.
    pub fn connect_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.connect_timeout = timeout;
        self
    }
848
849 pub fn pem_certificates(&mut self, cert_data: &String)
851 -> Result<&mut Self, Error>
852 {
853 validate_certs(cert_data).context("invalid PEM certificate")?;
854 self.pem = Some(cert_data.clone());
855 Ok(self)
856 }
857
    /// Sets the TLS security mode used by `build()` to pick the certificate
    /// verifier.
    pub fn tls_security(&mut self, value: TlsSecurity) -> &mut Self {
        self.tls_security = value;
        self
    }

    /// Enables or disables insecure dev mode; when enabled, `build()`
    /// selects the verifier that accepts any certificate.
    pub fn insecure_dev_mode(&mut self, value: bool) -> &mut Self {
        self.insecure_dev_mode = value;
        self
    }
875
876 pub fn display_addr<'x>(&'x self) -> impl fmt::Display + 'x {
878 if self.initialized {
879 DisplayAddr(Some(&self.address))
880 } else {
881 DisplayAddr(None)
882 }
883 }
884
885 fn insecure(&self) -> bool {
886 use TlsSecurity::Insecure;
887 self.insecure_dev_mode || self.tls_security == Insecure
888 }
889
890 fn trust_anchors(&self) -> Result<Vec<tls::OwnedTrustAnchor>, Error> {
891 tls::OwnedTrustAnchor::read_all(
892 self.pem.as_deref().unwrap_or("")
893 ).map_err(ClientError::with_source_ref)
894 }
895
    /// Public (unstable) wrapper around `_root_cert_store`.
    #[cfg(feature="unstable")]
    pub fn root_cert_store(&self) -> Result<rustls::RootCertStore, Error> {
        self._root_cert_store()
    }
901
902 fn _root_cert_store(&self) -> Result<rustls::RootCertStore, Error> {
903 let mut roots = rustls::RootCertStore::empty();
904 if self.pem.is_some() {
905 roots.add_server_trust_anchors(
906 self.trust_anchors()?.into_iter().map(Into::into)
907 );
908 } else {
909 roots.add_server_trust_anchors(
910 webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
911 rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
912 ta.subject,
913 ta.spki,
914 ta.name_constraints,
915 )
916 })
917 );
918 }
919 Ok(roots)
920 }
921
922 pub fn build(&self) -> Result<Config, Error> {
924 use TlsSecurity::*;
925
926 if !self.initialized {
927 return Err(ClientNoCredentialsError::with_message(
928 "EdgeDB connection options are not initialized. \
929 Run `edgedb project init` or use environment variables \
930 to configure connection."));
931 }
932 let verifier = match self.tls_security {
933 _ if self.insecure() => Arc::new(tls::NullVerifier) as Verifier,
934 Insecure => Arc::new(tls::NullVerifier) as Verifier,
935 NoHostVerification => {
936 Arc::new(tls::NoHostnameVerifier::new(
937 self.trust_anchors()?
938 )) as Verifier
939 }
940 Strict => {
941 Arc::new(rustls::client::WebPkiVerifier::new(
942 self._root_cert_store()?,
943 None,
944 )) as Verifier
945 }
946 Default => match self.pem {
947 Some(_) => {
948 Arc::new(tls::NoHostnameVerifier::new(
949 self.trust_anchors()?
950 )) as Verifier
951 }
952 None => {
953 Arc::new(rustls::client::WebPkiVerifier::new(
954 self._root_cert_store()?,
955 None,
956 )) as Verifier
957 }
958 },
959 };
960
961 Ok(Config(Arc::new(ConfigInner {
962 address: self.address.clone(),
963 admin: self.admin,
964 user: self.user.clone(),
965 password: self.password.clone(),
966 database: self.database.clone(),
967 verifier,
968 instance_name: self.instance_name.clone(),
969 wait: self.wait,
970 connect_timeout: self.connect_timeout,
971 tls_security: self.tls_security,
972 insecure_dev_mode: self.insecure_dev_mode,
973
974 max_connections: self.max_connections,
976 })))
977 }
978
    /// Sets the `max_connections` value copied into the built `Config`.
    pub fn max_connections(&mut self, value: usize) -> &mut Self {
        self.max_connections = value;
        self
    }
984
985 #[cfg(feature="admin_socket")]
990 pub fn get_unix_path(&self) -> Option<PathBuf> {
991 self._get_unix_path().unwrap_or(None)
992 }
993 fn _get_unix_path(&self) -> Result<Option<PathBuf>, Error> {
994 match &self.address {
995 Address::Unix(path) => Ok(Some(path.clone())),
996 Address::Tcp(_) => Ok(None),
997 }
998 }
999}
1000
1001fn validate_certs(data: &str) -> Result<(), Error> {
1002 let anchors = tls::OwnedTrustAnchor::read_all(data)
1003 .map_err(|e| ClientError::with_source_ref(e))?;
1004 if anchors.is_empty() {
1005 return Err(ClientError::with_message(
1006 "PEM data contains no certificate"));
1007 }
1008 Ok(())
1009}
1010
1011impl Config {
1012
    /// Returns a displayable form of the configured address.
    pub fn display_addr<'x>(&'x self) -> impl fmt::Display + 'x {
        DisplayAddr(Some(&self.0.address))
    }
1017
    /// Public (unstable) wrapper around `_connect_with_cert_verifier`:
    /// connects using the supplied verifier instead of the one stored in
    /// the config.
    #[cfg(feature="unstable")]
    pub async fn connect_with_cert_verifier(
        &self, cert_verifier: Arc<dyn ServerCertVerifier>
    ) -> Result<Connection, Error> {
        self._connect_with_cert_verifier(cert_verifier).await
    }
1027
1028 async fn _connect_with_cert_verifier(
1029 &self, cert_verifier: Arc<dyn ServerCertVerifier>
1030 ) -> Result<Connection, Error> {
1031 self.connect_inner(cert_verifier).await.map_err(|e| {
1032 if e.is::<ClientConnectionError>() {
1033 e.refine_kind::<ClientConnectionFailedError>()
1034 } else {
1035 e
1036 }
1037 })
1038 }
1039
1040 #[cfg(feature="admin_socket")]
1045 pub fn get_unix_path(&self) -> Option<PathBuf> {
1046 self._get_unix_path().unwrap_or(None)
1047 }
1048 fn _get_unix_path(&self) -> Result<Option<PathBuf>, Error> {
1049 match &self.0.address {
1050 Address::Unix(path) => Ok(Some(path.clone())),
1051 Address::Tcp(_) => Ok(None),
1052 }
1053 }
1054
1055 async fn connect_inner(
1056 &self, cert_verifier: Arc<dyn ServerCertVerifier>
1057 ) -> Result<Connection, Error> {
1058 let tls = tls::connector(cert_verifier).map_err(tls_fail)?;
1059 if log::log_enabled!(log::Level::Info) {
1060 match &self.0.address {
1061 Address::Unix(path) => {
1062 log::info!("Connecting via Unix `{}`", path.display());
1063 }
1064 Address::Tcp((host, port)) => {
1065 log::info!("Connecting via TCP {host}:{port}");
1066 }
1067 }
1068 }
1069
1070 let start = Instant::now();
1071 let ref mut warned = false;
1072 let conn = loop {
1073 match
1074 timeout(self.0.connect_timeout,
1075 self._connect(&tls, warned)).await
1076 {
1077 Err(e) if is_temporary(&e) => {
1078 log::debug!("Temporary connection error: {:#}", e);
1079 if self.0.wait > start.elapsed() {
1080 sleep(sleep_duration()).await;
1081 continue;
1082 } else if self.0.wait > Duration::new(0, 0) {
1083 return Err(e.context(
1084 format!("cannot establish connection for {:?}",
1085 self.0.wait)));
1086 } else {
1087 return Err(e);
1088 }
1089 }
1090 Err(e) => {
1091 log::debug!("Connection error: {:#}", e);
1092 return Err(e)?;
1093 }
1094 Ok(conn) => break conn,
1095 }
1096 };
1097 Ok(conn)
1098 }
1099
1100
1101 fn do_verify_hostname(&self) -> Option<bool> {
1102 use TlsSecurity::*;
1103 if self.0.insecure_dev_mode {
1104 return Some(false);
1105 }
1106 match self.0.tls_security {
1107 Insecure => Some(false),
1108 NoHostVerification => Some(false),
1109 Strict => Some(true),
1110 Default => None,
1111 }
1112 }
    /// Public (unstable) wrapper around `private_connect`.
    #[cfg(feature="unstable")]
    pub async fn connect(&self) -> Result<Connection, Error> {
        self.private_connect().await
    }
1118
1119 pub(crate) async fn private_connect(&self) -> Result<Connection, Error> {
1120 let verify_host = self.do_verify_hostname();
1121 match (&self.0.address, verify_host) {
1122 (Address::Tcp((host, _)), Some(true))
1123 if IpAddr::from_str(host).is_ok() => {
1124 return Err(ClientError::with_message(
1125 "Cannot use `verify_hostname` or system \
1126 root certificates with an IP address"));
1127 }
1128 _ => {}
1129 }
1130 self._connect_with_cert_verifier(self.0.verifier.clone()).await
1131 }
1132 async fn _connect(&self, tls: &TlsConnectorBox, warned: &mut bool)
1133 -> Result<Connection, Error>
1134 {
1135 let stream = match self._connect_stream(tls).await {
1136 Err(e) if e.is::<ProtocolTlsError>() => {
1137 if !*warned {
1138 log::warn!("TLS connection failed. \
1139 Trying plaintext...");
1140 *warned = true;
1141 }
1142 self._connect_stream(
1143 &PlainConnector::builder()
1144 .map_err(ClientError::with_source_ref)?
1145 .build().map_err(ClientError::with_source_ref)?
1146 .into_dyn()
1147 ).await?
1148 }
1149 Err(e) => return Err(e),
1150 Ok(r) => match r.get_alpn_protocol() {
1151 Ok(Some(protocol)) if protocol == b"edgedb-binary" => r,
1152 _ => match self._get_unix_path()? {
1153 None => Err(ClientConnectionFailedError::with_message(
1154 "Server does not support the EdgeDB binary protocol."
1155 ))?,
1156 Some(_) => r, }
1158 }
1159 };
1160 self._connect_with(stream).await
1161 }
1162
1163 async fn _connect_stream(&self, tls: &TlsConnectorBox)
1164 -> Result<TlsStream, Error>
1165 {
1166 match &self.0.address {
1167 Address::Tcp((host, port)) => {
1168 let conn = TcpStream::connect(&(&host[..], *port)).await
1169 .map_err(ClientConnectionError::with_source)?;
1170 let is_valid_dns_name = DnsNameRef::try_from_ascii_str(host)
1171 .is_ok();
1172 let host = if !is_valid_dns_name {
1173 let host = format!("{}.host-for-ip.edgedb.net", host);
1177 let host = host.replace(":", "-").replace("%", "-");
1179 if host.starts_with("-") {
1180 Cow::from(format!("i{}", host))
1181 } else {
1182 Cow::from(host)
1183 }
1184 } else {
1185 Cow::from(&host[..])
1186 };
1187 Ok(tls.connect(&host[..], conn).await.map_err(tls_fail)?)
1188 }
1189 Address::Unix(path) => {
1190 #[cfg(windows)] {
1191 return Err(ClientError::with_message(
1192 "Unix socket are not supported on windows",
1193 ));
1194 }
1195 #[cfg(unix)] {
1196 use async_std::os::unix::net::UnixStream;
1197 let conn = UnixStream::connect(&path).await
1198 .map_err(ClientConnectionError::with_source)?;
1199 Ok(
1200 PlainConnector::builder()
1201 .map_err(ClientError::with_source_ref)?
1202 .build().map_err(ClientError::with_source_ref)?
1203 .into_dyn()
1204 .connect("localhost", conn).await.map_err(tls_fail)?
1205 )
1206 }
1207 }
1208 }
1209 }
1210
    /// Runs the protocol handshake over an established stream and returns
    /// the ready `Connection`.
    ///
    /// Steps: send `ClientHandshake` with user and database, accept an
    /// optional `ServerHandshake` that changes the protocol version,
    /// authenticate (no-op or SCRAM-SHA-256), then collect server
    /// parameters until `ReadyForCommand` arrives.
    ///
    /// Fails with `PasswordRequired` when SASL is requested but no password
    /// is configured, with `AuthenticationError` when SCRAM-SHA-256 is not
    /// offered, with the server's own error on `ErrorResponse`, and with
    /// `ProtocolError` on any other unexpected message.
    async fn _connect_with(&self, stream: TlsStream)
        -> Result<Connection, Error>
    {
        let mut version = ProtocolVersion::current();
        let (input, output) = stream.split();
        let mut conn = Connection {
            ping_interval: PingInterval::Unknown,
            input,
            output,
            input_buf: BytesMut::with_capacity(8192),
            output_buf: BytesMut::with_capacity(8192),
            params: TypeMap::custom(),
            transaction_state: TransactionState::NotInTransaction,
            state: State::Normal {
                idle_since: Instant::now(),
            },
            version: version.clone(),
        };
        let mut seq = conn.start_sequence().await?;
        let mut params = HashMap::new();
        params.insert(String::from("user"), self.0.user.clone());
        params.insert(String::from("database"), self.0.database.clone());

        let (major_ver, minor_ver) = version.version_tuple();
        seq.send_messages(&[
            ClientMessage::ClientHandshake(ClientHandshake {
                major_ver,
                minor_ver,
                params,
                extensions: HashMap::new(),
            }),
        ]).await?;

        let mut msg = seq.message().await?;
        // A ServerHandshake carries the version the server wants to use;
        // the authentication message follows it.
        if let ServerMessage::ServerHandshake(ServerHandshake {
            major_ver, minor_ver, extensions: _
        }) = msg {
            version = ProtocolVersion::new(major_ver, minor_ver);
            msg = seq.message().await?;
        }
        match msg {
            ServerMessage::Authentication(Authentication::Ok) => {}
            ServerMessage::Authentication(Authentication::Sasl { methods })
            => {
                if methods.iter().any(|x| x == "SCRAM-SHA-256") {
                    if let Some(password) = &self.0.password {
                        scram(&mut seq, &self.0.user, password).await
                            .map_err(ClientError::with_source)?;
                    } else {
                        return Err(PasswordRequired::with_message(
                            "Password required for the specified user/host"));
                    }
                } else {
                    return Err(AuthenticationError::with_message(format!(
                        "No supported authentication \
                        methods: {:?}", methods)));
                }
            }
            ServerMessage::ErrorResponse(err) => {
                return Err(err.into());
            }
            msg => {
                return Err(ProtocolError::with_message(format!(
                    "Error authenticating, unexpected message {:?}", msg)));
            }
        }

        // Collect server-side parameters until the server is ready.
        let mut server_params = TypeMap::custom();
        loop {
            let msg = seq.message().await?;
            match msg {
                ServerMessage::ReadyForCommand(ready) => {
                    seq.reader.consume_ready(ready);
                    seq.end_clean();
                    break;
                }
                // Key data is accepted and ignored.
                ServerMessage::ServerKeyData(_) => {
                }
                ServerMessage::ParameterStatus(par) => {
                    match &par.name[..] {
                        b"pgaddr" => {
                            let pgaddr: PostgresAddress;
                            pgaddr = match from_slice(&par.value[..]) {
                                Ok(a) => a,
                                Err(e) => {
                                    // An undecodable value is skipped, not
                                    // treated as a fatal error.
                                    log::warn!("Can't decode param {:?}: {}",
                                        par.name, e);
                                    continue;
                                }
                            };
                            server_params.insert::<PostgresAddress>(pgaddr);
                        }
                        b"system_config" => {
                            self.handle_system_config(par, &mut server_params)?;
                        }
                        _ => {}
                    }
                }
                _ => {
                    log::warn!("unsolicited message {:?}", msg);
                }
            }
        }
        // Store the negotiated version and collected parameters, and
        // restart the idle clock now that the handshake is done.
        conn.version = version;
        conn.params = server_params;
        conn.state = State::Normal {
            idle_since: Instant::now()
        };
        Ok(conn)
    }
1323
1324 fn handle_system_config(
1325 &self,
1326 param_status: ParameterStatus,
1327 server_params: &mut TypeMap<dyn DebugAny + Send + Sync>
1328 ) -> Result<(), Error> {
1329 let (typedesc, data) = param_status.parse_system_config()
1330 .map_err(ProtocolEncodingError::with_source)?;
1331 let codec = typedesc.build_codec()
1332 .map_err(ProtocolEncodingError::with_source)?;
1333 let system_config = codec.decode(data.as_ref())
1334 .map_err(ProtocolEncodingError::with_source)?;
1335 let mut config = SystemConfig {
1336 session_idle_timeout: None,
1337 };
1338 if let Value::Object { shape, fields } = system_config {
1339 for (el, field) in shape.elements.iter().zip(fields) {
1340 match el.name.as_str() {
1341 "id" => {},
1342 "session_idle_timeout" => {
1343 config.session_idle_timeout = match field {
1344 Some(Value::Duration(timeout)) =>
1345 Some(timeout.abs_duration()),
1346 _ => {
1347 log::warn!(
1348 "Wrong protocol: {}={:?}", el.name, field
1349 );
1350 None
1351 },
1352 };
1353 }
1354 name => {
1355 log::debug!(
1356 "Unhandled system config: {}={:?}", name, field
1357 );
1358 }
1359 }
1360 }
1361 server_params.insert::<SystemConfig>(config);
1362 } else {
1363 log::warn!("Received empty system config message.");
1364 }
1365 Ok(())
1366 }
1367}
1368
1369async fn scram(seq: &mut Sequence<'_>, user: &str, password: &str)
1370 -> Result<(), Error>
1371{
1372 use edgedb_protocol::client_message::SaslInitialResponse;
1373 use edgedb_protocol::client_message::SaslResponse;
1374
1375 let scram = ScramClient::new(&user, &password, None);
1376
1377 let (scram, first) = scram.client_first();
1378 seq.send_messages(&[
1379 ClientMessage::AuthenticationSaslInitialResponse(
1380 SaslInitialResponse {
1381 method: "SCRAM-SHA-256".into(),
1382 data: Bytes::copy_from_slice(first.as_bytes()),
1383 }),
1384 ]).await?;
1385 let msg = seq.message().await?;
1386 let data = match msg {
1387 ServerMessage::Authentication(
1388 Authentication::SaslContinue { data }
1389 ) => data,
1390 ServerMessage::ErrorResponse(err) => {
1391 return Err(err.into());
1392 }
1393 msg => {
1394 return Err(ProtocolError::with_message(format!(
1395 "Bad auth response: {:?}", msg)));
1396 }
1397 };
1398 let data = str::from_utf8(&data[..])
1399 .map_err(|e| ProtocolError::with_source(e).context(
1400 "invalid utf-8 in SCRAM-SHA-256 auth"))?;
1401 let scram = scram.handle_server_first(&data)
1402 .map_err(AuthenticationError::with_source)?;
1403 let (scram, data) = scram.client_final();
1404 seq.send_messages(&[
1405 ClientMessage::AuthenticationSaslResponse(
1406 SaslResponse {
1407 data: Bytes::copy_from_slice(data.as_bytes()),
1408 }),
1409 ]).await?;
1410 let msg = seq.message().await?;
1411 let data = match msg {
1412 ServerMessage::Authentication(Authentication::SaslFinal { data })
1413 => data,
1414 ServerMessage::ErrorResponse(err) => {
1415 return Err(err.into());
1416 }
1417 msg => {
1418 return Err(ProtocolError::with_message(format!(
1419 "auth response: {:?}", msg)));
1420 }
1421 };
1422 let data = str::from_utf8(&data[..])
1423 .map_err(|_| ProtocolError::with_message(
1424 "invalid utf-8 in SCRAM-SHA-256 auth"))?;
1425 scram.handle_server_final(&data)
1426 .map_err(|e| AuthenticationError::with_message(format!(
1427 "Authentication error: {}", e)))?;
1428 loop {
1429 let msg = seq.message().await?;
1430 match msg {
1431 ServerMessage::Authentication(Authentication::Ok) => break,
1432 msg => {
1433 log::warn!("unsolicited message {:?}", msg);
1434 }
1435 };
1436 }
1437 Ok(())
1438}
1439
1440impl fmt::Debug for Config {
1441 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1442 f.debug_struct("Config")
1443 .field("address", &self.0.address)
1444 .field("max_connections", &self.0.max_connections)
1445 .finish()
1447 }
1448}
1449
/// Reads `tests/credentials1.json` into an uninitialized builder and checks
/// the port, user, database and password it ends up with.
#[test]
fn read_credentials() {
    let mut builder = Builder::uninitialized();
    let loading = builder.read_credentials("tests/credentials1.json");
    async_std::task::block_on(loading).unwrap();

    assert!(matches!(&builder.address, Address::Tcp((_, 10702))));
    assert_eq!(&builder.user, "test3n");
    assert_eq!(&builder.database, "test3n");
    assert_eq!(builder.password, Some("lZTBy1RVCfOpBAOwSCwIyBIR".into()));
}
1460
/// Parses a host/port DSN and, when the `admin_socket` feature is enabled,
/// checks the unix socket path produced for an explicit socket file and for
/// a directory plus port.
#[test]
fn display() {
    let mut builder = Builder::uninitialized();
    async_std::task::block_on(
        builder.read_dsn("edgedb://localhost:1756")
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 1756)) if host == "localhost"
    ));

    #[cfg(feature="admin_socket")] {
        // Full path to the socket file, no port.
        builder.unix_path("/test/.s.EDGEDB.8888", None, false);
        assert_eq!(
            builder.build().unwrap()._get_unix_path().unwrap(),
            Some("/test/.s.EDGEDB.8888".into())
        );

        // Directory plus port yields the same socket file path.
        builder.unix_path("/test", Some(8888), false);
        assert_eq!(
            builder.build().unwrap()._get_unix_path().unwrap(),
            Some("/test/.s.EDGEDB.8888".into())
        );
    }
}
1484
/// Exercises `read_dsn` with four DSN shapes: user + password + database,
/// user + explicit port, host + port only, and a bracketed IPv6 host.
#[test]
fn from_dsn() {
    // Case 1: credentials and database given, port left out.
    let mut builder = Builder::uninitialized();
    async_std::task::block_on(
        builder.read_dsn(
            "edgedb://user1:EiPhohl7@edb-0134.elb.us-east-2.amazonaws.com/db2"
        )
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 5656))
            if host == "edb-0134.elb.us-east-2.amazonaws.com",
    ));
    assert_eq!(&builder.user, "user1");
    assert_eq!(&builder.database, "db2");
    assert_eq!(builder.password, Some("EiPhohl7".into()));

    // Case 2: fresh builder, user without password, explicit port.
    let mut builder = Builder::uninitialized();
    async_std::task::block_on(
        builder.read_dsn(
            "edgedb://user2@edb-0134.elb.us-east-2.amazonaws.com:1756/db2"
        )
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 1756))
            if host == "edb-0134.elb.us-east-2.amazonaws.com",
    ));
    assert_eq!(&builder.user, "user2");
    assert_eq!(&builder.database, "db2");
    assert_eq!(builder.password, None);

    // Case 3: same builder as case 2; the DSN names only host and port.
    async_std::task::block_on(
        builder.read_dsn(
            "edgedb://edb-0134.elb.us-east-2.amazonaws.com:1756"
        )
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 1756))
            if host == "edb-0134.elb.us-east-2.amazonaws.com",
    ));
    assert_eq!(&builder.user, "edgedb");
    assert_eq!(&builder.database, "edgedb");
    assert_eq!(builder.password, None);

    // Case 4: same builder again; bracketed IPv6 literal as the host.
    async_std::task::block_on(
        builder.read_dsn("edgedb://user3:123123@[::1]:5555/abcdef")
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 5555)) if host == "::1",
    ));
    assert_eq!(&builder.user, "user3");
    assert_eq!(&builder.database, "abcdef");
    assert_eq!(builder.password, Some("123123".into()));
}
1537
/// Parses a DSN whose host is a scoped (zone-id) IPv6 literal.
///
/// The test is marked `should_panic`, so it passes only while some step
/// below panics.
// NOTE(review): presumably scoped IPv6 hosts are not yet accepted by the DSN
// parser and the `unwrap` is what panics — confirm, and drop `should_panic`
// once that is supported.
#[test]
#[should_panic]
fn from_dsn_ipv6_scoped_address() {
    let mut builder = Builder::uninitialized();
    async_std::task::block_on(
        builder.read_dsn(
            "edgedb://user3@[fe80::1ff:fe23:4567:890a%25eth0]:3000/ab"
        )
    ).unwrap();
    assert!(matches!(
        &builder.address,
        Address::Tcp((host, 3000)) if host == "fe80::1ff:fe23:4567:890a%eth0",
    ));
    assert_eq!(&builder.user, "user3");
    assert_eq!(&builder.database, "ab");
    assert_eq!(builder.password, None);
}