// surrealdb/api/engine/remote/ws/native.rs
use super::PATH;
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::ws::PING_METHOD;
use crate::api::err::Error;
use crate::api::opt::Endpoint;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::api::opt::Tls;
use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Strand;
use crate::sql::Value;
use flume::Receiver;
use futures::stream::SplitSink;
use futures::SinkExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use indexmap::IndexMap;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::net::TcpStream;
use tokio::time;
use tokio::time::MissedTickBehavior;
use tokio_tungstenite::tungstenite::error::Error as WsError;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::Connector;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use trice::Instant;
use url::Url;

54type WsResult<T> = std::result::Result<T, WsError>;
55
/// Largest WebSocket message the client accepts: 64 MiB.
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Largest single WebSocket frame the client accepts: 16 MiB.
pub(crate) const MAX_FRAME_SIZE: usize = 16 << 20;
/// Size in bytes of the outgoing write buffer; `MAX_WRITE_BUFFER_SIZE` is derived from it.
pub(crate) const WRITE_BUFFER_SIZE: usize = 128000;
/// Hard cap on the outgoing write buffer.
///
/// This is one full write buffer plus one maximum-size message, so a single
/// large message can always be queued on top of a full buffer.
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE + MAX_MESSAGE_SIZE;
/// Value passed as the `disable_nagle` argument of tokio-tungstenite's `connect_*` functions.
///
/// NOTE(review): because the argument means "disable Nagle", `false` leaves Nagle's
/// algorithm *enabled* (TCP_NODELAY unset). The constant's name suggests the intent may
/// have been the opposite — confirm before relying on either behaviour.
pub(crate) const NAGLE_ALG: bool = false;

62pub(crate) enum Either {
63 Request(Option<Route>),
64 Response(WsResult<Message>),
65 Ping,
66}
67
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
    /// Wraps the user-supplied TLS configuration in the matching tokio-tungstenite connector.
    fn from(tls: Tls) -> Self {
        match tls {
            #[cfg(feature = "native-tls")]
            Tls::Native(config) => Self::NativeTls(config),
            #[cfg(feature = "rustls")]
            Tls::Rust(config) => Self::Rustls(Arc::new(config)),
        }
    }
}

