Skip to main content

surql/connection/
client.rs

1//! Async SurrealDB client wrapper.
2//!
3//! Port of `surql/connection/client.py`. Wraps
4//! [`surrealdb::Surreal<surrealdb::engine::any::Any>`], which picks the
5//! underlying engine (WebSocket, HTTP, in-memory, file, `SurrealKV`) from
6//! the URL at runtime. Retry logic, connection timeout, and
7//! auth-level dispatch mirror the Python client one-for-one.
8//!
9//! Targets the `surrealdb` crate 3.x line, which removed the
10//! top-level `api::` module in favour of `engine::`, replaced the
11//! opaque `Jwt` return on signin with a structured `Token`, and made
12//! the `SurrealValue` trait the typed-call envelope. For the typed
13//! CRUD helpers exposed by [`DatabaseClient`] we intentionally round
14//! through raw SurrealQL + `serde_json::Value` so callers only need
15//! `serde::Serialize + serde::de::DeserializeOwned` bounds on their
16//! types (not `SurrealValue`).
17
18use std::collections::BTreeMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use serde::{de::DeserializeOwned, Serialize};
23use serde_json::Value;
24use surrealdb::engine::any::Any;
25use surrealdb::opt::auth::{
26    Database as SdkDatabase, Namespace as SdkNamespace, Record as SdkRecord, Root as SdkRoot, Token,
27};
28use surrealdb::Surreal;
29use tokio::sync::RwLock;
30use tokio::time::sleep;
31
32use crate::connection::auth::{AuthType, Credentials, ScopeCredentials, TokenAuth};
33use crate::connection::config::ConnectionConfig;
34use crate::error::{Result, SurqlError};
35
36/// Async SurrealDB client with connection + retry management.
37///
38/// This is a thin wrapper over [`surrealdb::Surreal`] bound to the
39/// dynamic [`Any`] engine. All methods are `async` and cancellation-safe
40/// at the tokio level.
41///
42/// The client is `Clone`-able: every clone shares the same underlying
43/// connection (the `surrealdb` SDK holds its own `Arc`).
44#[derive(Debug, Clone)]
45pub struct DatabaseClient {
46    config: ConnectionConfig,
47    inner: Surreal<Any>,
48    connected: Arc<RwLock<bool>>,
49}
50
51impl DatabaseClient {
52    /// Build a new client. Does **not** open a network connection; call
53    /// [`DatabaseClient::connect`] for that.
54    pub fn new(config: ConnectionConfig) -> Result<Self> {
55        config.validate()?;
56        Ok(Self {
57            config,
58            inner: Surreal::init(),
59            connected: Arc::new(RwLock::new(false)),
60        })
61    }
62
63    /// Borrow the underlying configuration.
64    pub fn config(&self) -> &ConnectionConfig {
65        &self.config
66    }
67
68    /// Borrow the underlying SurrealDB SDK handle (advanced usage).
69    pub fn inner(&self) -> &Surreal<Any> {
70        &self.inner
71    }
72
73    /// Return `true` if [`DatabaseClient::connect`] has completed successfully.
74    pub fn is_connected(&self) -> bool {
75        self.connected.try_read().is_ok_and(|g| *g)
76    }
77
78    /// Establish the connection and select the configured namespace / database.
79    ///
80    /// Retries with exponential backoff up to
81    /// [`ConnectionConfig::retry_max_attempts`] times; each attempt is
82    /// bounded by [`ConnectionConfig::timeout`].
83    pub async fn connect(&self) -> Result<()> {
84        // Reconnect is idempotent: disconnect any previous session first.
85        if *self.connected.read().await {
86            self.disconnect().await.ok();
87        }
88
89        let attempts = self.config.retry_max_attempts().max(1);
90        let mut last_err: Option<SurqlError> = None;
91
92        for attempt in 1..=attempts {
93            match self.connect_once().await {
94                Ok(()) => {
95                    *self.connected.write().await = true;
96                    return Ok(());
97                }
98                Err(err) => {
99                    last_err = Some(err);
100                    if attempt < attempts {
101                        let wait = self.backoff_for(attempt);
102                        sleep(wait).await;
103                    }
104                }
105            }
106        }
107
108        Err(last_err.unwrap_or_else(|| SurqlError::Connection {
109            reason: format!("connection failed after {attempts} attempts"),
110        }))
111    }
112
113    /// Close the underlying connection. Safe to call even if not connected.
114    pub async fn disconnect(&self) -> Result<()> {
115        {
116            let mut guard = self.connected.write().await;
117            if !*guard {
118                return Ok(());
119            }
120            *guard = false;
121        }
122        // The SDK exposes `invalidate` to clear auth, but there is no
123        // explicit disconnect on `Surreal<Any>` beyond dropping the
124        // handle. We invalidate the session so subsequent calls fail
125        // cleanly.
126        self.inner.invalidate().await.ok();
127        Ok(())
128    }
129
130    /// Sign in using one of the four auth levels.
131    pub async fn signin<C: Credentials + ?Sized>(&self, creds: &C) -> Result<TokenAuth> {
132        self.require_connected()?;
133        let payload = creds.to_signin_payload();
134        let token = match creds.auth_type() {
135            AuthType::Root => {
136                let username = payload_str(&payload, "username")?;
137                let password = payload_str(&payload, "password")?;
138                self.inner
139                    .signin(SdkRoot { username, password })
140                    .await
141                    .map_err(|e| connection_err(&e))?
142            }
143            AuthType::Namespace => {
144                let namespace = payload_str(&payload, "namespace")?;
145                let username = payload_str(&payload, "username")?;
146                let password = payload_str(&payload, "password")?;
147                self.inner
148                    .signin(SdkNamespace {
149                        namespace,
150                        username,
151                        password,
152                    })
153                    .await
154                    .map_err(|e| connection_err(&e))?
155            }
156            AuthType::Database => {
157                let namespace = payload_str(&payload, "namespace")?;
158                let database = payload_str(&payload, "database")?;
159                let username = payload_str(&payload, "username")?;
160                let password = payload_str(&payload, "password")?;
161                self.inner
162                    .signin(SdkDatabase {
163                        namespace,
164                        database,
165                        username,
166                        password,
167                    })
168                    .await
169                    .map_err(|e| connection_err(&e))?
170            }
171            AuthType::Scope => {
172                let namespace = payload_str(&payload, "namespace")?;
173                let database = payload_str(&payload, "database")?;
174                let access = payload_str(&payload, "access")?;
175                // Everything else is scope-defined vars. In v3 the
176                // `Record` credential is generic over `P: SurrealValue`;
177                // `serde_json::Value` implements it, so we bundle the
178                // remaining credential fields into a JSON object.
179                let mut params = serde_json::Map::new();
180                for (k, v) in &payload {
181                    if !matches!(k.as_str(), "namespace" | "database" | "access") {
182                        params.insert(k.clone(), v.clone());
183                    }
184                }
185                self.inner
186                    .signin(SdkRecord {
187                        namespace,
188                        database,
189                        access,
190                        params: Value::Object(params),
191                    })
192                    .await
193                    .map_err(|e| connection_err(&e))?
194            }
195        };
196        Ok(TokenAuth::new(token.access.into_insecure_token()))
197    }
198
199    /// Sign up a scope user (record access).
200    pub async fn signup(&self, creds: &ScopeCredentials) -> Result<TokenAuth> {
201        self.require_connected()?;
202        let mut params = serde_json::Map::new();
203        for (k, v) in &creds.variables {
204            params.insert(k.clone(), v.clone());
205        }
206        let token = self
207            .inner
208            .signup(SdkRecord {
209                namespace: creds.namespace.clone(),
210                database: creds.database.clone(),
211                access: creds.access.clone(),
212                params: Value::Object(params),
213            })
214            .await
215            .map_err(|e| connection_err(&e))?;
216        Ok(TokenAuth::new(token.access.into_insecure_token()))
217    }
218
219    /// Authenticate using a previously-issued JWT.
220    pub async fn authenticate(&self, token: &str) -> Result<()> {
221        self.require_connected()?;
222        self.inner
223            .authenticate(Token::from(token))
224            .await
225            .map_err(|e| connection_err(&e))?;
226        Ok(())
227    }
228
229    /// Invalidate the current session.
230    pub async fn invalidate(&self) -> Result<()> {
231        self.require_connected()?;
232        self.inner
233            .invalidate()
234            .await
235            .map_err(|e| connection_err(&e))?;
236        Ok(())
237    }
238
239    /// Execute a raw SurrealQL query and return every statement's result
240    /// as a JSON array (one entry per statement).
241    pub async fn query(&self, surql: &str) -> Result<Value> {
242        self.query_with_vars(surql, BTreeMap::new()).await
243    }
244
245    /// Execute a raw SurrealQL query with bound variables.
246    pub async fn query_with_vars(
247        &self,
248        surql: &str,
249        vars: BTreeMap<String, Value>,
250    ) -> Result<Value> {
251        self.require_connected()?;
252        let mut builder = self.inner.query(surql.to_owned());
253        for (k, v) in vars {
254            // In 3.x the `bind` input must implement `SurrealValue`;
255            // `(String, serde_json::Value)` qualifies because both
256            // components do (and tuples are encoded as 2-element
257            // arrays which `into_variables` unpacks as key/value
258            // chunks).
259            builder = builder.bind((k, v));
260        }
261        let mut response = builder.await.map_err(|e| query_err(&e))?;
262        let count = response.num_statements();
263        let mut out = Vec::with_capacity(count);
264        for i in 0..count {
265            // `IndexedResults::take(usize)` in 3.x only accepts
266            // `surrealdb::types::Value` / `Vec<T>` / `Option<T>` for
267            // index-based retrieval. Take the core `Value` (which
268            // preserves record IDs, durations, decimals, etc.) and
269            // downgrade to `serde_json::Value` via
270            // `into_json_value`.
271            let raw: surrealdb::types::Value = response.take(i).map_err(|e| query_err(&e))?;
272            out.push(raw.into_json_value());
273        }
274        Ok(Value::Array(out))
275    }
276
277    /// Typed `SELECT` against a table or record ID (`"user"` / `"user:alice"`).
278    ///
279    /// Internally routes through raw SurrealQL + `serde_json::Value`
280    /// so callers only need `serde::de::DeserializeOwned`; the 3.x
281    /// SDK's typed `select` would force a `SurrealValue` bound on
282    /// `T`, which would be a breaking change for existing users.
283    pub async fn select<T: DeserializeOwned>(&self, target: &str) -> Result<Vec<T>> {
284        self.require_connected()?;
285        let surql = format!("SELECT * FROM {target};");
286        let raw = self.query(&surql).await?;
287        flatten_rows_typed(&raw)
288    }
289
290    /// Typed `CREATE`. Returns the created record.
291    pub async fn create<T>(&self, target: &str, data: T) -> Result<T>
292    where
293        T: Serialize + DeserializeOwned + Send + Sync + 'static,
294    {
295        self.require_connected()?;
296        let content = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
297            reason: e.to_string(),
298        })?;
299        let mut vars: BTreeMap<String, Value> = BTreeMap::new();
300        vars.insert("data".into(), content);
301        let surql = format!("CREATE {target} CONTENT $data;");
302        let raw = self.query_with_vars(&surql, vars).await?;
303        first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
304            reason: format!("CREATE on {target} returned no record"),
305        })
306    }
307
308    /// Typed `UPDATE`. Returns the updated record.
309    pub async fn update<T>(&self, target: &str, data: T) -> Result<T>
310    where
311        T: Serialize + DeserializeOwned + Send + Sync + 'static,
312    {
313        self.require_connected()?;
314        let content = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
315            reason: e.to_string(),
316        })?;
317        let mut vars: BTreeMap<String, Value> = BTreeMap::new();
318        vars.insert("data".into(), content);
319        let surql = format!("UPDATE {target} CONTENT $data;");
320        let raw = self.query_with_vars(&surql, vars).await?;
321        first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
322            reason: format!("UPDATE on {target} returned no record"),
323        })
324    }
325
326    /// Typed `MERGE`. Returns the merged record.
327    ///
328    /// The input (`D`) is a partial patch; the output (`T`) is the full
329    /// merged record. Pass a `serde_json::Value` or a dedicated patch
330    /// struct for `D`.
331    pub async fn merge<D, T>(&self, target: &str, data: D) -> Result<T>
332    where
333        D: Serialize + Send + Sync + 'static,
334        T: DeserializeOwned + Send + Sync + 'static,
335    {
336        self.require_connected()?;
337        let patch = serde_json::to_value(&data).map_err(|e| SurqlError::Serialization {
338            reason: e.to_string(),
339        })?;
340        let mut vars: BTreeMap<String, Value> = BTreeMap::new();
341        vars.insert("patch".into(), patch);
342        let surql = format!("UPDATE {target} MERGE $patch;");
343        let raw = self.query_with_vars(&surql, vars).await?;
344        first_row_typed(&raw)?.ok_or_else(|| SurqlError::Query {
345            reason: format!("MERGE on {target} returned no record"),
346        })
347    }
348
349    /// Typed `DELETE`. Returns the deleted records.
350    pub async fn delete<T: DeserializeOwned>(&self, target: &str) -> Result<Vec<T>> {
351        self.require_connected()?;
352        let surql = format!("DELETE {target} RETURN BEFORE;");
353        let raw = self.query(&surql).await?;
354        flatten_rows_typed(&raw)
355    }
356
357    /// Server-side health check (wraps `Surreal::health`).
358    pub async fn health(&self) -> Result<bool> {
359        self.require_connected()?;
360        match self.inner.health().await {
361            Ok(()) => Ok(true),
362            Err(_) => Ok(false),
363        }
364    }
365
366    // -- internal ----------------------------------------------------------
367
368    async fn connect_once(&self) -> Result<()> {
369        let timeout = Duration::from_secs_f64(self.config.timeout().max(0.1));
370
371        tokio::time::timeout(timeout, self.inner.connect(self.config.url().to_owned()))
372            .await
373            .map_err(|_| SurqlError::Connection {
374                reason: format!("connect timed out after {timeout:?}"),
375            })?
376            .map_err(|e| connection_err(&e))?;
377
378        if let (Some(user), Some(pass)) = (self.config.username(), self.config.password()) {
379            self.inner
380                .signin(SdkRoot {
381                    username: user.to_owned(),
382                    password: pass.to_owned(),
383                })
384                .await
385                .map_err(|e| connection_err(&e))?;
386        }
387
388        self.inner
389            .use_ns(self.config.namespace().to_owned())
390            .use_db(self.config.database().to_owned())
391            .await
392            .map_err(|e| connection_err(&e))?;
393
394        Ok(())
395    }
396
397    fn backoff_for(&self, attempt: u32) -> Duration {
398        let min = self.config.retry_min_wait();
399        let max = self.config.retry_max_wait();
400        let mult = self.config.retry_multiplier();
401        let exp = f64::from(attempt.saturating_sub(1));
402        let secs = (min * mult.powf(exp)).clamp(min, max);
403        Duration::from_secs_f64(secs)
404    }
405
406    fn require_connected(&self) -> Result<()> {
407        if self.is_connected() {
408            Ok(())
409        } else {
410            Err(SurqlError::Connection {
411                reason: "client is not connected to database".into(),
412            })
413        }
414    }
415}
416
417impl From<surrealdb::Error> for SurqlError {
418    fn from(err: surrealdb::Error) -> Self {
419        // 3.x unifies `Error` into a single struct with a `kind_str()`
420        // discriminator and a human-readable message. Map the relevant
421        // kinds onto the richer `SurqlError` taxonomy; fall back to a
422        // substring match on the message for anything not yet modelled
423        // in the typed details.
424        classify_surrealdb_error(&err, err.to_string())
425    }
426}
427
428fn classify_surrealdb_error(err: &surrealdb::Error, msg: String) -> SurqlError {
429    if err.is_connection() {
430        return SurqlError::Connection { reason: msg };
431    }
432    if err.is_query() || err.is_not_found() || err.is_not_allowed() || err.is_thrown() {
433        return SurqlError::Query { reason: msg };
434    }
435    if err.is_serialization() {
436        return SurqlError::Serialization { reason: msg };
437    }
438    let lowered = msg.to_lowercase();
439    if lowered.contains("transaction") {
440        return SurqlError::Transaction { reason: msg };
441    }
442    if lowered.contains("connect")
443        || lowered.contains("not connected")
444        || lowered.contains("websocket")
445        || lowered.contains("timed out")
446        || lowered.contains("subprotocol")
447    {
448        return SurqlError::Connection { reason: msg };
449    }
450    SurqlError::Database { reason: msg }
451}
452
453pub(crate) fn connection_err(err: &surrealdb::Error) -> SurqlError {
454    SurqlError::Connection {
455        reason: err.to_string(),
456    }
457}
458
459pub(crate) fn query_err(err: &surrealdb::Error) -> SurqlError {
460    classify_surrealdb_error(err, err.to_string())
461}
462
463/// Flatten every row in the raw `query()` response into a typed vector.
464fn flatten_rows_typed<T: DeserializeOwned>(raw: &Value) -> Result<Vec<T>> {
465    let mut out: Vec<T> = Vec::new();
466    collect_rows(raw, &mut out)?;
467    Ok(out)
468}
469
470fn collect_rows<T: DeserializeOwned>(value: &Value, out: &mut Vec<T>) -> Result<()> {
471    match value {
472        Value::Null => Ok(()),
473        Value::Array(items) => {
474            for item in items {
475                collect_rows(item, out)?;
476            }
477            Ok(())
478        }
479        Value::Object(obj) => {
480            if let Some(inner) = obj.get("result") {
481                return collect_rows(inner, out);
482            }
483            let row: T = serde_json::from_value(Value::Object(obj.clone())).map_err(|e| {
484                SurqlError::Serialization {
485                    reason: e.to_string(),
486                }
487            })?;
488            out.push(row);
489            Ok(())
490        }
491        other => {
492            let row: T =
493                serde_json::from_value(other.clone()).map_err(|e| SurqlError::Serialization {
494                    reason: e.to_string(),
495                })?;
496            out.push(row);
497            Ok(())
498        }
499    }
500}
501
502fn first_row_typed<T: DeserializeOwned>(raw: &Value) -> Result<Option<T>> {
503    let mut rows: Vec<T> = flatten_rows_typed(raw)?;
504    Ok(if rows.is_empty() {
505        None
506    } else {
507        Some(rows.remove(0))
508    })
509}
510
511fn payload_str(map: &serde_json::Map<String, Value>, key: &str) -> Result<String> {
512    match map.get(key) {
513        Some(Value::String(s)) => Ok(s.clone()),
514        Some(_) => Err(SurqlError::Validation {
515            reason: format!("credential field {key:?} must be a string"),
516        }),
517        None => Err(SurqlError::Validation {
518            reason: format!("credential field {key:?} is missing"),
519        }),
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::auth::RootCredentials;

    /// Single-column row shared by the response-flattening tests.
    #[derive(serde::Deserialize, Debug)]
    struct Row {
        name: String,
    }

    #[test]
    fn new_validates_config() {
        let client =
            DatabaseClient::new(ConnectionConfig::default()).expect("valid default config");
        assert!(!client.is_connected());
    }

    #[test]
    fn new_rejects_invalid_config() {
        let ftp_config = ConnectionConfig {
            db_url: "ftp://nope".into(),
            ..Default::default()
        };
        assert!(DatabaseClient::new(ftp_config).is_err());
    }

    #[test]
    fn flatten_rows_typed_handles_wrapped_and_flat_shapes() {
        // Legacy per-statement envelope: `[{ "result": [...] }]`.
        let wrapped = serde_json::json!([
            { "result": [{ "name": "alice" }, { "name": "bob" }] }
        ]);
        let from_wrapped: Vec<Row> = flatten_rows_typed(&wrapped).unwrap();
        let names: Vec<&str> = from_wrapped.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["alice", "bob"]);

        // One bare array per statement: `[[...]]`.
        let flat = serde_json::json!([[{ "name": "carol" }]]);
        let from_flat: Vec<Row> = flatten_rows_typed(&flat).unwrap();
        let names: Vec<&str> = from_flat.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["carol"]);
    }

    #[test]
    fn first_row_typed_returns_none_for_empty_array() {
        let empty = serde_json::json!([[]]);
        let first: Option<Row> = first_row_typed(&empty).unwrap();
        assert!(first.is_none());
    }

    #[test]
    fn payload_str_round_trip() {
        let root = RootCredentials::new("root", "secret");
        let payload = root.to_signin_payload();
        assert_eq!(payload_str(&payload, "username").unwrap(), "root");
        assert_eq!(payload_str(&payload, "password").unwrap(), "secret");
        assert!(payload_str(&payload, "missing").is_err());
    }

    #[tokio::test]
    async fn disconnect_when_never_connected_is_ok() {
        let idle = DatabaseClient::new(ConnectionConfig::default()).unwrap();
        idle.disconnect().await.unwrap();
        assert!(!idle.is_connected());
    }

    #[tokio::test]
    async fn operations_fail_when_not_connected() {
        let idle = DatabaseClient::new(ConnectionConfig::default()).unwrap();
        let outcome = idle.query("INFO FOR DB").await;
        assert!(matches!(outcome, Err(SurqlError::Connection { .. })));
    }

    #[test]
    fn backoff_respects_bounds() {
        let client = DatabaseClient::new(ConnectionConfig {
            db_retry_min_wait: 0.5,
            db_retry_max_wait: 4.0,
            db_retry_multiplier: 2.0,
            ..Default::default()
        })
        .unwrap();
        let first_wait = client.backoff_for(1);
        let fifth_wait = client.backoff_for(5);
        assert!(first_wait >= Duration::from_secs_f64(0.5));
        assert!(fifth_wait <= Duration::from_secs_f64(4.0));
    }

    #[test]
    fn surrealdb_error_maps_to_surql_error() {
        // In 3.x `surrealdb::Error` is one struct whose kind is exposed via
        // predicate methods. Build representative errors with the public
        // constructors and check which `SurqlError` variant each lands in.
        let thrown: SurqlError = surrealdb::Error::thrown("boom".into()).into();
        assert!(matches!(thrown, SurqlError::Query { .. }));

        let connection: SurqlError = surrealdb::Error::connection("down".into(), None).into();
        assert!(matches!(connection, SurqlError::Connection { .. }));

        let internal: SurqlError = surrealdb::Error::internal("boom".into()).into();
        assert!(matches!(internal, SurqlError::Database { .. }));
    }
}