1use super::{
2 error::*, format_duration, BuildContextImpl, CredentialsFile, FromParamStr, InstanceName,
3 Param, Params,
4};
5use crate::{
6 gel::{parse_duration, BuildPhase},
7 host::{Host, HostType, LOCALHOST_HOSTNAME},
8};
9use rustls_pki_types::CertificateDer;
10use serde::{Deserialize, Serialize};
11use std::{
12 borrow::Cow,
13 collections::HashMap,
14 fmt,
15 num::NonZero,
16 path::{Path, PathBuf},
17 str::FromStr,
18 time::Duration,
19};
20use url::Url;
21
/// Value `Config::default()` uses for `connect_timeout` (10 seconds).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Value `Config::default()` uses for `wait_until_available` (30 seconds).
pub const DEFAULT_WAIT: Duration = Duration::from_secs(30);
/// Keepalive interval reported by `TcpKeepalive::Default` (60 seconds).
pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60);
/// Default pool size.
// NOTE(review): not referenced anywhere in this file; presumably consumed by
// the connection pool elsewhere — confirm.
pub const DEFAULT_POOL_SIZE: usize = 10;
/// Host type `Config::default()` connects to.
pub const DEFAULT_HOST: &HostType = crate::host::LOCALHOST;
/// Port `Config::default()` connects to; `dsn_url` omits the port when it
/// equals this value.
pub const DEFAULT_PORT: u16 = 5656;
/// User name `Config::default()` uses; `dsn_url` omits the user when it equals
/// this value.
pub const DEFAULT_USER: &str = crate::gel::branding::BRANDING_DEFAULT_USERNAME_LEGACY;
/// The "nothing selected" database/branch value.
pub const DEFAULT_BRANCH: DatabaseBranch = DatabaseBranch::Default;

/// Name returned by `DatabaseBranch::database` for `DatabaseBranch::Default`.
pub const DEFAULT_DATABASE_NAME: &str = "edgedb";

