1use anyhow::anyhow;
2use crdb_core::{
3 ClientMessage, CrdbFn, MaybeObject, ObjectId, QueryId, Request, RequestId, ResponsePart,
4 SavedQuery, SerializableError, ServerMessage, SessionToken, SystemTimeExt, Update, UpdateData,
5 Updatedness, Updates,
6};
7use futures::{channel::mpsc, future::OptionFuture, SinkExt, StreamExt};
8use std::{
9 collections::{HashMap, VecDeque},
10 sync::Arc,
11 time::Duration,
12};
13use waaaa::{WebSocket, WsMessage};
14use web_time::Instant;
15use web_time::SystemTime;
16
17use crate::client_db::SavedObject;
18
/// While disconnected, how long to wait before retrying to connect (any other event waking the
/// connection loop up triggers a retry sooner).
const RECONNECT_INTERVAL: Duration = Duration::from_millis(10_000);
/// Delay before sending a ping, counted from the login or from the previous ping's answer.
const PING_INTERVAL: Duration = Duration::from_millis(10_000);
/// How long to wait for the server to answer a `SetToken` or ping request before dropping the
/// socket.
const PONG_DEADLINE: Duration = Duration::from_millis(10_000);
22
/// A request together with the binary payloads sent right after it.
#[derive(Debug)]
pub struct RequestWithSidecar {
    /// The request itself, sent as a JSON text frame.
    pub request: Arc<Request>,
    /// Payloads sent after `request`, one binary frame each, in order.
    pub sidecar: Vec<Arc<[u8]>>,
}
28
/// One part of a server response, as delivered on a request's [`ResponseSender`].
#[derive(Debug)]
pub struct ResponsePartWithSidecar {
    /// The response part. Each binary frame received from the server is delivered as its own
    /// `ResponsePart::Binaries(1)` part.
    pub response: ResponsePart,
    /// The binary frame's content for binary parts, `None` for every other part.
    pub sidecar: Option<Arc<[u8]>>,
}
34
/// Commands controlling the session of a [`Connection`].
#[derive(Debug)]
pub enum Command {
    /// Connect to `url` and authenticate with `token`, replacing the current state (and socket,
    /// if any).
    Login {
        url: Arc<String>,
        token: SessionToken,
    },
    /// Log out of the server if connected, and forget the current credentials.
    Logout,
}
43
/// Notifications passed to the connection's event callback.
#[derive(Debug)]
pub enum ConnectionEvent {
    /// A `Login` command was received.
    LoggingIn,
    /// Opening the websocket failed.
    FailedConnecting(anyhow::Error),
    /// The websocket was opened, but sending the session token on it failed.
    FailedSendingToken(anyhow::Error),
    /// The socket failed, or the server sent something unexpected; the connection was dropped.
    LostConnection(anyhow::Error),
    /// The server reported this session token as invalid.
    InvalidToken(SessionToken),
    /// The server accepted the session token.
    Connected,
    /// Estimated offset of the server clock relative to the local clock, in the unit returned by
    /// `ms_since_posix` (milliseconds, judging by the name). Positive when the server is ahead;
    /// `0` when the server time fell between the ping's sending and its answer.
    TimeOffset(i64),
    /// A `Logout` command was processed.
    LoggedOut,
}
55
/// A frame received from the server.
enum IncomingMessage<T> {
    /// A text frame, already parsed from JSON.
    Text(T),
    /// A raw binary frame.
    Binary(Arc<[u8]>),
}
60
/// State machine of a [`Connection`].
pub enum State {
    /// No server URL and session token are known: initially, after a logout, or after the server
    /// rejected the token during login.
    NoValidInfo,
    /// Credentials are known but no socket is open; the connection loop tries to (re)connect.
    Disconnected {
        url: Arc<String>,
        token: SessionToken,
    },
    /// A socket is open and `SetToken` was sent as request `request_id`; waiting for its answer.
    TokenSent {
        url: Arc<String>,
        token: SessionToken,
        socket: WebSocket,
        request_id: RequestId,
    },
    /// The server accepted the token; requests can be sent on `socket`.
    Connected {
        url: Arc<String>,
        token: SessionToken,
        socket: WebSocket,
        /// `Some((request_id, n))` while `n` more binary frames answering `request_id` are
        /// expected; any text frame received meanwhile drops the connection.
        expected_binaries: Option<(RequestId, usize)>,
    },
}
81
82impl State {
83 fn disconnect(self) -> Self {
84 match self {
85 State::NoValidInfo => State::NoValidInfo,
86 State::Disconnected { url, token }
87 | State::TokenSent { url, token, .. }
88 | State::Connected { url, token, .. } => State::Disconnected { url, token },
89 }
90 }
91
92 fn no_longer_expecting_binaries(&mut self) {
93 match self {
94 State::NoValidInfo | State::Disconnected { .. } | State::TokenSent { .. } => (),
95 State::Connected {
96 expected_binaries, ..
97 } => *expected_binaries = None,
98 }
99 }
100
101 async fn next_msg(&mut self) -> Option<anyhow::Result<IncomingMessage<ServerMessage>>> {
102 match self {
103 State::NoValidInfo | State::Disconnected { .. } => None,
104 State::TokenSent { socket, .. } | State::Connected { socket, .. } => {
105 match socket.recv().await {
106 Err(err) => Some(Err(err)),
107 Ok(None) => Some(Err(anyhow!(
108 "Got websocket end-of-stream, expected a message"
109 ))),
110 Ok(Some(WsMessage::Binary(b))) => {
111 Some(Ok(IncomingMessage::Binary(b.into_boxed_slice().into())))
112 }
113 Ok(Some(WsMessage::Text(msg))) => match serde_json::from_str(&msg) {
114 Ok(msg) => Some(Ok(IncomingMessage::Text(msg))),
115 Err(err) => Some(Err(err.into())),
116 },
117 }
118 }
119 }
120 }
121}
122
/// Channel on which the parts of one request's response are delivered.
pub type ResponseSender = mpsc::UnboundedSender<ResponsePartWithSidecar>;
124
/// Client side of the websocket connection to the server.
///
/// Owns the socket (inside `state`), sends the requests received on `requests` over it, routes
/// response parts back to each request's sender, forwards server-pushed updates, and pings the
/// server periodically to detect dead sockets and estimate the clock offset. All of this is
/// driven by [`Connection::run`].
pub struct Connection<GetSavedObjects, GetSavedQueries> {
    /// Current connection state, including the socket when one is open.
    state: State,
    /// Last id handed out by `next_request_id`.
    last_request_id: RequestId,
    /// Login/logout commands.
    commands: mpsc::UnboundedReceiver<Command>,
    /// Requests to send, each with the channel that receives its response parts.
    requests: mpsc::UnboundedReceiver<(ResponseSender, Arc<RequestWithSidecar>)>,
    /// Requests not sent yet because no authenticated socket was available.
    not_sent_requests: VecDeque<(RequestId, Arc<RequestWithSidecar>, ResponseSender)>,
    /// Requests sent on the current socket and not fully answered yet. The `bool` becomes `true`
    /// once at least one response part was forwarded to the sender; on reconnection such
    /// requests are failed with `ConnectionLoss` instead of being sent again.
    pending_requests: HashMap<RequestId, (Arc<RequestWithSidecar>, ResponseSender, bool)>,
    /// Called with every connection event.
    event_cb: Box<dyn CrdbFn<ConnectionEvent>>,
    /// Destination of server-pushed updates and of re-subscription results.
    update_sender: mpsc::UnboundedSender<Updates>,
    /// Local time (`ms_since_posix`) at which the last ping was sent.
    last_ping: i64,
    /// When to send the next ping; `None` while a ping awaits its answer, and until the first
    /// accepted login.
    next_ping: Option<Instant>,
    /// Id of the request (`SetToken` or ping) whose answer is awaited, and the deadline after
    /// which the socket is dropped.
    next_pong_deadline: Option<(RequestId, Instant)>,
    /// Lists the objects to re-subscribe to after each accepted login.
    get_saved_objects: GetSavedObjects,
    /// Lists the queries to re-subscribe to after each accepted login.
    get_saved_queries: GetSavedQueries,
}
142
143impl<GSO, GSQ> Connection<GSO, GSQ>
144where
145 GSO: 'static + FnMut() -> HashMap<ObjectId, SavedObject>,
146 GSQ: 'static + FnMut() -> HashMap<QueryId, SavedQuery>,
147{
148 pub fn new(
149 commands: mpsc::UnboundedReceiver<Command>,
150 requests: mpsc::UnboundedReceiver<(ResponseSender, Arc<RequestWithSidecar>)>,
151 event_cb: Box<dyn CrdbFn<ConnectionEvent>>,
152 update_sender: mpsc::UnboundedSender<Updates>,
153 get_saved_objects: GSO,
154 get_saved_queries: GSQ,
155 ) -> Connection<GSO, GSQ> {
156 Connection {
157 state: State::NoValidInfo,
158 last_request_id: RequestId(0),
159 commands,
160 requests,
161 not_sent_requests: VecDeque::new(),
162 pending_requests: HashMap::new(),
163 event_cb,
164 update_sender,
165 last_ping: SystemTime::now().ms_since_posix().unwrap(),
166 next_ping: None,
167 next_pong_deadline: None,
168 get_saved_objects,
169 get_saved_queries,
170 }
171 }
172
173 pub async fn run(mut self) {
174 loop {
175 tokio::select! {
177 _reconnect_attempt_interval = waaaa::sleep(RECONNECT_INTERVAL),
181 if self.is_trying_to_connect() => {
182 tracing::trace!("reconnect interval elapsed");
183 },
184
185 Some(_) = OptionFuture::from(self.next_ping.map(waaaa::sleep_until)), if self.is_connected() => {
187 tracing::trace!("sending ping request");
188 let request_id = self.next_request_id();
189 let _ = self.send_connected(&ClientMessage {
190 request_id,
191 request: Arc::new(Request::GetTime),
192 }).await;
193 self.last_ping = SystemTime::now().ms_since_posix().unwrap();
194 self.next_ping = None;
195 self.next_pong_deadline = Some((request_id, Instant::now() + PONG_DEADLINE));
196 }
197
198 Some(_) = OptionFuture::from(self.next_pong_deadline.map(|(_, t)| waaaa::sleep_until(t))), if self.is_connecting() => {
200 tracing::trace!("pong did not come in time, disconnecting");
201 self.state = self.state.disconnect();
202 self.next_pong_deadline = None;
203 }
204
205 command = self.commands.next() => {
208 tracing::trace!(?command, "received command");
209 let Some(command) = command else {
210 break; };
212 self.handle_command(command).await;
213 }
214
215 request = self.requests.next() => {
217 let Some((sender, request)) = request else {
218 break; };
220 let request_id = self.next_request_id();
221 tracing::trace!(?request, ?request_id, "submitting request");
222 match self.state {
223 State::Connected { .. } => self.handle_request(request_id, request, sender).await,
224 _ => self.not_sent_requests.push_back((request_id, request, sender)),
225 }
226 }
227
228 Some(message) = self.state.next_msg() => match message {
230
231 Err(err) => {
233 tracing::trace!(?err, "received server error");
234 self.state = self.state.disconnect();
235 (self.event_cb)(ConnectionEvent::LostConnection(err));
236 }
237
238 Ok(IncomingMessage::Text(message)) => match self.state {
240 State::NoValidInfo | State::Disconnected { .. } => unreachable!(),
241
242 State::TokenSent { url, token, socket, request_id: req } => match message {
244 ServerMessage::Response {
245 request_id,
246 response: ResponsePart::Success,
247 last_response: true
248 } if req == request_id => {
249 tracing::trace!("received server success for token sending");
250 self.state = State::Connected { url, token, socket, expected_binaries: None };
251 self.next_ping = Some(Instant::now() + PING_INTERVAL);
252 self.next_pong_deadline = None;
253 (self.event_cb)(ConnectionEvent::Connected);
254
255 let saved_objects = (self.get_saved_objects)();
259 let saved_queries = (self.get_saved_queries)();
260 if !saved_objects.is_empty() {
261 let (responses_sender, responses_receiver) = mpsc::unbounded();
262 let request_id = self.next_request_id();
263 let subscribed_objects = saved_objects
264 .iter()
265 .filter(|(_, o)| o.importance.subscribe())
266 .map(|(id, o)| (*id, o.have_all_until))
267 .collect();
268 let not_subscribed_objects = saved_objects
269 .iter()
270 .filter_map(|(id, o)| o.have_all_until.and_then(|u| (!o.importance.subscribe()).then_some((*id, u))))
271 .collect();
272 self.handle_request(
273 request_id,
274 Arc::new(RequestWithSidecar {
275 request: Arc::new(Request::Get {
276 object_ids: subscribed_objects,
277 subscribe: true,
278 }),
279 sidecar: Vec::new(),
280 }),
281 responses_sender,
282 ).await;
283 waaaa::spawn(Self::send_responses_as_updates(self.update_sender.clone(), responses_receiver));
284 let (ignore_sender, _receiver) = mpsc::unbounded();
285 self.handle_request(
286 request_id,
287 Arc::new(RequestWithSidecar {
288 request: Arc::new(Request::AlreadyHave {
289 object_ids: not_subscribed_objects,
290 }),
291 sidecar: Vec::new(),
292 }),
293 ignore_sender,
294 ).await;
295 }
296 for (query_id, q) in saved_queries {
297 if !q.importance.subscribe() {
298 continue;
302 }
303 let (responses_sender, responses_receiver) = mpsc::unbounded();
304 let request_id = self.next_request_id();
305 self.handle_request(
306 request_id,
307 Arc::new(RequestWithSidecar {
308 request: Arc::new(Request::Query {
309 query_id,
310 type_id: q.type_id,
311 query: q.query,
312 only_updated_since: q.have_all_until,
313 subscribe: true,
314 }),
315 sidecar: Vec::new(),
316 }),
317 responses_sender,
318 ).await;
319 waaaa::spawn(Self::send_responses_as_updates(self.update_sender.clone(), responses_receiver));
320 }
321 }
322 ServerMessage::Response {
323 request_id,
324 response: ResponsePart::Error(crdb_core::SerializableError::InvalidToken(tok)),
325 last_response: true
326 } if req == request_id && tok == token => {
327 tracing::trace!("server answered that token is invalid");
328 self.state = State::NoValidInfo;
329 (self.event_cb)(ConnectionEvent::InvalidToken(token));
330 }
331 resp => {
332 tracing::trace!(?resp, "server gave unexpected answer");
333 self.state = State::Disconnected { url, token };
334 (self.event_cb)(ConnectionEvent::LostConnection(
335 anyhow!("Unexpected server answer to login request: {resp:?}")
336 ));
337 }
338 }
339
340 State::Connected { expected_binaries: None, .. } => {
342 tracing::trace!(?message, "received server message");
343 if let ServerMessage::Response {
344 request_id: _,
345 response: ResponsePart::Error(SerializableError::InvalidToken(token)),
346 last_response: _,
347 } = &message {
348 self.state = self.state.disconnect();
349 (self.event_cb)(ConnectionEvent::InvalidToken(*token));
350 } else {
351 self.handle_connected_message(message).await;
352 }
353 }
354
355 State::Connected { expected_binaries: Some(_), .. } => {
357 tracing::trace!(?message, "received server message but expected a binary");
358 self.state = self.state.disconnect();
359 (self.event_cb)(ConnectionEvent::LostConnection(
360 anyhow!("Unexpected server message while waiting for binaries: {message:?}")
361 ));
362 }
363 }
364
365 Ok(IncomingMessage::Binary(message)) => {
367 tracing::trace!("received server binary message");
368 if let State::Connected { expected_binaries: Some((request_id, num_bins)), .. } = &mut self.state {
369 if let Some((_, sender, already_sent)) = self.pending_requests.get_mut(request_id) {
370 *already_sent = true;
371 let _ = sender.unbounded_send(ResponsePartWithSidecar {
372 response: ResponsePart::Binaries(1),
373 sidecar: Some(message),
374 });
375 *num_bins -= 1;
376 if *num_bins == 0 {
377 self.state.no_longer_expecting_binaries();
378 }
379 } else {
380 tracing::error!(?request_id, "Connection::State.expected_binaries is pointing to a non-existent request");
381 }
382 } else {
383 self.state = self.state.disconnect();
384 (self.event_cb)(ConnectionEvent::LostConnection(
385 anyhow!("Unexpected server binary frame while not waiting for it")
386 ));
387 }
388 }
389 }
390 }
391
392 if let State::Connected { .. } = self.state {
393 if !self.not_sent_requests.is_empty() {
394 let not_sent_requests = std::mem::take(&mut self.not_sent_requests);
395 tracing::trace!(?not_sent_requests, "sending not-sent requests");
396 for (request_id, request, sender) in not_sent_requests {
397 self.handle_request(request_id, request, sender).await;
398 }
399 }
400 }
401
402 if let State::Disconnected { url, token } = &self.state {
404 let url = url.clone();
405 let token = *token;
406 tracing::trace!(%url, "connecting to websocket");
407 let mut socket = match WebSocket::connect(&url).await {
408 Ok(socket) => socket,
409 Err(err) => {
410 (self.event_cb)(ConnectionEvent::FailedConnecting(err));
411 self.state = State::Disconnected { url, token }; continue;
413 }
414 };
415 let request_id = self.next_request_id();
416 let message = ClientMessage {
417 request_id,
418 request: Arc::new(Request::SetToken(token)),
419 };
420 tracing::trace!("sending token");
421 if let Err(err) = Self::send(&mut socket, &message).await {
422 (self.event_cb)(ConnectionEvent::FailedSendingToken(err));
423 self.state = State::Disconnected { url, token }; continue;
425 }
426 self.state = State::TokenSent {
427 url,
428 token,
429 socket,
430 request_id,
431 };
432 self.next_pong_deadline = Some((request_id, Instant::now() + PONG_DEADLINE));
433 if !self.pending_requests.is_empty() {
435 for (request_id, (request, sender, already_sent)) in
436 self.pending_requests.drain()
437 {
438 if already_sent {
439 let _ = sender.unbounded_send(ResponsePartWithSidecar {
440 response: ResponsePart::Error(
441 crdb_core::SerializableError::ConnectionLoss,
442 ),
443 sidecar: None,
444 });
445 } else {
446 self.not_sent_requests
447 .push_front((request_id, request, sender));
448 }
449 }
450 self.not_sent_requests
451 .make_contiguous()
452 .sort_unstable_by_key(|v| v.0);
453 }
454 }
455 }
456 }
457
458 fn is_trying_to_connect(&self) -> bool {
459 matches!(self.state, State::Disconnected { .. })
460 }
461
462 fn is_connected(&self) -> bool {
463 matches!(self.state, State::Connected { .. })
464 }
465
466 fn is_connecting(&self) -> bool {
467 matches!(
468 self.state,
469 State::Connected { .. } | State::TokenSent { .. }
470 )
471 }
472
473 fn next_request_id(&mut self) -> RequestId {
474 self.last_request_id = RequestId(self.last_request_id.0 + 1);
475 self.last_request_id
476 }
477
478 async fn handle_command(&mut self, command: Command) {
479 match command {
480 Command::Login { url, token } => {
481 self.state = State::Disconnected { url, token };
482 (self.event_cb)(ConnectionEvent::LoggingIn);
483 }
484 Command::Logout => {
485 if let State::Connected { .. } = self.state {
486 let request_id = self.next_request_id();
487 self.send_connected(&ClientMessage {
488 request_id,
489 request: Arc::new(Request::Logout),
490 })
491 .await;
492 }
493 self.last_request_id = RequestId(0);
494 while let Ok(Some(_)) = self.requests.try_next() {} self.pending_requests.clear();
496 self.state = State::NoValidInfo;
497 (self.event_cb)(ConnectionEvent::LoggedOut);
498 }
499 }
500 }
501
502 async fn handle_request(
503 &mut self,
504 request_id: RequestId,
505 request: Arc<RequestWithSidecar>,
506 sender: ResponseSender,
507 ) {
508 let message = ClientMessage {
509 request_id,
510 request: request.request.clone(),
511 };
512 self.send_connected(&message).await;
513 self.send_connected_sidecar(&request.sidecar).await;
514 self.pending_requests
515 .insert(request_id, (request, sender, false));
516 }
517
518 async fn handle_connected_message(&mut self, message: ServerMessage) {
519 match message {
520 ServerMessage::Updates(updates) => {
521 if let Err(err) = self.update_sender.send(updates).await {
522 tracing::error!(?err, "failed sending updates");
523 }
524 }
525 ServerMessage::Response {
526 request_id,
527 response,
528 last_response,
529 } => {
530 if let Some((_, sender, already_sent)) = self.pending_requests.get_mut(&request_id)
531 {
532 if let ResponsePart::Binaries(num_bins) = &response {
533 let State::Connected {
535 expected_binaries, ..
536 } = &mut self.state
537 else {
538 panic!("Called send_connected while not connected");
539 };
540 if *num_bins > 0 {
541 *expected_binaries = Some((request_id, *num_bins));
542 }
543 } else {
545 *already_sent = true;
548 let _ = sender.unbounded_send(ResponsePartWithSidecar {
549 response,
550 sidecar: None,
551 });
552 if last_response {
553 self.pending_requests.remove(&request_id);
554 }
555 }
556 } else if self.next_pong_deadline.map(|(r, _)| r) == Some(request_id) {
557 let ResponsePart::CurrentTime(server_time) = response else {
558 tracing::error!("Server answered GetTime with unexpected {response:?}");
559 return;
560 };
561 let Ok(server_time) = server_time.ms_since_posix() else {
562 tracing::error!("Server answered GetTime with obviously-wrong timestamp {server_time:?}");
563 return;
564 };
565 self.next_ping = Some(Instant::now() + PING_INTERVAL);
566 self.next_pong_deadline = None;
567 let now = SystemTime::now().ms_since_posix().unwrap();
569 if server_time.saturating_sub(now) > 0 {
570 (self.event_cb)(ConnectionEvent::TimeOffset(
571 server_time.saturating_sub(now),
572 ));
573 } else if server_time.saturating_sub(self.last_ping) < 0 {
574 (self.event_cb)(ConnectionEvent::TimeOffset(
575 server_time.saturating_sub(self.last_ping),
576 ));
577 } else {
578 (self.event_cb)(ConnectionEvent::TimeOffset(0));
579 }
580 } else {
581 tracing::warn!(
582 "Server gave us a response to {request_id:?} that we do not know of"
583 );
584 }
585 }
586 }
587 }
588
589 async fn send_responses_as_updates(
590 update_sender: mpsc::UnboundedSender<Updates>,
591 mut responses_receiver: mpsc::UnboundedReceiver<ResponsePartWithSidecar>,
592 ) {
593 while let Some(response) = responses_receiver.next().await {
595 match response.response {
597 ResponsePart::Error(crdb_core::SerializableError::ConnectionLoss) => (), ResponsePart::Error(crdb_core::SerializableError::ObjectDoesNotExist(
599 object_id,
600 )) => {
601 let _ = update_sender.unbounded_send(Updates {
604 data: vec![Arc::new(Update {
605 object_id,
606 data: UpdateData::LostReadRights,
607 })],
608 now_have_all_until: Updatedness::from_u128(0), });
610 }
611 ResponsePart::Error(err) => {
612 tracing::error!(?err, "got unexpected server error upon re-subscribing");
613 }
614 ResponsePart::Objects { data, .. } => {
615 for maybe_object in data.into_iter() {
616 match maybe_object {
617 MaybeObject::AlreadySubscribed(_) => continue,
618 MaybeObject::NotYetSubscribed(object) => {
619 let now_have_all_until = object.now_have_all_until;
620 let _ = update_sender.unbounded_send(Updates {
621 data: object.into_updates(),
622 now_have_all_until,
623 });
624 }
625 }
626 }
627 }
636 response => {
637 tracing::error!(
638 ?response,
639 "got unexpected server response upon re-subscribing"
640 );
641 }
642 }
643 }
644 }
645
646 async fn send_connected_sidecar(&mut self, sidecar: &[Arc<[u8]>]) {
647 let State::Connected {
648 socket, url, token, ..
649 } = &mut self.state
650 else {
651 panic!("Called send_connected while not connected");
652 };
653 if let Err(err) = send_sidecar(socket, sidecar).await {
654 (self.event_cb)(ConnectionEvent::LostConnection(err));
655 self.state = State::Disconnected {
656 url: url.clone(),
657 token: *token,
658 };
659 }
660 }
661
662 async fn send_connected(&mut self, message: &ClientMessage) {
663 let State::Connected {
664 socket, url, token, ..
665 } = &mut self.state
666 else {
667 panic!("Called send_connected while not connected");
668 };
669 if let Err(err) = Self::send(socket, message).await {
670 (self.event_cb)(ConnectionEvent::LostConnection(err));
671 self.state = State::Disconnected {
672 url: url.clone(),
673 token: *token,
674 };
675 }
676 }
677
678 async fn send(sock: &mut WebSocket, msg: &ClientMessage) -> anyhow::Result<()> {
679 let msg = serde_json::to_string(msg)?;
680 sock.send(WsMessage::Text(msg)).await
681 }
682}
683
684async fn send_sidecar(socket: &mut WebSocket, sidecar: &[Arc<[u8]>]) -> anyhow::Result<()> {
685 for bin in sidecar {
686 socket.send(WsMessage::Binary(bin.to_vec())).await?;
689 }
690
691 Ok(())
692}