#![recursion_limit = "256"]
#![doc(html_favicon_url = "https://surrealdb.s3.amazonaws.com/favicon.png")]
#![doc(html_logo_url = "https://surrealdb.s3.amazonaws.com/icon.png")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[cfg(all(target_family = "wasm", feature = "ml"))]
compile_error!("The `ml` feature is not supported on Wasm.");
#[macro_use]
extern crate tracing;
pub mod engine;
#[doc(hidden)]
#[cfg(feature = "protocol-http")]
pub mod headers;
pub mod method;
pub mod opt;
mod conn;
mod notification;
/// Re-exports `Receiver`, `Sender`, `bounded` and `unbounded` from `async_channel`.
///
/// The module is hidden from the generated documentation.
#[doc(hidden)]
pub mod channel {
pub use async_channel::{Receiver, Sender, bounded, unbounded};
}
/// Re-exports `value` from `surrealdb_core::syn`.
pub mod parse {
pub use surrealdb_core::syn::value;
}
#[doc(inline)]
pub use method::Stats;
#[doc(inline)]
pub use method::Stream;
#[doc(inline)]
pub use method::query::IndexedResults;
#[doc(inline)]
pub use surrealdb_types as types;
#[doc(inline)]
pub use crate::notification::Notification;
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
use std::fmt;
use std::fmt::Debug;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::sync::{Arc, OnceLock};
use async_channel::{Receiver, Sender};
use method::BoxFuture;
use semver::{Version, VersionReq};
#[doc(inline)]
pub use surrealdb_types::Error;
use tokio::sync::watch;
use uuid::Uuid;
use self::conn::Router;
use self::opt::{Endpoint, EndpointKind, WaitFor};
/// Both halves of a `tokio::sync::watch` channel carrying an optional `WaitFor` value.
type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
/// Semver requirement string for the server versions this client accepts.
///
/// `check_server_version` parses it with `VersionReq::parse` and rejects any
/// version that does not match.
const SUPPORTED_VERSIONS: &str = ">=3.0.0-alpha.1, <4.0.0";
/// Marker trait bounding the engine type parameter of [`Surreal`] and [`Connect`].
///
/// It declares no items of its own; it only requires `conn::Sealed`.
/// NOTE(review): the `Sealed` supertrait is presumably there to prevent
/// implementations outside this crate — confirm in the `conn` module.
pub trait Connection: conn::Sealed {}
/// A pending connection; nothing happens until it is awaited.
///
/// The `Response` parameter selects which `IntoFuture` implementation applies:
/// with `Surreal<C>` awaiting yields a newly connected client, with `()` the
/// client stored in `surreal` is initialised in place.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Connect<C: Connection, Response> {
/// Client handle. The `()` flavour stores the new router in it; the
/// `Surreal<C>` flavour does not read it.
surreal: Surreal<C>,
/// Endpoint to connect to, or the error to report once the future is awaited.
address: Result<Endpoint>,
/// Value forwarded as the second argument of `Client::connect`.
/// NOTE(review): presumably a channel capacity — confirm in the engine implementations.
capacity: usize,
/// Zero-sized marker recording the `Response` type.
response_type: PhantomData<Response>,
}
impl<C: Connection, R> Connect<C, R> {
    /// Replaces the capacity that is handed to the engine when connecting.
    ///
    /// Consumes the pending connection and returns it with `capacity` stored;
    /// every other field is left untouched.
    pub const fn with_capacity(self, capacity: usize) -> Self {
        let mut this = self;
        this.capacity = capacity;
        this
    }
}
impl<Client: Connection> IntoFuture for Connect<Client, Surreal<Client>> {
    type Output = Result<Surreal<Client>>;
    type IntoFuture = BoxFuture<'static, Self::Output>;

    /// Connects to the stored endpoint and resolves to a new client.
    ///
    /// # Errors
    /// Fails if the address could not be resolved into an endpoint, if the
    /// engine fails to connect, or — for remote endpoints — if the server
    /// version cannot be fetched or is outside the supported range.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move {
            let endpoint = self.address?;
            let kind = EndpointKind::from(endpoint.url.scheme());
            let client = Client::connect(endpoint, self.capacity, None).await?;

            if kind.is_remote() {
                let mut version = client.version().await?;
                // Clear the pre-release component so the check only looks at
                // major.minor.patch.
                version.pre = Default::default();
                client.check_server_version(&version).await?;
            }