/// Name returned by `DatabaseBranch::branch_for_connect` for
/// `DatabaseBranch::Default`.
pub const DEFAULT_BRANCH_NAME_CONNECT: &str = "__default__";
/// Name returned by `DatabaseBranch::branch_for_create` for
/// `DatabaseBranch::Default`.
pub const DEFAULT_BRANCH_NAME_CREATE: &str = "main";
38
39pub struct ConfigResult {
41 pub(crate) result: Result<Config, gel_errors::Error>,
42 pub(crate) warnings: Warnings,
43}
44
45impl std::ops::Deref for ConfigResult {
46 type Target = Result<Config, gel_errors::Error>;
47
48 fn deref(&self) -> &Self::Target {
49 &self.result
50 }
51}
52
53impl From<ConfigResult> for Result<Config, gel_errors::Error> {
54 fn from(val: ConfigResult) -> Self {
55 val.result
56 }
57}
58
59impl ConfigResult {
60 pub fn unwrap(self) -> Config {
61 self.result.unwrap()
62 }
63
64 pub fn expect(self, message: &str) -> Config {
65 self.result.expect(message)
66 }
67
68 pub fn result(&self) -> &Result<Config, gel_errors::Error> {
69 &self.result
70 }
71
72 pub fn into_result(self) -> Result<Config, gel_errors::Error> {
73 self.result
74 }
75
76 pub fn parse_error(&self) -> Option<&ParseError> {
77 use std::error::Error;
78 self.result
79 .as_ref()
80 .err()
81 .and_then(|e| e.source().and_then(|e| e.downcast_ref::<ParseError>()))
82 }
83
84 pub fn warnings(&self) -> &Warnings {
85 &self.warnings
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq)]
91pub struct Config {
92 pub host: Host,
93 pub db: DatabaseBranch,
94 pub user: String,
95
96 pub instance_name: Option<InstanceName>,
98
99 pub authentication: Authentication,
100
101 pub client_security: ClientSecurity,
102 pub tls_security: TlsSecurity,
103
104 pub tls_ca: Option<Vec<CertificateDer<'static>>>,
105 pub tls_server_name: Option<String>,
106 pub wait_until_available: Duration,
107
108 pub connect_timeout: Duration,
109 pub max_concurrency: Option<usize>,
110 pub tcp_keepalive: TcpKeepalive,
111
112 pub cloud_certs: Option<CloudCerts>,
113
114 pub server_settings: HashMap<String, String>,
115}
116
117impl Default for Config {
118 fn default() -> Self {
119 Self {
120 host: Host::new(DEFAULT_HOST.clone(), DEFAULT_PORT),
121 db: DatabaseBranch::Default,
122 user: DEFAULT_USER.to_string(),
123 instance_name: None,
124 authentication: Authentication::None,
125 client_security: ClientSecurity::Default,
126 tls_security: TlsSecurity::Strict,
127 tls_ca: None,
128 tls_server_name: None,
129 wait_until_available: DEFAULT_WAIT,
130 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
131 max_concurrency: None,
132 tcp_keepalive: TcpKeepalive::Default,
133 cloud_certs: None,
134 server_settings: HashMap::new(),
135 }
136 }
137}
138
/// Error returned by `Config::as_credentials`.
#[derive(Debug, Clone, PartialEq, Eq, derive_more::Error, derive_more::Display)]
pub enum CredentialsError {
    /// The configured host has no target name, or the target is not a TCP
    /// host/port pair.
    #[display("no TCP address")]
    NoTcpAddress,
}
144
145fn to_pem(certs: &[CertificateDer<'static>]) -> String {
146 use base64::Engine;
147 let prefix = "-----BEGIN CERTIFICATE-----\n";
148 let suffix = "-----END CERTIFICATE-----\n";
149 let mut pem = String::new();
150 for cert in certs {
151 pem.push_str(prefix);
152 let mut b64 = vec![0; cert.len() * 4 / 3 + 4];
153 let len = base64::prelude::BASE64_STANDARD
154 .encode_slice(cert.as_ref(), &mut b64)
155 .unwrap();
156 b64.truncate(len);
157 let lines = b64.chunks(64);
158 for line in lines {
159 pem.push_str(std::str::from_utf8(line).unwrap());
160 pem.push('\n');
161 }
162 pem.push_str(suffix);
163 }
164 pem
165}
166
167impl Config {
168 pub fn instance_name(&self) -> Option<&InstanceName> {
169 self.instance_name.as_ref()
170 }
171
172 pub fn local_instance_name(&self) -> Option<&str> {
173 self.instance_name.as_ref().and_then(InstanceName::local)
174 }
175
176 pub fn admin(&self) -> bool {
177 self.host.is_unix()
178 }
179
180 pub fn user(&self) -> &str {
181 &self.user
182 }
183
184 pub fn port(&self) -> u16 {
185 self.host.1
186 }
187
188 pub fn display_addr(&self) -> impl fmt::Display + '_ {
189 self.host.to_string()
190 }
191
192 pub fn secret_key(&self) -> Option<&str> {
193 self.authentication.secret_key()
194 }
195
196 pub fn tls_ca_pem(&self) -> Option<String> {
197 self.tls_ca.as_ref().map(|v| to_pem(v))
198 }
199
200 pub fn http_url(&self, tls: bool) -> Option<String> {
202 if let Some((host, port)) = self.host.target_name().ok()?.tcp() {
203 let s = if tls { "s" } else { "" };
204 Some(format!("http{s}://{host}:{port}"))
205 } else {
206 None
207 }
208 }
209
210 pub fn dsn_url(&self) -> Option<String> {
214 let mut url = Url::parse("gel://").unwrap();
215
216 if let Some((host, port)) = self.host.target_name().ok()?.tcp() {
217 if host != LOCALHOST_HOSTNAME {
218 if port != DEFAULT_PORT {
219 _ = url.set_host(Some(&host));
220 _ = url.set_port(Some(port));
221 } else {
222 _ = url.set_host(Some(&host));
223 }
224 } else if port != DEFAULT_PORT {
225 url.query_pairs_mut().append_pair("port", &port.to_string());
226 }
227 } else {
228 return None;
229 }
230
231 if self.db != DatabaseBranch::Default {
232 if let Some(database) = self.db.database() {
233 url.set_path(database);
234 }
235
236 if let Some(branch) = self.db.branch_for_connect() {
237 url.set_path(branch);
238 }
239 }
240
241 if self.user() != DEFAULT_USER {
242 if url.host().is_none() {
243 url.query_pairs_mut().append_pair("user", self.user());
244 } else {
245 _ = url.set_username(self.user());
246 }
247 }
248
249 if let Some(password) = self.authentication.password() {
250 if url.host().is_none() {
251 url.query_pairs_mut().append_pair("password", password);
252 } else {
253 _ = url.set_password(Some(password));
254 }
255 }
256
257 if self.tls_ca.is_some() {
259 url.query_pairs_mut().append_pair("tls_ca_file", "<...>");
260 }
261
262 if let Some(secret_key) = self.authentication.secret_key() {
263 url.query_pairs_mut().append_pair("secret_key", secret_key);
264 }
265
266 if self.tls_security != TlsSecurity::Strict {
267 url.query_pairs_mut()
268 .append_pair("tls_security", &self.tls_security.to_string());
269 }
270
271 if let Some(tls_server_name) = &self.tls_server_name {
272 url.query_pairs_mut()
273 .append_pair("tls_server_name", tls_server_name);
274 }
275
276 if self.wait_until_available != DEFAULT_WAIT {
277 url.query_pairs_mut().append_pair(
278 "wait_until_available",
279 &format_duration(&self.wait_until_available),
280 );
281 }
282
283 for (key, value) in &self.server_settings {
284 url.query_pairs_mut().append_pair(key, value);
285 }
286
287 Some(url.to_string())
288 }
289
290 pub fn with_host(&self, host: &str, port: u16) -> Result<Self, ParseError> {
291 Ok(Self {
292 host: Host::new(HostType::from_str(host)?, port),
293 ..self.clone()
294 })
295 }
296
297 pub fn with_branch(&self, branch: &str) -> Self {
298 Self {
299 db: DatabaseBranch::Branch(branch.to_string()),
300 ..self.clone()
301 }
302 }
303
304 pub fn with_db(&self, db: DatabaseBranch) -> Self {
305 Self { db, ..self.clone() }
306 }
307
308 pub fn with_user(&self, user: &str) -> Self {
309 Self {
310 user: user.to_string(),
311 ..self.clone()
312 }
313 }
314
315 pub fn with_password(&self, password: &str) -> Self {
316 Self {
317 authentication: Authentication::Password(password.to_string()),
318 ..self.clone()
319 }
320 }
321
322 pub fn with_wait_until_available(&self, dur: Duration) -> Self {
323 Self {
324 wait_until_available: dur,
325 ..self.clone()
326 }
327 }
328
329 pub fn with_tls_ca(&self, certs: &[CertificateDer<'static>]) -> Self {
330 Self {
331 tls_ca: Some(certs.to_vec()),
332 ..self.clone()
333 }
334 }
335
336 #[deprecated = "use with_tls_ca instead"]
337 pub fn with_pem_certificates(&self, certs: &str) -> Result<Self, ParseError> {
338 let certs = <Vec<CertificateDer<'static>> as FromParamStr>::from_param_str(
339 certs,
340 &BuildContextImpl::default(),
341 )?;
342 Ok(Self {
343 tls_ca: Some(certs),
344 ..self.clone()
345 })
346 }
347
348 #[cfg(feature = "serde")]
349 pub fn to_json(&self) -> impl serde::Serialize + std::fmt::Display {
350 use serde::Serialize;
351 use std::collections::BTreeMap;
352
353 #[derive(Serialize)]
354 #[allow(non_snake_case)]
355 struct ConfigJson {
356 address: (String, usize),
357 branch: Option<String>,
358 database: Option<String>,
359 password: Option<String>,
360 secretKey: Option<String>,
361 serverSettings: BTreeMap<String, String>,
362 tlsCAData: Option<String>,
363 tlsSecurity: String,
364 tlsServerName: Option<String>,
365 user: String,
366 waitUntilAvailable: String,
367 }
368
369 impl std::fmt::Display for ConfigJson {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 write!(f, "{}", serde_json::to_string(self).unwrap())
372 }
373 }
374
375 ConfigJson {
376 address: (self.host.0.to_string(), self.host.1 as usize),
377 branch: self.db.branch_for_connect().map(|s| s.to_string()),
378 database: self.db.database().map(|s| s.to_string()),
379 password: self.authentication.password().map(|s| s.to_string()),
380 secretKey: self.authentication.secret_key().map(|s| s.to_string()),
381 serverSettings: BTreeMap::from_iter(self.server_settings.clone()),
382 tlsCAData: self.tls_ca.as_ref().map(|cert| to_pem(cert)),
383 tlsSecurity: self.tls_security.to_string(),
384 tlsServerName: self.tls_server_name.clone(),
385 user: self.user.clone(),
386 waitUntilAvailable: super::duration::Duration::from_micros(
387 self.wait_until_available.as_micros() as i64,
388 )
389 .to_string(),
390 }
391 }
392
393 pub fn as_credentials(&self) -> Result<CredentialsFile, CredentialsError> {
395 let target = self
396 .host
397 .target_name()
398 .map_err(|_| CredentialsError::NoTcpAddress)?;
399 let tcp = target.tcp().ok_or(CredentialsError::NoTcpAddress)?;
400 Ok(CredentialsFile {
401 user: Some(self.user.clone()),
402 host: Some(tcp.0.to_string()),
403 port: Some(NonZero::new(tcp.1).expect("invalid zero port")),
404 password: self.authentication.password().map(|s| s.to_string()),
405 secret_key: self.authentication.secret_key().map(|s| s.to_string()),
406 database: self.db.database().map(|s| s.to_string()),
407 branch: self.db.branch_for_connect().map(|s| s.to_string()),
408 tls_ca: self.tls_ca_pem(),
409 tls_security: self.tls_security,
410 tls_server_name: self.tls_server_name.clone(),
411 warnings: vec![],
412 })
413 }
414
415 #[allow(clippy::field_reassign_with_default)]
416 pub fn to_tls(&self) -> gel_stream::TlsParameters {
417 use gel_stream::{TlsAlpn, TlsCert, TlsParameters, TlsServerCertVerify};
418 use std::borrow::Cow;
419 use std::net::IpAddr;
420
421 let mut tls = TlsParameters::default();
422 tls.root_cert = TlsCert::Webpki;
423 match &self.tls_ca {
424 Some(certs) => {
425 tls.root_cert = TlsCert::Custom(certs.to_vec());
426 }
427 None => {
428 if let Some(cloud_certs) = self.cloud_certs {
429 tls.root_cert = TlsCert::WebpkiPlus(cloud_certs.certificates().to_vec());
430 }
431 }
432 }
433 tls.server_cert_verify = match self.tls_security {
434 TlsSecurity::Insecure => TlsServerCertVerify::Insecure,
435 TlsSecurity::NoHostVerification => TlsServerCertVerify::IgnoreHostname,
436 TlsSecurity::Strict | TlsSecurity::Default => TlsServerCertVerify::VerifyFull,
437 };
438 tls.alpn = TlsAlpn::new_str(&["edgedb-binary", "gel-binary"]);
439 tls.sni_override = match &self.tls_server_name {
440 Some(server_name) => Some(Cow::from(server_name.clone())),
441 None => {
442 if let Ok(host) = self.host.target_name() {
443 if let Some(host) = host.host() {
444 if let Ok(ip) = IpAddr::from_str(&host) {
445 let host = format!("{ip}.host-for-ip.edgedb.net");
447 let host = host.replace([':', '%'], "-");
449 if host.starts_with('-') {
450 Some(Cow::from(format!("i{host}")))
451 } else {
452 Some(Cow::from(host.to_string()))
453 }
454 } else {
455 Some(Cow::from(host.to_string()))
456 }
457 } else {
458 None
459 }
460 } else {
461 None
462 }
463 }
464 };
465 tls
466 }
467}
468
/// How the client authenticates: not at all, with a password, or with a
/// secret key.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Authentication {
    /// No credentials.
    #[default]
    None,
    /// Authenticate with the contained password.
    Password(String),
    /// Authenticate with the contained secret key.
    SecretKey(String),
}

