#[cfg(not(target_family = "wasm"))]
pub(crate) mod native;
#[cfg(target_family = "wasm")]
pub(crate) mod wasm;
use std::marker::PhantomData;
#[cfg(not(target_family = "wasm"))]
use std::pin::pin;
use std::sync::Arc;
#[cfg(not(target_family = "wasm"))]
use std::task::{Poll, ready};
#[cfg(not(target_family = "wasm"))]
use std::{future::Future, path::PathBuf};
use async_channel::Sender;
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
use futures::StreamExt;
#[cfg(not(target_family = "wasm"))]
use futures::stream::poll_fn;
use surrealdb_core::dbs::{QueryResult, QueryResultBuilder, Session};
use surrealdb_core::iam;
#[cfg(not(target_family = "wasm"))]
use surrealdb_core::kvs::export::Config as DbExportConfig;
use surrealdb_core::kvs::{Datastore, LockType, Transaction, TransactionType};
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
use surrealdb_core::{
iam::{Action, ResourceKind, check::check_ns_db},
ml::storage::surml_file::SurMlFile,
};
use surrealdb_types::Error as TypesError;
use tokio::sync::RwLock;
#[cfg(not(target_family = "wasm"))]
use tokio::{
fs::OpenOptions,
io::{self, AsyncReadExt, AsyncWriteExt},
};
#[cfg(not(target_family = "wasm"))]
use tokio_util::bytes::BytesMut;
use uuid::Uuid;
use crate::conn::Command;
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
use crate::conn::MlExportConfig;
use crate::engine::SessionError;
use crate::opt::IntoEndpoint;
use crate::opt::auth::{AccessToken, RefreshToken, SecureToken, Token};
use crate::types::{HashMap, Notification, SurrealValue, ToSql, Value, Variables};
use crate::{Connect, Surreal};
/// Endpoint marker type, compiled only with the `kv-mem` feature.
///
/// NOTE(review): presumably selects the in-memory key-value backend when used
/// as an endpoint scheme — confirm against the `IntoEndpoint` implementations.
#[cfg(feature = "kv-mem")]
#[cfg_attr(docsrs, doc(cfg(feature = "kv-mem")))]
#[derive(Debug)]
pub struct Mem;
/// Endpoint marker type, compiled only with the `kv-rocksdb` feature.
///
/// NOTE(review): presumably selects the RocksDB-backed datastore when used as
/// an endpoint scheme — confirm against the `IntoEndpoint` implementations.
#[cfg(feature = "kv-rocksdb")]
#[cfg_attr(docsrs, doc(cfg(feature = "kv-rocksdb")))]
#[derive(Debug)]
pub struct RocksDb;
/// Endpoint marker type, compiled only with the `kv-indxdb` feature.
///
/// NOTE(review): presumably selects the browser IndexedDB-backed datastore
/// when used as an endpoint scheme — confirm against the `IntoEndpoint`
/// implementations.
#[cfg(feature = "kv-indxdb")]
#[cfg_attr(docsrs, doc(cfg(feature = "kv-indxdb")))]
#[derive(Debug)]
pub struct IndxDb;
/// Endpoint marker type, compiled only with the `kv-tikv` feature.
///
/// NOTE(review): presumably selects the TiKV-backed datastore when used as an
/// endpoint scheme — confirm against the `IntoEndpoint` implementations.
#[cfg(feature = "kv-tikv")]
#[cfg_attr(docsrs, doc(cfg(feature = "kv-tikv")))]
#[derive(Debug)]
pub struct TiKv;
/// Endpoint marker type, compiled only with the `kv-surrealkv` feature.
///
/// NOTE(review): presumably selects the SurrealKV-backed datastore when used
/// as an endpoint scheme — confirm against the `IntoEndpoint` implementations.
#[cfg(feature = "kv-surrealkv")]
#[cfg_attr(docsrs, doc(cfg(feature = "kv-surrealkv")))]
#[derive(Debug)]
pub struct SurrealKv;
/// Client type used by `Surreal<Db>`.
///
/// Endpoints whose `IntoEndpoint::Client` is `Db` are connected through
/// `Surreal::<Db>::connect`. The private unit field means values can only be
/// created from within this module.
#[derive(Debug, Clone)]
pub struct Db(());
impl Surreal<Db> {
    /// Prepares a connection from this client to `address`.
    ///
    /// The returned [`Connect`] carries a clone of this client's inner handle,
    /// the endpoint produced by `IntoEndpoint::into_endpoint`, and a `capacity`
    /// of `0`.
    pub fn connect<P>(&self, address: impl IntoEndpoint<P, Client = Db>) -> Connect<Db, ()> {
        // Evaluation order matches the struct literal this replaced: handle first,
        // then the endpoint conversion.
        let surreal = self.inner.clone().into();
        let address = address.into_endpoint();
        Connect {
            surreal,
            address,
            capacity: 0,
            response_type: PhantomData,
        }
    }
}
/// Datastore handle plus the state of every session opened against it.
struct RouterState {
    /// Datastore that session commands are executed against.
    kvs: Arc<Datastore>,
    /// Per-session state keyed by session id. An `Err` entry records why the
    /// session could not be created (for example, cloning from an unknown id).
    sessions: HashMap<Uuid, SessionResult>,
}
impl RouterState {
    /// Stores fresh, default session state under `session_id`.
    fn handle_session_initial(&self, session_id: Uuid) {
        let fresh = SessionState::new(session_id);
        self.sessions.insert(session_id, Ok(Arc::new(fresh)));
    }