            // Publish `WaitFor::Connection` on the watch channel; a failed send is ignored.
            let _ = client.inner.waiter.0.send(Some(WaitFor::Connection));
            Ok(client)
        })
    }
}
impl<Client> IntoFuture for Connect<Client, ()>
where
Client: Connection,
{
type Output = Result<()>;
type IntoFuture = BoxFuture<'static, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
if self.surreal.inner.router.get().is_some() {
return Err(Error::connection(
"Already connected".to_string(),
Some(crate::types::ConnectionError::AlreadyConnected),
));
}
let endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let session_clone = self.surreal.inner.session_clone.clone();
let client = Client::connect(endpoint, self.capacity, Some(session_clone)).await?;
if endpoint_kind.is_remote() {
match client.version().await {
Ok(mut version) => {
version.pre = Default::default();
client.check_server_version(&version).await?;
}
Err(e) => return Err(e),
}
}
let router = client.inner.router.wait().clone();
self.surreal.inner.router.set(router).map_err(|_| {
Error::connection(
"Already connected".to_string(),
Some(crate::types::ConnectionError::AlreadyConnected),
)
})?;
self.surreal.inner.waiter.0.send(Some(WaitFor::Connection)).ok();
Ok(())
})
}
}
/// Named features referred to elsewhere in the crate.
///
/// NOTE(review): presumably the optional capabilities a connection engine may
/// or may not support — confirm at the use sites; nothing in this file reads it.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) enum ExtraFeatures {
Backup,
LiveQueries,
}
/// Session lifecycle events sent over the `SessionClone` channel.
#[derive(Debug)]
// The payloads are never read in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
enum SessionId {
/// Sent when a `Surreal` handle is built from an `Arc<Inner>`; carries its fresh session id.
Initial(Uuid),
/// Sent when a `Surreal` handle is cloned.
Clone {
/// Session id of the handle being cloned.
old: Uuid,
/// Session id given to the clone.
new: Uuid,
},
/// Sent when a `Surreal` handle is dropped; carries its session id.
Drop(Uuid),
}
/// Both ends of the channel over which `SessionId` events are published.
#[derive(Debug, Clone)]
struct SessionClone {
/// Sending half, used by `Surreal` handles to publish events.
sender: Sender<SessionId>,
/// Receiving half; not read in this file, hence the `dead_code` allowance.
/// NOTE(review): presumably consumed by the connection engines — confirm.
#[allow(dead_code)]
receiver: Receiver<SessionId>,
}
impl SessionClone {
fn new() -> Self {
let (sender, receiver) = async_channel::unbounded();
Self {
sender,
receiver,
}
}
}
/// State shared, behind an `Arc`, by a `Surreal` handle and its clones.
#[derive(Debug)]
struct Inner {
/// Router for the connection; may be empty until `Connect<_, ()>` fills it in.
router: OnceLock<Router>,
/// Watch channel on which `WaitFor::Connection` is sent once connecting succeeds.
waiter: Waiter,
/// Channel used to publish session lifecycle events.
session_clone: SessionClone,
}
impl Inner {
    /// Publishes a `SessionId::Clone` event linking session `old` to its copy `new`.
    ///
    /// The send is best effort: a failure is silently ignored.
    fn clone_session(&self, old: Uuid, new: Uuid) {
        let event = SessionId::Clone {
            old,
            new,
        };
        let _ = self.session_clone.sender.try_send(event);
    }
}
/// A client handle, generic over the connection engine `C`.
///
/// Every handle carries its own session id: cloning a handle generates a new
/// id, and dropping a handle publishes a `SessionId::Drop` event.
pub struct Surreal<C: Connection> {
/// Connection state shared with every clone of this handle.
inner: Arc<Inner>,
/// Session id of this particular handle.
session_id: Uuid,
/// Zero-sized marker for the engine type.
engine: PhantomData<C>,
}
#[doc(hidden)]
impl<C: Connection> From<Arc<Inner>> for Surreal<C> {
    /// Wraps shared state in a new handle with a freshly generated session id.
    ///
    /// A `SessionId::Initial` event is published for the new id; a failed send
    /// is ignored.
    fn from(inner: Arc<Inner>) -> Self {
        let session_id = Uuid::new_v4();
        let announcement = SessionId::Initial(session_id);
        let _ = inner.session_clone.sender.try_send(announcement);
        Self {
            inner,
            session_id,
            engine: PhantomData,
        }
    }
}
#[doc(hidden)]
impl<C: Connection> From<(OnceLock<Router>, Waiter, SessionClone)> for Surreal<C> {
    /// Bundles the parts into a shared `Inner` and builds a handle from it.
    fn from(parts: (OnceLock<Router>, Waiter, SessionClone)) -> Self {
        let (router, waiter, session_clone) = parts;
        let shared = Arc::new(Inner {
            router,
            waiter,
            session_clone,
        });
        Self::from(shared)
    }
}
#[doc(hidden)]
impl<C: Connection> From<(Router, Waiter, SessionClone)> for Surreal<C> {
    /// Builds a handle whose router cell is already filled with `router`.
    fn from(parts: (Router, Waiter, SessionClone)) -> Self {
        let (router, waiter, session_clone) = parts;
        let cell: OnceLock<Router> = OnceLock::with_value(router);
        Self::from((cell, waiter, session_clone))
    }
}
impl<C> Surreal<C>
where
C: Connection,
{
async fn check_server_version(&self, version: &Version) -> Result<()> {
let req = VersionReq::parse(SUPPORTED_VERSIONS).expect("valid supported versions");
if !req.matches(version) {
return Err(Error::internal(format!(
"server version `{version}` does not match the range supported by the client `{SUPPORTED_VERSIONS}`"
)));
}
Ok(())
}
}
impl<C> Clone for Surreal<C>
where
C: Connection,
{
fn clone(&self) -> Self {
let session_id = Uuid::new_v4();
self.inner.clone_session(self.session_id, session_id);
Self {
inner: self.inner.clone(),
session_id,
engine: self.engine,
}
}
}
impl<C: Connection> Drop for Surreal<C> {
    /// Publishes a `SessionId::Drop` event for this handle's session id.
    ///
    /// The send is best effort: a failure is silently ignored.
    fn drop(&mut self) {
        let event = SessionId::Drop(self.session_id);
        let _ = self.inner.session_clone.sender.try_send(event);
    }
}
impl<C: Connection> Debug for Surreal<C> {
    /// Formats the router cell, the session id and the engine marker.
    ///
    /// The waiter and the session channel held in the shared state are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            session_id,
            engine,
        } = self;
        let mut out = f.debug_struct("Surreal");
        out.field("router", &inner.router);
        out.field("session_id", session_id);
        out.field("engine", engine);
        out.finish()
    }
}
/// Private helpers for working with an `OnceLock<Router>`.
trait OnceLockExt {
    /// Builds a cell that already holds `value`.
    fn with_value(value: Router) -> OnceLock<Router> {
        let cell = OnceLock::new();
        // The cell was created on the line above, so nothing else can have set it.
        if cell.set(value).is_err() {
            unreachable!("don't have exclusive access to `cell`");
        }
        cell
    }

