use anyhow::{Context, Result};
use std::path::Path;
use std::sync::OnceLock;
use surrealdb::Surreal;
use surrealdb::engine::local::SurrealKv;
use surrealdb::engine::remote::ws::{Client as WsClient, Ws, Wss};
use surrealdb::opt::auth::{Database, Namespace, Root};
use tokio::runtime::Runtime;
use super::SurrealDatabase;
/// Selects which SurrealDB backend a connection uses.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SurrealMode {
    /// In-process database stored on disk via the SurrealKV engine (the default).
    #[default]
    Embedded,
    /// Remote server reached over a WebSocket (`ws://` or `wss://`).
    Network,
}
/// Credential scope used when signing in to a networked SurrealDB server.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AuthLevel {
    /// Root-level user (the default).
    #[default]
    Root,
    /// User defined on the configured namespace.
    Namespace,
    /// User defined on the configured namespace and database.
    Database,
}
impl AuthLevel {
    /// Parses an `MX_SURREAL_AUTH_LEVEL` value, ignoring case.
    ///
    /// Accepted spellings are `root`, `namespace` / `ns`, and `database` / `db`.
    ///
    /// # Errors
    ///
    /// Fails for any other value. The message quotes the lower-cased input and
    /// lists the accepted spellings.
    pub fn from_env_str(s: &str) -> Result<Self> {
        let lowered = s.to_lowercase();
        let level = match lowered.as_str() {
            "root" => Self::Root,
            "namespace" | "ns" => Self::Namespace,
            "database" | "db" => Self::Database,
            unknown => anyhow::bail!(
                "Unknown MX_SURREAL_AUTH_LEVEL '{}'. Valid values: root, namespace (ns), database (db)",
                unknown
            ),
        };
        Ok(level)
    }
}
impl std::fmt::Display for AuthLevel {
    /// Writes the canonical lower-case name: `root`, `namespace`, or `database`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Root => "root",
            Self::Namespace => "namespace",
            Self::Database => "database",
        };
        f.write_str(name)
    }
}
/// Settings used to open a [`SurrealDatabase`].
///
/// `Debug` is implemented by hand so the password never appears in log output.
#[derive(Clone)]
pub struct SurrealConfig {
    /// Embedded (on-disk) or network (WebSocket) backend.
    pub mode: SurrealMode,
    /// Server URL (`ws://…` or `wss://…`); only consulted in network mode.
    pub url: String,
    /// Username used for sign-in when `pass` is set (network mode).
    pub user: String,
    /// Password for sign-in; `None` connects without authenticating.
    pub pass: Option<String>,
    /// Namespace selected after connecting; also the scope for namespace/database auth.
    pub namespace: String,
    /// Database selected after connecting; also the scope for database-level auth.
    pub database: String,
    /// Credential scope used for sign-in.
    pub auth_level: AuthLevel,
}

