Skip to main content

surreal_client/
connection.rs

//! Connection builder for SurrealDB with authentication and engine creation
2
3use crate::{DebugEngine, Engine, Result, SurrealClient, SurrealError, WsCborEngine};
4
5use serde_json::Value;
6use url::Url;
7
8/// Connection builder for SurrealDB
9#[derive(Default, Debug, Clone)]
10pub struct SurrealConnection {
11    /// URL to connect to
12    pub url: Option<String>,
13
14    /// Namespace to use
15    namespace: Option<String>,
16
17    /// Database to use
18    database: Option<String>,
19
20    /// Authentication credentials
21    auth: Option<AuthParams>,
22
23    /// Whether to check SurrealDB version compatibility
24    version_check: bool,
25
26    /// Whether to enable debug mode for query logging
27    debug: bool,
28}
29
30/// Authentication parameters
31#[derive(Debug, Clone)]
32pub enum AuthParams {
33    /// Root authentication
34    Root { username: String, password: String },
35    /// Namespace authentication
36    Namespace { username: String, password: String },
37    /// Database authentication
38    Database { username: String, password: String },
39    /// Scope authentication
40    Scope {
41        namespace: String,
42        database: String,
43        scope: String,
44        params: Value,
45    },
46    /// JWT token authentication
47    Token(String),
48}
49
50impl SurrealConnection {
51    /// Create a new connection builder
52    pub fn new() -> Self {
53        Self {
54            version_check: true,
55            debug: false,
56            ..Default::default()
57        }
58    }
59
60    /// Parse connection from DSN string
61    pub fn dsn(dsn: impl AsRef<str>) -> Result<Self> {
62        let mut conn = Self::new();
63        let url = Url::parse(dsn.as_ref())?;
64
65        // Ensure URL has a proper host
66        if url.host().is_none() {
67            return Err(SurrealError::Connection(
68                "URL must have a valid host".to_string(),
69            ));
70        }
71
72        // Store the URL without user credentials and path/query
73        let base_url = format!("{}://{}", url.scheme(), url.host_str().unwrap());
74        let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
75        let final_url = format!("{}{}", base_url, port);
76        conn.url = Some(final_url);
77
78        // Extract user credentials for root auth
79        if !url.username().is_empty() {
80            let username = url.username().to_string();
81            let password = url.password().unwrap_or("").to_string();
82            conn.auth = Some(AuthParams::Root { username, password });
83        }
84
85        // Extract namespace and database from path segments
86        let path_segments: Vec<&str> = url.path_segments().map(|c| c.collect()).unwrap_or_default();
87
88        if let Some(namespace) = path_segments.first().filter(|s| !s.is_empty()) {
89            conn.namespace = Some(namespace.to_string());
90        }
91        if let Some(database) = path_segments.get(1).filter(|s| !s.is_empty()) {
92            conn.database = Some(database.to_string());
93        }
94
95        // Parse query parameters
96        for (key, value) in url.query_pairs() {
97            match key.as_ref() {
98                "namespace" => conn.namespace = Some(value.into_owned()),
99                "database" => conn.database = Some(value.into_owned()),
100                "version_check" => {
101                    conn.version_check = value.parse().unwrap_or(true);
102                }
103                _ => {}
104            }
105        }
106
107        Ok(conn)
108    }
109
110    /// Set the URL to connect to
111    pub fn url(mut self, url: impl Into<String>) -> Self {
112        self.url = Some(url.into());
113        self
114    }
115
116    /// Set the namespace
117    pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
118        self.namespace = Some(namespace.into());
119        self
120    }
121
122    /// Set the database
123    pub fn database(mut self, database: impl Into<String>) -> Self {
124        self.database = Some(database.into());
125        self
126    }
127
128    /// Set root authentication
129    pub fn auth_root(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
130        self.auth = Some(AuthParams::Root {
131            username: username.into(),
132            password: password.into(),
133        });
134        self
135    }
136
137    /// Set namespace authentication
138    pub fn auth_namespace(
139        mut self,
140        username: impl Into<String>,
141        password: impl Into<String>,
142    ) -> Self {
143        self.auth = Some(AuthParams::Namespace {
144            username: username.into(),
145            password: password.into(),
146        });
147        self
148    }
149
150    /// Set database authentication
151    pub fn auth_database(
152        mut self,
153        username: impl Into<String>,
154        password: impl Into<String>,
155    ) -> Self {
156        self.auth = Some(AuthParams::Database {
157            username: username.into(),
158            password: password.into(),
159        });
160        self
161    }
162
163    /// Set scope authentication
164    pub fn auth_scope(
165        mut self,
166        namespace: impl Into<String>,
167        database: impl Into<String>,
168        scope: impl Into<String>,
169        params: Value,
170    ) -> Self {
171        self.auth = Some(AuthParams::Scope {
172            namespace: namespace.into(),
173            database: database.into(),
174            scope: scope.into(),
175            params,
176        });
177        self
178    }
179
180    /// Set JWT token authentication
181    pub fn auth_token(mut self, token: impl Into<String>) -> Self {
182        self.auth = Some(AuthParams::Token(token.into()));
183        self
184    }
185
186    /// Set version check flag
187    pub fn version_check(mut self, check: bool) -> Self {
188        self.version_check = check;
189        self
190    }
191
192    /// Enable debug mode for query logging
193    pub fn with_debug(mut self, enabled: bool) -> Self {
194        self.debug = enabled;
195        self
196    }
197
198    // /// Configure connection pool with custom settings
199    // pub fn with_pool_config(mut self, config: PoolConfig) -> Self {
200    //     self.pool_config = Some(config);
201    //     self
202    // }
203
204    pub(crate) async fn init_engine(&self, engine: &mut crate::WsCborEngine) -> Result<()> {
205        use ciborium::Value as CborValue;
206
207        match self.auth.as_ref().ok_or(SurrealError::Connection(
208            "Attempted to connect without auth".to_string(),
209        ))? {
210            AuthParams::Root { username, password } => {
211                let auth_params = CborValue::Array(vec![CborValue::Map(vec![
212                    (
213                        CborValue::Text("user".to_string()),
214                        CborValue::Text(username.clone()),
215                    ),
216                    (
217                        CborValue::Text("pass".to_string()),
218                        CborValue::Text(password.clone()),
219                    ),
220                ])]);
221                engine.send_message_cbor("signin", auth_params).await?;
222            }
223            AuthParams::Namespace { username, password } => {
224                let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
225                    "Namespace is required for namespace auth".to_string(),
226                ))?;
227                let auth_params = CborValue::Array(vec![CborValue::Map(vec![
228                    (
229                        CborValue::Text("user".to_string()),
230                        CborValue::Text(username.clone()),
231                    ),
232                    (
233                        CborValue::Text("pass".to_string()),
234                        CborValue::Text(password.clone()),
235                    ),
236                    (
237                        CborValue::Text("NS".to_string()),
238                        CborValue::Text(namespace),
239                    ),
240                ])]);
241                engine.send_message_cbor("signin", auth_params).await?;
242            }
243            AuthParams::Database { username, password } => {
244                let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
245                    "Namespace is required for database auth".to_string(),
246                ))?;
247                let database = self.database.clone().ok_or(SurrealError::Connection(
248                    "Database is required for database auth".to_string(),
249                ))?;
250                let auth_params = CborValue::Array(vec![CborValue::Map(vec![
251                    (
252                        CborValue::Text("user".to_string()),
253                        CborValue::Text(username.clone()),
254                    ),
255                    (
256                        CborValue::Text("pass".to_string()),
257                        CborValue::Text(password.clone()),
258                    ),
259                    (
260                        CborValue::Text("NS".to_string()),
261                        CborValue::Text(namespace),
262                    ),
263                    (CborValue::Text("DB".to_string()), CborValue::Text(database)),
264                ])]);
265                engine.send_message_cbor("signin", auth_params).await?;
266            }
267            _ => {
268                return Err(SurrealError::Connection(
269                    "Unsupported authentication method".to_string(),
270                ));
271            }
272        }
273
274        if let Some(namespace) = &self.namespace {
275            let use_params = CborValue::Array(vec![
276                CborValue::Text(namespace.clone()),
277                CborValue::Text(self.database.as_ref().unwrap_or(&String::new()).clone()),
278            ]);
279            engine.send_message_cbor("use", use_params).await?;
280        }
281
282        Ok(())
283    }
284
285    /// Connect to SurrealDB and return an immutable client
286    pub async fn connect(self) -> Result<SurrealClient> {
287        let url_str = self
288            .url
289            .as_ref()
290            .ok_or_else(|| SurrealError::Connection("URL is required".to_string()))?;
291        let url = Url::parse(url_str)
292            .map_err(|e| SurrealError::Connection(format!("Invalid URL: {}", e)))?;
293
294        let mut engine: Box<dyn Engine> = match url.scheme() {
295            "ws" | "wss" | "cbor" => Box::new(WsCborEngine::from_connection(&self).await?),
296            _ => {
297                return Err(SurrealError::Protocol(
298                    "Unsupported protocol. Use ws://, wss://, or cbor://".to_string(),
299                ));
300            }
301        };
302
303        if self.debug {
304            engine = DebugEngine::wrap(engine);
305        }
306
307        let client = SurrealClient::new(engine, self.namespace, self.database);
308        Ok(client.with_debug(self.debug))
309    }
310}
311
/// Unit tests for the builder and DSN parsing; none of them open a network connection.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_connection_builder() {
        let conn = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_ns")
            .database("test_db")
            .auth_root("root", "root")
            .version_check(false);

        assert_eq!(conn.url, Some("ws://localhost:8000".to_string()));
        assert_eq!(conn.namespace, Some("test_ns".to_string()));
        assert_eq!(conn.database, Some("test_db".to_string()));
        assert!(!conn.version_check);
        assert!(matches!(conn.auth, Some(AuthParams::Root { .. })));
    }

    /// A full DSN yields URL, credentials, namespace, database and flags.
    #[test]
    fn test_dsn_parsing() {
        let conn = SurrealConnection::dsn(
            "ws://root:root@localhost:8000/test_ns/test_db?version_check=false",
        )
        .unwrap();

        assert_eq!(conn.url, Some("ws://localhost:8000".to_string()));
        assert_eq!(conn.namespace, Some("test_ns".to_string()));
        assert_eq!(conn.database, Some("test_db".to_string()));
        assert!(!conn.version_check);
        assert!(matches!(conn.auth, Some(AuthParams::Root { .. })));
    }

    /// Namespace and database may be supplied through query parameters.
    #[test]
    fn test_dsn_with_query_params() {
        let conn =
            SurrealConnection::dsn("http://localhost:8000?namespace=ns&database=db").unwrap();

        assert_eq!(conn.url, Some("http://localhost:8000".to_string()));
        assert_eq!(conn.namespace, Some("ns".to_string()));
        assert_eq!(conn.database, Some("db".to_string()));
    }

    /// Query parameters take precedence over path segments; defaults stay intact.
    #[test]
    fn test_dsn_query_params_override_path() {
        let conn =
            SurrealConnection::dsn("ws://localhost:8000/path_ns/path_db?namespace=query_ns")
                .unwrap();

        assert_eq!(conn.namespace, Some("query_ns".to_string()));
        assert_eq!(conn.database, Some("path_db".to_string()));
        assert!(conn.auth.is_none());
        assert!(conn.version_check);
    }

    /// Unparseable input and URLs without a host are rejected.
    #[test]
    fn test_dsn_rejects_invalid_input() {
        assert!(SurrealConnection::dsn("not a url").is_err());
        assert!(SurrealConnection::dsn("data:text/plain,hello").is_err());
    }

    /// Each `auth_*` builder stores the matching `AuthParams` variant.
    #[test]
    fn test_auth_methods() {
        let conn1 = SurrealConnection::new().auth_root("admin", "pass");
        assert!(matches!(conn1.auth, Some(AuthParams::Root { .. })));

        let conn2 = SurrealConnection::new().auth_namespace("ns_user", "ns_pass");
        assert!(matches!(conn2.auth, Some(AuthParams::Namespace { .. })));

        let conn3 = SurrealConnection::new().auth_database("db_user", "db_pass");
        assert!(matches!(conn3.auth, Some(AuthParams::Database { .. })));

        let conn4 = SurrealConnection::new().auth_token("jwt_token");
        assert!(matches!(conn4.auth, Some(AuthParams::Token(_))));
    }

    /// Documents the intended Connection -> engine -> client flow.
    ///
    /// Nothing is awaited here, so this is a plain synchronous test.
    #[test]
    fn test_connection_to_client_flow() {
        // Example of the new flow: Connection -> authenticate -> creates engine -> returns immutable client

        // This would be the typical usage:
        // let client = SurrealConnection::new()
        //     .url("ws://localhost:8000")
        //     .namespace("bakery")
        //     .database("inventory")
        //     .auth_root("root", "root")
        //     .connect()
        //     .await
        //     .unwrap();

        // For testing, we just verify the builder pattern works
        let connection = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_namespace")
            .database("test_database")
            .auth_root("admin", "password")
            .version_check(false);

        assert_eq!(connection.url, Some("ws://localhost:8000".to_string()));
        assert_eq!(connection.namespace, Some("test_namespace".to_string()));
        assert_eq!(connection.database, Some("test_database".to_string()));
        assert!(!connection.version_check);
        assert!(matches!(connection.auth, Some(AuthParams::Root { .. })));

        // The client would be immutable once created:
        // - client.query() - no mut needed
        // - client.select() - no mut needed
        // - client.let_var() - changes session but client stays immutable
        // - Multiple clients can be cloned, each with unique session
    }
}