// crdb_server/lib.rs

1use anyhow::anyhow;
2use axum::extract::ws::{self, WebSocket};
3use bytes::Bytes;
4use crdb_cache::CacheDb;
5use crdb_core::{
6    BinPtr, ClientMessage, Db, EventId, MaybeObject, ObjectId, Query, QueryId, ReadPermsChanges,
7    Request, RequestId, ResponsePart, ResultExt, ServerMessage, ServerSideDb, Session, SessionRef,
8    SessionToken, SystemTimeExt, Update, UpdateData, Updatedness, Updates, UpdatesWithSnap, Upload,
9    User,
10};
11use futures::{future::OptionFuture, pin_mut, stream, FutureExt, StreamExt};
12use std::{
13    collections::{HashMap, HashSet},
14    sync::{Arc, Mutex, RwLock},
15    time::{Duration, SystemTime},
16};
17use tokio::{
18    sync::{mpsc, oneshot},
19    task::JoinHandle,
20};
21use tokio_util::sync::CancellationToken;
22use ulid::Ulid;
23
24pub use chrono;
25pub use cron;
26pub use sqlx;
27
28// TODO(api-highest): kill the concept of snapshot version, replace with always using protobuf and it doing its stuff internally
29pub use crdb_core::{Error, Result};
30pub use crdb_postgres::PostgresDb;
31
32#[cfg(test)]
33mod tests;
34
/// For one user: the updates (plus optional new latest snapshot) to push, per object.
pub type UserUpdatesMap = HashMap<ObjectId, Arc<UpdatesWithSnap>>;

/// The updates produced under one updatedness slot, keyed by the user who may receive them.
pub type UpdatesMap = HashMap<User, Arc<UserUpdatesMap>>;

/// Builder form of [`UpdatesMap`], used while the per-user maps are still being filled in.
type EditableUpdatesMap = HashMap<User, HashMap<ObjectId, Arc<UpdatesWithSnap>>>;

/// For each user, for each of their sessions, the senders feeding updates to every live
/// websocket connection of that session.
type SessionsSenderMap = HashMap<
    User,
    HashMap<SessionRef, Vec<mpsc::UnboundedSender<(Updatedness, Arc<UserUpdatesMap>)>>>,