impl std::fmt::Debug for SurrealConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SurrealConfig")
            .field("mode", &self.mode)
            .field("url", &self.url)
            .field("user", &self.user)
            // Show only whether a password is configured, never its value.
            .field("pass", &self.pass.as_ref().map(|_| "<redacted>"))
            .field("namespace", &self.namespace)
            .field("database", &self.database)
            .field("auth_level", &self.auth_level)
            .finish()
    }
}
impl Default for SurrealConfig {
    /// Embedded mode against namespace `memory` / database `knowledge`.
    ///
    /// The network-only settings default to `ws://localhost:8000` as user `root` with
    /// root-level auth and no password.
    fn default() -> Self {
        Self {
            mode: SurrealMode::Embedded,
            auth_level: AuthLevel::Root,
            url: String::from("ws://localhost:8000"),
            user: String::from("root"),
            pass: None,
            namespace: String::from("memory"),
            database: String::from("knowledge"),
        }
    }
}
impl SurrealConfig {
    /// Builds a configuration from `MX_SURREAL_*` environment variables.
    ///
    /// - `MX_SURREAL_MODE`: `network` or `embedded`, case-insensitive. Unset or empty
    ///   means embedded. Any other value prints a warning and also falls back to embedded.
    /// - `MX_SURREAL_URL`, `MX_SURREAL_USER`, `MX_SURREAL_NS`, `MX_SURREAL_DB`: if unset
    ///   (or not valid UTF-8), the value from [`SurrealConfig::default`] is used.
    /// - `MX_SURREAL_PASS`, or else the contents of the file named by
    ///   `MX_SURREAL_PASS_FILE`: the password, trimmed. A blank value means no password.
    /// - `MX_SURREAL_AUTH_LEVEL`: parsed by [`AuthLevel::from_env_str`]. An invalid value
    ///   prints a warning and falls back to root.
    ///
    /// This never fails. Problems are reported as `[mx] WARNING` lines on stderr.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        let mode = match std::env::var("MX_SURREAL_MODE") {
            Ok(raw) => match raw.to_lowercase().as_str() {
                "network" => SurrealMode::Network,
                "embedded" | "" => SurrealMode::Embedded,
                other => {
                    // Warn like the auth-level parsing does, instead of silently
                    // swapping a mistyped "network" for embedded mode.
                    eprintln!(
                        "[mx] WARNING: Unknown MX_SURREAL_MODE '{other}', defaulting to embedded"
                    );
                    SurrealMode::Embedded
                }
            },
            Err(_) => defaults.mode,
        };
        let url = std::env::var("MX_SURREAL_URL").unwrap_or(defaults.url);
        let user = std::env::var("MX_SURREAL_USER").unwrap_or(defaults.user);
        let pass = Self::password_from_env();
        let namespace = std::env::var("MX_SURREAL_NS").unwrap_or(defaults.namespace);
        let database = std::env::var("MX_SURREAL_DB").unwrap_or(defaults.database);
        let auth_level = match std::env::var("MX_SURREAL_AUTH_LEVEL") {
            Ok(raw) => AuthLevel::from_env_str(&raw).unwrap_or_else(|e| {
                eprintln!("[mx] WARNING: {e}, defaulting to root");
                AuthLevel::Root
            }),
            Err(_) => defaults.auth_level,
        };
        Self {
            mode,
            url,
            user,
            pass,
            namespace,
            database,
            auth_level,
        }
    }

    /// Resolves the password from `MX_SURREAL_PASS`, falling back to the file named by
    /// `MX_SURREAL_PASS_FILE`.
    ///
    /// Both sources are trimmed, and a blank value is treated as absent.
    fn password_from_env() -> Option<String> {
        fn non_blank(raw: &str) -> Option<String> {
            let trimmed = raw.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        }

        // Check for blankness before falling back, so an exported-but-empty
        // MX_SURREAL_PASS does not hide MX_SURREAL_PASS_FILE.
        if let Some(pass) = std::env::var("MX_SURREAL_PASS")
            .ok()
            .as_deref()
            .and_then(non_blank)
        {
            return Some(pass);
        }
        let path = std::env::var("MX_SURREAL_PASS_FILE")
            .ok()
            .filter(|p| !p.is_empty())?;
        match std::fs::read_to_string(&path) {
            Ok(contents) => non_blank(&contents),
            Err(err) => {
                // Report the failure instead of swallowing it. Without a password the
                // connection would quietly proceed unauthenticated.
                eprintln!("[mx] WARNING: Failed to read MX_SURREAL_PASS_FILE '{path}': {err}");
                None
            }
        }
    }

    /// Returns `true` when the network (WebSocket) backend is selected.
    pub fn is_network(&self) -> bool {
        self.mode == SurrealMode::Network
    }
}
/// Client handle for whichever backend was opened.
pub enum SurrealConnection {
    /// In-process database opened with the SurrealKV engine (embedded mode).
    Embedded(Surreal<surrealdb::engine::local::Db>),
    /// Client connected to a remote server over a plain (`Ws`) or TLS (`Wss`) WebSocket.
    Network(Surreal<WsClient>),
}
/// SurrealQL schema embedded into the binary at compile time.
///
/// It is executed when an embedded database is opened. The network opener does not
/// apply it.
const SCHEMA: &str = include_str!("../../schema/surrealdb-schema.surql");
/// Normalizes a timestamp string so that it carries an explicit time zone.
///
/// - A space-separated timestamp such as `2024-01-15 10:30:00` is first rewritten with
///   a `T` separator (`2024-01-15T10:30:00`).
/// - If the time portion (after `T`) already ends in `Z`/`z` or contains a `+`/`-`
///   offset, the value is returned unchanged.
/// - Otherwise `Z` is appended, so naive timestamps are treated as UTC.
/// - A value with neither `T` nor a space (e.g. a bare date) gets `Z` appended, unless
///   it already ends in `Z` or contains `+`.
pub(super) fn normalize_datetime(s: &str) -> String {
    let candidate = if s.contains('T') {
        s.to_string()
    } else {
        // Only the first space separates the date from the time.
        s.replacen(' ', "T", 1)
    };
    // Search for a zone designator in the time portion only. The date part is full of
    // '-' characters (e.g. "2024-01-05"), so scanning the whole string would mistake
    // naive timestamps for ones that carry an offset.
    let has_zone = match candidate.split_once('T') {
        Some((_, time)) => {
            time.ends_with('Z') || time.ends_with('z') || time.contains(['+', '-'])
        }
        None => candidate.ends_with('Z') || candidate.contains('+'),
    };
    if has_zone {
        candidate
    } else {
        candidate + "Z"
    }
}
impl SurrealDatabase {
pub(super) fn runtime() -> &'static Runtime {
static RT: OnceLock<Runtime> = OnceLock::new();
RT.get_or_init(|| Runtime::new().expect("Failed to create tokio runtime"))
}
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let config = SurrealConfig::from_env();
Self::runtime().block_on(Self::open_with_config_async(path, &config, false))
}
pub fn open_with_verbose<P: AsRef<Path>>(path: P, verbose: bool) -> Result<Self> {
let config = SurrealConfig::from_env();
Self::runtime().block_on(Self::open_with_config_async(path, &config, verbose))
}
pub fn connect<P: AsRef<Path>>(path: P, config: &SurrealConfig) -> Result<Self> {
Self::runtime().block_on(Self::open_with_config_async(path, config, false))
}
async fn open_with_config_async<P: AsRef<Path>>(
path: P,
config: &SurrealConfig,
verbose: bool,
) -> Result<Self> {
match config.mode {
SurrealMode::Embedded => Self::open_embedded_async(path, config, verbose).await,
SurrealMode::Network => Self::open_network_async(config, verbose).await,
}
}
async fn open_embedded_async<P: AsRef<Path>>(
path: P,
config: &SurrealConfig,
verbose: bool,
) -> Result<Self> {
let path = path.as_ref();
if verbose {
eprintln!(
"[mx] Connecting to SurrealDB in embedded mode: {}",
path.display()
);
}
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Failed to create database directory: {:?}", parent))?;
}
let db = Surreal::new::<SurrealKv>(path).await.with_context(|| {
format!(
"Failed to open SurrealDB at {} (check file permissions and disk space)",
path.display()
)
})?;
if verbose {
eprintln!(
"[mx] Using namespace '{}' and database '{}'",
config.namespace, config.database
);
}
db.use_ns(&config.namespace)
.use_db(&config.database)
.await
.context("Failed to set namespace and database")?;
if verbose {
eprintln!("[mx] Applying database schema");
}
let mut response = db
.query(SCHEMA)
.await
.context("Failed to apply database schema")?;
let errors = response.take_errors();
if !errors.is_empty() {
return Err(anyhow::anyhow!("Schema application failed: {:?}", errors));
}
if verbose {
eprintln!("[mx] Embedded connection established successfully");
}
Ok(Self {
conn: SurrealConnection::Embedded(db),
})
}
fn is_localhost_url(url: &str) -> bool {
url.contains("://localhost") || url.contains("://127.0.0.1") || url.contains("://[::1]")
}
fn sanitize_ws_url(url: &str) -> String {
url.strip_prefix("ws://")
.or_else(|| url.strip_prefix("wss://"))
.unwrap_or(url)
.to_string()
}
async fn open_network_async(config: &SurrealConfig, verbose: bool) -> Result<Self> {
if verbose {
eprintln!(
"[mx] Connecting to SurrealDB in network mode: {}",
config.url
);
}
if config.pass.is_some()
&& config.url.starts_with("ws://")
&& !Self::is_localhost_url(&config.url)
{
eprintln!(
"[mx] WARNING: Sending credentials over unencrypted WebSocket to {}",
config.url
);
eprintln!("[mx] WARNING: Consider using wss:// (TLS) for secure authentication");
}
let is_tls = config.url.starts_with("wss://");
let sanitized_url = Self::sanitize_ws_url(&config.url);
let db = if is_tls {
Surreal::new::<Wss>(sanitized_url.as_str())
.await
.with_context(|| {
format!(
"Failed to connect to SurrealDB at {} (check that server is running and URL is correct)",
config.url
)
})?
} else {
Surreal::new::<Ws>(sanitized_url.as_str())
.await
.with_context(|| {
format!(
"Failed to connect to SurrealDB at {} (check that server is running and URL is correct)",
config.url
)
})?
};
if let Some(pass) = &config.pass {
if verbose {
eprintln!(
"[mx] Authenticating as user '{}' (auth level: {})",
config.user, config.auth_level
);
}
match config.auth_level {
AuthLevel::Namespace => {
db.signin(Namespace {
namespace: &config.namespace,
username: &config.user,
password: pass,
})
.await
.with_context(|| {
format!(
"Failed to authenticate to SurrealDB at {} as namespace-level user '{}' in namespace '{}' (check credentials in MX_SURREAL_USER and MX_SURREAL_PASS)",
config.url, config.user, config.namespace
)
})?;
}
AuthLevel::Database => {
db.signin(Database {
namespace: &config.namespace,
database: &config.database,
username: &config.user,
password: pass,
})
.await
.with_context(|| {
format!(
"Failed to authenticate to SurrealDB at {} as database-level user '{}' in namespace '{}' database '{}' (check credentials in MX_SURREAL_USER and MX_SURREAL_PASS)",
config.url, config.user, config.namespace, config.database
)
})?;
}
AuthLevel::Root => {
db.signin(Root {
username: &config.user,
password: pass,
})
.await
.with_context(|| {
format!(
"Failed to authenticate to SurrealDB at {} as user '{}' (check credentials in MX_SURREAL_USER and MX_SURREAL_PASS)",
config.url, config.user
)
})?;
}
}
} else if verbose {
eprintln!("[mx] No password provided, connecting without authentication");
}
if verbose {
eprintln!(
"[mx] Using namespace '{}' and database '{}'",
config.namespace, config.database
);
}
db.use_ns(&config.namespace)
.use_db(&config.database)
.await
.with_context(|| {
format!(
"Failed to set namespace '{}' and database '{}' (check that they exist on the server)",
config.namespace, config.database
)
})?;
if verbose {
eprintln!("[mx] Network connection established successfully");
}
Ok(Self {
conn: SurrealConnection::Network(db),
})
}
#[allow(dead_code)]
async fn open_async<P: AsRef<Path>>(path: P) -> Result<Self> {
let config = SurrealConfig::from_env();
Self::open_with_config_async(path, &config, false).await
}
#[cfg(test)]
pub fn open_in_memory() -> Result<Self> {
use tempfile::tempdir;
let temp_dir = tempdir()?;
let config = SurrealConfig::default(); Self::connect(temp_dir.path(), &config)
}
#[deprecated(note = "Use connection-agnostic methods instead")]
pub fn inner(&self) -> Option<&Surreal<surrealdb::engine::local::Db>> {
match &self.conn {
SurrealConnection::Embedded(db) => Some(db),
SurrealConnection::Network(_) => None,
}
}
}