1use anyhow::anyhow;
2use axum::extract::ws::{self, WebSocket};
3use bytes::Bytes;
4use crdb_cache::CacheDb;
5use crdb_core::{
6 BinPtr, ClientMessage, Db, EventId, MaybeObject, ObjectId, Query, QueryId, ReadPermsChanges,
7 Request, RequestId, ResponsePart, ResultExt, ServerMessage, ServerSideDb, Session, SessionRef,
8 SessionToken, SystemTimeExt, Update, UpdateData, Updatedness, Updates, UpdatesWithSnap, Upload,
9 User,
10};
11use futures::{future::OptionFuture, pin_mut, stream, FutureExt, StreamExt};
12use std::{
13 collections::{HashMap, HashSet},
14 sync::{Arc, Mutex, RwLock},
15 time::{Duration, SystemTime},
16};
17use tokio::{
18 sync::{mpsc, oneshot},
19 task::JoinHandle,
20};
21use tokio_util::sync::CancellationToken;
22use ulid::Ulid;
23
24pub use chrono;
25pub use cron;
26pub use sqlx;
27
28pub use crdb_core::{Error, Result};
30pub use crdb_postgres::PostgresDb;
31
// Unit tests for this module; only compiled for test builds.
#[cfg(test)]
mod tests;
34
35pub type UserUpdatesMap = HashMap<ObjectId, Arc<UpdatesWithSnap>>;
36
37pub type UpdatesMap = HashMap<User, Arc<UserUpdatesMap>>;
38
39type EditableUpdatesMap = HashMap<User, HashMap<ObjectId, Arc<UpdatesWithSnap>>>;
40
41type SessionsSenderMap = HashMap<
42 User,
43 HashMap<SessionRef, Vec<mpsc::UnboundedSender<(Updatedness, Arc<UserUpdatesMap>)>>>,
44>;
45
46pub struct Server<C: crdb_core::Config> {
47 db: Arc<CacheDb<PostgresDb<C>>>,
48 last_completed_updatedness: Arc<Mutex<Updatedness>>,
49 updatedness_requester:
50 mpsc::UnboundedSender<oneshot::Sender<(Updatedness, oneshot::Sender<UpdatesMap>)>>,
51 _cleanup_token: tokio_util::sync::DropGuard,
52 sessions: Arc<Mutex<SessionsSenderMap>>,
53}
54
55impl<C: crdb_core::Config> Server<C> {
56 pub async fn new<Tz>(
61 config: C,
62 db: sqlx::PgPool,
63 cache_watermark: usize,
64 vacuum_schedule: ServerVacuumSchedule<Tz>,
65 ) -> anyhow::Result<(Self, JoinHandle<usize>)>
66 where
67 Tz: 'static + Send + chrono::TimeZone,
68 Tz::Offset: Send,
69 {
70 let _ = config; C::check_ulids();
74
75 let db = PostgresDb::connect(db, cache_watermark).await?;
77
78 db.update_pending_rdeps()
81 .await
82 .wrap_context("updating all pending reverse-dependencies")?;
83
84 let upgrade_handle = tokio::task::spawn({
86 let db = db.clone();
87 async move { C::reencode_old_versions(&*db).await }
88 });
89
90 let (updatedness_requester, mut updatedness_request_receiver) = mpsc::unbounded_channel::<
92 oneshot::Sender<(Updatedness, oneshot::Sender<UpdatesMap>)>,
93 >();
94 let (update_sender, mut update_receiver) = mpsc::unbounded_channel();
95 let last_completed_updatedness = Arc::new(Mutex::new(Updatedness::from_u128(0)));
96 tokio::task::spawn(async move {
97 let mut generator = ulid::Generator::new();
99 while let Some(requester) = updatedness_request_receiver.recv().await {
101 let updatedness = Updatedness(generator.generate().expect(
103 "you're either very unlucky, or generated 2**80 updates within one millisecond",
104 ));
105 let (sender, receiver) = oneshot::channel();
106 if update_sender.send((updatedness, receiver)).is_err() {
107 tracing::error!(
108 "Update reorderer task went away before updatedness request handler task"
109 );
110 }
111 let _ = requester.send((updatedness, sender)); }
113 });
114 let sessions = Arc::new(Mutex::new(SessionsSenderMap::new()));
115 tokio::task::spawn({
116 let sessions = sessions.clone();
117 let last_completed_updatedness = last_completed_updatedness.clone();
118 async move {
119 while let Some((updatedness, update_receiver)) = update_receiver.recv().await {
122 if let Ok(updates) = update_receiver.await {
124 let mut sessions = sessions.lock().unwrap();
125 for (user, updates) in updates {
126 if let Some(sessions) = sessions.get_mut(&user) {
127 for senders in sessions.values_mut() {
128 senders.retain(|sender| {
130 sender.send((updatedness, updates.clone())).is_ok()
131 });
132 }
134 }
135 }
136 }
137 *last_completed_updatedness.lock().unwrap() = updatedness;
138 }
139 }
140 });
141
142 let cancellation_token = CancellationToken::new();
144 tokio::task::spawn({
145 let db = db.clone();
146 let cancellation_token = cancellation_token.clone();
147 let updatedness_requester = updatedness_requester.clone();
148 async move {
149 for next_time in vacuum_schedule.schedule.upcoming(vacuum_schedule.timezone) {
150 let sleep_for = next_time.signed_duration_since(chrono::Utc::now());
152 let sleep_for = sleep_for
153 .to_std()
154 .unwrap_or_else(|_| Duration::from_secs(0));
155 tokio::select! {
156 _ = tokio::time::sleep(sleep_for) => (),
157 _ = cancellation_token.cancelled() => break,
158 }
159
160 let no_new_changes_before = vacuum_schedule.recreate_older_than.map(|d| {
162 EventId(Ulid::from_parts(
163 (SystemTime::now() - d).ms_since_posix().unwrap() as u64,
164 u128::MAX,
165 ))
166 });
167 let kill_sessions_older_than = vacuum_schedule
168 .kill_sessions_older_than
169 .map(|d| SystemTime::now() - d);
170
171 let (sender, receiver) = oneshot::channel();
173 if updatedness_requester.send(sender).is_err() {
174 tracing::error!(
175 "Updatedness request handler thread went away before autovacuum thread"
176 );
177 }
178 let Ok((updatedness, slot)) = receiver.await else {
179 tracing::error!(
180 "Updatedness request handler thread never answered autovacuum thread"
181 );
182 continue;
183 };
184
185 if let Err(err) = Self::run_vacuum(
187 &db,
188 no_new_changes_before,
189 updatedness,
190 kill_sessions_older_than,
191 slot,
192 )
193 .await
194 {
195 tracing::error!(?err, "scheduled vacuum failed");
196 }
197 }
198 }
199 });
200
201 let this = Server {
203 db,
204 last_completed_updatedness,
205 updatedness_requester,
206 _cleanup_token: cancellation_token.drop_guard(),
207 sessions,
208 };
209 Ok((this, upgrade_handle))
210 }
211
212 pub async fn login_session(
213 &self,
214 user_id: User,
215 session_name: String,
216 expiration_time: Option<SystemTime>,
217 ) -> crate::Result<(SessionToken, SessionRef)> {
218 let now = SystemTime::now();
219 self.db
220 .login_session(Session {
221 user_id,
222 session_ref: SessionRef::now(),
223 session_name,
224 login_time: now,
225 last_active: now,
226 expiration_time,
227 })
228 .await
229 }
230
231 pub async fn answer(&self, socket: WebSocket) {
232 let mut conn = ConnectionState {
233 socket,
234 session: None,
235 };
236 loop {
237 tokio::select! {
238 msg = conn.socket.next() => match msg {
239 None => break, Some(Err(err)) => {
241 tracing::warn!(?err, "received an error while waiting for message on websocket");
242 break;
243 }
244 Some(Ok(ws::Message::Ping(_) | ws::Message::Pong(_))) => continue, Some(Ok(ws::Message::Close(_))) => break, Some(Ok(ws::Message::Text(msg))) => {
247 if let Err(err) = self.handle_client_message(&mut conn, &msg).await {
248 tracing::warn!(?err, ?msg, "client message violated protocol");
249 break;
250 }
251 }
252 Some(Ok(ws::Message::Binary(bin))) => {
253 let bin = Vec::<u8>::from(bin).into_boxed_slice().into();
254 if let Err(err) = self.handle_client_binary(&mut conn, bin).await {
255 tracing::warn!(?err, "client binary violated protocol");
256 break;
257 }
258 }
259 },
260
261 Some(update) = OptionFuture::from(conn.session.as_mut().map(|s| s.updates_receiver.recv())) => {
262 if conn.has_expired() {
263 break;
265 }
266 let Some((updatedness, update)) = update else {
267 tracing::error!("Update receiver broke before connection went down");
268 break;
269 };
270 let sess = conn.session.as_ref().unwrap();
271 let mut data = Vec::new();
273 for (object_id, updates) in update.iter() {
274 if sess.is_subscribed_to(*object_id, updates.new_last_snapshot.as_deref()) {
275 data.extend(updates.updates.iter().cloned());
276 }
277 }
278 let send_res = Self::send(&mut conn.socket, &ServerMessage::Updates(Updates {
279 data,
280 now_have_all_until: updatedness,
281 })).await;
282 if let Err(err) = send_res {
283 tracing::warn!(?err, "failed sending update to client");
284 break;
285 }
286 },
287 }
288 }
289 }
290
291 async fn handle_client_binary(
292 &self,
293 conn: &mut ConnectionState,
294 bin: Arc<[u8]>,
295 ) -> crate::Result<()> {
296 conn.kill_session_if_expired()?;
297 {
299 let sess = conn
300 .session
301 .as_mut()
302 .ok_or(crate::Error::ProtocolViolation)?;
303 sess.expected_binaries = sess
304 .expected_binaries
305 .checked_sub(1)
306 .ok_or(crate::Error::ProtocolViolation)?;
307 }
308
309 let binary_id = crdb_core::hash_binary(&bin);
311 self.db.create_binary(binary_id, bin).await
312 }
313
314 async fn handle_client_message(
315 &self,
316 conn: &mut ConnectionState,
317 msg: &str,
318 ) -> crate::Result<()> {
319 conn.kill_session_if_expired()?;
320 if conn
321 .session
322 .as_ref()
323 .map(|sess| sess.expected_binaries > 0)
324 .unwrap_or(false)
325 {
326 return Err(crate::Error::ProtocolViolation);
327 }
328 tracing::trace!(?msg, session=?conn.session.as_ref().map(|s| &s.session), "received client message");
329 if let Some(sess) = &conn.session {
330 self.db
334 .mark_session_active(sess.token, SystemTime::now())
335 .await
336 .wrap_context("marking session as active")?;
337 }
338 let msg = serde_json::from_str::<ClientMessage>(msg)
339 .wrap_context("deserializing client message")?;
340 match &*msg.request {
344 Request::SetToken(token) => {
345 let res = self.db.resume_session(*token).await.map(|session| {
346 let (updates_sender, updates_receiver) = mpsc::unbounded_channel();
347 self.sessions
348 .lock()
349 .unwrap()
350 .entry(session.user_id)
351 .or_default()
352 .entry(session.session_ref)
353 .or_default()
354 .push(updates_sender);
355 conn.session = Some(SessionInfo {
356 token: *token,
357 session,
358 expected_binaries: 0,
359 subscribed_objects: Arc::new(RwLock::new(HashSet::new())),
360 subscribed_queries: Arc::new(RwLock::new(HashMap::new())),
361 updates_receiver,
362 });
363 ResponsePart::Success
364 });
365 Self::send_res(&mut conn.socket, msg.request_id, res).await
366 }
367 Request::RenameSession(name) => {
368 let res = match &conn.session {
369 None => Err(crate::Error::ProtocolViolation),
370 Some(sess) => self
371 .db
372 .rename_session(sess.token, name)
373 .await
374 .map(|()| ResponsePart::Success),
375 };
376 Self::send_res(&mut conn.socket, msg.request_id, res).await
377 }
378 Request::CurrentSession => {
379 let res = match &conn.session {
380 None => Err(crate::Error::ProtocolViolation),
381 Some(sess) => Ok(ResponsePart::Sessions(vec![sess.session.clone()])),
382 };
383 Self::send_res(&mut conn.socket, msg.request_id, res).await
384 }
385 Request::ListSessions => {
386 let res = match &conn.session {
387 None => Err(crate::Error::ProtocolViolation),
388 Some(sess) => self
389 .db
390 .list_sessions(sess.session.user_id)
391 .await
392 .map(ResponsePart::Sessions),
393 };
394 Self::send_res(&mut conn.socket, msg.request_id, res).await
395 }
396 Request::Logout => {
397 let res = match &conn.session {
398 None => Err(crate::Error::ProtocolViolation),
399 Some(sess) => self
400 .db
401 .disconnect_session(sess.session.user_id, sess.session.session_ref)
402 .await
403 .map(|()| ResponsePart::Success),
404 };
405 conn.session = None;
406 Self::send_res(&mut conn.socket, msg.request_id, res).await
407 }
408 Request::DisconnectSession(session_ref) => {
409 let res = match &conn.session {
410 None => Err(crate::Error::ProtocolViolation),
411 Some(sess) => self
412 .db
413 .disconnect_session(sess.session.user_id, *session_ref)
414 .await
415 .map(|()| ResponsePart::Success),
416 };
417 Self::send_res(&mut conn.socket, msg.request_id, res).await
418 }
419 Request::GetTime => {
420 let res = match &conn.session {
421 None => Err(crate::Error::ProtocolViolation),
422 Some(_) => Ok(ResponsePart::CurrentTime(SystemTime::now())),
423 };
424 Self::send_res(&mut conn.socket, msg.request_id, res).await
425 }
426 Request::AlreadyHave { object_ids: _ } => {
427 Ok(())
429 }
430 Request::Get {
431 object_ids,
432 subscribe,
433 } => {
434 assert!(subscribe); let object_ids = object_ids.iter().map(|(o, u)| (*o, *u)).collect::<Vec<_>>();
437 self.subscribe_and_send_objects(conn, msg.request_id, None, object_ids.into_iter())
438 .await
439 }
440 Request::Query {
441 query_id,
442 type_id,
443 query,
444 only_updated_since,
445 subscribe,
446 } => {
447 assert!(subscribe); let sess = conn
449 .session
450 .as_ref()
451 .ok_or(crate::Error::ProtocolViolation)?;
452 sess.subscribed_queries
455 .write()
456 .unwrap()
457 .insert(*query_id, query.clone());
458 let updatedness = *self.last_completed_updatedness.lock().unwrap();
459 let object_ids = self
460 .db
461 .server_query(
462 sess.session.user_id,
463 *type_id,
464 *only_updated_since,
465 query.clone(),
466 )
467 .await
468 .wrap_context("listing objects matching query")?;
469 self.subscribe_and_send_objects(
472 conn,
473 msg.request_id,
474 Some(updatedness),
475 object_ids.into_iter().map(|o| (o, None)),
476 )
477 .await
478 }
479 Request::GetBinaries(binary_ids) => {
480 let _ = conn
482 .session
483 .as_ref()
484 .ok_or(crate::Error::ProtocolViolation)?;
485 self.send_binaries(conn, msg.request_id, binary_ids.iter().copied())
486 .await
487 }
488 Request::Unsubscribe(object_ids) => {
489 {
490 let mut subscribed_objects = conn
491 .session
492 .as_ref()
493 .ok_or(crate::Error::ProtocolViolation)?
494 .subscribed_objects
495 .write()
496 .unwrap();
497 for id in object_ids {
498 subscribed_objects.remove(id);
499 }
500 }
501 Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
502 }
503 Request::UnsubscribeQuery(query_id) => {
504 conn.session
505 .as_ref()
506 .ok_or(crate::Error::ProtocolViolation)?
507 .subscribed_queries
508 .write()
509 .unwrap()
510 .remove(query_id);
511 Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
512 }
513 Request::Upload(upload) => {
514 let sess = conn
515 .session
516 .as_ref()
517 .ok_or(crate::Error::ProtocolViolation)?;
518 match upload {
519 Upload::Object {
520 object_id,
521 type_id,
522 created_at,
523 snapshot_version,
524 object,
525 subscribe,
526 } => {
527 let (updatedness, update_sender) = self.updatedness_slot().await?;
528 let res = C::upload_object(
529 &*self.db,
530 sess.session.user_id,
531 updatedness,
532 *type_id,
533 *object_id,
534 *created_at,
535 *snapshot_version,
536 object.clone(),
537 )
538 .await;
539 let res = match res {
540 Ok(res) => res,
541 Err(err) => {
542 return Self::send_res(&mut conn.socket, msg.request_id, Err(err))
543 .await;
544 }
545 };
546 if let Some((new_update, users_who_can_read, rdeps)) = res {
547 let mut new_data = HashMap::new();
548 self.add_rdeps_updates(&mut new_data, rdeps)
549 .await
550 .wrap_context("listing updates for rdeps")?;
551 for user in users_who_can_read {
552 let existing = new_data
553 .entry(user)
554 .or_insert_with(HashMap::new)
555 .insert(*object_id, new_update.clone());
556 if let Some(existing) = existing {
557 tracing::error!(
558 ?user,
559 ?object_id,
560 ?existing,
561 "replacing mistakenly-already-existing update"
562 );
563 }
564 }
565 let new_data = new_data
566 .into_iter()
567 .map(|(k, v)| (k, Arc::new(v)))
568 .collect();
569
570 update_sender.send(new_data).map_err(|_| {
571 crate::Error::Other(anyhow!(
572 "Update reorderer thread went away before updating thread",
573 ))
574 })?;
575 }
576 if *subscribe {
577 sess.subscribed_objects.write().unwrap().insert(*object_id);
578 }
579 Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success))
580 .await
581 }
582 Upload::Event {
583 object_id,
584 type_id,
585 event_id,
586 event,
587 subscribe,
588 } => {
589 let (updatedness, update_sender) = self.updatedness_slot().await?;
590 let res = C::upload_event(
591 &*self.db,
592 sess.session.user_id,
593 updatedness,
594 *type_id,
595 *object_id,
596 *event_id,
597 event.clone(),
598 )
599 .await;
600 let res = match res {
601 Ok(res) => res,
602 Err(err) => {
603 return Self::send_res(&mut conn.socket, msg.request_id, Err(err))
604 .await;
605 }
606 };
607 if let Some((new_update, users_who_can_read, rdeps)) = res {
608 let mut new_data = HashMap::new();
609 self.add_rdeps_updates(&mut new_data, rdeps)
610 .await
611 .wrap_context("listing updates for rdeps")?;
612 for user in users_who_can_read {
613 let existing = new_data
614 .entry(user)
615 .or_insert_with(HashMap::new)
616 .insert(*object_id, new_update.clone());
617 if let Some(existing) = existing {
618 tracing::error!(
619 ?user,
620 ?object_id,
621 ?existing,
622 "replacing mistakenly-already-existing update"
623 );
624 }
625 }
626 let new_data = new_data
627 .into_iter()
628 .map(|(k, v)| (k, Arc::new(v)))
629 .collect();
630
631 update_sender.send(new_data).map_err(|_| {
632 crate::Error::Other(anyhow!(
633 "Update reorderer thread went away before updating thread",
634 ))
635 })?;
636 }
637 if *subscribe {
638 sess.subscribed_objects.write().unwrap().insert(*object_id);
639 }
640 Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success))
641 .await
642 }
643 }
644 }
645 Request::UploadBinaries(num_binaries) => {
646 conn.session
647 .as_mut()
648 .ok_or(crate::Error::ProtocolViolation)?
649 .expected_binaries = *num_binaries;
650 Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
651 }
652 }
653 }
654
655 async fn add_rdeps_updates(
658 &self,
659 updates: &mut EditableUpdatesMap,
660 rdeps: Vec<ReadPermsChanges>,
661 ) -> crate::Result<()> {
662 for c in rdeps {
663 for u in c.lost_read {
664 updates.entry(u).or_default().insert(
665 c.object_id,
666 Arc::new(UpdatesWithSnap {
667 updates: vec![Arc::new(Update {
668 object_id: c.object_id,
669 data: UpdateData::LostReadRights,
670 })],
671 new_last_snapshot: None,
672 }),
673 );
674 }
675 if let Some(one_user) = c.gained_read.iter().next() {
676 let mut t = self.db.get_transaction().await?;
677 let object = self
678 .db
679 .get_all(&mut t, *one_user, c.object_id, None)
680 .await?;
681 let last_snapshot = self
682 .db
683 .get_latest_snapshot(&mut t, *one_user, c.object_id)
684 .await?;
685 let new_updates = object.into_updates();
686 for u in c.gained_read {
687 updates.entry(u).or_default().insert(
688 c.object_id,
689 Arc::new(UpdatesWithSnap {
690 updates: new_updates.clone(),
691 new_last_snapshot: Some(last_snapshot.clone()),
692 }),
693 );
694 }
695 }
696 }
697 Ok(())
698 }
699
700 async fn subscribe_and_send_objects(
701 &self,
702 conn: &mut ConnectionState,
703 request_id: RequestId,
704 query_updatedness: Option<Updatedness>,
705 objects: impl Iterator<Item = (ObjectId, Option<Updatedness>)>,
706 ) -> crate::Result<()> {
707 let sess = conn
708 .session
709 .as_ref()
710 .ok_or(crate::Error::ProtocolViolation)?;
711 let user = sess.session.user_id;
712 let subscribed_objects = sess.subscribed_objects.clone();
713 let objects = objects.map(|(object_id, updatedness)| {
714 let subscribed_objects = subscribed_objects.clone();
715 async move {
716 if subscribed_objects.read().unwrap().contains(&object_id) {
717 Ok(MaybeObject::AlreadySubscribed(object_id))
718 } else {
719 subscribed_objects.write().unwrap().insert(object_id);
722 let mut t = self.db.get_transaction().await?;
723 let object = self
724 .db
725 .get_all(&mut t, user, object_id, updatedness)
726 .await?;
727 Ok(MaybeObject::NotYetSubscribed(object))
728 }
729 }
730 });
731 let objects = stream::iter(objects).buffer_unordered(16); pin_mut!(objects);
733 let mut size_of_message = 0;
734 let mut current_data = Vec::new();
735 while let Some(object) = objects.next().await {
738 if size_of_message >= 1024 * 1024 {
739 let data = std::mem::take(&mut current_data);
741 size_of_message = 0;
742 Self::send(
743 &mut conn.socket,
744 &ServerMessage::Response {
745 request_id,
746 response: ResponsePart::Objects {
747 data,
748 now_have_all_until: None,
749 },
750 last_response: false,
751 },
752 )
753 .await?;
754 }
755 match object {
756 Ok(object) => {
757 size_of_message += size_as_json(&object)?;
758 current_data.push(object);
759 }
760 Err(err @ crate::Error::ObjectDoesNotExist(_)) => {
761 if query_updatedness.is_some() {
762 } else {
765 Self::send(
768 &mut conn.socket,
769 &ServerMessage::Response {
770 request_id,
771 response: ResponsePart::Error(err.into()),
772 last_response: false,
773 },
774 )
775 .await?;
776 }
777 }
778 Err(err) => return Self::send_res(&mut conn.socket, request_id, Err(err)).await,
779 }
780 }
781 Self::send_res(
782 &mut conn.socket,
783 request_id,
784 Ok(ResponsePart::Objects {
785 data: current_data,
786 now_have_all_until: query_updatedness,
787 }),
788 )
789 .await
790 }
791
792 async fn send_binaries(
793 &self,
794 conn: &mut ConnectionState,
795 request_id: RequestId,
796 binaries: impl Iterator<Item = BinPtr>,
797 ) -> crate::Result<()> {
798 let binaries =
799 binaries.map(|binary_id| self.db.get_binary(binary_id).map(move |r| (binary_id, r)));
800 let binaries = stream::iter(binaries).buffer_unordered(16); pin_mut!(binaries);
802 let mut size_of_message = 0;
803 let mut current_data = Vec::new();
804 while let Some((binary_id, binary)) = binaries.next().await {
807 if size_of_message >= 1024 * 1024 {
808 size_of_message = 0;
810 Self::send_binaries_msg(
811 &mut conn.socket,
812 request_id,
813 false,
814 current_data.drain(..),
815 )
816 .await?;
817 }
818 let binary = match binary {
819 Ok(Some(binary)) => Ok(binary),
820 Ok(None) => Err(crate::Error::MissingBinaries(vec![binary_id])),
821 Err(err) => Err(err),
822 };
823 let binary = match binary {
824 Ok(binary) => binary,
825 Err(err) => {
826 if !current_data.is_empty() {
827 Self::send_binaries_msg(
828 &mut conn.socket,
829 request_id,
830 false,
831 current_data.drain(..),
832 )
833 .await?;
834 }
835 return Self::send_res(&mut conn.socket, request_id, Err(err)).await;
836 }
837 };
838 size_of_message += binary.len();
839 current_data.push(binary);
840 }
841 Self::send_binaries_msg(&mut conn.socket, request_id, true, current_data.drain(..)).await
842 }
843
844 async fn send_binaries_msg(
845 socket: &mut WebSocket,
846 request_id: RequestId,
847 last_response: bool,
848 bins: impl ExactSizeIterator<Item = Arc<[u8]>>,
849 ) -> crate::Result<()> {
850 Self::send(
851 socket,
852 &ServerMessage::Response {
853 request_id,
854 last_response,
855 response: ResponsePart::Binaries(bins.len()),
856 },
857 )
858 .await?;
859 for bin in bins {
860 socket
861 .send(ws::Message::Binary(Bytes::from_owner(bin)))
862 .await
863 .wrap_context("sending binary to client")?
864 }
865 Ok(())
866 }
867
868 async fn send_res(
869 socket: &mut WebSocket,
870 request_id: RequestId,
871 res: crate::Result<ResponsePart>,
872 ) -> crate::Result<()> {
873 let response = match res {
874 Ok(res) => res,
875 Err(err) => ResponsePart::Error(err.into()),
876 };
877 Self::send(
878 socket,
879 &ServerMessage::Response {
880 request_id,
881 response,
882 last_response: true,
883 },
884 )
885 .await
886 }
887
888 async fn send(socket: &mut WebSocket, msg: &ServerMessage) -> crate::Result<()> {
889 let msg = serde_json::to_string(msg).wrap_context("serializing server message")?;
890 socket
891 .send(ws::Message::Text(msg.into()))
892 .await
893 .wrap_context("sending response to client")
894 }
895
896 pub async fn vacuum(
901 &self,
902 no_new_changes_before: Option<EventId>,
903 kill_sessions_older_than: Option<SystemTime>,
904 ) -> crate::Result<()> {
905 let (updatedness, slot) = self.updatedness_slot().await?;
906 Self::run_vacuum(
907 &self.db,
908 no_new_changes_before,
909 updatedness,
910 kill_sessions_older_than,
911 slot,
912 )
913 .await
914 }
915
916 async fn run_vacuum(
917 db: &CacheDb<PostgresDb<C>>,
918 no_new_changes_before: Option<EventId>,
919 updatedness: Updatedness,
920 kill_sessions_older_than: Option<SystemTime>,
921 slot: oneshot::Sender<UpdatesMap>,
922 ) -> crate::Result<()> {
923 let mut updates = HashMap::new();
925 let res = db
926 .server_vacuum(
927 no_new_changes_before,
928 updatedness,
929 kill_sessions_older_than,
930 |update, users_who_can_read| {
931 let object_id = update.object_id;
933 let update = Arc::new(UpdatesWithSnap {
934 updates: vec![Arc::new(update)],
935 new_last_snapshot: None,
936 });
937 for u in users_who_can_read {
938 updates
939 .entry(u)
940 .or_insert_with(HashMap::new)
941 .insert(object_id, update.clone());
942 }
943 },
944 )
945 .await;
946
947 let updates = updates.into_iter().map(|(k, v)| (k, Arc::new(v))).collect();
950
951 if slot.send(updates).is_err() {
953 tracing::error!("Update reorderer went away before server");
954 }
955
956 res
958 }
959
960 async fn updatedness_slot(&self) -> crate::Result<(Updatedness, oneshot::Sender<UpdatesMap>)> {
961 let (sender, receiver) = oneshot::channel();
962 self.updatedness_requester.send(sender).map_err(|_| {
963 crate::Error::Other(anyhow!(
964 "Updatedness request handler thread went away too early"
965 ))
966 })?;
967 let slot = receiver.await.map_err(|_| {
968 crate::Error::Other(anyhow!("Updatedness request handler thread never answered"))
969 })?;
970 Ok(slot)
971 }
972}
973
974pub struct ServerVacuumSchedule<Tz: chrono::TimeZone> {
975 schedule: cron::Schedule,
976 timezone: Tz,
977 recreate_older_than: Option<Duration>,
979 kill_sessions_older_than: Option<Duration>,
980}
981
982impl<Tz: chrono::TimeZone> ServerVacuumSchedule<Tz> {
983 pub fn new(schedule: cron::Schedule, timezone: Tz) -> ServerVacuumSchedule<Tz> {
984 ServerVacuumSchedule {
985 schedule,
986 timezone,
987 recreate_older_than: None,
988 kill_sessions_older_than: None,
989 }
990 }
991
992 pub fn recreate_older_than(mut self, age: Duration) -> Self {
993 self.recreate_older_than = Some(age);
994 self
995 }
996
997 pub fn kill_sessions_older_than(mut self, age: Duration) -> Self {
998 self.kill_sessions_older_than = Some(age);
999 self
1000 }
1001}
1002
1003struct ConnectionState {
1004 socket: WebSocket,
1005 session: Option<SessionInfo>,
1006}
1007
1008impl ConnectionState {
1009 fn has_expired(&self) -> bool {
1010 self.session
1011 .as_ref()
1012 .and_then(|sess| sess.session.expiration_time)
1013 .map(|expiration_time| expiration_time < SystemTime::now())
1014 .unwrap_or(false)
1015 }
1016
1017 fn kill_session_if_expired(&mut self) -> crate::Result<()> {
1021 if self.has_expired() {
1022 let token = self.session.as_ref().unwrap().token;
1023 self.session = None;
1024 return Err(crate::Error::InvalidToken(token));
1025 }
1026 Ok(())
1027 }
1028}
1029
1030struct SessionInfo {
1031 token: SessionToken,
1032 session: Session,
1033 expected_binaries: usize,
1034 subscribed_objects: Arc<RwLock<HashSet<ObjectId>>>,
1035 subscribed_queries: Arc<RwLock<HashMap<QueryId, Arc<Query>>>>,
1036 updates_receiver: mpsc::UnboundedReceiver<(Updatedness, Arc<UserUpdatesMap>)>,
1037}
1038
1039impl SessionInfo {
1040 fn is_subscribed_to(
1041 &self,
1042 object_id: ObjectId,
1043 new_last_snapshot: Option<&serde_json::Value>,
1044 ) -> bool {
1045 if self.subscribed_objects.read().unwrap().contains(&object_id) {
1046 return true;
1047 }
1048 if let Some(new_last_snapshot) = new_last_snapshot {
1049 for query in self.subscribed_queries.read().unwrap().values() {
1050 if query.matches_json(new_last_snapshot) {
1051 self.subscribed_objects.write().unwrap().insert(object_id);
1052 return true;
1053 }
1054 }
1055 }
1056 false
1057 }
1058}
1059
1060fn size_as_json<T: serde::Serialize>(value: &T) -> crate::Result<usize> {
1061 struct Size(usize);
1062 impl std::io::Write for Size {
1063 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1064 self.0 += buf.len();
1065 Ok(buf.len())
1066 }
1067
1068 fn flush(&mut self) -> std::io::Result<()> {
1069 Ok(())
1070 }
1071 }
1072 let mut size = Size(0);
1073 serde_json::to_writer(&mut size, value)
1074 .wrap_context("figuring out the serialized size of value")?;
1075 Ok(size.0)
1076}