#[cfg(not(target_family = "wasm"))]
pub(crate) mod native;
#[cfg(target_family = "wasm")]
pub(crate) mod wasm;
use std::marker::PhantomData;
#[cfg(not(target_family = "wasm"))]
use std::path::PathBuf;
use std::sync::Arc;
use futures::TryStreamExt;
use reqwest::RequestBuilder;
use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use surrealdb_core::dbs::{QueryResult, QueryResultBuilder};
use surrealdb_core::iam::Token as CoreToken;
use surrealdb_core::rpc::{self, DbResponse, DbResult};
use surrealdb_types::{AuthError, NotAllowedError};
#[cfg(not(target_family = "wasm"))]
use tokio::fs::OpenOptions;
#[cfg(not(target_family = "wasm"))]
use tokio::io;
use tokio::sync::RwLock;
#[cfg(not(target_family = "wasm"))]
use tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url;
use uuid::Uuid;
#[cfg(target_family = "wasm")]
use wasm_bindgen_futures::spawn_local;
use crate::conn::{Command, RequestData};
use crate::engine::SessionError;
use crate::engine::remote::RouterRequest;
use crate::headers::{AUTH_DB, AUTH_NS, DB, NS};
use crate::opt::IntoEndpoint;
use crate::opt::auth::{AccessToken, Token};
use crate::types::{HashMap, SurrealValue, Value};
use crate::{Connect, Error, Result, Surreal};
/// Path, joined onto the endpoint's base URL, that RPC requests are POSTed to by `send_request`.
const RPC_PATH: &str = "rpc";
/// Per-session state kept by the HTTP router.
#[derive(Debug)]
struct SessionState {
/// Headers attached to every request of this session; `router` sets or clears the
/// `NS`/`DB` entries from the reply to a `use` command.
headers: RwLock<HeaderMap>,
/// Credentials applied to outgoing requests: `None` until a signin/authenticate/refresh
/// stores one, and reset to `None` by `invalidate`.
auth: RwLock<Option<Auth>>,
/// Commands re-sent, in order, by `RouterState::replay_session`. For sessions created by
/// `RouterState::handle_session_initial` the first entry is the session's `Attach` command.
replay: boxcar::Vec<Command>,
}
impl Default for SessionState {
    /// Builds a blank session: no headers, no credentials and an empty replay log.
    fn default() -> Self {
        let headers = RwLock::new(HeaderMap::new());
        let auth = RwLock::new(None);
        let replay = boxcar::Vec::new();
        Self {
            headers,
            auth,
            replay,
        }
    }
}
impl SessionState {
    /// Snapshots this session into an independent `SessionState`.
    ///
    /// The current headers and credentials are copied into fresh locks, and the replay
    /// log is cloned, so later changes to either session do not affect the other.
    async fn clone_state(&self) -> Self {
        let headers_snapshot = self.headers.read().await.clone();
        let auth_snapshot = self.auth.read().await.clone();
        let replay = self.replay.clone();
        Self {
            headers: RwLock::new(headers_snapshot),
            auth: RwLock::new(auth_snapshot),
            replay,
        }
    }
}
/// Entry stored per session id: the shared session state, or the error that prevented the
/// session from being set up (e.g. a failed replay or a clone of an unknown session).
type SessionResult = std::result::Result<Arc<SessionState>, SessionError>;
/// State shared by the HTTP router across all sessions of one connection.
struct RouterState {
/// Known sessions keyed by session id.
sessions: HashMap<Uuid, SessionResult>,
/// HTTP client used for every request.
client: reqwest::Client,
/// Endpoint base URL that request paths (e.g. `RPC_PATH`) are joined onto.
base_url: Url,
}
impl RouterState {
/// Creates router state with no sessions that sends requests to `base_url` via `client`.
fn new(client: reqwest::Client, base_url: Url) -> Self {
Self {
sessions: HashMap::new(),
client,
base_url,
}
}
/// Re-sends every command recorded in `session_state.replay`, in order, for `session_id`.
///
/// Requests use a snapshot of the session's headers and credentials taken before the loop.
/// Replies are discarded; the first request error stops the replay and is returned.
///
/// # Panics
/// Panics if a recorded command cannot be converted into a router request.
async fn replay_session(&self, session_id: Uuid, session_state: &SessionState) -> Result<()> {
let headers = session_state.headers.read().await.clone();
let auth = session_state.auth.read().await.clone();
for (_, command) in &session_state.replay {
let request = command
.clone()
.into_router_request(None, Some(session_id))
.expect("replay command should convert to router request");
send_request(request, &self.base_url, &self.client, &headers, &auth).await?;
}
Ok(())
}
/// Registers a new session whose replay log starts with an `Attach` command, then sends it.
///
/// The session is stored as `Ok` before the replay runs; if the replay fails, the entry is
/// overwritten with `SessionError::Remote` carrying the error's text.
async fn handle_session_initial(&self, session_id: Uuid) {
let session_state = SessionState::default();
session_state.replay.push(Command::Attach {
session_id,
});
let session_state = Arc::new(session_state);
self.sessions.insert(session_id, Ok(session_state.clone()));
if let Err(error) = self.replay_session(session_id, &session_state).await {
self.sessions.insert(session_id, Err(SessionError::Remote(error.to_string())));
}
}
/// Creates session `new` from a snapshot of session `old` and replays it under `new`.
///
/// If `old` holds an error, `new` gets that error. If `old` is unknown, `new` is stored as
/// `SessionError::NotFound(old)`. A failed replay is stored as `SessionError::Remote`.
async fn handle_session_clone(&self, old: Uuid, new: Uuid) {
match self.sessions.get(&old) {
Some(Ok(session_state)) => {
let mut new_state = session_state.clone_state().await;
// Entry 0 is presumably the `Attach` pushed by `handle_session_initial`;
// retarget it so the replay attaches the new session id.
if let Some(cmd) = new_state.replay.get_mut(0) {
*cmd = Command::Attach {
session_id: new,
};
}
let new_state = Arc::new(new_state);
self.sessions.insert(new, Ok(new_state.clone()));
if let Err(error) = self.replay_session(new, &new_state).await {
self.sessions.insert(new, Err(SessionError::Remote(error.to_string())));
}
}
Some(Err(error)) => {
self.sessions.insert(new, Err(error));
}
None => {
self.sessions.insert(new, Err(SessionError::NotFound(old)));
}
}
}
/// Sends a best-effort `Detach` for `session_id`, then removes it from the session map.
///
/// The detach is only sent when the id is present in the map, even if it holds an error.
/// Its failure is ignored. It is sent from a fresh `SessionState`, so it carries neither
/// the session's headers nor its credentials.
async fn handle_session_drop(&self, session_id: Uuid) {
if self.sessions.get(&session_id).is_some() {
let session_state = SessionState::default();
session_state.replay.push(Command::Detach {
session_id,
});
self.replay_session(session_id, &session_state).await.ok();
}
self.sessions.remove(&session_id);
}
}
/// Marker type for plain-HTTP endpoints. Presumably used to select this engine through
/// `IntoEndpoint`; confirm against its impls.
#[derive(Debug)]
pub struct Http;
/// Marker type for HTTPS endpoints. Presumably used to select this engine through
/// `IntoEndpoint`; confirm against its impls.
#[derive(Debug)]
pub struct Https;
/// The HTTP(S) client engine, used as the connection type in `Surreal<Client>`.
/// The private `()` field prevents construction outside this module.
#[derive(Debug, Clone)]
pub struct Client(());
impl Surreal<Client> {
    /// Prepares a connection of this HTTP(S) client to `address`.
    ///
    /// This only builds the [`Connect`] value. It holds this instance's shared inner
    /// state and the converted endpoint, with a `capacity` of `0`.
    pub fn connect<P>(
        &self,
        address: impl IntoEndpoint<P, Client = Client>,
    ) -> Connect<Client, ()> {
        let surreal = self.inner.clone().into();
        let endpoint = address.into_endpoint();
        Connect {
            surreal,
            address: endpoint,
            capacity: 0,
            response_type: PhantomData,
        }
    }
}
/// Returns the default headers for RPC requests: both `Accept` and `Content-Type`
/// announce the FlatBuffers format.
pub(crate) fn default_headers() -> HeaderMap {
    let flatbuffers = HeaderValue::from_static(surrealdb_core::api::format::FLATBUFFERS);
    let mut map = HeaderMap::new();
    map.insert(ACCEPT, flatbuffers.clone());
    map.insert(CONTENT_TYPE, flatbuffers);
    map
}
/// Credentials attached to outgoing HTTP requests of a session (see `Authenticate`).
#[derive(Debug, Clone)]
enum Auth {
/// HTTP basic authentication; `ns`/`db`, when set, are also sent in the
/// `AUTH_NS`/`AUTH_DB` headers.
Basic {
user: String,
pass: String,
ns: Option<String>,
db: Option<String>,
},
/// Bearer authentication with the given access token.
Bearer {
token: AccessToken,
},
}
/// Applies a session's optional [`Auth`] to an outgoing request.
trait Authenticate {
/// Returns the request with `auth` applied; `None` leaves it unchanged.
fn auth(self, auth: &Option<Auth>) -> Self;
}
impl Authenticate for RequestBuilder {
    /// Adds either basic auth (plus the optional `AUTH_NS`/`AUTH_DB` headers) or a bearer
    /// token to the request, depending on the kind of stored credentials.
    fn auth(self, auth: &Option<Auth>) -> Self {
        let Some(credentials) = auth else {
            return self;
        };
        match credentials {
            Auth::Bearer {
                token,
            } => self.bearer_auth(token.as_insecure_token()),
            Auth::Basic {
                user,
                pass,
                ns,
                db,
            } => {
                let with_login = self.basic_auth(user, Some(pass));
                let with_ns = match ns {
                    Some(namespace) => with_login.header(&AUTH_NS, namespace),
                    None => with_login,
                };
                match db {
                    Some(database) => with_ns.header(&AUTH_DB, database),
                    None => with_ns,
                }
            }
        }
    }
}
/// User/password credentials decoded from a `Value` during signin.
///
/// When decoding succeeds, `router` stores them as `Auth::Basic`. `ac` is decoded but not
/// carried into `Auth::Basic`.
#[derive(Debug, Serialize, Deserialize, SurrealValue)]
#[surreal(crate = "crate::types")]
struct Credentials {
user: String,
pass: String,
ac: Option<String>,
ns: Option<String>,
db: Option<String>,
}
/// Shape of an authentication response: a status `code`, a `details` message and an
/// optional token. Its fields are never read, which is why `dead_code` is expected.
#[derive(Debug, Deserialize)]
#[expect(dead_code)]
struct AuthResponse {
code: u16,
details: String,
token: Option<Token>,
}
/// Channel sender through which `export_bytes` delivers exported data chunk by chunk.
type BackupSender = async_channel::Sender<Result<Vec<u8>>>;
#[cfg(not(target_family = "wasm"))]
async fn export_file(request: RequestBuilder, path: PathBuf) -> Result<()> {
let mut response = request
.send()
.await
.map_err(crate::std_error_to_types_error)?
.error_for_status()
.map_err(crate::std_error_to_types_error)?
.bytes_stream()
.map_err(futures::io::Error::other)
.into_async_read()
.compat();
let mut file =
match OpenOptions::new().write(true).create(true).truncate(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(Error::internal(format!(
"Failed to open `{}`: {error}",
path.display()
)));
}
};
if let Err(error) = io::copy(&mut response, &mut file).await {
return Err(Error::internal(format!("Failed to read `{}`: {error}", path.display())));
}
Ok(())
}
async fn export_bytes(request: RequestBuilder, bytes: BackupSender) -> Result<()> {
let response = request
.send()
.await
.map_err(crate::std_error_to_types_error)?
.error_for_status()
.map_err(crate::std_error_to_types_error)?;
let future = async move {
let mut response = response.bytes_stream();
while let Ok(Some(b)) = response.try_next().await {
if bytes.send(Ok(b.to_vec())).await.is_err() {
break;
}
}
};
#[cfg(not(target_family = "wasm"))]
tokio::spawn(future);
#[cfg(target_family = "wasm")]
spawn_local(future);
Ok(())
}
#[cfg(not(target_family = "wasm"))]
async fn import(request: RequestBuilder, path: PathBuf) -> Result<()> {
let file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(Error::internal(format!("Failed to open `{}`: {}", path.display(), error)));
}
};
let res = request
.header(ACCEPT, surrealdb_core::api::format::FLATBUFFERS)
.body(file)
.send()
.await
.map_err(crate::std_error_to_types_error)?;
if res.error_for_status_ref().is_err() {
let res = res.text().await.map_err(crate::std_error_to_types_error)?;
match res.parse::<serde_json::Value>() {
Ok(body) => {
let error_msg = format!(
"\n{}",
serde_json::to_string_pretty(&body).unwrap_or_else(|_| "{}".into())
);
return Err(Error::internal(format!("HTTP error: {error_msg}")));
}
Err(_) => {
return Err(Error::internal(format!("HTTP error: {res}")));
}
}
}
let bytes = res.bytes().await.map_err(crate::std_error_to_types_error)?;
let value: Value = surrealdb_core::rpc::format::flatbuffers::decode(&bytes)
.map_err(|x| format!("Failed to deserialize flatbuffers payload: {x:?}"))
.map_err(|e| {
crate::Error::internal(format!("The server returned an unexpected response: {e}"))
})?;
let Value::Array(arr) = value else {
return Err(Error::internal("Expected array response from import".to_string()));
};
for val in arr.into_vec() {
let result = QueryResult::from_value(val)
.map_err(|e| Error::internal(format!("Failed to parse query result: {e}")))?;
result.result?;
}
Ok(())
}
/// Sends a health-check `request` and succeeds only if the server answers with a
/// success status.
///
/// # Errors
/// Fails on transport errors or a non-success HTTP status.
pub(crate) async fn health(request: RequestBuilder) -> Result<()> {
    let response = request.send().await.map_err(crate::std_error_to_types_error)?;
    response.error_for_status().map(drop).map_err(crate::std_error_to_types_error)
}
/// POSTs `req`, encoded as FlatBuffers, to the RPC endpoint under `base_url` and decodes
/// the reply into query results.
///
/// `headers` and `auth` are applied as given. Non-query replies (`Other` and `Live`) are
/// wrapped in a single freshly-timed `QueryResult`.
///
/// # Errors
/// Fails on encoding errors, transport errors or a non-success HTTP status, undecodable
/// replies, and when the decoded response itself carries an error.
///
/// # Panics
/// Panics if `RPC_PATH` cannot be joined onto `base_url`.
async fn send_request(
    req: RouterRequest,
    base_url: &Url,
    client: &reqwest::Client,
    headers: &HeaderMap,
    auth: &Option<Auth>,
) -> Result<Vec<QueryResult>> {
    let endpoint = base_url.join(RPC_PATH).expect("valid RPC path");
    let payload =
        surrealdb_core::rpc::format::flatbuffers::encode(&req.into_value()).map_err(|x| {
            let detail = format!("Failed to serialize to flatbuffers: {x}");
            crate::Error::internal(format!(
                "Tried to send a value which could not be serialized: {detail}"
            ))
        })?;
    let reply = client
        .post(endpoint)
        .headers(headers.clone())
        .auth(auth)
        .body(payload)
        .send()
        .await
        .map_err(crate::std_error_to_types_error)?
        .error_for_status()
        .map_err(crate::std_error_to_types_error)?;
    let raw = reply.bytes().await.map_err(crate::std_error_to_types_error)?;
    let decoded: DbResponse =
        surrealdb_core::rpc::format::flatbuffers::decode(&raw).map_err(|x| {
            let detail = format!("Failed to deserialize flatbuffers payload: {x}");
            crate::Error::internal(format!(
                "The server returned an unexpected response: {detail}"
            ))
        })?;
    let results = match decoded.result? {
        DbResult::Query(results) => results,
        DbResult::Other(value) => {
            vec![QueryResultBuilder::started_now().finish_with_result(Ok(value))]
        }
        DbResult::Live(notification) => vec![
            QueryResultBuilder::started_now().finish_with_result(Ok(notification.into_value())),
        ],
    };
    Ok(results)
}
/// Sends a `Refresh` command for `token` and returns the first result's value together
/// with every result of the reply.
///
/// # Errors
/// Fails if the request fails, the reply contains no results, or the first result is an
/// error.
///
/// # Panics
/// Panics if the refresh command cannot be converted into a router request.
async fn refresh_token(
    token: CoreToken,
    base_url: &Url,
    client: &reqwest::Client,
    headers: &HeaderMap,
    auth: &Option<Auth>,
    session_id: Option<uuid::Uuid>,
) -> Result<(Value, Vec<QueryResult>)> {
    let request = Command::Refresh {
        token,
    }
    .into_router_request(None, session_id)
    .expect("refresh should be a valid router request");
    let results = send_request(request, base_url, client, headers, auth).await?;
    let Some(first) = results.first() else {
        error!("received invalid result from server");
        return Err(Error::internal("Received invalid result from server".to_string()));
    };
    let value = first.result.clone()?;
    Ok((value, results))
}
async fn router(
req: RequestData,
base_url: &Url,
client: &reqwest::Client,
session_state: &SessionState,
) -> Result<Vec<QueryResult>> {
let session_id = req.session_id;
match req.command {
Command::Use {
namespace,
database,
} => {
let req = Command::Use {
namespace: namespace.clone(),
database: database.clone(),
}
.into_router_request(None, Some(session_id))
.expect("USE command should convert to router request");
let out = send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await?;
let mut headers = session_state.headers.write().await;
if let Some(result) = out.first()
&& let Ok(Value::Object(ref obj)) = result.clone().result
{
match obj.get("namespace") {
Some(Value::String(ns)) => {
let header_value = HeaderValue::try_from(ns.as_str()).map_err(|_| {
Error::internal(format!("Invalid namespace name: {ns:?}"))
})?;
headers.insert(&NS, header_value);
}
_ => {
headers.remove(&NS);
}
}
match obj.get("database") {
Some(Value::String(db)) => {
let header_value = HeaderValue::try_from(db.as_str()).map_err(|_| {
Error::internal(format!("Invalid database name: {db:?}"))
})?;
headers.insert(&DB, header_value);
}
_ => {
headers.remove(&DB);
}
}
}
Ok(out)
}
Command::Signin {
credentials,
} => {
let req = Command::Signin {
credentials: credentials.clone(),
}
.into_router_request(None, Some(session_id))
.expect("signin should be a valid router request");
let results = send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await?;
let value = match results.first() {
Some(result) => result.clone().result?,
None => {
error!("received invalid result from server");
return Err(Error::internal("Received invalid result from server".to_string()));
}
};
let mut auth = session_state.auth.write().await;
match Credentials::from_value(value.clone()) {
Ok(credentials) => {
*auth = Some(Auth::Basic {
user: credentials.user,
pass: credentials.pass,
ns: credentials.ns,
db: credentials.db,
});
}
Err(err) => {
debug!("Error converting Value to Credentials: {err}");
let token =
Token::from_value(value).map_err(|e| Error::internal(e.to_string()))?;
*auth = Some(Auth::Bearer {
token: token.access,
});
}
}
Ok(results)
}
Command::Authenticate {
token,
} => {
let req = Command::Authenticate {
token: token.clone(),
}
.into_router_request(None, Some(session_id))
.expect("authenticate should be a valid router request");
let mut results = send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await?;
if let Some(result) = results.first_mut() {
match &mut result.result {
Ok(result) => {
let value = token.into_value();
*session_state.auth.write().await = Some(Auth::Bearer {
token: Token::from_value(value.clone())
.map_err(|e| Error::internal(e.to_string()))?
.access,
});
*result = value;
}
Err(error) => {
if let CoreToken::WithRefresh {
..
} = &token
{
if error.not_allowed_details().is_some_and(|a| {
matches!(a, NotAllowedError::Auth(AuthError::TokenExpired))
}) {
let (value, refresh_results) = refresh_token(
token,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
Some(session_id),
)
.await?;
*session_state.auth.write().await = Some(Auth::Bearer {
token: Token::from_value(value)
.map_err(|e| Error::internal(e.to_string()))?
.access,
});
results = refresh_results;
}
}
}
}
}
Ok(results)
}
Command::Refresh {
token,
} => {
let (value, results) = refresh_token(
token,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
Some(session_id),
)
.await?;
*session_state.auth.write().await = Some(Auth::Bearer {
token: Token::from_value(value).map_err(|e| Error::internal(e.to_string()))?.access,
});
Ok(results)
}
Command::Invalidate => {
let req = Command::Invalidate
.into_router_request(None, Some(session_id))
.expect("invalidate should be a valid router request");
let results = send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await?;
*session_state.auth.write().await = None;
Ok(results)
}
Command::Set {
key,
value,
} => {
surrealdb_core::rpc::check_protected_param(&key)?;
let req = Command::Set {
key,
value,
}
.into_router_request(None, Some(session_id))
.expect("set should be a valid router request");
send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await
}
Command::Unset {
key,
} => {
let req = Command::Unset {
key,
}
.into_router_request(None, Some(session_id))
.expect("unset should be a valid router request");
send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await
}
#[cfg(target_family = "wasm")]
Command::ExportFile {
..
}
| Command::ExportMl {
..
}
| Command::ImportFile {
..
}
| Command::ImportMl {
..
} => {
Err(Error::internal(
"The protocol or storage engine does not support backups on this architecture"
.to_string(),
))
}
#[cfg(not(target_family = "wasm"))]
Command::ExportFile {
path,
config,
} => {
let req_path = base_url.join("export").map_err(crate::std_error_to_types_error)?;
let config = config.unwrap_or_default();
let config_value: Value = config.into_value();
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request =
client
.post(req_path)
.body(rpc::format::json::encode_str(config_value).map_err(|e| {
Error::internal(format!("failed to serialize Value: {}", e))
})?)
.headers(headers.clone())
.auth(&auth)
.header(CONTENT_TYPE, "application/json")
.header(ACCEPT, "application/octet-stream");
export_file(request, path).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
Command::ExportBytes {
bytes,
config,
} => {
let req_path = base_url.join("export").map_err(crate::std_error_to_types_error)?;
let config = config.unwrap_or_default();
let config_value = config.into_value();
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request =
client
.post(req_path)
.body(rpc::format::json::encode_str(config_value).map_err(|e| {
Error::internal(format!("failed to serialize Value: {}", e))
})?)
.headers(headers.clone())
.auth(&auth)
.header(CONTENT_TYPE, "application/json")
.header(ACCEPT, "application/octet-stream");
export_bytes(request, bytes).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
#[cfg(not(target_family = "wasm"))]
Command::ExportMl {
path,
config,
} => {
let req_path = base_url
.join("ml")
.map_err(crate::std_error_to_types_error)?
.join("export")
.map_err(crate::std_error_to_types_error)?
.join(&config.name)
.map_err(crate::std_error_to_types_error)?
.join(&config.version)
.map_err(crate::std_error_to_types_error)?;
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(&auth)
.header(ACCEPT, "application/octet-stream");
export_file(request, path).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
Command::ExportBytesMl {
bytes,
config,
} => {
let req_path = base_url
.join("ml")
.map_err(crate::std_error_to_types_error)?
.join("export")
.map_err(crate::std_error_to_types_error)?
.join(&config.name)
.map_err(crate::std_error_to_types_error)?
.join(&config.version)
.map_err(crate::std_error_to_types_error)?;
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(&auth)
.header(ACCEPT, "application/octet-stream");
export_bytes(request, bytes).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
#[cfg(not(target_family = "wasm"))]
Command::ImportFile {
path,
} => {
let req_path = base_url.join("import").map_err(crate::std_error_to_types_error)?;
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request = client
.post(req_path)
.headers(headers.clone())
.auth(&auth)
.header(CONTENT_TYPE, "application/octet-stream");
import(request, path).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
#[cfg(not(target_family = "wasm"))]
Command::ImportMl {
path,
} => {
let req_path = base_url
.join("ml")
.map_err(crate::std_error_to_types_error)?
.join("import")
.map_err(crate::std_error_to_types_error)?;
let headers = session_state.headers.read().await;
let auth = session_state.auth.read().await;
let request = client
.post(req_path)
.headers(headers.clone())
.auth(&auth)
.header(CONTENT_TYPE, "application/octet-stream");
import(request, path).await?;
Ok(vec![QueryResultBuilder::instant_none()])
}
Command::SubscribeLive {
..
} => Err(Error::internal(
"The protocol or storage engine does not support live queries on this architecture"
.to_string(),
)),
Command::Query {
txn,
query,
variables,
} => {
let req = Command::Query {
txn,
query,
variables,
}
.into_router_request(None, Some(session_id))
.expect("command should convert to router request");
send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await
}
cmd => {
let req = cmd
.into_router_request(None, Some(session_id))
.expect("command should convert to router request");
let res = send_request(
req,
base_url,
client,
&*session_state.headers.read().await,
&*session_state.auth.read().await,
)
.await?;
Ok(res)
}
}
}