>;
45
/// The crdb server: wraps a cached postgres database and pushes updates to all
/// websocket sessions served through [`Server::answer`].
pub struct Server<C: crdb_core::Config> {
    /// Cached database handle, shared with the background tasks spawned by `new`.
    db: Arc<CacheDb<PostgresDb<C>>>,
    /// The last updatedness whose updates have been fully dispatched to sessions.
    last_completed_updatedness: Arc<Mutex<Updatedness>>,
    /// Requests a new updatedness slot: the answer carries the allocated `Updatedness`
    /// plus a oneshot on which the resulting `UpdatesMap` must be sent (or dropped to
    /// release the slot without publishing anything).
    updatedness_requester:
        mpsc::UnboundedSender<oneshot::Sender<(Updatedness, oneshot::Sender<UpdatesMap>)>>,
    /// Cancels the auto-vacuum task when this `Server` is dropped.
    _cleanup_token: tokio_util::sync::DropGuard,
    /// Update senders for every connected session, by user then session.
    sessions: Arc<Mutex<SessionsSenderMap>>,
}
54
55impl<C: crdb_core::Config> Server<C> {
    /// Creates the server and spawns its background tasks (updatedness allocation,
    /// update reordering, auto-vacuum).
    ///
    /// Returns both the server itself, as well as a `JoinHandle` that will resolve once all the operations
    /// needed for database upgrading are over. The handle resolves with the number of errors that occurred
    /// during the upgrade, normal runs would return 0. There will be one error message in the tracing logs
    /// for each such error.
    pub async fn new<Tz>(
        config: C,
        db: sqlx::PgPool,
        cache_watermark: usize,
        vacuum_schedule: ServerVacuumSchedule<Tz>,
    ) -> anyhow::Result<(Self, JoinHandle<usize>)>
    where
        Tz: 'static + Send + chrono::TimeZone,
        Tz::Offset: Send,
    {
        let _ = config; // ignore argument; it only serves to pin the `C` type parameter

        // Check that all type ULIDs are distinct
        C::check_ulids();

        // Connect to the database and setup the cache
        let db = PostgresDb::connect(db, cache_watermark).await?;

        // Immediately update the permissions of objects pending permissions upgrades
        // This must happen before starting the server, so long as we do not actually push the returned ReadPermsChange's to subscribers
        db.update_pending_rdeps()
            .await
            .wrap_context("updating all pending reverse-dependencies")?;

        // Start the upgrading task
        let upgrade_handle = tokio::task::spawn({
            let db = db.clone();
            async move { C::reencode_old_versions(&*db).await }
        });

        // Setup the update reorderer task
        let (updatedness_requester, mut updatedness_request_receiver) = mpsc::unbounded_channel::<
            oneshot::Sender<(Updatedness, oneshot::Sender<UpdatesMap>)>,
        >();
        let (update_sender, mut update_receiver) = mpsc::unbounded_channel();
        let last_completed_updatedness = Arc::new(Mutex::new(Updatedness::from_u128(0)));
        tokio::task::spawn(async move {
            // Updatedness request handler: hands out monotonically increasing updatedness
            // values and enqueues, in allocation order, one slot per request for the
            // reorderer task below.
            let mut generator = ulid::Generator::new();
            // No cancellation token needed, closing the sender will naturally close this task
            while let Some(requester) = updatedness_request_receiver.recv().await {
                // TODO(blocked): use generate_overflowing once it lands https://github.com/dylanhart/ulid-rs/pull/75
                let updatedness = Updatedness(generator.generate().expect(
                    "you're either very unlucky, or generated 2**80 updates within one millisecond",
                ));
                let (sender, receiver) = oneshot::channel();
                if update_sender.send((updatedness, receiver)).is_err() {
                    tracing::error!(
                        "Update reorderer task went away before updatedness request handler task"
                    );
                }
                let _ = requester.send((updatedness, sender)); // Ignore any failures, they'll free the slot anyway
            }
        });
        let sessions = Arc::new(Mutex::new(SessionsSenderMap::new()));
        tokio::task::spawn({
            let sessions = sessions.clone();
            let last_completed_updatedness = last_completed_updatedness.clone();
            async move {
                // Actual update reorderer: awaits each slot's updates in allocation order,
                // so sessions always observe updates in increasing updatedness order.
                // No cancellation token needed, closing the senders will naturally close this task
                while let Some((updatedness, update_receiver)) = update_receiver.recv().await {
                    // Ignore the case where the slot sender was dropped
                    if let Ok(updates) = update_receiver.await {
                        let mut sessions = sessions.lock().unwrap();
                        for (user, updates) in updates {
                            if let Some(sessions) = sessions.get_mut(&user) {
                                for senders in sessions.values_mut() {
                                    // Discard all senders that return an error
                                    senders.retain(|sender| {
                                        sender.send((updatedness, updates.clone())).is_ok()
                                    });
                                    // TODO(perf-med): remove the entry from the hashmap altogether if it becomes empty
                                }
                            }
                        }
                    }
                    *last_completed_updatedness.lock().unwrap() = updatedness;
                }
            }
        });

        // Setup the auto-vacuum task
        let cancellation_token = CancellationToken::new();
        tokio::task::spawn({
            let db = db.clone();
            let cancellation_token = cancellation_token.clone();
            let updatedness_requester = updatedness_requester.clone();
            async move {
                for next_time in vacuum_schedule.schedule.upcoming(vacuum_schedule.timezone) {
                    // Sleep until the next vacuum
                    let sleep_for = next_time.signed_duration_since(chrono::Utc::now());
                    let sleep_for = sleep_for
                        .to_std()
                        .unwrap_or_else(|_| Duration::from_secs(0));
                    tokio::select! {
                        _ = tokio::time::sleep(sleep_for) => (),
                        _ = cancellation_token.cancelled() => break,
                    }

                    // Define the parameters
                    let no_new_changes_before = vacuum_schedule.recreate_older_than.map(|d| {
                        EventId(Ulid::from_parts(
                            (SystemTime::now() - d).ms_since_posix().unwrap() as u64,
                            u128::MAX,
                        ))
                    });
                    let kill_sessions_older_than = vacuum_schedule
                        .kill_sessions_older_than
                        .map(|d| SystemTime::now() - d);

                    // Retrieve the updatedness slot
                    let (sender, receiver) = oneshot::channel();
                    if updatedness_requester.send(sender).is_err() {
                        tracing::error!(
                            "Updatedness request handler thread went away before autovacuum thread"
                        );
                    }
                    let Ok((updatedness, slot)) = receiver.await else {
                        tracing::error!(
                            "Updatedness request handler thread never answered autovacuum thread"
                        );
                        continue;
                    };

                    // Finally, run the vacuum
                    if let Err(err) = Self::run_vacuum(
                        &db,
                        no_new_changes_before,
                        updatedness,
                        kill_sessions_older_than,
                        slot,
                    )
                    .await
                    {
                        tracing::error!(?err, "scheduled vacuum failed");
                    }
                }
            }
        });

        // Finally, return the information
        let this = Server {
            db,
            last_completed_updatedness,
            updatedness_requester,
            _cleanup_token: cancellation_token.drop_guard(),
            sessions,
        };
        Ok((this, upgrade_handle))
    }
211
212    pub async fn login_session(
213        &self,
214        user_id: User,
215        session_name: String,
216        expiration_time: Option<SystemTime>,
217    ) -> crate::Result<(SessionToken, SessionRef)> {
218        let now = SystemTime::now();
219        self.db
220            .login_session(Session {
221                user_id,
222                session_ref: SessionRef::now(),
223                session_name,
224                login_time: now,
225                last_active: now,
226                expiration_time,
227            })
228            .await
229    }
230
    /// Serves one websocket connection until it closes, errors, or violates the protocol.
    ///
    /// Concurrently waits on two sources: frames arriving from the client, and updates
    /// to push down to the client (the latter only once a session has been resumed via
    /// `SetToken`, since `conn.session` starts out `None`).
    pub async fn answer(&self, socket: WebSocket) {
        let mut conn = ConnectionState {
            socket,
            session: None,
        };
        loop {
            tokio::select! {
                // A frame arrived from the client
                msg = conn.socket.next() => match msg {
                    None => break, // End-of-stream
                    Some(Err(err)) => {
                        tracing::warn!(?err, "received an error while waiting for message on websocket");
                        break;
                    }
                    Some(Ok(ws::Message::Ping(_) | ws::Message::Pong(_))) => continue, // Auto-handled by axum, ignore
                    Some(Ok(ws::Message::Close(_))) => break, // End-of-stream
                    Some(Ok(ws::Message::Text(msg))) => {
                        if let Err(err) = self.handle_client_message(&mut conn, &msg).await {
                            tracing::warn!(?err, ?msg, "client message violated protocol");
                            break;
                        }
                    }
                    Some(Ok(ws::Message::Binary(bin))) => {
                        let bin = Vec::<u8>::from(bin).into_boxed_slice().into();
                        if let Err(err) = self.handle_client_binary(&mut conn, bin).await {
                            tracing::warn!(?err, "client binary violated protocol");
                            break;
                        }
                    }
                },

                // An update is ready to be pushed to this session (only polled once logged in)
                Some(update) = OptionFuture::from(conn.session.as_mut().map(|s| s.updates_receiver.recv())) => {
                    if conn.has_expired() {
                        // Do not reset conn.session, so that client gets a proper InvalidToken error on next request
                        break;
                    }
                    let Some((updatedness, update)) = update else {
                        tracing::error!("Update receiver broke before connection went down");
                        break;
                    };
                    // unwrap() is safe: the OptionFuture above only resolves when session is Some
                    let sess = conn.session.as_ref().unwrap();
                    // TODO(perf-high): make sure the size of update messages is both as batched as possible and not too big
                    // Only forward updates for objects this session is actually subscribed to.
                    let mut data = Vec::new();
                    for (object_id, updates) in update.iter() {
                        if sess.is_subscribed_to(*object_id, updates.new_last_snapshot.as_deref()) {
                            data.extend(updates.updates.iter().cloned());
                        }
                    }
                    let send_res = Self::send(&mut conn.socket, &ServerMessage::Updates(Updates {
                        data,
                        now_have_all_until: updatedness,
                    })).await;
                    if let Err(err) = send_res {
                        tracing::warn!(?err, "failed sending update to client");
                        break;
                    }
                },
            }
        }
    }
290
291    async fn handle_client_binary(
292        &self,
293        conn: &mut ConnectionState,
294        bin: Arc<[u8]>,
295    ) -> crate::Result<()> {
296        conn.kill_session_if_expired()?;
297        // Check we're waiting for binaries and count one as done
298        {
299            let sess = conn
300                .session
301                .as_mut()
302                .ok_or(crate::Error::ProtocolViolation)?;
303            sess.expected_binaries = sess
304                .expected_binaries
305                .checked_sub(1)
306                .ok_or(crate::Error::ProtocolViolation)?;
307        }
308
309        // Actually send the binary
310        let binary_id = crdb_core::hash_binary(&bin);
311        self.db.create_binary(binary_id, bin).await
312    }
313
314    async fn handle_client_message(
315        &self,
316        conn: &mut ConnectionState,
317        msg: &str,
318    ) -> crate::Result<()> {
319        conn.kill_session_if_expired()?;
320        if conn
321            .session
322            .as_ref()
323            .map(|sess| sess.expected_binaries > 0)
324            .unwrap_or(false)
325        {
326            return Err(crate::Error::ProtocolViolation);
327        }
328        tracing::trace!(?msg, session=?conn.session.as_ref().map(|s| &s.session), "received client message");
329        if let Some(sess) = &conn.session {
330            // TODO(perf-low): do not mark the session as active upon each incoming message? we don't really need to
331            // mark the session as active every 10 seconds / 1 minute, yet that's the frequency at which the clients
332            // send GetTime in order to detect disconnection
333            self.db
334                .mark_session_active(sess.token, SystemTime::now())
335                .await
336                .wrap_context("marking session as active")?;
337        }
338        let msg = serde_json::from_str::<ClientMessage>(msg)
339            .wrap_context("deserializing client message")?;
340        // TODO(perf-med): We could parallelize requests here, and not just pipeline them. However, we need to be
341        // careful about not sending updates about subscribed objects before the objects themselves, so it is
342        // nontrivial. Do this only after thinking well about what could happen.
343        match &*msg.request {
344            Request::SetToken(token) => {
345                let res = self.db.resume_session(*token).await.map(|session| {
346                    let (updates_sender, updates_receiver) = mpsc::unbounded_channel();
347                    self.sessions
348                        .lock()
349                        .unwrap()
350                        .entry(session.user_id)
351                        .or_default()
352                        .entry(session.session_ref)
353                        .or_default()
354                        .push(updates_sender);
355                    conn.session = Some(SessionInfo {
356                        token: *token,
357                        session,
358                        expected_binaries: 0,
359                        subscribed_objects: Arc::new(RwLock::new(HashSet::new())),
360                        subscribed_queries: Arc::new(RwLock::new(HashMap::new())),
361                        updates_receiver,
362                    });
363                    ResponsePart::Success
364                });
365                Self::send_res(&mut conn.socket, msg.request_id, res).await
366            }
367            Request::RenameSession(name) => {
368                let res = match &conn.session {
369                    None => Err(crate::Error::ProtocolViolation),
370                    Some(sess) => self
371                        .db
372                        .rename_session(sess.token, name)
373                        .await
374                        .map(|()| ResponsePart::Success),
375                };
376                Self::send_res(&mut conn.socket, msg.request_id, res).await
377            }
378            Request::CurrentSession => {
379                let res = match &conn.session {
380                    None => Err(crate::Error::ProtocolViolation),
381                    Some(sess) => Ok(ResponsePart::Sessions(vec![sess.session.clone()])),
382                };
383                Self::send_res(&mut conn.socket, msg.request_id, res).await
384            }
385            Request::ListSessions => {
386                let res = match &conn.session {
387                    None => Err(crate::Error::ProtocolViolation),
388                    Some(sess) => self
389                        .db
390                        .list_sessions(sess.session.user_id)
391                        .await
392                        .map(ResponsePart::Sessions),
393                };
394                Self::send_res(&mut conn.socket, msg.request_id, res).await
395            }
396            Request::Logout => {
397                let res = match &conn.session {
398                    None => Err(crate::Error::ProtocolViolation),
399                    Some(sess) => self
400                        .db
401                        .disconnect_session(sess.session.user_id, sess.session.session_ref)
402                        .await
403                        .map(|()| ResponsePart::Success),
404                };
405                conn.session = None;
406                Self::send_res(&mut conn.socket, msg.request_id, res).await
407            }
408            Request::DisconnectSession(session_ref) => {
409                let res = match &conn.session {
410                    None => Err(crate::Error::ProtocolViolation),
411                    Some(sess) => self
412                        .db
413                        .disconnect_session(sess.session.user_id, *session_ref)
414                        .await
415                        .map(|()| ResponsePart::Success),
416                };
417                Self::send_res(&mut conn.socket, msg.request_id, res).await
418            }
419            Request::GetTime => {
420                let res = match &conn.session {
421                    None => Err(crate::Error::ProtocolViolation),
422                    Some(_) => Ok(ResponsePart::CurrentTime(SystemTime::now())),
423                };
424                Self::send_res(&mut conn.socket, msg.request_id, res).await
425            }
426            Request::AlreadyHave { object_ids: _ } => {
427                // TODO(api-highest): actually record the information to avoid re-sending the objects' contents
428                Ok(())
429            }
430            Request::Get {
431                object_ids,
432                subscribe,
433            } => {
434                assert!(subscribe); // TODO(api-highest): implement non-subscribing gets
435                                    // TODO(blocked): remove this copy once https://github.com/rust-lang/rust/issues/110338 is fixed
436                let object_ids = object_ids.iter().map(|(o, u)| (*o, *u)).collect::<Vec<_>>();
437                self.subscribe_and_send_objects(conn, msg.request_id, None, object_ids.into_iter())
438                    .await
439            }
440            Request::Query {
441                query_id,
442                type_id,
443                query,
444                only_updated_since,
445                subscribe,
446            } => {
447                assert!(subscribe); // TODO(api-highest): implement non-subscribing gets
448                let sess = conn
449                    .session
450                    .as_ref()
451                    .ok_or(crate::Error::ProtocolViolation)?;
452                // Subscribe BEFORE running the query. This makes sure no updates are lost.
453                // We must then not return to the update-sending loop until all the responses are sent.
454                sess.subscribed_queries
455                    .write()
456                    .unwrap()
457                    .insert(*query_id, query.clone());
458                let updatedness = *self.last_completed_updatedness.lock().unwrap();
459                let object_ids = self
460                    .db
461                    .server_query(
462                        sess.session.user_id,
463                        *type_id,
464                        *only_updated_since,
465                        query.clone(),
466                    )
467                    .await
468                    .wrap_context("listing objects matching query")?;
469                // Note: `subscribe_and_send_objects` will only fetch and send objects that the user has not yet subscribed upon.
470                // So, setting `None` here is the right thing to do.
471                self.subscribe_and_send_objects(
472                    conn,
473                    msg.request_id,
474                    Some(updatedness),
475                    object_ids.into_iter().map(|o| (o, None)),
476                )
477                .await
478            }
479            Request::GetBinaries(binary_ids) => {
480                // Just avoid unauthed binary gets
481                let _ = conn
482                    .session
483                    .as_ref()
484                    .ok_or(crate::Error::ProtocolViolation)?;
485                self.send_binaries(conn, msg.request_id, binary_ids.iter().copied())
486                    .await
487            }
488            Request::Unsubscribe(object_ids) => {
489                {
490                    let mut subscribed_objects = conn
491                        .session
492                        .as_ref()
493                        .ok_or(crate::Error::ProtocolViolation)?
494                        .subscribed_objects
495                        .write()
496                        .unwrap();
497                    for id in object_ids {
498                        subscribed_objects.remove(id);
499                    }
500                }
501                Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
502            }
503            Request::UnsubscribeQuery(query_id) => {
504                conn.session
505                    .as_ref()
506                    .ok_or(crate::Error::ProtocolViolation)?
507                    .subscribed_queries
508                    .write()
509                    .unwrap()
510                    .remove(query_id);
511                Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
512            }
513            Request::Upload(upload) => {
514                let sess = conn
515                    .session
516                    .as_ref()
517                    .ok_or(crate::Error::ProtocolViolation)?;
518                match upload {
519                    Upload::Object {
520                        object_id,
521                        type_id,
522                        created_at,
523                        snapshot_version,
524                        object,
525                        subscribe,
526                    } => {
527                        let (updatedness, update_sender) = self.updatedness_slot().await?;
528                        let res = C::upload_object(
529                            &*self.db,
530                            sess.session.user_id,
531                            updatedness,
532                            *type_id,
533                            *object_id,
534                            *created_at,
535                            *snapshot_version,
536                            object.clone(),
537                        )
538                        .await;
539                        let res = match res {
540                            Ok(res) => res,
541                            Err(err) => {
542                                return Self::send_res(&mut conn.socket, msg.request_id, Err(err))
543                                    .await;
544                            }
545                        };
546                        if let Some((new_update, users_who_can_read, rdeps)) = res {
547                            let mut new_data = HashMap::new();
548                            self.add_rdeps_updates(&mut new_data, rdeps)
549                                .await
550                                .wrap_context("listing updates for rdeps")?;
551                            for user in users_who_can_read {
552                                let existing = new_data
553                                    .entry(user)
554                                    .or_insert_with(HashMap::new)
555                                    .insert(*object_id, new_update.clone());
556                                if let Some(existing) = existing {
557                                    tracing::error!(
558                                        ?user,
559                                        ?object_id,
560                                        ?existing,
561                                        "replacing mistakenly-already-existing update"
562                                    );
563                                }
564                            }
565                            let new_data = new_data
566                                .into_iter()
567                                .map(|(k, v)| (k, Arc::new(v)))
568                                .collect();
569
570                            update_sender.send(new_data).map_err(|_| {
571                                crate::Error::Other(anyhow!(
572                                    "Update reorderer thread went away before updating thread",
573                                ))
574                            })?;
575                        }
576                        if *subscribe {
577                            sess.subscribed_objects.write().unwrap().insert(*object_id);
578                        }
579                        Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success))
580                            .await
581                    }
582                    Upload::Event {
583                        object_id,
584                        type_id,
585                        event_id,
586                        event,
587                        subscribe,
588                    } => {
589                        let (updatedness, update_sender) = self.updatedness_slot().await?;
590                        let res = C::upload_event(
591                            &*self.db,
592                            sess.session.user_id,
593                            updatedness,
594                            *type_id,
595                            *object_id,
596                            *event_id,
597                            event.clone(),
598                        )
599                        .await;
600                        let res = match res {
601                            Ok(res) => res,
602                            Err(err) => {
603                                return Self::send_res(&mut conn.socket, msg.request_id, Err(err))
604                                    .await;
605                            }
606                        };
607                        if let Some((new_update, users_who_can_read, rdeps)) = res {
608                            let mut new_data = HashMap::new();
609                            self.add_rdeps_updates(&mut new_data, rdeps)
610                                .await
611                                .wrap_context("listing updates for rdeps")?;
612                            for user in users_who_can_read {
613                                let existing = new_data
614                                    .entry(user)
615                                    .or_insert_with(HashMap::new)
616                                    .insert(*object_id, new_update.clone());
617                                if let Some(existing) = existing {
618                                    tracing::error!(
619                                        ?user,
620                                        ?object_id,
621                                        ?existing,
622                                        "replacing mistakenly-already-existing update"
623                                    );
624                                }
625                            }
626                            let new_data = new_data
627                                .into_iter()
628                                .map(|(k, v)| (k, Arc::new(v)))
629                                .collect();
630
631                            update_sender.send(new_data).map_err(|_| {
632                                crate::Error::Other(anyhow!(
633                                    "Update reorderer thread went away before updating thread",
634                                ))
635                            })?;
636                        }
637                        if *subscribe {
638                            sess.subscribed_objects.write().unwrap().insert(*object_id);
639                        }
640                        Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success))
641                            .await
642                    }
643                }
644            }
645            Request::UploadBinaries(num_binaries) => {
646                conn.session
647                    .as_mut()
648                    .ok_or(crate::Error::ProtocolViolation)?
649                    .expected_binaries = *num_binaries;
650                Self::send_res(&mut conn.socket, msg.request_id, Ok(ResponsePart::Success)).await
651            }
652        }
653    }
654
    // TODO(api-high): This uses get_latest_snapshot, and thus assumes that the latest snapshot
    // was just written in this run, with the latest snapshot version
    /// Adds to `updates` the updates to send out for read-permission changes (`rdeps`).
    ///
    /// Users who lost read access get a single `LostReadRights` update for the object;
    /// users who gained read access get the object's full update history plus its
    /// latest snapshot, fetched once and shared (via `clone`) across all of them.
    async fn add_rdeps_updates(
        &self,
        updates: &mut EditableUpdatesMap,
        rdeps: Vec<ReadPermsChanges>,
    ) -> crate::Result<()> {
        for c in rdeps {
            // Users who can no longer read the object only need to be told to drop it
            for u in c.lost_read {
                updates.entry(u).or_default().insert(
                    c.object_id,
                    Arc::new(UpdatesWithSnap {
                        updates: vec![Arc::new(Update {
                            object_id: c.object_id,
                            data: UpdateData::LostReadRights,
                        })],
                        new_last_snapshot: None,
                    }),
                );
            }
            // Fetch the object once, on behalf of an arbitrary user who gained access
            // (all of them can now read it), then fan the result out to every gainer
            if let Some(one_user) = c.gained_read.iter().next() {
                let mut t = self.db.get_transaction().await?;
                let object = self
                    .db
                    .get_all(&mut t, *one_user, c.object_id, None)
                    .await?;
                let last_snapshot = self
                    .db
                    .get_latest_snapshot(&mut t, *one_user, c.object_id)
                    .await?;
                let new_updates = object.into_updates();
                for u in c.gained_read {
                    updates.entry(u).or_default().insert(
                        c.object_id,
                        Arc::new(UpdatesWithSnap {
                            updates: new_updates.clone(),
                            new_last_snapshot: Some(last_snapshot.clone()),
                        }),
                    );
                }
            }
        }
        Ok(())
    }
