1use std::fmt;
4use std::fmt::Debug;
5use std::future::IntoFuture;
6use std::marker::PhantomData;
7use std::sync::{Arc, OnceLock};
8
9use anyhow::ensure;
10use method::BoxFuture;
11use semver::{BuildMetadata, Version, VersionReq};
12use tokio::sync::watch;
13
14use crate::Result;
15
/// Declares a `#[repr(transparent)]` newtype `$name` around `$inner`.
///
/// Besides the struct itself, the macro generates:
/// - hidden conversion helpers between `$name` and `$inner` (by value, by shared
///   reference and by mutable reference);
/// - `Display` and `Debug` impls that delegate to the *same* trait on the inner value.
///
/// `$inner` must therefore implement both `std::fmt::Display` and `std::fmt::Debug`.
macro_rules! transparent_wrapper {
    (
        $(#[$m:meta])*
        $vis:vis struct $name:ident($field_vis:vis $inner:ty)
    ) => {
        $(#[$m])*
        // Required for the reference casts in `from_inner_ref` / `from_inner_mut`.
        #[repr(transparent)]
        $vis struct $name($field_vis $inner);

        // Every invocation gets all helpers, but not every wrapper uses all of them.
        #[allow(dead_code)]
        impl $name {
            /// Wraps an owned inner value.
            #[doc(hidden)]
            pub fn from_inner(inner: $inner) -> Self {
                Self(inner)
            }

            /// Reinterprets a shared reference to the inner value as a wrapper reference.
            #[doc(hidden)]
            pub fn from_inner_ref(inner: &$inner) -> &Self {
                // SAFETY: `$name` is `#[repr(transparent)]` with `$inner` as its only
                // field, so both types have identical size, alignment and layout. The
                // returned reference borrows from `inner` for the same lifetime.
                unsafe { &*(inner as *const $inner).cast::<Self>() }
            }

            /// Reinterprets a mutable reference to the inner value as a wrapper reference.
            #[doc(hidden)]
            pub fn from_inner_mut(inner: &mut $inner) -> &mut Self {
                // SAFETY: same layout argument as `from_inner_ref`. Uniqueness is
                // inherited from the `&mut` borrow of `inner`.
                unsafe { &mut *(inner as *mut $inner).cast::<Self>() }
            }

            /// Unwraps into the owned inner value.
            #[doc(hidden)]
            pub fn into_inner(self) -> $inner {
                self.0
            }

            /// Borrows the inner value.
            #[doc(hidden)]
            pub fn into_inner_ref(&self) -> &$inner {
                &self.0
            }

            /// Mutably borrows the inner value.
            #[doc(hidden)]
            pub fn into_inner_mut(&mut self) -> &mut $inner {
                &mut self.0
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Fully qualified: a bare `self.0.fmt(fmt)` is resolved against the
                // traits in scope where the macro is invoked. It becomes ambiguous
                // (E0034) when `Debug` is also imported there, as it is in this module.
                std::fmt::Display::fmt(&self.0, fmt)
            }
        }

        impl std::fmt::Debug for $name {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Fully qualified for the same reason as the `Display` impl above.
                std::fmt::Debug::fmt(&self.0, fmt)
            }
        }
    };
}
74
/// Implements `revision::Revisioned`, `serde::Serialize` and `serde::Deserialize`
/// for a single-field newtype by delegating to its field (`self.0`).
///
/// `$ty` must provide a `from_inner` constructor (such as the one generated by
/// `transparent_wrapper!`), which rebuilds the wrapper from the deserialised inner
/// value. The revision number reported is the one of `crate::core::val::Value`.
macro_rules! impl_serialize_wrapper {
    ($ty:ty) => {
        impl ::revision::Revisioned for $ty {
            fn revision() -> u16 {
                crate::core::val::Value::revision()
            }

            fn serialize_revisioned<W: std::io::Write>(
                &self,
                w: &mut W,
            ) -> std::result::Result<(), ::revision::Error> {
                // Spelled out with the trait path, mirroring the deserialise side.
                ::revision::Revisioned::serialize_revisioned(&self.0, w)
            }

            fn deserialize_revisioned<R: std::io::Read>(
                r: &mut R,
            ) -> std::result::Result<Self, ::revision::Error>
            where
                Self: Sized,
            {
                ::revision::Revisioned::deserialize_revisioned(r).map(Self::from_inner)
            }
        }

        impl ::serde::Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
            where
                S: ::serde::ser::Serializer,
            {
                ::serde::Serialize::serialize(&self.0, serializer)
            }
        }

        impl<'de> ::serde::de::Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: ::serde::de::Deserializer<'de>,
            {
                let inner = ::serde::de::Deserialize::deserialize(deserializer)?;
                Ok(Self::from_inner(inner))
            }
        }
    };
}
118
119pub mod engine;
120pub mod err;
121#[cfg(feature = "protocol-http")]
122pub mod headers;
123pub mod method;
124pub mod opt;
125pub mod value;
126
127mod conn;
128
129pub use method::query::Response;
130
131use self::conn::Router;
132use self::err::Error;
133use self::opt::{Endpoint, EndpointKind, WaitFor};
134
135type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
137
/// Server versions this client accepts, as `(semver requirement, minimum build metadata)`.
///
/// Consumed by `Surreal::check_server_version`, which parses the first element with
/// `VersionReq::parse` and the second with `BuildMetadata::new`.
const SUPPORTED_VERSIONS: (&str, &str) = (
    // Requirement the server's version must satisfy.
    ">=1.2.0, <4.0.0",
    // Lower bound (compared with `>=`) for non-empty server build metadata.
    "20230701.55918b7c",
);
139
140pub trait Connection: conn::Sealed {}
142
143#[derive(Debug)]
145#[must_use = "futures do nothing unless you `.await` or poll them"]
146pub struct Connect<C: Connection, Response> {
147 surreal: Surreal<C>,
148 address: Result<Endpoint>,
149 capacity: usize,
150 response_type: PhantomData<Response>,
151}
152
153impl<C, R> Connect<C, R>
154where
155 C: Connection,
156{
157 pub const fn with_capacity(mut self, capacity: usize) -> Self {
184 self.capacity = capacity;
185 self
186 }
187}
188
189impl<Client> IntoFuture for Connect<Client, Surreal<Client>>
190where
191 Client: Connection,
192{
193 type Output = Result<Surreal<Client>>;
194 type IntoFuture = BoxFuture<'static, Self::Output>;
195
196 fn into_future(self) -> Self::IntoFuture {
197 Box::pin(async move {
198 let endpoint = self.address?;
199 let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
200 let client = Client::connect(endpoint, self.capacity).await?;
201 if endpoint_kind.is_remote() {
202 match client.version().await {
203 Ok(mut version) => {
204 version.pre = Default::default();
206 client.check_server_version(&version).await?;
207 }
208 Err(e) => return Err(e),
210 }
211 }
212 client.inner.waiter.0.send(Some(WaitFor::Connection)).ok();
214 Ok(client)
215 })
216 }
217}
218
219impl<Client> IntoFuture for Connect<Client, ()>
220where
221 Client: Connection,
222{
223 type Output = Result<()>;
224 type IntoFuture = BoxFuture<'static, Self::Output>;
225
226 fn into_future(self) -> Self::IntoFuture {
227 Box::pin(async move {
228 ensure!(self.surreal.inner.router.get().is_none(), Error::AlreadyConnected);
230 let endpoint = self.address?;
231 let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
232 let client = Client::connect(endpoint, self.capacity).await?;
233 if endpoint_kind.is_remote() {
234 match client.version().await {
235 Ok(mut version) => {
236 version.pre = Default::default();
238 client.check_server_version(&version).await?;
239 }
240 Err(e) => return Err(e),
242 }
243 }
244 let inner =
245 Arc::into_inner(client.inner).expect("new connection to have no references");
246 let router = inner.router.into_inner().expect("router to be set");
247 self.surreal.inner.router.set(router).map_err(|_| Error::AlreadyConnected)?;
248 self.surreal.inner.waiter.0.send(Some(WaitFor::Connection)).ok();
250 Ok(())
251 })
252 }
253}
254
/// Optional capabilities a connection may offer.
///
/// NOTE(review): the variant meanings (backups, live queries) are inferred from their
/// names. Confirm at the use sites. Variant order is significant because it defines
/// the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) enum ExtraFeatures {
    Backup,
    LiveQueries,
}
260
/// Connection state shared (through an `Arc`) by every clone of a `Surreal` handle.
#[derive(Debug)]
struct Inner {
    /// Router of the underlying connection. It can be set only once, and may still be
    /// empty for a handle that has not been connected yet (`extract` then returns
    /// `Error::ConnectionUninitialised`).
    router: OnceLock<Router>,
    /// Watch channel on which connection events such as `WaitFor::Connection` are
    /// published.
    waiter: Waiter,
}
266
267pub struct Surreal<C: Connection> {
273 inner: Arc<Inner>,
274 engine: PhantomData<C>,
275}
276
277impl<C> From<(OnceLock<Router>, Waiter)> for Surreal<C>
278where
279 C: Connection,
280{
281 fn from((router, waiter): (OnceLock<Router>, Waiter)) -> Self {
282 Surreal {
283 inner: Arc::new(Inner {
284 router,
285 waiter,
286 }),
287 engine: PhantomData,
288 }
289 }
290}
291
292impl<C> From<(Router, Waiter)> for Surreal<C>
293where
294 C: Connection,
295{
296 fn from((router, waiter): (Router, Waiter)) -> Self {
297 Surreal {
298 inner: Arc::new(Inner {
299 router: OnceLock::with_value(router),
300 waiter,
301 }),
302 engine: PhantomData,
303 }
304 }
305}
306
307impl<C> From<Arc<Inner>> for Surreal<C>
308where
309 C: Connection,
310{
311 fn from(inner: Arc<Inner>) -> Self {
312 Surreal {
313 inner,
314 engine: PhantomData,
315 }
316 }
317}
318
319impl<C> Surreal<C>
320where
321 C: Connection,
322{
323 async fn check_server_version(&self, version: &Version) -> Result<()> {
324 let (versions, build_meta) = SUPPORTED_VERSIONS;
325 let req = VersionReq::parse(versions).expect("valid supported versions");
327 let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
328 let server_build = &version.build;
329 ensure!(
330 req.matches(version),
331 Error::VersionMismatch {
332 server_version: version.clone(),
333 supported_versions: versions.to_owned(),
334 }
335 );
336
337 ensure!(
338 server_build.is_empty() || server_build >= &build_meta,
339 Error::BuildMetadataMismatch {
340 server_metadata: server_build.clone(),
341 supported_metadata: build_meta,
342 }
343 );
344 Ok(())
345 }
346}
347
348impl<C> Clone for Surreal<C>
349where
350 C: Connection,
351{
352 fn clone(&self) -> Self {
353 Self {
354 inner: self.inner.clone(),
355 engine: self.engine,
356 }
357 }
358}
359
360impl<C> Debug for Surreal<C>
361where
362 C: Connection,
363{
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 f.debug_struct("Surreal")
366 .field("router", &self.inner.router)
367 .field("engine", &self.engine)
368 .finish()
369 }
370}
371
372trait OnceLockExt {
373 fn with_value(value: Router) -> OnceLock<Router> {
374 let cell = OnceLock::new();
375 match cell.set(value) {
376 Ok(()) => cell,
377 Err(_) => unreachable!("don't have exclusive access to `cell`"),
378 }
379 }
380
381 fn extract(&self) -> Result<&Router>;
382}
383
384impl OnceLockExt for OnceLock<Router> {
385 fn extract(&self) -> Result<&Router> {
386 let router = self.get().ok_or(Error::ConnectionUninitialised)?;
387 Ok(router)
388 }
389}