80pub(crate) async fn connect(
81 url: &Url,
82 config: Option<WebSocketConfig>,
83 #[allow(unused_variables)] maybe_connector: Option<Connector>,
84) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
85 #[cfg(any(feature = "native-tls", feature = "rustls"))]
86 let (socket, _) =
87 tokio_tungstenite::connect_async_tls_with_config(url, config, NAGLE_ALG, maybe_connector)
88 .await?;
89
90 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
91 let (socket, _) = tokio_tungstenite::connect_async_with_config(url, config, NAGLE_ALG).await?;
92
93 Ok(socket)
94}
95
96impl crate::api::Connection for Client {}
97
98impl Connection for Client {
99 fn new(method: Method) -> Self {
100 Self {
101 id: 0,
102 method,
103 }
104 }
105
106 fn connect(
107 address: Endpoint,
108 capacity: usize,
109 ) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
110 Box::pin(async move {
111 let url = address.url.join(PATH)?;
112 #[cfg(any(feature = "native-tls", feature = "rustls"))]
113 let maybe_connector = address.config.tls_config.map(Connector::from);
114 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
115 let maybe_connector = None;
116
117 let config = WebSocketConfig {
118 max_message_size: Some(MAX_MESSAGE_SIZE),
119 max_frame_size: Some(MAX_FRAME_SIZE),
120 max_write_buffer_size: MAX_WRITE_BUFFER_SIZE,
121 ..Default::default()
122 };
123
124 let socket = connect(&url, Some(config), maybe_connector.clone()).await?;
125
126 let (route_tx, route_rx) = match capacity {
127 0 => flume::unbounded(),
128 capacity => flume::bounded(capacity),
129 };
130
131 router(url, maybe_connector, capacity, config, socket, route_rx);
132
133 let mut features = HashSet::new();
134 features.insert(ExtraFeatures::LiveQueries);
135
136 Ok(Surreal {
137 router: Arc::new(OnceLock::with_value(Router {
138 features,
139 conn: PhantomData,
140 sender: route_tx,
141 last_id: AtomicI64::new(0),
142 })),
143 })
144 })
145 }
146
147 fn send<'r>(
148 &'r mut self,
149 router: &'r Router<Self>,
150 param: Param,
151 ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
152 Box::pin(async move {
153 self.id = router.next_id();
154 let (sender, receiver) = flume::bounded(1);
155 let route = Route {
156 request: (self.id, self.method, param),
157 response: sender,
158 };
159 router.sender.send_async(Some(route)).await?;
160 Ok(receiver)
161 })
162 }
163}
164
165#[allow(clippy::too_many_lines)]
166pub(crate) fn router(
167 url: Url,
168 maybe_connector: Option<Connector>,
169 capacity: usize,
170 config: WebSocketConfig,
171 mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
172 route_rx: Receiver<Option<Route>>,
173) {
174 tokio::spawn(async move {
175 let ping = {
176 let mut request = BTreeMap::new();
177 request.insert("method".to_owned(), PING_METHOD.into());
178 let value = Value::from(request);
179 let value = serialize(&value).unwrap();
180 Message::Binary(value)
181 };
182
183 let mut vars = IndexMap::new();
184 let mut replay = IndexMap::new();
185
186 'router: loop {
187 let (socket_sink, socket_stream) = socket.split();
188 let mut socket_sink = Socket(Some(socket_sink));
189
190 if let Socket(Some(socket_sink)) = &mut socket_sink {
191 let mut routes = match capacity {
192 0 => HashMap::new(),
193 capacity => HashMap::with_capacity(capacity),
194 };
195 let mut live_queries = HashMap::new();
196
197 let mut interval = time::interval(PING_INTERVAL);
198 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
200 interval.tick().await;
202
203 let pinger = IntervalStream::new(interval);
204
205 let streams = (
206 socket_stream.map(Either::Response),
207 route_rx.stream().map(Either::Request),
208 pinger.map(|_| Either::Ping),
209 );
210
211 let mut merged = streams.merge();
212 let mut last_activity = Instant::now();
213
214 while let Some(either) = merged.next().await {
215 match either {
216 Either::Request(Some(Route {
217 request,
218 response,
219 })) => {
220 let (id, method, param) = request;
221 let params = match param.query {
222 Some((query, bindings)) => {
223 vec![query.into(), bindings.into()]
224 }
225 None => param.other,
226 };
227 match method {
228 Method::Set => {
229 if let [Value::Strand(Strand(key)), value] = ¶ms[..2] {
230 vars.insert(key.clone(), value.clone());
231 }
232 }
233 Method::Unset => {
234 if let [Value::Strand(Strand(key))] = ¶ms[..1] {
235 vars.remove(key);
236 }
237 }
238 Method::Live => {
239 if let Some(sender) = param.notification_sender {
240 if let [Value::Uuid(id)] = ¶ms[..1] {
241 live_queries.insert(*id, sender);
242 }
243 }
244 if response
245 .into_send_async(Ok(DbResponse::Other(Value::None)))
246 .await
247 .is_err()
248 {
249 trace!("Receiver dropped");
250 }
251 continue;
253 }
254 Method::Kill => {
255 if let [Value::Uuid(id)] = ¶ms[..1] {
256 live_queries.remove(id);
257 }
258 }
259 _ => {}
260 }
261 let method_str = match method {
262 Method::Health => PING_METHOD,
263 _ => method.as_str(),
264 };
265 let message = {
266 let mut request = BTreeMap::new();
267 request.insert("id".to_owned(), Value::from(id));
268 request.insert("method".to_owned(), method_str.into());
269 if !params.is_empty() {
270 request.insert("params".to_owned(), params.into());
271 }
272 let payload = Value::from(request);
273 trace!("Request {payload}");
274 let payload = serialize(&payload).unwrap();
275 Message::Binary(payload)
276 };
277 if let Method::Authenticate
278 | Method::Invalidate
279 | Method::Signin
280 | Method::Signup
281 | Method::Use = method
282 {
283 replay.insert(method, message.clone());
284 }
285 match socket_sink.send(message).await {
286 Ok(..) => {
287 last_activity = Instant::now();
288 match routes.entry(id) {
289 Entry::Vacant(entry) => {
290 entry.insert((method, response));
292 }
293 Entry::Occupied(..) => {
294 let error = Error::DuplicateRequestId(id);
295 if response
296 .into_send_async(Err(error.into()))
297 .await
298 .is_err()
299 {
300 trace!("Receiver dropped");
301 }
302 }
303 }
304 }
305 Err(error) => {
306 let error = Error::Ws(error.to_string());
307 if response.into_send_async(Err(error.into())).await.is_err() {
308 trace!("Receiver dropped");
309 }
310 break;
311 }
312 }
313 }
314 Either::Response(result) => {
315 last_activity = Instant::now();
316 match result {
317 Ok(message) => {
318 match Response::try_from(&message) {
319 Ok(option) => {
320 if let Some(response) = option {
322 trace!("{response:?}");
323 match response.id {
324 Some(id) => {
326 if let Ok(id) = id.coerce_to_i64() {
327 if let Some((_method, sender)) =
329 routes.remove(&id)
330 {
331 let _res = sender
333 .into_send_async(
334 DbResponse::from(
335 response.result,
336 ),
337 )
338 .await;
339 }
340 }
341 }
342 None => match response.result {
344 Ok(Data::Live(notification)) => {
345 let live_query_id = notification.id;
346 if let Some(sender) =
348 live_queries.get(&live_query_id)
349 {
350 if sender
352 .send(notification)
353 .await
354 .is_err()
355 {
356 live_queries
357 .remove(&live_query_id);
358 let kill = {
359 let mut request =
360 BTreeMap::new();
361 request.insert(
362 "method".to_owned(),
363 Method::Kill
364 .as_str()
365 .into(),
366 );
367 request.insert(
368 "params".to_owned(),
369 vec![Value::from(
370 live_query_id,
371 )]
372 .into(),
373 );
374 let value =
375 Value::from(request);
376 let value =
377 serialize(&value)
378 .unwrap();
379 Message::Binary(value)
380 };
381 if let Err(error) =
382 socket_sink.send(kill).await
383 {
384 trace!("failed to send kill query to the server; {error:?}");
385 break;
386 }
387 }
388 }
389 }
390 Ok(..) => { }
392 Err(error) => error!("{error:?}"),
393 },
394 }
395 }
396 }
397 Err(error) => {
398 #[derive(Deserialize)]
399 struct Response {
400 id: Option<Value>,
401 }
402
403 if let Message::Binary(binary) = message {
405 if let Ok(Response {
406 id,
407 }) = deserialize(&binary)
408 {
409 if let Some(Ok(id)) =
411 id.map(Value::coerce_to_i64)
412 {
413 if let Some((_method, sender)) =
414 routes.remove(&id)
415 {
416 let _res = sender
417 .into_send_async(Err(error))
418 .await;
419 }
420 }
421 } else {
422 warn!(
424 "Failed to deserialise message; {error:?}"
425 );
426 }
427 }
428 }
429 }
430 }
431 Err(error) => {
432 match error {
433 WsError::ConnectionClosed => {
434 trace!("Connection successfully closed on the server");
435 }
436 error => {
437 trace!("{error}");
438 }
439 }
440 break;
441 }
442 }
443 }
444 Either::Ping => {
445 if last_activity.elapsed() >= PING_INTERVAL {
447 trace!("Pinging the server");
448 if let Err(error) = socket_sink.send(ping.clone()).await {
449 trace!("failed to ping the server; {error:?}");
450 break;
451 }
452 }
453 }
454 Either::Request(None) => {
456 match socket_sink.send(Message::Close(None)).await {
457 Ok(..) => trace!("Connection closed successfully"),
458 Err(error) => {
459 warn!("Failed to close database connection; {error}")
460 }
461 }
462 break 'router;
463 }
464 }
465 }
466 }
467
468 'reconnect: loop {
469 trace!("Reconnecting...");
470 match connect(&url, Some(config), maybe_connector.clone()).await {
471 Ok(s) => {
472 socket = s;
473 for (_, message) in &replay {
474 if let Err(error) = socket.send(message.clone()).await {
475 trace!("{error}");
476 time::sleep(time::Duration::from_secs(1)).await;
477 continue 'reconnect;
478 }
479 }
480 #[cfg(feature = "protocol-ws")]
481 for (key, value) in &vars {
482 let mut request = BTreeMap::new();
483 request.insert("method".to_owned(), Method::Set.as_str().into());
484 request.insert(
485 "params".to_owned(),
486 vec![key.as_str().into(), value.clone()].into(),
487 );
488 let payload = Value::from(request);
489 trace!("Request {payload}");
490 if let Err(error) = socket.send(Message::Binary(payload.into())).await {
491 trace!("{error}");
492 time::sleep(time::Duration::from_secs(1)).await;
493 continue 'reconnect;
494 }
495 }
496 trace!("Reconnected successfully");
497 break;
498 }
499 Err(error) => {
500 trace!("Failed to reconnect; {error}");
501 time::sleep(time::Duration::from_secs(1)).await;
502 }
503 }
504 }
505 }
506 });
507}
508
509impl Response {
510 fn try_from(message: &Message) -> Result<Option<Self>> {
511 match message {
512 Message::Text(text) => {
513 trace!("Received an unexpected text message; {text}");
514 Ok(None)
515 }
516 Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
517 Error::ResponseFromBinary {
518 binary: binary.clone(),
519 error,
520 }
521 .into()
522 }),
523 Message::Ping(..) => {
524 trace!("Received a ping from the server");
525 Ok(None)
526 }
527 Message::Pong(..) => {
528 trace!("Received a pong from the server");
529 Ok(None)
530 }
531 Message::Frame(..) => {
532 trace!("Received an unexpected raw frame");
533 Ok(None)
534 }
535 Message::Close(..) => {
536 trace!("Received an unexpected close message");
537 Ok(None)
538 }
539 }
540 }
541}
542
543pub struct Socket(Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>);