1use std::collections::BTreeMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use serde::{de::DeserializeOwned, Serialize};
23use serde_json::Value;
24use surrealdb::engine::any::Any;
25use surrealdb::opt::auth::{
26 Database as SdkDatabase, Namespace as SdkNamespace, Record as SdkRecord, Root as SdkRoot, Token,
27};
28use surrealdb::Surreal;
29use tokio::sync::RwLock;
30use tokio::time::sleep;
31
32use crate::connection::auth::{AuthType, Credentials, ScopeCredentials, TokenAuth};
33use crate::connection::config::ConnectionConfig;
34use crate::error::{Result, SurqlError};
35
36#[derive(Debug, Clone)]
45pub struct DatabaseClient {
46 config: ConnectionConfig,
47 inner: Surreal<Any>,
48 connected: Arc<RwLock<bool>>,
49}
50
51impl DatabaseClient {
52 pub fn new(config: ConnectionConfig) -> Result<Self> {
55 config.validate()?;
56 Ok(Self {
57 config,
58 inner: Surreal::init(),
59 connected: Arc::new(RwLock::new(false)),
60 })
61 }
62
63 pub fn config(&self) -> &ConnectionConfig {
65 &self.config
66 }
67
68 pub fn inner(&self) -> &Surreal<Any> {
70 &self.inner
71 }
72
73 pub fn is_connected(&self) -> bool {
75 self.connected.try_read().is_ok_and(|g| *g)
76 }
77
78 pub async fn connect(&self) -> Result<()> {
84 if *self.connected.read().await {
86 self.disconnect().await.ok();
87 }
88
89 let attempts = self.config.retry_max_attempts().max(1);
90 let mut last_err: Option<SurqlError> = None;
91
92 for attempt in 1..=attempts {
93 match self.connect_once().await {
94 Ok(()) => {
95 *self.connected.write().await = true;
96 return Ok(());
97 }
98 Err(err) => {
99 last_err = Some(err);
100 if attempt < attempts {
101 let wait = self.backoff_for(attempt);
102 sleep(wait).await;
103 }
104 }
105 }
106 }
107
108 Err(last_err.unwrap_or_else(|| SurqlError::Connection {
109 reason: format!("connection failed after {attempts} attempts"),
110 }))
111 }
112
113 pub async fn disconnect(&self) -> Result<()> {
115 {
116 let mut guard = self.connected.write().await;
117 if !*guard {
118 return Ok(());
119 }
120 *guard = false;
121 }
122 self.inner.invalidate().await.ok();
127 Ok(())
128 }
129
130 pub async fn signin<C: Credentials + ?Sized>(&self, creds: &C) -> Result<TokenAuth> {
132 self.require_connected()?;
133 let payload = creds.to_signin_payload();
134 let token = match creds.auth_type() {
135 AuthType::Root => {
136 let username = payload_str(&payload, "username")?;
137 let password = payload_str(&payload, "password")?;
138 self.inner
139 .signin(SdkRoot { username, password })
140 .await
141 .map_err(|e| connection_err(&e))?
142 }
143 AuthType::Namespace => {
144 let namespace = payload_str(&payload, "namespace")?;
145 let username = payload_str(&payload, "username")?;
146 let password = payload_str(&payload, "password")?;
147 self.inner
148 .signin(SdkNamespace {
149 namespace,
150 username,
151 password,
152 })
153 .await
154 .map_err(|e| connection_err(&e))?
155 }
156 AuthType::Database => {
157 let namespace = payload_str(&payload, "namespace")?;
158 let database = payload_str(&payload, "database")?;
159 let username = payload_str(&payload, "username")?;
160 let password = payload_str(&payload, "password")?;
161 self.inner
162 .signin(SdkDatabase {
163 namespace,
164 database,
165 username,
166 password,
167 })
168 .await
169 .map_err(|e| connection_err(&e))?
170 }
171 AuthType::Scope => {
172 let namespace = payload_str(&payload, "namespace")?;
173 let database = payload_str(&payload, "database")?;
174 let access = payload_str(&payload, "access")?;
175 let mut params = serde_json::Map::new();
180 for (k, v) in &payload {
181 if !matches!(k.as_str(), "namespace" | "database" | "access") {
182 params.insert(k.clone(), v.clone());
183 }
184 }
185 self.inner
186 .signin(SdkRecord {
187 namespace,
188 database,
189 access,
190 params: Value::Object(params),
191 })
192 .await
193 .map_err(|e| connection_err(&e))?
194 }
195 };
196 Ok(TokenAuth::new(token.access.into_insecure_token()))
197 }
198
199 pub async fn signup(&self, creds: &ScopeCredentials) -> Result<TokenAuth> {
201 self.require_connected()?;
202 let mut params = serde_json::Map::new();
203 for (k, v) in &creds.variables {
204 params.insert(k.clone(), v.clone());
205 }
206 let token = self
207 .inner
208 .signup(SdkRecord {
209 namespace: creds.namespace.clone(),
210 database: creds.database.clone(),
211 access: creds.access.clone(),
212 params: Value::Object(params),
213 })
214 .await
215 .map_err(|e| connection_err(&e))?;
216 Ok(TokenAuth::new(token.access.into_insecure_token()))
217 }
218
219 pub async fn authenticate(&self, token: &str) -> Result<()> {
221 self.require_connected()?;
222 self.inner
223 .authenticate(Token::from(token))
224 .await
225 .map_err(|e| connection_err(&e))?;
226 Ok(())
227 }
228
229 pub async fn invalidate(&self) -> Result<()> {
231 self.require_connected()?;
232 self.inner
233 .invalidate()
234 .await
235 .map_err(|e| connection_err(&e))?;
236 Ok(())
237 }
238
239 pub async fn query(&self, surql: &str) -> Result<Value> {
242 self.query_with_vars(surql, BTreeMap::new()).await
243 }
244
245 pub async fn query_with_vars(
247 &self,
248 surql: &str,
249 vars: BTreeMap<String, Value>,
250 ) -> Result<Value> {
251 self.require_connected()?;
252 let mut builder = self.inner.query(surql.to_owned());
253 for (k, v) in vars {
254 builder = builder.bind((k, v));
260 }
261 let mut response = builder.await.map_err(|e| query_err(&e))?;
262 let count = response.num_statements();
263 let mut out = Vec::with_capacity(count);
264 for i in 0..count {
265 let raw: surrealdb::types::Value = response.take(i).map_err(|e| query_err(&e))?;
272 out.push(raw.into_json_value());
273 }
274 Ok(Value::Array(out))
275 }
276
277 pub async fn select<T: DeserializeOwned>(&self, target: &str) -> Result<Vec<T>> {
284 self.require_connected()?;
285 let surql = format!("SELECT * FROM {target};");
286 let raw = self.query(&surql).await?;
287 flatten_rows_typed(&raw)
288 }
289
290 pub async fn create<T>(&self, target: &str, data: T) -> Result<T>
292 where
293 T: Serialize + DeserializeOwned + Send + Sync + 'static,
294 {
295 self.require_connected()?;
296 let content = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
297 reason: e.to_string(),
298 })?;
299 let mut vars: BTreeMap<String, Value> = BTreeMap::new();
300 vars.insert("data".into(), content);
301 let surql = format!("CREATE {target} CONTENT $data;");
302 let raw = self.query_with_vars(&surql, vars).await?;
303 first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
304 reason: format!("CREATE on {target} returned no record"),
305 })
306 }
307
308 pub async fn update<T>(&self, target: &str, data: T) -> Result<T>
310 where
311 T: Serialize + DeserializeOwned + Send + Sync + 'static,
312 {
313 self.require_connected()?;
314 let content = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
315 reason: e.to_string(),
316 })?;
317 let mut vars: BTreeMap<String, Value> = BTreeMap::new();
318 vars.insert("data".into(), content);
319 let surql = format!("UPDATE {target} CONTENT $data;");
320 let raw = self.query_with_vars(&surql, vars).await?;
321 first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
322 reason: format!("UPDATE on {target} returned no record"),
323 })
324 }
325
326 pub async fn merge<D, T>(&self, target: &str, data: D) -> Result<T>
332 where
333 D: Serialize + Send + Sync + 'static,
334 T: DeserializeOwned + Send + Sync + 'static,
335 {
336 self.require_connected()?;
337 let patch = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
338 reason: e.to_string(),
339 })?;
340 let mut vars: BTreeMap<String, Value> = BTreeMap::new();
341 vars.insert("patch".into(), patch);
342 let surql = format!("UPDATE {target} MERGE $patch;");
343 let raw = self.query_with_vars(&surql, vars).await?;
344 first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
345 reason: format!("MERGE on {target} returned no record"),
346 })
347 }
348
349 pub async fn delete<T: DeserializeOwned>(&self, target: &str) -> Result<Vec<T>> {
351 self.require_connected()?;
352 let surql = format!("DELETE {target} RETURN BEFORE;");
353 let raw = self.query(&surql).await?;
354 flatten_rows_typed(&raw)
355 }
356
357 pub async fn health(&self) -> Result<bool> {
359 self.require_connected()?;
360 match self.inner.health().await {
361 Ok(()) => Ok(true),
362 Err(_) => Ok(false),
363 }
364 }
365
366 async fn connect_once(&self) -> Result<()> {
369 let timeout = Duration::from_secs_f64(self.config.timeout().max(0.1));
370
371 tokio::time::timeout(timeout, self.inner.connect(self.config.url().to_owned()))
372 .await
373 .map_err(|_| SurqlError::Connection {
374 reason: format!("connect timed out after {timeout:?}"),
375 })?
376 .map_err(|e| connection_err(&e))?;
377
378 if let (Some(user), Some(pass)) = (self.config.username(), self.config.password()) {
379 self.inner
380 .signin(SdkRoot {
381 username: user.to_owned(),
382 password: pass.to_owned(),
383 })
384 .await
385 .map_err(|e| connection_err(&e))?;
386 }
387
388 self.inner
389 .use_ns(self.config.namespace().to_owned())
390 .use_db(self.config.database().to_owned())
391 .await
392 .map_err(|e| connection_err(&e))?;
393
394 Ok(())
395 }
396
397 fn backoff_for(&self, attempt: u32) -> Duration {
398 let min = self.config.retry_min_wait();
399 let max = self.config.retry_max_wait();
400 let mult = self.config.retry_multiplier();
401 let exp = f64::from(attempt.saturating_sub(1));
402 let secs = (min * mult.powf(exp)).clamp(min, max);
403 Duration::from_secs_f64(secs)
404 }
405
406 fn require_connected(&self) -> Result<()> {
407 if self.is_connected() {
408 Ok(())
409 } else {
410 Err(SurqlError::Connection {
411 reason: "client is not connected to database".into(),
412 })
413 }
414 }
415}
416
417impl From<surrealdb::Error> for SurqlError {
418 fn from(err: surrealdb::Error) -> Self {
419 classify_surrealdb_error(&err, err.to_string())
425 }
426}
427
428fn classify_surrealdb_error(err: &surrealdb::Error, msg: String) -> SurqlError {
429 if err.is_connection() {
430 return SurqlError::Connection { reason: msg };
431 }
432 if err.is_query() || err.is_not_found() || err.is_not_allowed() || err.is_thrown() {
433 return SurqlError::Query { reason: msg };
434 }
435 if err.is_serialization() {
436 return SurqlError::Serialization { reason: msg };
437 }
438 let lowered = msg.to_lowercase();
439 if lowered.contains("transaction") {
440 return SurqlError::Transaction { reason: msg };
441 }
442 if lowered.contains("connect")
443 || lowered.contains("not connected")
444 || lowered.contains("websocket")
445 || lowered.contains("timed out")
446 || lowered.contains("subprotocol")
447 {
448 return SurqlError::Connection { reason: msg };
449 }
450 SurqlError::Database { reason: msg }
451}
452
453pub(crate) fn connection_err(err: &surrealdb::Error) -> SurqlError {
454 SurqlError::Connection {
455 reason: err.to_string(),
456 }
457}
458
459pub(crate) fn query_err(err: &surrealdb::Error) -> SurqlError {
460 classify_surrealdb_error(err, err.to_string())
461}
462
463fn flatten_rows_typed<T: DeserializeOwned>(raw: &Value) -> Result<Vec<T>> {
465 let mut out: Vec<T> = Vec::new();
466 collect_rows(raw, &mut out)?;
467 Ok(out)
468}
469
470fn collect_rows<T: DeserializeOwned>(value: &Value, out: &mut Vec<T>) -> Result<()> {
471 match value {
472 Value::Null => Ok(()),
473 Value::Array(items) => {
474 for item in items {
475 collect_rows(item, out)?;
476 }
477 Ok(())
478 }
479 Value::Object(obj) => {
480 if let Some(inner) = obj.get("result") {
481 return collect_rows(inner, out);
482 }
483 let row: T = serde_json::from_value(Value::Object(obj.clone())).map_err(|e| {
484 SurqlError::Serialization {
485 reason: e.to_string(),
486 }
487 })?;
488 out.push(row);
489 Ok(())
490 }
491 other => {
492 let row: T =
493 serde_json::from_value(other.clone()).map_err(|e| SurqlError::Serialization {
494 reason: e.to_string(),
495 })?;
496 out.push(row);
497 Ok(())
498 }
499 }
500}
501
502fn first_row_typed<T: DeserializeOwned>(raw: &Value) -> Result<Option<T>> {
503 let mut rows: Vec<T> = flatten_rows_typed(raw)?;
504 Ok(if rows.is_empty() {
505 None
506 } else {
507 Some(rows.remove(0))
508 })
509}
510
511fn payload_str(map: &serde_json::Map<String, Value>, key: &str) -> Result<String> {
512 match map.get(key) {
513 Some(Value::String(s)) => Ok(s.clone()),
514 Some(_) => Err(SurqlError::Validation {
515 reason: format!("credential field {key:?} must be a string"),
516 }),
517 None => Err(SurqlError::Validation {
518 reason: format!("credential field {key:?} is missing"),
519 }),
520 }
521}
522
523#[cfg(test)]
524mod tests {
525 use super::*;
526 use crate::connection::auth::RootCredentials;
527
528 #[test]
529 fn new_validates_config() {
530 let cfg = ConnectionConfig::default();
531 let client = DatabaseClient::new(cfg).expect("valid default config");
532 assert!(!client.is_connected());
533 }
534
535 #[test]
536 fn new_rejects_invalid_config() {
537 let bad = ConnectionConfig {
538 db_url: "ftp://nope".into(),
539 ..Default::default()
540 };
541 assert!(DatabaseClient::new(bad).is_err());
542 }
543
544 #[test]
545 fn flatten_rows_typed_handles_wrapped_and_flat_shapes() {
546 #[derive(serde::Deserialize, Debug, PartialEq)]
547 struct Row {
548 name: String,
549 }
550 let wrapped = serde_json::json!([
551 { "result": [{ "name": "alice" }, { "name": "bob" }] }
552 ]);
553 let rows: Vec<Row> = flatten_rows_typed(&wrapped).unwrap();
554 assert_eq!(rows.len(), 2);
555 assert_eq!(rows[0].name, "alice");
556
557 let flat = serde_json::json!([[{ "name": "carol" }]]);
558 let rows: Vec<Row> = flatten_rows_typed(&flat).unwrap();
559 assert_eq!(rows.len(), 1);
560 assert_eq!(rows[0].name, "carol");
561 }
562
563 #[test]
564 fn first_row_typed_returns_none_for_empty_array() {
565 #[derive(serde::Deserialize, Debug)]
566 struct Row {
567 #[allow(dead_code)]
568 name: String,
569 }
570 let raw = serde_json::json!([[]]);
571 let row: Option<Row> = first_row_typed(&raw).unwrap();
572 assert!(row.is_none());
573 }
574
575 #[test]
576 fn payload_str_round_trip() {
577 let creds = RootCredentials::new("root", "secret");
578 let m = creds.to_signin_payload();
579 assert_eq!(payload_str(&m, "username").unwrap(), "root");
580 assert_eq!(payload_str(&m, "password").unwrap(), "secret");
581 assert!(payload_str(&m, "missing").is_err());
582 }
583
584 #[tokio::test]
585 async fn disconnect_when_never_connected_is_ok() {
586 let client = DatabaseClient::new(ConnectionConfig::default()).unwrap();
587 client.disconnect().await.unwrap();
588 assert!(!client.is_connected());
589 }
590
591 #[tokio::test]
592 async fn operations_fail_when_not_connected() {
593 let client = DatabaseClient::new(ConnectionConfig::default()).unwrap();
594 let err = client.query("INFO FOR DB").await.unwrap_err();
595 assert!(matches!(err, SurqlError::Connection { .. }));
596 }
597
598 #[test]
599 fn backoff_respects_bounds() {
600 let cfg = ConnectionConfig {
601 db_retry_min_wait: 0.5,
602 db_retry_max_wait: 4.0,
603 db_retry_multiplier: 2.0,
604 ..Default::default()
605 };
606 let client = DatabaseClient::new(cfg).unwrap();
607 let a1 = client.backoff_for(1);
608 let a5 = client.backoff_for(5);
609 assert!(a1 >= Duration::from_secs_f64(0.5));
610 assert!(a5 <= Duration::from_secs_f64(4.0));
611 }
612
613 #[test]
614 fn surrealdb_error_maps_to_surql_error() {
615 let thrown: SurqlError = surrealdb::Error::thrown("boom".into()).into();
620 assert!(matches!(thrown, SurqlError::Query { .. }));
621
622 let connection: SurqlError = surrealdb::Error::connection("down".into(), None).into();
623 assert!(matches!(connection, SurqlError::Connection { .. }));
624
625 let internal: SurqlError = surrealdb::Error::internal("boom".into()).into();
626 assert!(matches!(internal, SurqlError::Database { .. }));
627 }
628}