    /// Borrows the stored router, failing if the cell has not been set.
    fn extract(&self) -> Result<&Router>;
}
impl OnceLockExt for OnceLock<Router> {
fn extract(&self) -> Result<&Router> {
let router = self.get().ok_or_else(|| {
Error::connection(
"Connection uninitialised".to_string(),
Some(crate::types::ConnectionError::Uninitialised),
)
})?;
Ok(router)
}
}
#[allow(dead_code)]
fn std_error_to_types_error(error: impl std::fmt::Display) -> Error {
Error::internal(error.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `SUPPORTED_VERSIONS` accepts every 3.x release, including 3.0.0
    /// pre-releases, and rejects 2.x and 4.x.
    #[test]
    fn test_supported_versions() {
        let req = VersionReq::parse(SUPPORTED_VERSIONS).expect("valid supported versions");

        let accepted = ["3.0.0-alpha.1", "3.0.0-beta.3", "3.0.0-rc.2", "3.0.0", "3.0.1", "3.9.0"];
        for raw in accepted {
            let version = Version::parse(raw).unwrap();
            assert!(req.matches(&version), "{raw} should be accepted");
        }

        let rejected = ["2.9.0", "4.0.0"];
        for raw in rejected {
            let version = Version::parse(raw).unwrap();
            assert!(!req.matches(&version), "{raw} should be rejected");
        }
    }
}