    /// Creates session `new` as a copy of session `old`.
    ///
    /// The copy gets the same `Session` (with its id changed to `new`) and the
    /// same parameters, but no transactions and no live queries. If `old` is
    /// unknown, `new` is recorded as `SessionError::NotFound(old)`; if `old`
    /// itself holds an error, that error is copied to `new`.
    async fn handle_session_clone(&self, old: Uuid, new: Uuid) {
        let cloned = match self.sessions.get(&old) {
            None => Err(SessionError::NotFound(old)),
            Some(Err(error)) => Err(error),
            Some(Ok(source)) => {
                // Snapshot the session first, then the variables (same lock order as before).
                let mut session = source.session.read().await.clone();
                session.id = Some(new);
                let vars = source.vars.read().await.clone();
                Ok(Arc::new(SessionState {
                    session: RwLock::new(session),
                    vars: RwLock::new(vars),
                    transactions: HashMap::new(),
                    live_queries: HashMap::new(),
                }))
            }
        };
        self.sessions.insert(new, cloned);
    }

    /// Forgets whatever is stored for `session_id`.
    fn handle_session_drop(&self, session_id: Uuid) {
        self.sessions.remove(&session_id);
    }
}
/// Outcome of registering a session: its shared state, or the error that
/// prevented it (e.g. [`SessionError::NotFound`] when cloning an unknown id).
type SessionResult = Result<Arc<SessionState>, SessionError>;
/// Mutable state belonging to a single session.
struct SessionState {
    /// The core session; rewritten by `Use`, sign-up/sign-in, authentication,
    /// refresh and invalidation commands.
    session: RwLock<Session>,
    /// Parameters set via `Set`/`Unset`; merged into each query's variables.
    vars: RwLock<Variables>,
    /// Transactions opened with `Begin`, keyed by the id returned to the caller.
    transactions: HashMap<Uuid, Arc<Transaction>>,
    /// Notification senders for live queries subscribed on this session,
    /// keyed by live-query id.
    live_queries: HashMap<Uuid, Sender<crate::Result<Notification>>>,
}
impl SessionState {
    /// Creates empty state for the session identified by `id`.
    ///
    /// The session starts from `Session::default().with_rt(true)` with its id
    /// set to `id`; there are no parameters, transactions or live queries yet.
    fn new(id: Uuid) -> Self {
        let session = {
            let mut fresh = Session::default().with_rt(true);
            fresh.id = Some(id);
            fresh
        };
        Self {
            vars: RwLock::new(Variables::default()),
            session: RwLock::new(session),
            live_queries: HashMap::new(),
            transactions: HashMap::new(),
        }
    }
}
#[cfg(not(target_family = "wasm"))]
async fn export_file(
kvs: &Datastore,
sess: &Session,
chn: async_channel::Sender<Vec<u8>>,
config: Option<DbExportConfig>,
) -> Result<(), crate::Error> {
let res = match config {
Some(config) => {
kvs.export_with_config(sess, chn, config)
.await
.map_err(crate::std_error_to_types_error)?
.await
}
None => kvs.export(sess, chn).await.map_err(crate::std_error_to_types_error)?.await,
};
if let Err(error) = res {
let error_str = error.to_string();
if error_str.contains("channel") || error_str.contains("Channel") {
trace!("{error_str}");
return Ok(());
}
return Err(crate::Error::internal(error.to_string()));
}
Ok(())
}
/// Streams the stored bytes of a SurrealML model into `chn`.
///
/// The session must have a namespace and database selected and be allowed to
/// view models on that database. The model is looked up by `name` and
/// `version`, and its bytes are read via `surrealdb_core::obs::stream` using
/// the model's hash. Sending stops early, without error, once the receiving
/// end of `chn` has been dropped.
///
/// # Errors
///
/// Returns an internal error if the session has no namespace/database, lacks
/// permission, the model does not exist, the byte stream cannot be opened, or
/// any chunk of that stream fails to read.
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
async fn export_ml(
    kvs: &Datastore,
    sess: &Session,
    chn: async_channel::Sender<Vec<u8>>,
    MlExportConfig {
        name,
        version,
    }: MlExportConfig,
) -> Result<(), crate::Error> {
    let (nsv, dbv) = check_ns_db(sess).map_err(|e| crate::Error::internal(e.to_string()))?;
    kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))
        .map_err(|e| crate::Error::internal(e.to_string()))?;
    let Some(model) = kvs
        .get_db_model(&nsv, &dbv, &name, &version)
        .await
        .map_err(|e| crate::Error::internal(e.to_string()))?
    else {
        return Err(crate::Error::internal("Model not found".to_string()));
    };
    let mut data = surrealdb_core::obs::stream(model.hash.clone())
        .await
        .map_err(|e| crate::Error::internal(e.to_string()))?;
    while let Some(chunk) = data.next().await {
        // Propagate read failures: silently stopping here would hand the caller a
        // truncated model while still reporting success.
        let bytes = chunk.map_err(|e| crate::Error::internal(e.to_string()))?;
        if chn.send(bytes.to_vec()).await.is_err() {
            // The receiver has gone away; there is nobody left to send to.
            break;
        }
    }
    Ok(())
}
#[cfg(not(target_family = "wasm"))]
async fn copy<'a, R, W>(
path: PathBuf,
reader: &'a mut R,
writer: &'a mut W,
) -> Result<(), crate::Error>
where
R: tokio::io::AsyncRead + Unpin + ?Sized,
W: tokio::io::AsyncWrite + Unpin + ?Sized,
{
io::copy(reader, writer).await.map(|_| ()).map_err(|error| {
crate::Error::internal(format!("Failed to read `{}`: {}", path.display(), error))
})
}
/// Kills live query `id` by executing a `KILL` statement as `session`,
/// with `vars` as the query parameters.
///
/// Returns the statement results produced by `Datastore::execute`.
async fn kill_live_query(
    kvs: &Datastore,
    id: Uuid,
    session: &Session,
    vars: Variables,
) -> Result<Vec<QueryResult>, TypesError> {
    let statement = format!("KILL {id}");
    kvs.execute(&statement, session, Some(vars)).await.map_err(Into::into)
}
async fn router(
kvs: &Arc<Datastore>,
state: &SessionState,
command: Command,
) -> Result<Vec<QueryResult>, crate::Error> {
match command {
Command::Use {
namespace,
database,
} => {
let result = {
kvs.process_use(None, &mut *state.session.write().await, namespace, database)
.await?
};
Ok(vec![result])
}
Command::Signup {
credentials,
} => {
let query_result = QueryResultBuilder::started_now();
let signup_data = {
iam::signup::signup(kvs, &mut *state.session.write().await, credentials.into())
.await
.map_err(|e| TypesError::not_allowed(e.to_string(), None))?
};
let token = match signup_data {
iam::Token::Access(token) => Token {
access: AccessToken(SecureToken(token)),
refresh: None,
},
iam::Token::WithRefresh {
access: token,
refresh,
} => Token {
access: AccessToken(SecureToken(token)),
refresh: Some(RefreshToken(SecureToken(refresh))),
},
};
let result = query_result.finish_with_result(Ok(token.into_value()));
Ok(vec![result])
}
Command::Signin {
credentials,
} => {
let query_result = QueryResultBuilder::started_now();
let signin_data = {
iam::signin::signin(kvs, &mut *state.session.write().await, credentials.into())
.await
.map_err(|e| TypesError::not_allowed(e.to_string(), None))?
};
let token = match signin_data {
iam::Token::Access(token) => Token {
access: AccessToken(SecureToken(token)),
refresh: None,
},
iam::Token::WithRefresh {
access,
refresh,
} => Token {
access: AccessToken(SecureToken(access)),
refresh: Some(RefreshToken(SecureToken(refresh))),
},
};
let result = query_result.finish_with_result(Ok(token.into_value()));
Ok(vec![result])
}
Command::Authenticate {
token,
} => {
let query_result = QueryResultBuilder::started_now();
let (access, with_refresh) = match &token {
iam::Token::Access(access) => (access, false),
iam::Token::WithRefresh {
access,
..
} => (access, true),
};
let result = {
match iam::verify::token(kvs, &mut *state.session.write().await, access).await {
Ok(_) => query_result.finish_with_result(Ok(token.into_value())),
Err(error) => {
if with_refresh && surrealdb_core::iam::is_expired_token_error(&error) {
let result =
match token.refresh(kvs, &mut *state.session.write().await).await {
Ok(token) => {
query_result.finish_with_result(Ok(token.into_value()))
}
Err(error) => query_result.finish_with_result(Err(
TypesError::internal(error.to_string()),
)),
};
return Ok(vec![result]);
}
query_result
.finish_with_result(Err(TypesError::internal(error.to_string())))
}
}
};
Ok(vec![result])
}
Command::Refresh {
token,
} => {
let query_result = QueryResultBuilder::started_now();
let result = {
match token.refresh(kvs, &mut *state.session.write().await).await {
Ok(token) => query_result.finish_with_result(Ok(token.into_value())),
Err(error) => query_result
.finish_with_result(Err(TypesError::internal(error.to_string()))),
}
};
Ok(vec![result])
}
Command::Invalidate => {
let query_result = QueryResultBuilder::started_now();
let result = {
match iam::clear::clear(&mut *state.session.write().await) {
Ok(_) => query_result.finish_with_result(Ok(Value::None)),
Err(error) => query_result
.finish_with_result(Err(TypesError::internal(error.to_string()))),
}
};
Ok(vec![result])
}
Command::Begin => {
let query_result = QueryResultBuilder::started_now();
let result = match kvs.transaction(TransactionType::Write, LockType::Optimistic).await {
Ok(txn) => {
let id = Uuid::now_v7();
state.transactions.insert(id, Arc::new(txn));
query_result.finish_with_result(Ok(Value::Uuid(id.into())))
}
Err(error) => {
query_result.finish_with_result(Err(TypesError::internal(error.to_string())))
}
};
Ok(vec![result])
}
Command::Revoke {
token,
} => {
let query_result = QueryResultBuilder::started_now();
let result = match token.revoke_refresh_token(kvs).await {
Ok(_) => query_result.finish_with_result(Ok(Value::None)),
Err(error) => {
query_result.finish_with_result(Err(TypesError::internal(error.to_string())))
}
};
Ok(vec![result])
}
Command::Rollback {
txn,
} => {
if let Some(tx) = state.transactions.get(&txn) {
state.transactions.remove(&txn);
tx.cancel().await.map_err(crate::std_error_to_types_error)?;
}
Ok(vec![QueryResultBuilder::instant_none()])
}
Command::Commit {
txn,
} => {
if let Some(tx) = state.transactions.get(&txn) {
state.transactions.remove(&txn);
tx.commit().await.map_err(crate::std_error_to_types_error)?;
}
Ok(vec![QueryResultBuilder::instant_none()])
}
Command::Query {
txn,
query,
variables,
} => {
let mut vars = state.vars.read().await.clone();
vars.extend(variables);
let response = if let Some(txn_id) = txn {
let tx_option = state.transactions.get(&txn_id);
if let Some(tx) = tx_option {
kvs.execute_with_transaction(
query.as_ref(),
&*state.session.read().await,
Some(vars),
tx,
)
.await?
} else {
return Ok(vec![QueryResultBuilder::started_now().finish_with_result(Err(
TypesError::not_found(
"Transaction not found".to_string(),
Some(surrealdb_types::NotFoundError::Transaction),
),
))]);
}
} else {
kvs.execute(query.as_ref(), &*state.session.read().await, Some(vars)).await?
};
Ok(response)
}
#[cfg(target_family = "wasm")]
Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportFile {
..
} => Err(crate::Error::internal(
"The protocol or storage engine does not support backups on this architecture"
.to_string(),
)),
#[cfg(any(target_family = "wasm", not(feature = "ml")))]
Command::ExportMl {
..
}
| Command::ExportBytesMl {
..
}
| Command::ImportMl {
..
} => Err(crate::Error::internal(
"The protocol or storage engine does not support backups on this architecture"
.to_string(),
)),
#[cfg(not(target_family = "wasm"))]
Command::ExportFile {
path: file,
config,
} => {
let query_result = QueryResultBuilder::started_now();
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
let session = state.session.read().await.clone();
let export = export_file(kvs, &session, tx, config);
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
break;
}
}
Ok(())
};
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&file)
.await
{
Ok(path) => path,
Err(error) => {
return Err(crate::Error::internal(format!(
"Failed to open `{}`: {}",
file.display(),
error
)));
}
};
let copy = copy(file, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(vec![query_result.finish()])
}
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
Command::ExportMl {
path,
config,
} => {
let query_result = QueryResultBuilder::started_now();
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
let session = state.session.read().await;
let export = export_ml(kvs, &session, tx, config);
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
break;
}
}
Ok(())
};
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(crate::Error::internal(format!(
"Failed to open `{}`: {}",
path.display(),
error
)));
}
};
let copy = copy(path, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(vec![query_result.finish()])
}
#[cfg(not(target_family = "wasm"))]
Command::ExportBytes {
bytes,
config,
} => {
let query_result = QueryResultBuilder::started_now();
let (tx, rx) = crate::channel::bounded(1);
let kvs = kvs.clone();
let session = state.session.read().await.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export_file(&kvs, &session, tx, config).await {
bytes.send(Err(error)).await.ok();
}
};
let bridge = async {
while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
break;
}
}
};
tokio::join!(export, bridge);
});
Ok(vec![query_result.finish()])
}
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
Command::ExportBytesMl {
bytes,
config,
} => {
let query_result = QueryResultBuilder::started_now();
let (tx, rx) = crate::channel::bounded(1);
let kvs = kvs.clone();
let session = state.session.read().await.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export_ml(&kvs, &session, tx, config).await {
bytes.send(Err(error)).await.ok();
}
};
let bridge = async {
while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
break;
}
}
};
tokio::join!(export, bridge);
});
Ok(vec![query_result.finish()])
}
#[cfg(not(target_family = "wasm"))]
Command::ImportFile {
path,
} => {
let query_result = QueryResultBuilder::started_now();
let file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(crate::Error::internal(format!(
"Failed to open `{}`: {}",
path.display(),
error
)));
}
};
let mut file = pin!(file);
let mut buffer = BytesMut::with_capacity(4096);
let stream = poll_fn(|ctx| {
if buffer.capacity() == 0 {
buffer.reserve(4096);
}
let future = pin!(file.read_buf(&mut buffer));
match ready!(future.poll(ctx)) {
Ok(0) => Poll::Ready(None),
Ok(_) => Poll::Ready(Some(Ok(buffer.split().freeze()))),
Err(e) => Poll::Ready(Some(Err(anyhow::anyhow!("{}", e)))),
}
});
let responses = kvs
.execute_import(
&*state.session.read().await,
Some(state.vars.read().await.clone()),
stream,
)
.await
.map_err(crate::std_error_to_types_error)?;
for response in responses {
response.result?;
}
Ok(vec![query_result.finish()])
}
#[cfg(all(not(target_family = "wasm"), feature = "ml"))]
Command::ImportMl {
path,
} => {
let query_result = QueryResultBuilder::started_now();
let mut file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(crate::Error::internal(format!(
"Failed to open `{}`: {}",
path.display(),
error
)));
}
};
let (nsv, dbv) = check_ns_db(&*state.session.read().await)
.map_err(crate::std_error_to_types_error)?;
kvs.check(
&*state.session.read().await,
Action::Edit,
ResourceKind::Model.on_db(&nsv, &dbv),
)
.map_err(crate::std_error_to_types_error)?;
let mut buffer = Vec::new();
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(crate::Error::internal(format!(
"Failed to read `{}`: {}",
path.display(),
error
)));
}
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(crate::Error::internal(format!(
"Invalid SurrealML file: {}",
error.message
)));
}
};
let data = file.to_bytes();
kvs.put_ml_model(
&*state.session.read().await,
&file.header.name.to_string(),
&file.header.version.to_string(),
&file.header.description.to_string(),
data,
)
.await
.map_err(crate::std_error_to_types_error)?;
Ok(vec![query_result.finish()])
}
Command::Health => Ok(vec![QueryResultBuilder::instant_none()]),
Command::Version => {
let query_result = QueryResultBuilder::started_now();
Ok(vec![
query_result.finish_with_result(Ok(Value::from_t(
surrealdb_core::env::VERSION.to_string(),
))),
])
}
Command::Set {
key,
value,
} => {
let query_result = QueryResultBuilder::started_now();
surrealdb_core::rpc::check_protected_param(&key)
.map_err(|e| crate::Error::internal(e.to_string()))?;
match value {
Value::None => state.vars.write().await.remove(&key),
v => state.vars.write().await.insert(key, v),
};
Ok(vec![query_result.finish()])
}
Command::Unset {
key,
} => {
let query_result = QueryResultBuilder::started_now();
state.vars.write().await.remove(&key);
Ok(vec![query_result.finish()])
}
Command::SubscribeLive {
uuid,
notification_sender,
} => {
let query_result = QueryResultBuilder::started_now();
state.live_queries.insert(uuid, notification_sender);
Ok(vec![query_result.finish()])
}
Command::Kill {
uuid,
} => {
state.live_queries.remove(&uuid);
let results = kill_live_query(
kvs,
uuid,
&*state.session.read().await,
state.vars.read().await.clone(),
)
.await?;
Ok(results)
}
Command::Run {
name,
version,
args,
} => {
let formatted_args = args.iter().map(|v| v.to_sql()).collect::<Vec<_>>().join(", ");
let sql = match version {
Some(v) => format!("{name}<{v}>({formatted_args})"),
None => format!("{name}({formatted_args})"),
};
let results = kvs
.execute(&sql, &*state.session.read().await, Some(state.vars.read().await.clone()))
.await?;
Ok(results)
}
Command::Attach {
..
} => {
let query_result = QueryResultBuilder::started_now();
Ok(vec![query_result.finish()])
}
Command::Detach {
..
} => {
let query_result = QueryResultBuilder::started_now();
Ok(vec![query_result.finish()])
}
}
}