impl Authentication {
    /// Returns the password for [`Authentication::Password`], `None`
    /// otherwise.
    pub fn password(&self) -> Option<&str> {
        if let Self::Password(value) = self {
            Some(value.as_str())
        } else {
            None
        }
    }

    /// Returns the secret key for [`Authentication::SecretKey`], `None`
    /// otherwise.
    pub fn secret_key(&self) -> Option<&str> {
        if let Self::SecretKey(value) = self {
            Some(value.as_str())
        } else {
            None
        }
    }
}
493
494#[derive(Debug, Clone, Default, PartialEq, Eq)]
496pub enum DatabaseBranch {
497 #[default]
498 Default,
499 Database(String),
500 Branch(String),
501 Ambiguous(String),
502}
503
504impl std::fmt::Display for DatabaseBranch {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 if f.alternate() {
508 match self {
509 Self::Database(database) => write!(f, "{database}"),
510 Self::Branch(branch) => write!(f, "{branch}"),
511 Self::Ambiguous(ambiguous) => write!(f, "{ambiguous}"),
512 Self::Default => write!(f, "(default)"),
513 }
514 } else {
515 match self {
516 Self::Database(database) => write!(f, "database '{database}'"),
517 Self::Branch(branch) => write!(f, "branch '{branch}'"),
518 Self::Ambiguous(ambiguous) => write!(f, "'{ambiguous}'"),
519 Self::Default => write!(f, "default database/branch"),
520 }
521 }
522 }
523}
524
525impl DatabaseBranch {
526 pub fn database(&self) -> Option<&str> {
527 match self {
528 Self::Database(database) => Some(database),
529 Self::Branch(branch) => Some(branch),
531 Self::Ambiguous(ambiguous) => Some(ambiguous),
532 Self::Default => Some(DEFAULT_DATABASE_NAME),
533 }
534 }
535
536 pub fn branch_for_connect(&self) -> Option<&str> {
537 match self {
538 Self::Branch(branch) => Some(branch),
539 Self::Database(database) => Some(database),
541 Self::Ambiguous(ambiguous) => Some(ambiguous),
542 Self::Default => Some(DEFAULT_BRANCH_NAME_CONNECT),
543 }
544 }
545
546 pub fn branch_for_create(&self) -> Option<&str> {
547 match self {
548 Self::Branch(branch) => Some(branch),
549 Self::Database(database) => Some(database),
551 Self::Ambiguous(ambiguous) => Some(ambiguous),
552 Self::Default => Some(DEFAULT_BRANCH_NAME_CREATE),
553 }
554 }
555
556 pub fn name(&self) -> Option<&str> {
557 match self {
558 Self::Database(database) => Some(database),
559 Self::Branch(branch) => Some(branch),
560 Self::Ambiguous(ambiguous) => Some(ambiguous),
561 Self::Default => None,
562 }
563 }
564}
565
566#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
568pub enum ClientSecurity {
569 InsecureDevMode,
571 Strict,
573 #[default]
575 Default,
576}
577
578impl FromStr for ClientSecurity {
579 type Err = ParseError;
580 fn from_str(s: &str) -> Result<ClientSecurity, Self::Err> {
581 use ClientSecurity::*;
582
583 match s {
584 "default" => Ok(Default),
585 "strict" => Ok(Strict),
586 "insecure_dev_mode" => Ok(InsecureDevMode),
587 _ => Err(ParseError::InvalidTlsSecurity(
589 TlsSecurityError::InvalidValue,
590 )),
591 }
592 }
593}
594
595#[derive(Debug, Clone, Copy, PartialEq, Eq)]
597pub enum CloudCerts {
598 Staging,
599 Local,
600}
601
602impl FromStr for CloudCerts {
603 type Err = ParseError;
604 fn from_str(s: &str) -> Result<CloudCerts, Self::Err> {
605 use CloudCerts::*;
606
607 match s {
608 "staging" => Ok(Staging),
609 "local" => Ok(Local),
610 _ => Err(ParseError::FileNotFound),
612 }
613 }
614}
615
616impl CloudCerts {
617 pub fn certificates(&self) -> &[CertificateDer<'static>] {
618 match self {
619 Self::Staging => {
620 static CERT: std::sync::OnceLock<Vec<CertificateDer<'static>>> =
621 std::sync::OnceLock::new();
622 CERT.get_or_init(|| {
623 Self::read_static_certs(Self::Staging.certificates_pem().as_bytes())
624 })
625 }
626 Self::Local => {
627 static CERT: std::sync::OnceLock<Vec<CertificateDer<'static>>> =
628 std::sync::OnceLock::new();
629 CERT.get_or_init(|| {
630 Self::read_static_certs(Self::Local.certificates_pem().as_bytes())
631 })
632 }
633 }
634 }
635
636 pub fn certificates_pem(&self) -> &'static str {
637 match self {
638 Self::Staging => include_str!("certs/letsencrypt_staging.pem"),
639 Self::Local => include_str!("certs/nebula_development.pem"),
640 }
641 }
642
643 fn read_static_certs(bytes: &'static [u8]) -> Vec<CertificateDer<'static>> {
644 let mut cursor = std::io::Cursor::new(bytes);
645 let mut certs = Vec::new();
646 for item in rustls_pemfile::read_all(&mut cursor) {
647 match item {
648 Ok(rustls_pemfile::Item::X509Certificate(data)) => certs.push(data),
649 _ => panic!(),
650 }
651 }
652 certs
653 }
654}
655
656#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
658#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
659#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
660pub enum TlsSecurity {
661 Insecure,
663 NoHostVerification,
670 Strict,
672 #[default]
675 Default,
676}
677
678impl FromStr for TlsSecurity {
679 type Err = ParseError;
680 fn from_str(val: &str) -> Result<Self, Self::Err> {
681 match val {
682 "default" => Ok(TlsSecurity::Default),
683 "insecure" => Ok(TlsSecurity::Insecure),
684 "no_host_verification" => Ok(TlsSecurity::NoHostVerification),
685 "strict" => Ok(TlsSecurity::Strict),
686 _ => Err(ParseError::InvalidTlsSecurity(
687 TlsSecurityError::InvalidValue,
688 )),
689 }
690 }
691}
692
693impl fmt::Display for TlsSecurity {
694 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
695 match self {
696 Self::Insecure => write!(f, "insecure"),
697 Self::NoHostVerification => write!(f, "no_host_verification"),
698 Self::Strict => write!(f, "strict"),
699 Self::Default => write!(f, "default"),
700 }
701 }
702}
703
704#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
706pub enum TcpKeepalive {
707 Disabled,
709 Explicit(Duration),
711 #[default]
713 Default,
714}
715
716impl TcpKeepalive {
717 pub fn as_keepalive(&self) -> Option<Duration> {
718 match self {
719 TcpKeepalive::Disabled => None,
720 TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE),
721 TcpKeepalive::Explicit(duration) => Some(*duration),
722 }
723 }
724}
725
726impl FromStr for TcpKeepalive {
727 type Err = ParseError;
728 fn from_str(s: &str) -> Result<Self, Self::Err> {
729 use TcpKeepalive::*;
730
731 match s {
732 "disabled" => Ok(Disabled),
733 "default" => Ok(Default),
734 _ => Ok(Explicit(
735 parse_duration(s).map_err(|_| ParseError::InvalidDuration)?,
736 )),
737 }
738 }
739}
740
741#[derive(derive_more::Debug, Clone, PartialEq, Eq)]
742enum UnixPathInner {
743 #[debug("{:?}{{port}}", _0)]
745 PortSuffixed(PathBuf),
746 #[debug("{:?}", _0)]
748 Exact(PathBuf),
749}
750
751#[derive(Clone, PartialEq, Eq, derive_more::Debug)]
753pub struct UnixPath {
754 #[debug("{:?}", inner)]
755 inner: UnixPathInner,
756}
757
758impl UnixPath {
759 pub fn with_port_suffix(path: PathBuf) -> Self {
761 UnixPath {
762 inner: UnixPathInner::PortSuffixed(path),
763 }
764 }
765
766 pub fn exact(path: PathBuf) -> Self {
768 UnixPath {
769 inner: UnixPathInner::Exact(path),
770 }
771 }
772
773 pub fn path_with_port(&self, port: u16) -> Cow<Path> {
775 match &self.inner {
776 UnixPathInner::PortSuffixed(path) => {
777 let Some(filename) = path.file_name() else {
778 return Cow::Owned(path.join(port.to_string()));
779 };
780 let mut path = path.clone();
781 let mut filename = filename.to_owned();
782 filename.push(port.to_string());
783 path.set_file_name(filename);
784 Cow::Owned(path)
785 }
786 UnixPathInner::Exact(path) => Cow::Borrowed(path),
787 }
788 }
789}
790
791impl FromStr for UnixPath {
792 type Err = ParseError;
793 fn from_str(s: &str) -> Result<Self, Self::Err> {
794 Ok(UnixPath::exact(PathBuf::from(s)))
795 }
796}
797
798impl<T: Into<PathBuf>> From<T> for UnixPath {
799 fn from(path: T) -> Self {
800 UnixPath::exact(path.into())
801 }
802}
803
/// Explicit, still-unparsed connection options, (de)serializable with serde.
///
/// Every field is optional and missing fields take their default; unknown
/// fields are rejected on deserialization. Several fields use camelCase
/// names in the serialized form (see the `rename` attributes). The
/// `TryInto<Params>` conversion rejects the combinations `credentials` +
/// `credentialsFile`, `tlsCA` + `tlsCAFile`, and `branch` + `database`.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct ConnectionOptions {
    pub dsn: Option<String>,
    pub user: Option<String>,
    pub password: Option<String>,
    pub instance: Option<String>,
    /// Mutually exclusive with `branch`.
    pub database: Option<String>,
    pub host: Option<String>,
    /// Accepted as either a string or a number; stored as text.
    #[serde(deserialize_with = "deserialize_string_or_number")]
    pub port: Option<String>,
    /// Mutually exclusive with `database`.
    pub branch: Option<String>,
    #[serde(rename = "tlsSecurity")]
    pub tls_security: Option<String>,
    /// Inline CA data; mutually exclusive with `tls_ca_file`.
    #[serde(rename = "tlsCA")]
    pub tls_ca: Option<String>,
    /// CA file to read; mutually exclusive with `tls_ca`.
    #[serde(rename = "tlsCAFile")]
    pub tls_ca_file: Option<String>,
    #[serde(rename = "tlsServerName")]
    pub tls_server_name: Option<String>,
    #[serde(rename = "waitUntilAvailable")]
    pub wait_until_available: Option<String>,
    #[serde(rename = "serverSettings")]
    pub server_settings: Option<HashMap<String, String>>,
    /// Credentials file to read; mutually exclusive with `credentials`.
    #[serde(rename = "credentialsFile")]
    pub credentials_file: Option<String>,
    /// Inline credentials; mutually exclusive with `credentials_file`.
    pub credentials: Option<String>,
    #[serde(rename = "secretKey")]
    pub secret_key: Option<String>,
}
836
/// Deserializes a value that may be written as a string or a number into
/// its textual form.
///
/// * a JSON string yields its contents (without quotes);
/// * JSON `null` yields `None`, so an explicit `null` behaves like an absent
///   field instead of becoming the literal text `"null"`;
/// * any other value (e.g. a number) yields its JSON text.
#[cfg(feature = "serde")]
fn deserialize_string_or_number<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    match serde_json::Value::deserialize(deserializer)? {
        serde_json::Value::Null => Ok(None),
        serde_json::Value::String(text) => Ok(Some(text)),
        other => Ok(Some(other.to_string())),
    }
}
849
850impl TryInto<Params> for ConnectionOptions {
851 type Error = ParseError;
852
853 fn try_into(self) -> Result<Params, Self::Error> {
854 if self.credentials.is_some() && self.credentials_file.is_some() {
855 return Err(ParseError::MultipleCompound(
856 BuildPhase::Options,
857 vec![
858 CompoundSource::CredentialsFile,
859 CompoundSource::CredentialsFile,
860 ],
861 ));
862 }
863
864 if self.tls_ca.is_some() && self.tls_ca_file.is_some() {
865 return Err(ParseError::ExclusiveOptions(
866 "tls_ca".to_string(),
867 "tls_ca_file".to_string(),
868 ));
869 }
870
871 if self.branch.is_some() && self.database.is_some() {
872 return Err(ParseError::ExclusiveOptions(
873 "branch".to_string(),
874 "database".to_string(),
875 ));
876 }
877
878 let mut credentials = Param::from_file(self.credentials_file.clone());
879 if credentials.is_none() {
880 credentials = Param::from_unparsed(self.credentials.clone());
881 }
882
883 let mut tls_ca = Param::from_unparsed(self.tls_ca.clone());
884 if tls_ca.is_none() {
885 tls_ca = Param::from_file(self.tls_ca_file.clone());
886 }
887
888 let explicit = Params {
889 dsn: Param::from_unparsed(self.dsn.clone()),
890 credentials,
891 user: Param::from_unparsed(self.user.clone()),
892 password: Param::from_unparsed(self.password.clone()),
893 instance: Param::from_unparsed(self.instance.clone()),
894 database: Param::from_unparsed(self.database.clone()),
895 host: Param::from_unparsed(self.host.clone()),
896 port: Param::from_unparsed(self.port.as_ref().map(|n| n.to_string())),
897 branch: Param::from_unparsed(self.branch.clone()),
898 secret_key: Param::from_unparsed(self.secret_key.clone()),
899 tls_security: Param::from_unparsed(self.tls_security.clone()),
900 tls_ca,
901 tls_server_name: Param::from_unparsed(self.tls_server_name.clone()),
902 server_settings: self.server_settings.unwrap_or_default(),
903 wait_until_available: Param::from_unparsed(self.wait_until_available.clone()),
904 ..Default::default()
905 };
906
907 Ok(explicit)
908 }
909}
910
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration converts to credentials for `localhost`.
    #[test]
    fn test_as_credentials() {
        let config = Config::default();
        let credentials = config.as_credentials().unwrap();
        assert_eq!(credentials.host, Some("localhost".to_string()));
    }

    /// `dsn_url` emits only the parts that differ from the defaults.
    #[test]
    fn test_dsn_url() {
        // All defaults: bare scheme.
        let config = Config::default();
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel://");

        // Non-default host and port go into the authority.
        let config = Config::default().with_host("example.com", 1234).unwrap();
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel://example.com:1234");

        // An explicit database becomes the path.
        let config = Config::default()
            .with_host("localhost", 5656)
            .unwrap()
            .with_db(DatabaseBranch::Database("edgedb".to_string()));
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel:///edgedb");

        // The default port is omitted for a non-default host.
        let config = Config::default()
            .with_host("example.com", 5656)
            .unwrap()
            .with_db(DatabaseBranch::Branch("main".to_string()));
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel://example.com/main");

        // Without a host in the URL the user is passed as a query parameter.
        let config = Config::default()
            .with_host("localhost", 5656)
            .unwrap()
            .with_db(DatabaseBranch::Branch("main".to_string()))
            .with_user("user");
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel:///main?user=user");

        // Special characters in the password are percent-encoded.
        let config = Config::default()
            .with_host("localhost", 5656)
            .unwrap()
            .with_db(DatabaseBranch::Branch("main".to_string()))
            .with_user("user")
            .with_password("%[]{}");
        let url = config.dsn_url().unwrap();
        assert_eq!(url, "gel:///main?user=user&password=%25%5B%5D%7B%7D");
    }
}
963}