surrealdb/api/engine/remote/ws/native.rs

1use super::PATH;
2use crate::api::conn::Connection;
3use crate::api::conn::DbResponse;
4use crate::api::conn::Method;
5use crate::api::conn::Param;
6use crate::api::conn::Route;
7use crate::api::conn::Router;
8use crate::api::engine::remote::ws::Client;
9use crate::api::engine::remote::ws::Response;
10use crate::api::engine::remote::ws::PING_INTERVAL;
11use crate::api::engine::remote::ws::PING_METHOD;
12use crate::api::err::Error;
13use crate::api::opt::Endpoint;
14#[cfg(any(feature = "native-tls", feature = "rustls"))]
15use crate::api::opt::Tls;
16use crate::api::ExtraFeatures;
17use crate::api::OnceLockExt;
18use crate::api::Result;
19use crate::api::Surreal;
20use crate::engine::remote::ws::Data;
21use crate::engine::IntervalStream;
22use crate::sql::serde::{deserialize, serialize};
23use crate::sql::Strand;
24use crate::sql::Value;
25use flume::Receiver;
26use futures::stream::SplitSink;
27use futures::SinkExt;
28use futures::StreamExt;
29use futures_concurrency::stream::Merge as _;
30use indexmap::IndexMap;
31use serde::Deserialize;
32use std::collections::hash_map::Entry;
33use std::collections::BTreeMap;
34use std::collections::HashMap;
35use std::collections::HashSet;
36use std::future::Future;
37use std::marker::PhantomData;
38use std::pin::Pin;
39use std::sync::atomic::AtomicI64;
40use std::sync::Arc;
41use std::sync::OnceLock;
42use tokio::net::TcpStream;
43use tokio::time;
44use tokio::time::MissedTickBehavior;
45use tokio_tungstenite::tungstenite::error::Error as WsError;
46use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
47use tokio_tungstenite::tungstenite::Message;
48use tokio_tungstenite::Connector;
49use tokio_tungstenite::MaybeTlsStream;
50use tokio_tungstenite::WebSocketStream;
51use trice::Instant;
52use url::Url;
53
54type WsResult<T> = std::result::Result<T, WsError>;
55
/// Largest WebSocket message accepted on the connection (64 MiB).
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Largest single WebSocket frame accepted on the connection (16 MiB).
pub(crate) const MAX_FRAME_SIZE: usize = 16 << 20;
/// tungstenite's default write buffer size (128 KiB).
///
/// The connection config keeps tungstenite's default `write_buffer_size`
/// (`..Default::default()`), which is `128 * 1024`, so this constant must use the
/// same value for `MAX_WRITE_BUFFER_SIZE` below to honour tungstenite's guidance.
pub(crate) const WRITE_BUFFER_SIZE: usize = 128 * 1024;
/// Upper bound on buffered outgoing data. tungstenite recommends at least
/// `write_buffer_size` plus one maximum-size message.
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE + MAX_MESSAGE_SIZE;
/// Passed as the `disable_nagle` argument of tokio-tungstenite's connect functions.
///
/// NOTE(review): `false` means Nagle's algorithm stays *enabled*; the constant's name
/// suggests the opposite — confirm this is intended.
pub(crate) const NAGLE_ALG: bool = false;
61
62pub(crate) enum Either {
63	Request(Option<Route>),
64	Response(WsResult<Message>),
65	Ping,
66}
67
/// Turns the SDK's TLS settings into the connector type expected by `tokio_tungstenite`.
///
/// Each match arm is compiled only when its TLS backend feature is enabled.
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
	fn from(tls: Tls) -> Self {
		match tls {
			#[cfg(feature = "native-tls")]
			Tls::Native(native_config) => Connector::NativeTls(native_config),
			#[cfg(feature = "rustls")]
			Tls::Rust(rustls_config) => {
				// The rustls connector variant holds its config behind an `Arc`
				let shared = Arc::new(rustls_config);
				Connector::Rustls(shared)
			}
		}
	}
}
79
80pub(crate) async fn connect(
81	url: &Url,
82	config: Option<WebSocketConfig>,
83	#[allow(unused_variables)] maybe_connector: Option<Connector>,
84) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
85	#[cfg(any(feature = "native-tls", feature = "rustls"))]
86	let (socket, _) =
87		tokio_tungstenite::connect_async_tls_with_config(url, config, NAGLE_ALG, maybe_connector)
88			.await?;
89
90	#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
91	let (socket, _) = tokio_tungstenite::connect_async_with_config(url, config, NAGLE_ALG).await?;
92
93	Ok(socket)
94}
95
96impl crate::api::Connection for Client {}
97
98impl Connection for Client {
99	fn new(method: Method) -> Self {
100		Self {
101			id: 0,
102			method,
103		}
104	}
105
106	fn connect(
107		address: Endpoint,
108		capacity: usize,
109	) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
110		Box::pin(async move {
111			let url = address.url.join(PATH)?;
112			#[cfg(any(feature = "native-tls", feature = "rustls"))]
113			let maybe_connector = address.config.tls_config.map(Connector::from);
114			#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
115			let maybe_connector = None;
116
117			let config = WebSocketConfig {
118				max_message_size: Some(MAX_MESSAGE_SIZE),
119				max_frame_size: Some(MAX_FRAME_SIZE),
120				max_write_buffer_size: MAX_WRITE_BUFFER_SIZE,
121				..Default::default()
122			};
123
124			let socket = connect(&url, Some(config), maybe_connector.clone()).await?;
125
126			let (route_tx, route_rx) = match capacity {
127				0 => flume::unbounded(),
128				capacity => flume::bounded(capacity),
129			};
130
131			router(url, maybe_connector, capacity, config, socket, route_rx);
132
133			let mut features = HashSet::new();
134			features.insert(ExtraFeatures::LiveQueries);
135
136			Ok(Surreal {
137				router: Arc::new(OnceLock::with_value(Router {
138					features,
139					conn: PhantomData,
140					sender: route_tx,
141					last_id: AtomicI64::new(0),
142				})),
143			})
144		})
145	}
146
147	fn send<'r>(
148		&'r mut self,
149		router: &'r Router<Self>,
150		param: Param,
151	) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
152		Box::pin(async move {
153			self.id = router.next_id();
154			let (sender, receiver) = flume::bounded(1);
155			let route = Route {
156				request: (self.id, self.method, param),
157				response: sender,
158			};
159			router.sender.send_async(Some(route)).await?;
160			Ok(receiver)
161		})
162	}
163}
164
165#[allow(clippy::too_many_lines)]
166pub(crate) fn router(
167	url: Url,
168	maybe_connector: Option<Connector>,
169	capacity: usize,
170	config: WebSocketConfig,
171	mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
172	route_rx: Receiver<Option<Route>>,
173) {
174	tokio::spawn(async move {
175		let ping = {
176			let mut request = BTreeMap::new();
177			request.insert("method".to_owned(), PING_METHOD.into());
178			let value = Value::from(request);
179			let value = serialize(&value).unwrap();
180			Message::Binary(value)
181		};
182
183		let mut vars = IndexMap::new();
184		let mut replay = IndexMap::new();
185
186		'router: loop {
187			let (socket_sink, socket_stream) = socket.split();
188			let mut socket_sink = Socket(Some(socket_sink));
189
190			if let Socket(Some(socket_sink)) = &mut socket_sink {
191				let mut routes = match capacity {
192					0 => HashMap::new(),
193					capacity => HashMap::with_capacity(capacity),
194				};
195				let mut live_queries = HashMap::new();
196
197				let mut interval = time::interval(PING_INTERVAL);
198				// don't bombard the server with pings if we miss some ticks
199				interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
200				// Delay sending the first ping
201				interval.tick().await;
202
203				let pinger = IntervalStream::new(interval);
204
205				let streams = (
206					socket_stream.map(Either::Response),
207					route_rx.stream().map(Either::Request),
208					pinger.map(|_| Either::Ping),
209				);
210
211				let mut merged = streams.merge();
212				let mut last_activity = Instant::now();
213
214				while let Some(either) = merged.next().await {
215					match either {
216						Either::Request(Some(Route {
217							request,
218							response,
219						})) => {
220							let (id, method, param) = request;
221							let params = match param.query {
222								Some((query, bindings)) => {
223									vec![query.into(), bindings.into()]
224								}
225								None => param.other,
226							};
227							match method {
228								Method::Set => {
229									if let [Value::Strand(Strand(key)), value] = &params[..2] {
230										vars.insert(key.clone(), value.clone());
231									}
232								}
233								Method::Unset => {
234									if let [Value::Strand(Strand(key))] = &params[..1] {
235										vars.remove(key);
236									}
237								}
238								Method::Live => {
239									if let Some(sender) = param.notification_sender {
240										if let [Value::Uuid(id)] = &params[..1] {
241											live_queries.insert(*id, sender);
242										}
243									}
244									if response
245										.into_send_async(Ok(DbResponse::Other(Value::None)))
246										.await
247										.is_err()
248									{
249										trace!("Receiver dropped");
250									}
251									// There is nothing to send to the server here
252									continue;
253								}
254								Method::Kill => {
255									if let [Value::Uuid(id)] = &params[..1] {
256										live_queries.remove(id);
257									}
258								}
259								_ => {}
260							}
261							let method_str = match method {
262								Method::Health => PING_METHOD,
263								_ => method.as_str(),
264							};
265							let message = {
266								let mut request = BTreeMap::new();
267								request.insert("id".to_owned(), Value::from(id));
268								request.insert("method".to_owned(), method_str.into());
269								if !params.is_empty() {
270									request.insert("params".to_owned(), params.into());
271								}
272								let payload = Value::from(request);
273								trace!("Request {payload}");
274								let payload = serialize(&payload).unwrap();
275								Message::Binary(payload)
276							};
277							if let Method::Authenticate
278							| Method::Invalidate
279							| Method::Signin
280							| Method::Signup
281							| Method::Use = method
282							{
283								replay.insert(method, message.clone());
284							}
285							match socket_sink.send(message).await {
286								Ok(..) => {
287									last_activity = Instant::now();
288									match routes.entry(id) {
289										Entry::Vacant(entry) => {
290											// Register query route
291											entry.insert((method, response));
292										}
293										Entry::Occupied(..) => {
294											let error = Error::DuplicateRequestId(id);
295											if response
296												.into_send_async(Err(error.into()))
297												.await
298												.is_err()
299											{
300												trace!("Receiver dropped");
301											}
302										}
303									}
304								}
305								Err(error) => {
306									let error = Error::Ws(error.to_string());
307									if response.into_send_async(Err(error.into())).await.is_err() {
308										trace!("Receiver dropped");
309									}
310									break;
311								}
312							}
313						}
314						Either::Response(result) => {
315							last_activity = Instant::now();
316							match result {
317								Ok(message) => {
318									match Response::try_from(&message) {
319										Ok(option) => {
320											// We are only interested in responses that are not empty
321											if let Some(response) = option {
322												trace!("{response:?}");
323												match response.id {
324													// If `id` is set this is a normal response
325													Some(id) => {
326														if let Ok(id) = id.coerce_to_i64() {
327															// We can only route responses with IDs
328															if let Some((_method, sender)) =
329																routes.remove(&id)
330															{
331																// Send the response back to the caller
332																let _res = sender
333																	.into_send_async(
334																		DbResponse::from(
335																			response.result,
336																		),
337																	)
338																	.await;
339															}
340														}
341													}
342													// If `id` is not set, this may be a live query notification
343													None => match response.result {
344														Ok(Data::Live(notification)) => {
345															let live_query_id = notification.id;
346															// Check if this live query is registered
347															if let Some(sender) =
348																live_queries.get(&live_query_id)
349															{
350																// Send the notification back to the caller or kill live query if the receiver is already dropped
351																if sender
352																	.send(notification)
353																	.await
354																	.is_err()
355																{
356																	live_queries
357																		.remove(&live_query_id);
358																	let kill = {
359																		let mut request =
360																			BTreeMap::new();
361																		request.insert(
362																			"method".to_owned(),
363																			Method::Kill
364																				.as_str()
365																				.into(),
366																		);
367																		request.insert(
368																			"params".to_owned(),
369																			vec![Value::from(
370																				live_query_id,
371																			)]
372																			.into(),
373																		);
374																		let value =
375																			Value::from(request);
376																		let value =
377																			serialize(&value)
378																				.unwrap();
379																		Message::Binary(value)
380																	};
381																	if let Err(error) =
382																		socket_sink.send(kill).await
383																	{
384																		trace!("failed to send kill query to the server; {error:?}");
385																		break;
386																	}
387																}
388															}
389														}
390														Ok(..) => { /* Ignored responses like pings */
391														}
392														Err(error) => error!("{error:?}"),
393													},
394												}
395											}
396										}
397										Err(error) => {
398											#[derive(Deserialize)]
399											struct Response {
400												id: Option<Value>,
401											}
402
403											// Let's try to find out the ID of the response that failed to deserialise
404											if let Message::Binary(binary) = message {
405												if let Ok(Response {
406													id,
407												}) = deserialize(&binary)
408												{
409													// Return an error if an ID was returned
410													if let Some(Ok(id)) =
411														id.map(Value::coerce_to_i64)
412													{
413														if let Some((_method, sender)) =
414															routes.remove(&id)
415														{
416															let _res = sender
417																.into_send_async(Err(error))
418																.await;
419														}
420													}
421												} else {
422													// Unfortunately, we don't know which response failed to deserialize
423													warn!(
424														"Failed to deserialise message; {error:?}"
425													);
426												}
427											}
428										}
429									}
430								}
431								Err(error) => {
432									match error {
433										WsError::ConnectionClosed => {
434											trace!("Connection successfully closed on the server");
435										}
436										error => {
437											trace!("{error}");
438										}
439									}
440									break;
441								}
442							}
443						}
444						Either::Ping => {
445							// only ping if we haven't talked to the server recently
446							if last_activity.elapsed() >= PING_INTERVAL {
447								trace!("Pinging the server");
448								if let Err(error) = socket_sink.send(ping.clone()).await {
449									trace!("failed to ping the server; {error:?}");
450									break;
451								}
452							}
453						}
454						// Close connection request received
455						Either::Request(None) => {
456							match socket_sink.send(Message::Close(None)).await {
457								Ok(..) => trace!("Connection closed successfully"),
458								Err(error) => {
459									warn!("Failed to close database connection; {error}")
460								}
461							}
462							break 'router;
463						}
464					}
465				}
466			}
467
468			'reconnect: loop {
469				trace!("Reconnecting...");
470				match connect(&url, Some(config), maybe_connector.clone()).await {
471					Ok(s) => {
472						socket = s;
473						for (_, message) in &replay {
474							if let Err(error) = socket.send(message.clone()).await {
475								trace!("{error}");
476								time::sleep(time::Duration::from_secs(1)).await;
477								continue 'reconnect;
478							}
479						}
480						#[cfg(feature = "protocol-ws")]
481						for (key, value) in &vars {
482							let mut request = BTreeMap::new();
483							request.insert("method".to_owned(), Method::Set.as_str().into());
484							request.insert(
485								"params".to_owned(),
486								vec![key.as_str().into(), value.clone()].into(),
487							);
488							let payload = Value::from(request);
489							trace!("Request {payload}");
490							if let Err(error) = socket.send(Message::Binary(payload.into())).await {
491								trace!("{error}");
492								time::sleep(time::Duration::from_secs(1)).await;
493								continue 'reconnect;
494							}
495						}
496						trace!("Reconnected successfully");
497						break;
498					}
499					Err(error) => {
500						trace!("Failed to reconnect; {error}");
501						time::sleep(time::Duration::from_secs(1)).await;
502					}
503				}
504			}
505		}
506	});
507}
508
509impl Response {
510	fn try_from(message: &Message) -> Result<Option<Self>> {
511		match message {
512			Message::Text(text) => {
513				trace!("Received an unexpected text message; {text}");
514				Ok(None)
515			}
516			Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
517				Error::ResponseFromBinary {
518					binary: binary.clone(),
519					error,
520				}
521				.into()
522			}),
523			Message::Ping(..) => {
524				trace!("Received a ping from the server");
525				Ok(None)
526			}
527			Message::Pong(..) => {
528				trace!("Received a pong from the server");
529				Ok(None)
530			}
531			Message::Frame(..) => {
532				trace!("Received an unexpected raw frame");
533				Ok(None)
534			}
535			Message::Close(..) => {
536				trace!("Received an unexpected close message");
537				Ok(None)
538			}
539		}
540	}
541}
542
543pub struct Socket(Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>);