Skip to main content

surrealdb/api/
mod.rs

1//! Functionality for connecting to local and remote databases
2
3use std::fmt;
4use std::fmt::Debug;
5use std::future::IntoFuture;
6use std::marker::PhantomData;
7use std::sync::{Arc, OnceLock};
8
9use anyhow::ensure;
10use method::BoxFuture;
11use semver::{BuildMetadata, Version, VersionReq};
12use tokio::sync::watch;
13
14use crate::Result;
15
/// Declares a `#[repr(transparent)]` newtype around `$inner`.
///
/// The generated type comes with conversions to and from the wrapped type (by value, by
/// shared reference and by mutable reference) and with `Display`/`Debug` impls that forward
/// to the inner value.
///
/// The forwarding impls call `Display::fmt` / `Debug::fmt` by their fully-qualified paths.
/// As a result, each impl always uses the matching trait of the inner value, whichever
/// formatting traits happen to be imported where the macro is invoked.
macro_rules! transparent_wrapper{
	(
		$(#[$m:meta])*
		$vis:vis struct $name:ident($field_vis:vis $inner:ty)
	) => {
		$(#[$m])*
		#[repr(transparent)]
		$vis struct $name($field_vis $inner);

		// Not every wrapper uses every conversion helper.
		#[allow(dead_code)]
		impl $name{
			#[doc(hidden)]
			pub fn from_inner(inner: $inner) -> Self{
				$name(inner)
			}

			#[doc(hidden)]
			pub fn from_inner_ref(inner: &$inner) -> &Self{
				// SAFETY: `$name` is `#[repr(transparent)]` over `$inner`, so both types have
				// identical size, alignment and layout. The returned reference borrows from
				// `inner` for the same lifetime.
				unsafe { &*(inner as *const $inner as *const Self) }
			}

			#[doc(hidden)]
			pub fn from_inner_mut(inner: &mut $inner) -> &mut Self{
				// SAFETY: same layout argument as `from_inner_ref`. The exclusive borrow of
				// `inner` is transferred to the returned reference.
				unsafe { &mut *(inner as *mut $inner as *mut Self) }
			}

			#[doc(hidden)]
			pub fn into_inner(self) -> $inner{
				self.0
			}

			#[doc(hidden)]
			pub fn into_inner_ref(&self) -> &$inner{
				&self.0
			}

			#[doc(hidden)]
			pub fn into_inner_mut(&mut self) -> &mut $inner{
				&mut self.0
			}
		}

		impl ::std::fmt::Display for $name{
			fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result{
				::std::fmt::Display::fmt(&self.0, fmt)
			}
		}
		impl ::std::fmt::Debug for $name{
			fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result{
				::std::fmt::Debug::fmt(&self.0, fmt)
			}
		}
	};
}
74
/// Implements `revision::Revisioned`, `serde::Serialize` and `serde::Deserialize` for a
/// newtype declared with `transparent_wrapper!`.
///
/// Every impl forwards to the wrapped value (`self.0`). Deserialization rebuilds the
/// wrapper via `from_inner`. All trait methods are called by fully-qualified path, so the
/// invoking module does not need to import `Revisioned` or `Serialize` for the expansion
/// to compile.
macro_rules! impl_serialize_wrapper {
	($ty:ty) => {
		impl ::revision::Revisioned for $ty {
			fn revision() -> u16 {
				// NOTE(review): this reports `Value`'s revision for every wrapper, whatever its
				// inner type is. Confirm that all wrapped types share that revision.
				<crate::core::val::Value as ::revision::Revisioned>::revision()
			}

			fn serialize_revisioned<W: ::std::io::Write>(
				&self,
				w: &mut W,
			) -> ::std::result::Result<(), ::revision::Error> {
				::revision::Revisioned::serialize_revisioned(&self.0, w)
			}

			fn deserialize_revisioned<R: ::std::io::Read>(
				r: &mut R,
			) -> ::std::result::Result<Self, ::revision::Error>
			where
				Self: Sized,
			{
				::revision::Revisioned::deserialize_revisioned(r).map(Self::from_inner)
			}
		}

		impl ::serde::Serialize for $ty {
			fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
			where
				S: ::serde::ser::Serializer,
			{
				::serde::Serialize::serialize(&self.0, serializer)
			}
		}

		impl<'de> ::serde::de::Deserialize<'de> for $ty {
			fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
			where
				D: ::serde::de::Deserializer<'de>,
			{
				::serde::de::Deserialize::deserialize(deserializer).map(Self::from_inner)
			}
		}
	};
}
118
119pub mod engine;
120pub mod err;
121#[cfg(feature = "protocol-http")]
122pub mod headers;
123pub mod method;
124pub mod opt;
125pub mod value;
126
127mod conn;
128
129pub use method::query::Response;
130
131use self::conn::Router;
132use self::err::Error;
133use self::opt::{Endpoint, EndpointKind, WaitFor};
134
/// Channel for waiters: both halves of a `tokio::sync::watch` channel carrying an
/// `Option<WaitFor>`.
///
/// The sender is used to publish `Some(WaitFor::Connection)` once a client is connected.
/// `None` presumably means "nothing signalled yet" — confirm where the channel is created.
type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
137
/// Server versions this client accepts, as `(version requirement, minimum build metadata)`.
///
/// The first element is parsed as a `semver::VersionReq`, which the server's version must
/// satisfy. The connect futures clear the pre-release tag before checking.
///
/// The second element is parsed as `semver::BuildMetadata`. If the server reports non-empty
/// build metadata, it must compare greater than or equal to this value.
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.2.0, <4.0.0", "20230701.55918b7c");
139
140/// Connection trait implemented by supported engines
141pub trait Connection: conn::Sealed {}
142
143/// The future returned when creating a new SurrealDB instance
144#[derive(Debug)]
145#[must_use = "futures do nothing unless you `.await` or poll them"]
146pub struct Connect<C: Connection, Response> {
147	surreal: Surreal<C>,
148	address: Result<Endpoint>,
149	capacity: usize,
150	response_type: PhantomData<Response>,
151}
152
153impl<C, R> Connect<C, R>
154where
155	C: Connection,
156{
157	/// Sets the maximum capacity of the connection
158	///
159	/// This is used to set bounds of the channels used internally
160	/// as well set the capacity of the `HashMap` used for routing
161	/// responses in case of the WebSocket client.
162	///
163	/// Setting this capacity to `0` (the default) means that
164	/// unbounded channels will be used. If your queries per second
165	/// are so high that the client is running out of memory,
166	/// it might be helpful to set this to a number that works best
167	/// for you.
168	///
169	/// # Examples
170	///
171	/// ```no_run
172	/// # #[tokio::main]
173	/// # async fn main() -> surrealdb::Result<()> {
174	/// use surrealdb::engine::remote::ws::Ws;
175	/// use surrealdb::Surreal;
176	///
177	/// let db = Surreal::new::<Ws>("localhost:8000")
178	///     .with_capacity(100_000)
179	///     .await?;
180	/// # Ok(())
181	/// # }
182	/// ```
183	pub const fn with_capacity(mut self, capacity: usize) -> Self {
184		self.capacity = capacity;
185		self
186	}
187}
188
189impl<Client> IntoFuture for Connect<Client, Surreal<Client>>
190where
191	Client: Connection,
192{
193	type Output = Result<Surreal<Client>>;
194	type IntoFuture = BoxFuture<'static, Self::Output>;
195
196	fn into_future(self) -> Self::IntoFuture {
197		Box::pin(async move {
198			let endpoint = self.address?;
199			let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
200			let client = Client::connect(endpoint, self.capacity).await?;
201			if endpoint_kind.is_remote() {
202				match client.version().await {
203					Ok(mut version) => {
204						// we would like to be able to connect to pre-releases too
205						version.pre = Default::default();
206						client.check_server_version(&version).await?;
207					}
208					// TODO(raphaeldarley) don't error if Method Not allowed
209					Err(e) => return Err(e),
210				}
211			}
212			// Both ends of the channel are still alive at this point
213			client.inner.waiter.0.send(Some(WaitFor::Connection)).ok();
214			Ok(client)
215		})
216	}
217}
218
219impl<Client> IntoFuture for Connect<Client, ()>
220where
221	Client: Connection,
222{
223	type Output = Result<()>;
224	type IntoFuture = BoxFuture<'static, Self::Output>;
225
226	fn into_future(self) -> Self::IntoFuture {
227		Box::pin(async move {
228			// Avoid establishing another connection if already connected
229			ensure!(self.surreal.inner.router.get().is_none(), Error::AlreadyConnected);
230			let endpoint = self.address?;
231			let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
232			let client = Client::connect(endpoint, self.capacity).await?;
233			if endpoint_kind.is_remote() {
234				match client.version().await {
235					Ok(mut version) => {
236						// we would like to be able to connect to pre-releases too
237						version.pre = Default::default();
238						client.check_server_version(&version).await?;
239					}
240					// TODO(raphaeldarley) don't error if Method Not allowed
241					Err(e) => return Err(e),
242				}
243			}
244			let inner =
245				Arc::into_inner(client.inner).expect("new connection to have no references");
246			let router = inner.router.into_inner().expect("router to be set");
247			self.surreal.inner.router.set(router).map_err(|_| Error::AlreadyConnected)?;
248			// Both ends of the channel are still alive at this point
249			self.surreal.inner.waiter.0.send(Some(WaitFor::Connection)).ok();
250			Ok(())
251		})
252	}
253}
254
/// Optional capabilities (backups and live queries).
///
/// Presumably these are features not every engine supports; how they are queried is
/// defined outside this file. Variant order is significant because `Ord` is derived.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum ExtraFeatures {
	Backup,
	LiveQueries,
}
260
/// State shared (via `Arc`) by every clone of a [`Surreal`] handle.
#[derive(Debug)]
struct Inner {
	/// Request router. It stays empty until a connection is established; `extract` then
	/// reports `Error::ConnectionUninitialised`.
	router: OnceLock<Router>,
	/// Sender/receiver pair used to signal events such as `WaitFor::Connection`.
	waiter: Waiter,
}
266
267/// A database client instance for embedded or remote databases.
268///
269/// See [Running SurrealDB embedded in
270/// Rust](crate#running-surrealdb-embedded-in-rust) for tips on how to optimize
271/// performance for the client when working with embedded instances.
272pub struct Surreal<C: Connection> {
273	inner: Arc<Inner>,
274	engine: PhantomData<C>,
275}
276
277impl<C> From<(OnceLock<Router>, Waiter)> for Surreal<C>
278where
279	C: Connection,
280{
281	fn from((router, waiter): (OnceLock<Router>, Waiter)) -> Self {
282		Surreal {
283			inner: Arc::new(Inner {
284				router,
285				waiter,
286			}),
287			engine: PhantomData,
288		}
289	}
290}
291
292impl<C> From<(Router, Waiter)> for Surreal<C>
293where
294	C: Connection,
295{
296	fn from((router, waiter): (Router, Waiter)) -> Self {
297		Surreal {
298			inner: Arc::new(Inner {
299				router: OnceLock::with_value(router),
300				waiter,
301			}),
302			engine: PhantomData,
303		}
304	}
305}
306
307impl<C> From<Arc<Inner>> for Surreal<C>
308where
309	C: Connection,
310{
311	fn from(inner: Arc<Inner>) -> Self {
312		Surreal {
313			inner,
314			engine: PhantomData,
315		}
316	}
317}
318
319impl<C> Surreal<C>
320where
321	C: Connection,
322{
323	async fn check_server_version(&self, version: &Version) -> Result<()> {
324		let (versions, build_meta) = SUPPORTED_VERSIONS;
325		// invalid version requirements should be caught during development
326		let req = VersionReq::parse(versions).expect("valid supported versions");
327		let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
328		let server_build = &version.build;
329		ensure!(
330			req.matches(version),
331			Error::VersionMismatch {
332				server_version: version.clone(),
333				supported_versions: versions.to_owned(),
334			}
335		);
336
337		ensure!(
338			server_build.is_empty() || server_build >= &build_meta,
339			Error::BuildMetadataMismatch {
340				server_metadata: server_build.clone(),
341				supported_metadata: build_meta,
342			}
343		);
344		Ok(())
345	}
346}
347
348impl<C> Clone for Surreal<C>
349where
350	C: Connection,
351{
352	fn clone(&self) -> Self {
353		Self {
354			inner: self.inner.clone(),
355			engine: self.engine,
356		}
357	}
358}
359
360impl<C> Debug for Surreal<C>
361where
362	C: Connection,
363{
364	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365		f.debug_struct("Surreal")
366			.field("router", &self.inner.router)
367			.field("engine", &self.engine)
368			.finish()
369	}
370}
371
372trait OnceLockExt {
373	fn with_value(value: Router) -> OnceLock<Router> {
374		let cell = OnceLock::new();
375		match cell.set(value) {
376			Ok(()) => cell,
377			Err(_) => unreachable!("don't have exclusive access to `cell`"),
378		}
379	}
380
381	fn extract(&self) -> Result<&Router>;
382}
383
384impl OnceLockExt for OnceLock<Router> {
385	fn extract(&self) -> Result<&Router> {
386		let router = self.get().ok_or(Error::ConnectionUninitialised)?;
387		Ok(router)
388	}
389}