699
    /// Subscribes the session to each object in `objects` and streams their contents back to
    /// the client as `ResponsePart::Objects` responses to `request_id`.
    ///
    /// `query_updatedness` is `Some` when the objects come from a query result: it is sent as
    /// `now_have_all_until` on the final message, and objects that vanished between the query
    /// and the read (lost read access) are silently skipped. When it is `None`, the objects
    /// were explicitly requested, so a missing object is reported as an error to the client
    /// while the remaining objects keep being processed.
    ///
    /// Returns a `ProtocolViolation` error if the connection has no authenticated session.
    async fn subscribe_and_send_objects(
        &self,
        conn: &mut ConnectionState,
        request_id: RequestId,
        query_updatedness: Option<Updatedness>,
        objects: impl Iterator<Item = (ObjectId, Option<Updatedness>)>,
    ) -> crate::Result<()> {
        let sess = conn
            .session
            .as_ref()
            .ok_or(crate::Error::ProtocolViolation)?;
        let user = sess.session.user_id;
        let subscribed_objects = sess.subscribed_objects.clone();
        let objects = objects.map(|(object_id, updatedness)| {
            let subscribed_objects = subscribed_objects.clone();
            async move {
                if subscribed_objects.read().unwrap().contains(&object_id) {
                    Ok(MaybeObject::AlreadySubscribed(object_id))
                } else {
                    // Subscribe BEFORE getting the object. This makes sure no updates are lost.
                    // We must then not return to the update-sending loop until all the responses are sent.
                    // NOTE(review): if get_all fails below (e.g. the object does not exist),
                    // the id stays in subscribed_objects — confirm this is intended.
                    subscribed_objects.write().unwrap().insert(object_id);
                    let mut t = self.db.get_transaction().await?;
                    let object = self
                        .db
                        .get_all(&mut t, user, object_id, updatedness)
                        .await?;
                    Ok(MaybeObject::NotYetSubscribed(object))
                }
            }
        });
        // Fetch up to 16 objects concurrently, handling them in completion order.
        let objects = stream::iter(objects).buffer_unordered(16); // TODO(perf-low): is 16 a good number?
        pin_mut!(objects);
        let mut size_of_message = 0;
        let mut current_data = Vec::new();
        // Send all the objects to the client, batching them by messages of a reasonable size, to both allow for better
        // resumption after a connection loss, while not sending one message per mini-object.
        while let Some(object) = objects.next().await {
            // Flush the batch once it exceeds ~1MiB of serialized JSON. The check runs before
            // adding the next object, so a message can overshoot the threshold by one object.
            if size_of_message >= 1024 * 1024 {
                // TODO(perf-low): is 1MiB a good number?
                let data = std::mem::take(&mut current_data);
                size_of_message = 0;
                Self::send(
                    &mut conn.socket,
                    &ServerMessage::Response {
                        request_id,
                        response: ResponsePart::Objects {
                            data,
                            // Intermediate batches never carry now_have_all_until
                            now_have_all_until: None,
                        },
                        last_response: false,
                    },
                )
                .await?;
            }
            match object {
                Ok(object) => {
                    size_of_message += size_as_json(&object)?;
                    current_data.push(object);
                }
                Err(err @ crate::Error::ObjectDoesNotExist(_)) => {
                    if query_updatedness.is_some() {
                        // User lost read access to object between query and read
                        // Do nothing
                    } else {
                        // User explicitly requested a non-existing object
                        // Return an error but keep processing the request
                        Self::send(
                            &mut conn.socket,
                            &ServerMessage::Response {
                                request_id,
                                response: ResponsePart::Error(err.into()),
                                last_response: false,
                            },
                        )
                        .await?;
                    }
                }
                // Any other error aborts the request with a final error response
                Err(err) => return Self::send_res(&mut conn.socket, request_id, Err(err)).await,
            }
        }
        // Final message: remaining objects, plus the updatedness watermark if this was a query
        Self::send_res(
            &mut conn.socket,
            request_id,
            Ok(ResponsePart::Objects {
                data: current_data,
                now_have_all_until: query_updatedness,
            }),
        )
        .await
    }
