use crate::{DebugEngine, Engine, Result, SurrealClient, SurrealError, WsCborEngine};
use serde_json::Value;
use url::Url;
/// Builder describing how to open a SurrealDB connection.
///
/// Start from [`SurrealConnection::new`], [`SurrealConnection::dsn`], or
/// [`Default::default`] (all three start in the same state), adjust it with the
/// builder methods, then call `connect` to obtain a client.
#[derive(Debug, Clone)]
pub struct SurrealConnection {
    /// Server address; `connect` requires it and accepts `ws`, `wss`, and `cbor` schemes.
    pub url: Option<String>,
    /// Namespace sent in the `use` call after sign-in; also required by
    /// namespace- and database-level auth.
    namespace: Option<String>,
    /// Database sent in the `use` call after sign-in; also required by
    /// database-level auth.
    database: Option<String>,
    /// Credentials used at sign-in; connecting without any is an error.
    auth: Option<AuthParams>,
    /// Whether a server version check is requested.
    /// NOTE(review): consumed outside this file (presumably by the engine) — confirm semantics.
    version_check: bool,
    /// When set, `connect` wraps the engine in `DebugEngine` and passes the flag
    /// to `SurrealClient::with_debug`.
    debug: bool,
}

/// Produces the same initial state as [`SurrealConnection::new`]:
/// nothing configured, version checking enabled, debug disabled.
///
/// This is written by hand because a derived `Default` would set
/// `version_check` to `false`, silently disagreeing with `new()`.
impl Default for SurrealConnection {
    fn default() -> Self {
        Self {
            url: None,
            namespace: None,
            database: None,
            auth: None,
            version_check: true,
            debug: false,
        }
    }
}
/// Credentials a [`SurrealConnection`] can sign in with.
#[derive(Debug, Clone)]
pub enum AuthParams {
    /// Root-level user credentials.
    Root {
        /// Sent as `user` in the `signin` request.
        username: String,
        /// Sent as `pass` in the `signin` request.
        password: String,
    },
    /// Namespace-level user credentials; the connection's namespace must be set.
    Namespace {
        /// Sent as `user` in the `signin` request.
        username: String,
        /// Sent as `pass` in the `signin` request.
        password: String,
    },
    /// Database-level user credentials; the connection's namespace and database must be set.
    Database {
        /// Sent as `user` in the `signin` request.
        username: String,
        /// Sent as `pass` in the `signin` request.
        password: String,
    },
    /// Scope sign-in. `SurrealConnection::init_engine` currently rejects this variant.
    Scope {
        namespace: String,
        database: String,
        scope: String,
        /// Extra sign-in parameters — presumably a JSON object; TODO confirm.
        params: Value,
    },
    /// Pre-issued token. `SurrealConnection::init_engine` currently rejects this variant.
    Token(String),
}
impl SurrealConnection {
pub fn new() -> Self {
Self {
version_check: true,
debug: false,
..Default::default()
}
}
pub fn dsn(dsn: impl AsRef<str>) -> Result<Self> {
let mut conn = Self::new();
let url = Url::parse(dsn.as_ref())?;
if url.host().is_none() {
return Err(SurrealError::Connection(
"URL must have a valid host".to_string(),
));
}
let base_url = format!("{}://{}", url.scheme(), url.host_str().unwrap());
let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
let final_url = format!("{}{}", base_url, port);
conn.url = Some(final_url);
if !url.username().is_empty() {
let username = url.username().to_string();
let password = url.password().unwrap_or("").to_string();
conn.auth = Some(AuthParams::Root { username, password });
}
let path_segments: Vec<&str> = url.path_segments().map(|c| c.collect()).unwrap_or_default();
if let Some(namespace) = path_segments.first().filter(|s| !s.is_empty()) {
conn.namespace = Some(namespace.to_string());
}
if let Some(database) = path_segments.get(1).filter(|s| !s.is_empty()) {
conn.database = Some(database.to_string());
}
for (key, value) in url.query_pairs() {
match key.as_ref() {
"namespace" => conn.namespace = Some(value.into_owned()),
"database" => conn.database = Some(value.into_owned()),
"version_check" => {
conn.version_check = value.parse().unwrap_or(true);
}
_ => {}
}
}
Ok(conn)
}
pub fn url(mut self, url: impl Into<String>) -> Self {
self.url = Some(url.into());
self
}
pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
self.namespace = Some(namespace.into());
self
}
pub fn database(mut self, database: impl Into<String>) -> Self {
self.database = Some(database.into());
self
}
pub fn auth_root(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
self.auth = Some(AuthParams::Root {
username: username.into(),
password: password.into(),
});
self
}
pub fn auth_namespace(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.auth = Some(AuthParams::Namespace {
username: username.into(),
password: password.into(),
});
self
}
pub fn auth_database(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.auth = Some(AuthParams::Database {
username: username.into(),
password: password.into(),
});
self
}
pub fn auth_scope(
mut self,
namespace: impl Into<String>,
database: impl Into<String>,
scope: impl Into<String>,
params: Value,
) -> Self {
self.auth = Some(AuthParams::Scope {
namespace: namespace.into(),
database: database.into(),
scope: scope.into(),
params,
});
self
}
pub fn auth_token(mut self, token: impl Into<String>) -> Self {
self.auth = Some(AuthParams::Token(token.into()));
self
}
pub fn version_check(mut self, check: bool) -> Self {
self.version_check = check;
self
}
pub fn with_debug(mut self, enabled: bool) -> Self {
self.debug = enabled;
self
}
pub(crate) async fn init_engine(&self, engine: &mut crate::WsCborEngine) -> Result<()> {
use ciborium::Value as CborValue;
match self.auth.as_ref().ok_or(SurrealError::Connection(
"Attempted to connect without auth".to_string(),
))? {
AuthParams::Root { username, password } => {
let auth_params = CborValue::Array(vec![CborValue::Map(vec![
(
CborValue::Text("user".to_string()),
CborValue::Text(username.clone()),
),
(
CborValue::Text("pass".to_string()),
CborValue::Text(password.clone()),
),
])]);
engine.send_message_cbor("signin", auth_params).await?;
}
AuthParams::Namespace { username, password } => {
let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
"Namespace is required for namespace auth".to_string(),
))?;
let auth_params = CborValue::Array(vec![CborValue::Map(vec![
(
CborValue::Text("user".to_string()),
CborValue::Text(username.clone()),
),
(
CborValue::Text("pass".to_string()),
CborValue::Text(password.clone()),
),
(
CborValue::Text("NS".to_string()),
CborValue::Text(namespace),
),
])]);
engine.send_message_cbor("signin", auth_params).await?;
}
AuthParams::Database { username, password } => {
let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
"Namespace is required for database auth".to_string(),
))?;
let database = self.database.clone().ok_or(SurrealError::Connection(
"Database is required for database auth".to_string(),
))?;
let auth_params = CborValue::Array(vec![CborValue::Map(vec![
(
CborValue::Text("user".to_string()),
CborValue::Text(username.clone()),
),
(
CborValue::Text("pass".to_string()),
CborValue::Text(password.clone()),
),
(
CborValue::Text("NS".to_string()),
CborValue::Text(namespace),
),
(CborValue::Text("DB".to_string()), CborValue::Text(database)),
])]);
engine.send_message_cbor("signin", auth_params).await?;
}
_ => {
return Err(SurrealError::Connection(
"Unsupported authentication method".to_string(),
));
}
}
if let Some(namespace) = &self.namespace {
let use_params = CborValue::Array(vec![
CborValue::Text(namespace.clone()),
CborValue::Text(self.database.as_ref().unwrap_or(&String::new()).clone()),
]);
engine.send_message_cbor("use", use_params).await?;
}
Ok(())
}
pub async fn connect(self) -> Result<SurrealClient> {
let url_str = self
.url
.as_ref()
.ok_or_else(|| SurrealError::Connection("URL is required".to_string()))?;
let url = Url::parse(url_str)
.map_err(|e| SurrealError::Connection(format!("Invalid URL: {}", e)))?;
let mut engine: Box<dyn Engine> = match url.scheme() {
"ws" | "wss" | "cbor" => Box::new(WsCborEngine::from_connection(&self).await?),
_ => {
return Err(SurrealError::Protocol(
"Unsupported protocol. Use ws://, wss://, or cbor://".to_string(),
));
}
};
if self.debug {
engine = DebugEngine::wrap(engine);
}
let client = SurrealClient::new(engine, self.namespace, self.database);
Ok(client.with_debug(self.debug))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string slice the way the builder stores optional settings.
    fn owned(value: &str) -> Option<String> {
        Some(value.to_string())
    }

    #[test]
    fn test_connection_builder() {
        let SurrealConnection {
            url,
            namespace,
            database,
            auth,
            version_check,
            ..
        } = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_ns")
            .database("test_db")
            .auth_root("root", "root")
            .version_check(false);

        assert_eq!(url, owned("ws://localhost:8000"));
        assert_eq!(namespace, owned("test_ns"));
        assert_eq!(database, owned("test_db"));
        assert!(!version_check);
        assert!(matches!(auth, Some(AuthParams::Root { .. })));
    }

    #[test]
    fn test_dsn_parsing() {
        let dsn = "ws://root:root@localhost:8000/test_ns/test_db?version_check=false";
        let parsed = SurrealConnection::dsn(dsn).unwrap();

        assert_eq!(parsed.url, owned("ws://localhost:8000"));
        assert_eq!(parsed.namespace, owned("test_ns"));
        assert_eq!(parsed.database, owned("test_db"));
        assert!(!parsed.version_check);
        assert!(matches!(parsed.auth, Some(AuthParams::Root { .. })));
    }

    #[test]
    fn test_dsn_with_query_params() {
        let dsn = "http://localhost:8000?namespace=ns&database=db";
        let parsed = SurrealConnection::dsn(dsn).unwrap();

        assert_eq!(parsed.url, owned("http://localhost:8000"));
        assert_eq!(parsed.namespace, owned("ns"));
        assert_eq!(parsed.database, owned("db"));
    }

    #[test]
    fn test_auth_methods() {
        let root = SurrealConnection::new().auth_root("admin", "pass").auth;
        assert!(matches!(root, Some(AuthParams::Root { .. })));

        let ns_level = SurrealConnection::new()
            .auth_namespace("ns_user", "ns_pass")
            .auth;
        assert!(matches!(ns_level, Some(AuthParams::Namespace { .. })));

        let db_level = SurrealConnection::new()
            .auth_database("db_user", "db_pass")
            .auth;
        assert!(matches!(db_level, Some(AuthParams::Database { .. })));

        let token = SurrealConnection::new().auth_token("jwt_token").auth;
        assert!(matches!(token, Some(AuthParams::Token(_))));
    }

    #[tokio::test]
    async fn test_connection_to_client_flow() {
        let SurrealConnection {
            url,
            namespace,
            database,
            auth,
            version_check,
            ..
        } = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_namespace")
            .database("test_database")
            .auth_root("admin", "password")
            .version_check(false);

        assert_eq!(url, owned("ws://localhost:8000"));
        assert_eq!(namespace, owned("test_namespace"));
        assert_eq!(database, owned("test_database"));
        assert!(!version_check);
        assert!(matches!(auth, Some(AuthParams::Root { .. })));
    }
}