1use std::collections::HashMap;
8use std::env;
9use std::fmt;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{Result, SurqlError};
14
/// Prefix prepended to every environment variable read by `ConnectionConfig::from_env`
/// (for example `SURQL_URL`). Named configurations insert their upper-cased name after it
/// (for example `SURQL_PRIMARY_URL`).
pub const ENV_PREFIX: &str = "SURQL_";
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub enum Protocol {
21 WebSocket,
23 WebSocketSecure,
25 Http,
27 Https,
29 Memory,
31 File,
33 SurrealKv,
35}
36
37impl Protocol {
38 pub fn supports_live_queries(self) -> bool {
40 !matches!(self, Self::Http | Self::Https)
41 }
42
43 pub fn is_embedded(self) -> bool {
45 matches!(self, Self::Memory | Self::File | Self::SurrealKv)
46 }
47}
48
49impl fmt::Display for Protocol {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 let s = match self {
52 Self::WebSocket => "ws",
53 Self::WebSocketSecure => "wss",
54 Self::Http => "http",
55 Self::Https => "https",
56 Self::Memory => "memory",
57 Self::File => "file",
58 Self::SurrealKv => "surrealkv",
59 };
60 f.write_str(s)
61 }
62}
63
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct ConnectionConfig {
87 pub db_url: String,
89 pub db_ns: String,
91 pub db: String,
93 pub db_user: Option<String>,
95 pub db_pass: Option<String>,
97 pub db_timeout: f64,
99 pub db_max_connections: u32,
101 pub db_retry_max_attempts: u32,
103 pub db_retry_min_wait: f64,
105 pub db_retry_max_wait: f64,
107 pub db_retry_multiplier: f64,
109 pub enable_live_queries: bool,
111}
112
113impl Default for ConnectionConfig {
114 fn default() -> Self {
115 Self {
116 db_url: "ws://localhost:8000/rpc".into(),
117 db_ns: "development".into(),
118 db: "main".into(),
119 db_user: None,
120 db_pass: None,
121 db_timeout: 30.0,
122 db_max_connections: 10,
123 db_retry_max_attempts: 3,
124 db_retry_min_wait: 1.0,
125 db_retry_max_wait: 10.0,
126 db_retry_multiplier: 2.0,
127 enable_live_queries: true,
128 }
129 }
130}
131
132impl ConnectionConfig {
133 pub fn builder() -> ConnectionConfigBuilder {
135 ConnectionConfigBuilder::default()
136 }
137
138 pub fn validate(&self) -> Result<()> {
140 validate_url(&self.db_url)?;
141 validate_identifier(&self.db_ns, "namespace")?;
142 validate_identifier(&self.db, "database")?;
143 validate_numeric_range("timeout", self.db_timeout, 1.0, f64::INFINITY)?;
144 validate_numeric_range(
145 "max_connections",
146 f64::from(self.db_max_connections),
147 1.0,
148 100.0,
149 )?;
150 validate_numeric_range(
151 "retry_max_attempts",
152 f64::from(self.db_retry_max_attempts),
153 1.0,
154 10.0,
155 )?;
156 validate_numeric_range("retry_min_wait", self.db_retry_min_wait, 0.1, f64::INFINITY)?;
157 validate_numeric_range("retry_max_wait", self.db_retry_max_wait, 1.0, f64::INFINITY)?;
158 validate_numeric_range(
159 "retry_multiplier",
160 self.db_retry_multiplier,
161 1.0,
162 f64::INFINITY,
163 )?;
164 if self.db_retry_max_wait <= self.db_retry_min_wait {
165 return Err(SurqlError::Validation {
166 reason: "db_retry_max_wait must be greater than db_retry_min_wait".into(),
167 });
168 }
169 let proto = Self::detect_protocol(&self.db_url)?;
170 if self.enable_live_queries && !proto.supports_live_queries() {
171 return Err(SurqlError::Validation {
172 reason: "Live queries require WebSocket (ws://, wss://) or embedded \
173 (mem://, memory://, file://, surrealkv://) connection"
174 .into(),
175 });
176 }
177 Ok(())
178 }
179
180 pub fn protocol(&self) -> Result<Protocol> {
182 Self::detect_protocol(&self.db_url)
183 }
184
185 fn detect_protocol(url: &str) -> Result<Protocol> {
186 let trimmed = url.trim();
187 if let Some(rest) = trimmed.strip_prefix("ws://") {
188 if rest.is_empty() {
189 return Err(SurqlError::Validation {
190 reason: "URL host must not be empty".into(),
191 });
192 }
193 return Ok(Protocol::WebSocket);
194 }
195 if trimmed.starts_with("wss://") {
196 return Ok(Protocol::WebSocketSecure);
197 }
198 if trimmed.starts_with("http://") {
199 return Ok(Protocol::Http);
200 }
201 if trimmed.starts_with("https://") {
202 return Ok(Protocol::Https);
203 }
204 if trimmed.starts_with("mem://") || trimmed.starts_with("memory://") {
205 return Ok(Protocol::Memory);
206 }
207 if trimmed.starts_with("file://") {
208 return Ok(Protocol::File);
209 }
210 if trimmed.starts_with("surrealkv://") {
211 return Ok(Protocol::SurrealKv);
212 }
213 Err(SurqlError::Validation {
214 reason: "URL must use one of: ws://, wss://, http://, https://, \
215 mem://, memory://, file://, surrealkv://"
216 .into(),
217 })
218 }
219
220 pub fn from_env() -> Result<Self> {
232 Self::from_env_with_prefix(ENV_PREFIX)
233 }
234
235 pub fn from_env_with_prefix(prefix: &str) -> Result<Self> {
238 let lookup = |key: &str| env::var(key).ok();
239 Self::from_source_with_prefix(prefix, lookup)
240 }
241
242 pub fn from_source_with_prefix<F>(prefix: &str, mut lookup: F) -> Result<Self>
248 where
249 F: FnMut(&str) -> Option<String>,
250 {
251 let mut cfg = Self::default();
252 let p = prefix;
253
254 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["URL", "DB_URL"]) {
255 cfg.db_url = v;
256 }
257 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["NAMESPACE", "DB_NS"]) {
258 cfg.db_ns = v;
259 }
260 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["DATABASE", "DB"]) {
261 cfg.db = v;
262 }
263 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["USERNAME", "DB_USER"]) {
264 cfg.db_user = Some(v);
265 }
266 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["PASSWORD", "DB_PASS"]) {
267 cfg.db_pass = Some(v);
268 }
269 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["TIMEOUT", "DB_TIMEOUT"]) {
270 cfg.db_timeout = parse_env("timeout", &v)?;
271 }
272 if let Some(v) =
273 lookup_with_aliases(&mut lookup, p, &["MAX_CONNECTIONS", "DB_MAX_CONNECTIONS"])
274 {
275 cfg.db_max_connections = parse_env("max_connections", &v)?;
276 }
277 if let Some(v) = lookup_with_aliases(
278 &mut lookup,
279 p,
280 &["RETRY_MAX_ATTEMPTS", "DB_RETRY_MAX_ATTEMPTS"],
281 ) {
282 cfg.db_retry_max_attempts = parse_env("retry_max_attempts", &v)?;
283 }
284 if let Some(v) =
285 lookup_with_aliases(&mut lookup, p, &["RETRY_MIN_WAIT", "DB_RETRY_MIN_WAIT"])
286 {
287 cfg.db_retry_min_wait = parse_env("retry_min_wait", &v)?;
288 }
289 if let Some(v) =
290 lookup_with_aliases(&mut lookup, p, &["RETRY_MAX_WAIT", "DB_RETRY_MAX_WAIT"])
291 {
292 cfg.db_retry_max_wait = parse_env("retry_max_wait", &v)?;
293 }
294 if let Some(v) =
295 lookup_with_aliases(&mut lookup, p, &["RETRY_MULTIPLIER", "DB_RETRY_MULTIPLIER"])
296 {
297 cfg.db_retry_multiplier = parse_env("retry_multiplier", &v)?;
298 }
299 if let Some(v) = lookup_with_aliases(&mut lookup, p, &["ENABLE_LIVE_QUERIES"]) {
300 cfg.enable_live_queries = parse_bool(&v)?;
301 }
302
303 cfg.validate()?;
304 Ok(cfg)
305 }
306
307 pub fn from_map_with_prefix(prefix: &str, map: &HashMap<String, String>) -> Result<Self> {
309 Self::from_source_with_prefix(prefix, |k| map.get(k).cloned())
310 }
311
312 pub fn url(&self) -> &str {
314 &self.db_url
315 }
316
317 pub fn namespace(&self) -> &str {
319 &self.db_ns
320 }
321
322 pub fn database(&self) -> &str {
324 &self.db
325 }
326
327 pub fn username(&self) -> Option<&str> {
329 self.db_user.as_deref()
330 }
331
332 pub fn password(&self) -> Option<&str> {
334 self.db_pass.as_deref()
335 }
336
337 pub fn timeout(&self) -> f64 {
339 self.db_timeout
340 }
341
342 pub fn max_connections(&self) -> u32 {
344 self.db_max_connections
345 }
346
347 pub fn retry_max_attempts(&self) -> u32 {
349 self.db_retry_max_attempts
350 }
351
352 pub fn retry_min_wait(&self) -> f64 {
354 self.db_retry_min_wait
355 }
356
357 pub fn retry_max_wait(&self) -> f64 {
359 self.db_retry_max_wait
360 }
361
362 pub fn retry_multiplier(&self) -> f64 {
364 self.db_retry_multiplier
365 }
366}
367
368#[derive(Debug, Clone, Default)]
370pub struct ConnectionConfigBuilder {
371 inner: Option<ConnectionConfig>,
372}
373
// Generates a consuming builder setter `$name` that converts its argument into `$ty` and
// stores it in `inner.$field`. The inner value is first seeded with its `Default` if no
// setter has run yet. Leading attributes (e.g. doc comments) are forwarded to the method.
macro_rules! setter {
    ($(#[$meta:meta])* $name:ident, $ty:ty, $field:ident) => {
        $(#[$meta])*
        pub fn $name(mut self, v: impl Into<$ty>) -> Self {
            let target = self.inner.get_or_insert_with(Default::default);
            target.$field = v.into();
            self
        }
    };
}
385
386impl ConnectionConfigBuilder {
387 setter!(
388 url,
390 String,
391 db_url
392 );
393 setter!(
394 namespace,
396 String,
397 db_ns
398 );
399 setter!(
400 database,
402 String,
403 db
404 );
405
406 pub fn username(mut self, v: impl Into<String>) -> Self {
408 let mut inner = self.inner.unwrap_or_default();
409 inner.db_user = Some(v.into());
410 self.inner = Some(inner);
411 self
412 }
413
414 pub fn password(mut self, v: impl Into<String>) -> Self {
416 let mut inner = self.inner.unwrap_or_default();
417 inner.db_pass = Some(v.into());
418 self.inner = Some(inner);
419 self
420 }
421
422 pub fn timeout(mut self, secs: f64) -> Self {
424 let mut inner = self.inner.unwrap_or_default();
425 inner.db_timeout = secs;
426 self.inner = Some(inner);
427 self
428 }
429
430 pub fn max_connections(mut self, n: u32) -> Self {
432 let mut inner = self.inner.unwrap_or_default();
433 inner.db_max_connections = n;
434 self.inner = Some(inner);
435 self
436 }
437
438 pub fn retry_max_attempts(mut self, n: u32) -> Self {
440 let mut inner = self.inner.unwrap_or_default();
441 inner.db_retry_max_attempts = n;
442 self.inner = Some(inner);
443 self
444 }
445
446 pub fn retry_min_wait(mut self, secs: f64) -> Self {
448 let mut inner = self.inner.unwrap_or_default();
449 inner.db_retry_min_wait = secs;
450 self.inner = Some(inner);
451 self
452 }
453
454 pub fn retry_max_wait(mut self, secs: f64) -> Self {
456 let mut inner = self.inner.unwrap_or_default();
457 inner.db_retry_max_wait = secs;
458 self.inner = Some(inner);
459 self
460 }
461
462 pub fn retry_multiplier(mut self, m: f64) -> Self {
464 let mut inner = self.inner.unwrap_or_default();
465 inner.db_retry_multiplier = m;
466 self.inner = Some(inner);
467 self
468 }
469
470 pub fn enable_live_queries(mut self, on: bool) -> Self {
472 let mut inner = self.inner.unwrap_or_default();
473 inner.enable_live_queries = on;
474 self.inner = Some(inner);
475 self
476 }
477
478 pub fn build(self) -> Result<ConnectionConfig> {
480 let cfg = self.inner.unwrap_or_default();
481 cfg.validate()?;
482 Ok(cfg)
483 }
484}
485
486#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
488pub struct NamedConnectionConfig {
489 pub name: String,
491 pub config: ConnectionConfig,
493}
494
495impl NamedConnectionConfig {
496 pub fn from_env(name: &str) -> Result<Self> {
499 let prefix = format!("{ENV_PREFIX}{}_", name.to_uppercase());
500 let config = ConnectionConfig::from_env_with_prefix(&prefix)?;
501 Ok(Self {
502 name: name.to_lowercase(),
503 config,
504 })
505 }
506
507 pub fn from_source<F>(name: &str, lookup: F) -> Result<Self>
509 where
510 F: FnMut(&str) -> Option<String>,
511 {
512 let prefix = format!("{ENV_PREFIX}{}_", name.to_uppercase());
513 let config = ConnectionConfig::from_source_with_prefix(&prefix, lookup)?;
514 Ok(Self {
515 name: name.to_lowercase(),
516 config,
517 })
518 }
519}
520
521fn validate_url(url: &str) -> Result<()> {
526 if url.is_empty() {
527 return Err(SurqlError::Validation {
528 reason: "URL cannot be empty".into(),
529 });
530 }
531 let _ = ConnectionConfig::detect_protocol(url)?;
532 Ok(())
533}
534
535fn validate_identifier(value: &str, context: &str) -> Result<()> {
536 if value.is_empty() {
537 return Err(SurqlError::Validation {
538 reason: "Identifier cannot be empty".into(),
539 });
540 }
541 let ok = value
542 .chars()
543 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
544 if !ok {
545 return Err(SurqlError::Validation {
546 reason: format!(
547 "Identifier ({context}) must be alphanumeric with optional underscores/hyphens"
548 ),
549 });
550 }
551 Ok(())
552}
553
554fn validate_numeric_range(name: &str, value: f64, min: f64, max: f64) -> Result<()> {
555 if value.is_nan() {
556 return Err(SurqlError::Validation {
557 reason: format!("{name} must be a finite number"),
558 });
559 }
560 if value < min {
561 return Err(SurqlError::Validation {
562 reason: format!("{name} must be >= {min}"),
563 });
564 }
565 if value > max {
566 return Err(SurqlError::Validation {
567 reason: format!("{name} must be <= {max}"),
568 });
569 }
570 Ok(())
571}
572
/// Returns the first value found for `prefix + key` over `keys`, in order.
///
/// For each key the exact name is tried first, then the whole name lower-cased (prefix
/// included). The lower-cased name is only built when the exact lookup misses.
fn lookup_with_aliases<F>(lookup: &mut F, prefix: &str, keys: &[&str]) -> Option<String>
where
    F: FnMut(&str) -> Option<String>,
{
    keys.iter().find_map(|key| {
        let exact = format!("{prefix}{key}");
        match lookup(&exact) {
            Some(found) => Some(found),
            None => lookup(&exact.to_lowercase()),
        }
    })
}
589
590fn parse_env<T: std::str::FromStr>(name: &str, raw: &str) -> Result<T>
591where
592 T::Err: std::fmt::Display,
593{
594 raw.parse::<T>().map_err(|e| SurqlError::Validation {
595 reason: format!("invalid {name}={raw:?}: {e}"),
596 })
597}
598
599fn parse_bool(raw: &str) -> Result<bool> {
600 match raw.trim().to_lowercase().as_str() {
601 "1" | "true" | "yes" | "on" => Ok(true),
602 "0" | "false" | "no" | "off" => Ok(false),
603 other => Err(SurqlError::Validation {
604 reason: format!("invalid boolean value {other:?}"),
605 }),
606 }
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the default config with its URL replaced by `url`.
    fn config_with_url(url: &str) -> ConnectionConfig {
        ConnectionConfig {
            db_url: url.to_owned(),
            ..ConnectionConfig::default()
        }
    }

    /// Builds an environment-like map, prepending `prefix` to every key.
    fn prefixed_env(prefix: &str, pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (format!("{prefix}{key}"), value.to_owned()))
            .collect()
    }

    #[test]
    fn defaults_are_valid() {
        let config = ConnectionConfig::default();
        config.validate().unwrap();
        assert_eq!(config.url(), "ws://localhost:8000/rpc");
        assert_eq!(config.namespace(), "development");
        assert_eq!(config.database(), "main");
        assert!(config.enable_live_queries);
    }

    #[test]
    fn builder_overrides_fields() {
        let config = ConnectionConfig::builder()
            .url("wss://db.example.com/rpc")
            .namespace("prod")
            .database("app")
            .username("alice")
            .password("hunter2")
            .timeout(60.0)
            .build()
            .unwrap();
        assert_eq!(config.url(), "wss://db.example.com/rpc");
        assert_eq!(config.username(), Some("alice"));
        assert_eq!(config.password(), Some("hunter2"));
        assert!((config.timeout() - 60.0).abs() < f64::EPSILON);
    }

    #[test]
    fn rejects_empty_url() {
        assert!(config_with_url("").validate().is_err());
    }

    #[test]
    fn rejects_unsupported_protocol() {
        assert!(config_with_url("ftp://localhost").validate().is_err());
    }

    #[test]
    fn accepts_embedded_protocols() {
        let embedded_urls = [
            "mem://",
            "memory://",
            "file:///tmp/db.sdb",
            "surrealkv:///tmp/db.skv",
        ];
        for url in embedded_urls {
            config_with_url(url).validate().unwrap();
        }
    }

    #[test]
    fn rejects_live_queries_over_http() {
        let https_url = "https://db.example.com/rpc";

        let live = ConnectionConfig {
            enable_live_queries: true,
            ..config_with_url(https_url)
        };
        assert!(live.validate().is_err());

        let without_live = ConnectionConfig {
            enable_live_queries: false,
            ..config_with_url(https_url)
        };
        without_live.validate().unwrap();
    }

    #[test]
    fn rejects_invalid_identifiers() {
        for candidate in ["", "has space", "has/slash", "has!bang"] {
            let config = ConnectionConfig {
                db_ns: candidate.into(),
                ..ConnectionConfig::default()
            };
            assert!(config.validate().is_err(), "ns {candidate:?} should be invalid");
        }
    }

    #[test]
    fn retry_max_must_exceed_min() {
        let config = ConnectionConfig {
            db_retry_min_wait: 5.0,
            db_retry_max_wait: 3.0,
            ..ConnectionConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn protocol_detection() {
        let expectations = [
            ("ws://localhost:8000", Protocol::WebSocket),
            ("wss://host/rpc", Protocol::WebSocketSecure),
            ("http://host", Protocol::Http),
            ("https://host", Protocol::Https),
            ("mem://", Protocol::Memory),
            ("memory://", Protocol::Memory),
            ("file:///tmp/db", Protocol::File),
            ("surrealkv:///tmp/db", Protocol::SurrealKv),
        ];
        for (url, expected) in expectations {
            let config = ConnectionConfig {
                enable_live_queries: expected.supports_live_queries(),
                ..config_with_url(url)
            };
            config.validate().unwrap();
            assert_eq!(config.protocol().unwrap(), expected);
        }
    }

    #[test]
    fn protocol_helpers() {
        assert!(Protocol::WebSocket.supports_live_queries());
        assert!(Protocol::Memory.supports_live_queries());
        assert!(!Protocol::Http.supports_live_queries());
        assert!(!Protocol::Https.supports_live_queries());
        assert!(Protocol::Memory.is_embedded());
        assert!(!Protocol::WebSocket.is_embedded());
    }

    #[test]
    fn from_source_reads_vars() {
        let prefix = "SURQL_TEST_CFG_";
        let vars = prefixed_env(
            prefix,
            &[
                ("URL", "wss://env.example/rpc"),
                ("NAMESPACE", "envns"),
                ("DATABASE", "envdb"),
                ("USERNAME", "envuser"),
                ("TIMEOUT", "45.5"),
                ("ENABLE_LIVE_QUERIES", "false"),
            ],
        );

        let config = ConnectionConfig::from_map_with_prefix(prefix, &vars).unwrap();
        assert_eq!(config.url(), "wss://env.example/rpc");
        assert_eq!(config.namespace(), "envns");
        assert_eq!(config.database(), "envdb");
        assert_eq!(config.username(), Some("envuser"));
        assert!((config.timeout() - 45.5).abs() < f64::EPSILON);
        assert!(!config.enable_live_queries);
    }

    #[test]
    fn from_source_accepts_legacy_aliases() {
        let prefix = "SURQL_LEGACY_";
        let vars = prefixed_env(
            prefix,
            &[
                ("DB_URL", "ws://legacy.example/rpc"),
                ("DB_NS", "legns"),
                ("DB", "legdb"),
                ("DB_USER", "leguser"),
                ("DB_PASS", "legpass"),
            ],
        );
        let config = ConnectionConfig::from_map_with_prefix(prefix, &vars).unwrap();
        assert_eq!(config.url(), "ws://legacy.example/rpc");
        assert_eq!(config.namespace(), "legns");
        assert_eq!(config.database(), "legdb");
        assert_eq!(config.username(), Some("leguser"));
        assert_eq!(config.password(), Some("legpass"));
    }

    #[test]
    fn named_from_source_uses_prefix() {
        let vars = prefixed_env(
            "SURQL_PRIMARY_",
            &[
                ("URL", "ws://primary.example/rpc"),
                ("NAMESPACE", "pns"),
                ("DATABASE", "pdb"),
            ],
        );
        let named =
            NamedConnectionConfig::from_source("primary", |key| vars.get(key).cloned()).unwrap();
        assert_eq!(named.name, "primary");
        assert_eq!(named.config.url(), "ws://primary.example/rpc");
    }

    #[test]
    fn serde_roundtrip() {
        let original = ConnectionConfig::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ConnectionConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}