791
792    async fn send_binaries(
793        &self,
794        conn: &mut ConnectionState,
795        request_id: RequestId,
796        binaries: impl Iterator<Item = BinPtr>,
797    ) -> crate::Result<()> {
798        let binaries =
799            binaries.map(|binary_id| self.db.get_binary(binary_id).map(move |r| (binary_id, r)));
800        let binaries = stream::iter(binaries).buffer_unordered(16); // TODO(perf-low): is 16 a good number?
801        pin_mut!(binaries);
802        let mut size_of_message = 0;
803        let mut current_data = Vec::new();
804        // Send all the binaries to the client, trying to avoid having too many ResponsePart::Binaries messages while still sending as
805        // many binaries as possible before any potential error (in particular missing-binary).
806        while let Some((binary_id, binary)) = binaries.next().await {
807            if size_of_message >= 1024 * 1024 {
808                // TODO(perf-low): is 1MiB a good number?
809                size_of_message = 0;
810                Self::send_binaries_msg(
811                    &mut conn.socket,
812                    request_id,
813                    false,
814                    current_data.drain(..),
815                )
816                .await?;
817            }
818            let binary = match binary {
819                Ok(Some(binary)) => Ok(binary),
820                Ok(None) => Err(crate::Error::MissingBinaries(vec![binary_id])),
821                Err(err) => Err(err),
822            };
823            let binary = match binary {
824                Ok(binary) => binary,
825                Err(err) => {
826                    if !current_data.is_empty() {
827                        Self::send_binaries_msg(
828                            &mut conn.socket,
829                            request_id,
830                            false,
831                            current_data.drain(..),
832                        )
833                        .await?;
834                    }
835                    return Self::send_res(&mut conn.socket, request_id, Err(err)).await;
836                }
837            };
838            size_of_message += binary.len();
839            current_data.push(binary);
840        }
841        Self::send_binaries_msg(&mut conn.socket, request_id, true, current_data.drain(..)).await
842    }
843
844    async fn send_binaries_msg(
845        socket: &mut WebSocket,
846        request_id: RequestId,
847        last_response: bool,
848        bins: impl ExactSizeIterator<Item = Arc<[u8]>>,
849    ) -> crate::Result<()> {
850        Self::send(
851            socket,
852            &ServerMessage::Response {
853                request_id,
854                last_response,
855                response: ResponsePart::Binaries(bins.len()),
856            },
857        )
858        .await?;
859        for bin in bins {
860            socket
861                .send(ws::Message::Binary(Bytes::from_owner(bin)))
862                .await
863                .wrap_context("sending binary to client")?
864        }
865        Ok(())
866    }
867
868    async fn send_res(
869        socket: &mut WebSocket,
870        request_id: RequestId,
871        res: crate::Result<ResponsePart>,
872    ) -> crate::Result<()> {
873        let response = match res {
874            Ok(res) => res,
875            Err(err) => ResponsePart::Error(err.into()),
876        };
877        Self::send(
878            socket,
879            &ServerMessage::Response {
880                request_id,
881                response,
882                last_response: true,
883            },
884        )
885        .await
886    }
887
888    async fn send(socket: &mut WebSocket, msg: &ServerMessage) -> crate::Result<()> {
889        let msg = serde_json::to_string(msg).wrap_context("serializing server message")?;
890        socket
891            .send(ws::Message::Text(msg.into()))
892            .await
893            .wrap_context("sending response to client")
894    }
895
896    /// Cleans up and optimizes up the database
897    ///
898    /// After running this, the database will reject any new change that would happen before
899    /// `no_new_changes_before` if it is set.
900    pub async fn vacuum(
901        &self,
902        no_new_changes_before: Option<EventId>,
903        kill_sessions_older_than: Option<SystemTime>,
904    ) -> crate::Result<()> {
905        let (updatedness, slot) = self.updatedness_slot().await?;
906        Self::run_vacuum(
907            &self.db,
908            no_new_changes_before,
909            updatedness,
910            kill_sessions_older_than,
911            slot,
912        )
913        .await
914    }
915
916    async fn run_vacuum(
917        db: &CacheDb<PostgresDb<C>>,
918        no_new_changes_before: Option<EventId>,
919        updatedness: Updatedness,
920        kill_sessions_older_than: Option<SystemTime>,
921        slot: oneshot::Sender<UpdatesMap>,
922    ) -> crate::Result<()> {
923        // Perform the vacuum, collecting all updates
924        let mut updates = HashMap::new();
925        let res = db
926            .server_vacuum(
927                no_new_changes_before,
928                updatedness,
929                kill_sessions_older_than,
930                |update, users_who_can_read| {
931                    // Vacuum cannot change any latest snapshot
932                    let object_id = update.object_id;
933                    let update = Arc::new(UpdatesWithSnap {
934                        updates: vec![Arc::new(update)],
935                        new_last_snapshot: None,
936                    });
937                    for u in users_who_can_read {
938                        updates
939                            .entry(u)
940                            .or_insert_with(HashMap::new)
941                            .insert(object_id, update.clone());
942                    }
943                },
944            )
945            .await;
946
947        // Arc where appropriate
948        // TODO(perf-low): this could probably be done without copying by having &muts to the underlying Arc at creation time
949        let updates = updates.into_iter().map(|(k, v)| (k, Arc::new(v))).collect();
950
951        // Submit the updates
952        if slot.send(updates).is_err() {
953            tracing::error!("Update reorderer went away before server");
954        }
955
956        // And return the result
957        res
958    }
959
960    async fn updatedness_slot(&self) -> crate::Result<(Updatedness, oneshot::Sender<UpdatesMap>)> {
961        let (sender, receiver) = oneshot::channel();
962        self.updatedness_requester.send(sender).map_err(|_| {
963            crate::Error::Other(anyhow!(
964                "Updatedness request handler thread went away too early"
965            ))
966        })?;
967        let slot = receiver.await.map_err(|_| {
968            crate::Error::Other(anyhow!("Updatedness request handler thread never answered"))
969        })?;
970        Ok(slot)
971    }
972}
973
/// Describes when, and with which cleanup options, the server should vacuum the database.
pub struct ServerVacuumSchedule<Tz: chrono::TimeZone> {
    // Cron schedule on which vacuums are triggered
    schedule: cron::Schedule,
    // Timezone in which `schedule` is interpreted
    timezone: Tz,
    // TODO(api-high): recreate different object types at different frequencies
    // Presumably an age cutoff for object recreation during vacuum — the consuming
    // code is outside this view; confirm against the scheduler loop
    recreate_older_than: Option<Duration>,
    // If set, sessions older than this age are killed during vacuum
    kill_sessions_older_than: Option<Duration>,
}
981
982impl<Tz: chrono::TimeZone> ServerVacuumSchedule<Tz> {
983    pub fn new(schedule: cron::Schedule, timezone: Tz) -> ServerVacuumSchedule<Tz> {
984        ServerVacuumSchedule {
985            schedule,
986            timezone,
987            recreate_older_than: None,
988            kill_sessions_older_than: None,
989        }
990    }
991
992    pub fn recreate_older_than(mut self, age: Duration) -> Self {
993        self.recreate_older_than = Some(age);
994        self
995    }
996
997    pub fn kill_sessions_older_than(mut self, age: Duration) -> Self {
998        self.kill_sessions_older_than = Some(age);
999        self
1000    }
1001}
1002
// State held for one client websocket connection.
struct ConnectionState {
    // The websocket this client is connected over
    socket: WebSocket,
    // `None` when no session is attached (not yet authenticated, or killed after expiry
    // by `kill_session_if_expired`)
    session: Option<SessionInfo>,
}
1007
1008impl ConnectionState {
1009    fn has_expired(&self) -> bool {
1010        self.session
1011            .as_ref()
1012            .and_then(|sess| sess.session.expiration_time)
1013            .map(|expiration_time| expiration_time < SystemTime::now())
1014            .unwrap_or(false)
1015    }
1016
1017    /// Kills the session if it expired.
1018    ///
1019    /// Returns an InvalidToken error if the session was killed.
1020    fn kill_session_if_expired(&mut self) -> crate::Result<()> {
1021        if self.has_expired() {
1022            let token = self.session.as_ref().unwrap().token;
1023            self.session = None;
1024            return Err(crate::Error::InvalidToken(token));
1025        }
1026        Ok(())
1027    }
1028}
1029
// Per-session state attached to an authenticated connection.
struct SessionInfo {
    // Token the client authenticated with; echoed back in InvalidToken errors
    token: SessionToken,
    // The session itself (carries at least user_id and expiration_time)
    session: Session,
    // Number of binary frames still expected after an UploadBinaries request
    expected_binaries: usize,
    // Objects this session is subscribed to; shared (Arc) with the per-request
    // fetch futures that subscribe objects concurrently
    subscribed_objects: Arc<RwLock<HashSet<ObjectId>>>,
    // Queries this session is subscribed to, matched against new snapshots to
    // auto-subscribe matching objects
    subscribed_queries: Arc<RwLock<HashMap<QueryId, Arc<Query>>>>,
    // Incoming stream of updates fanned out to this session's user
    updates_receiver: mpsc::UnboundedReceiver<(Updatedness, Arc<UserUpdatesMap>)>,
}
1038
1039impl SessionInfo {
1040    fn is_subscribed_to(
1041        &self,
1042        object_id: ObjectId,
1043        new_last_snapshot: Option<&serde_json::Value>,
1044    ) -> bool {
1045        if self.subscribed_objects.read().unwrap().contains(&object_id) {
1046            return true;
1047        }
1048        if let Some(new_last_snapshot) = new_last_snapshot {
1049            for query in self.subscribed_queries.read().unwrap().values() {
1050                if query.matches_json(new_last_snapshot) {
1051                    self.subscribed_objects.write().unwrap().insert(object_id);
1052                    return true;
1053                }
1054            }
1055        }
1056        false
1057    }
1058}
1059
1060fn size_as_json<T: serde::Serialize>(value: &T) -> crate::Result<usize> {
1061    struct Size(usize);
1062    impl std::io::Write for Size {
1063        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1064            self.0 += buf.len();
1065            Ok(buf.len())
1066        }
1067
1068        fn flush(&mut self) -> std::io::Result<()> {
1069            Ok(())
1070        }
1071    }
1072    let mut size = Size(0);
1073    serde_json::to_writer(&mut size, value)
1074        .wrap_context("figuring out the serialized size of value")?;
1075    Ok(size.0)
1076}