use crate::{
client_db::SavedObject,
connection::{
Command, Connection, ConnectionEvent, RequestWithSidecar, ResponsePartWithSidecar,
ResponseSender,
},
};
use anyhow::anyhow;
use crdb_cache::CacheDb;
use crdb_core::{
BinPtr, ClientSideDb, CrdbSyncFn, Db, Event, EventId, Importance, MaybeObject, Object,
ObjectData, ObjectId, Query, QueryId, Request, ResponsePart, ResultExt, SavedQuery, Session,
SessionRef, SessionToken, Updatedness, Updates, Upload, UploadId,
};
use futures::{channel::mpsc, future::Either, pin_mut, stream, FutureExt, StreamExt};
use std::{
collections::{HashMap, HashSet, VecDeque},
future::Future,
iter,
sync::{Arc, Mutex, RwLock},
};
use tokio::sync::{oneshot, watch};
/// What to do with an upload that the server rejected with an error.
///
/// This is the answer of the error handler given to `ApiDb::new`; the upload-resender task
/// applies it to the rejected upload.
#[derive(Debug)]
#[non_exhaustive]
pub enum OnError {
    /// Undo the upload's effect in the local database, then remove it from the upload queue.
    Rollback,
    /// Keep the local changes and leave the upload in the persisted upload queue; the error is
    /// only reported to whoever is waiting on the upload.
    KeepLocal,
    /// Undo the upload locally, apply the given upload to the local database in its place,
    /// then remove the original upload from the upload queue.
    ReplaceWith(Upload),
}
/// Client-side handle on the server API.
///
/// Login/logout commands go straight to the connection task (through `connection`), while
/// both plain requests and uploads are handed to a background upload-resender task (through
/// `upload_resender`), which forwards them to the connection.
pub struct ApiDb<LocalDb: ClientSideDb> {
    /// Commands for the connection task (login and logout in this file).
    connection: mpsc::UnboundedSender<Command>,
    /// Publishes the list of pending upload ids; shared with the upload-resender task, which
    /// refreshes it when uploads complete.
    upload_queue_watcher_sender: Arc<Mutex<watch::Sender<Vec<UploadId>>>>,
    /// Kept around so that new watchers can be handed out by cloning it.
    upload_queue_watcher_receiver: watch::Receiver<Vec<UploadId>>,
    /// Local database, in particular holding the persisted upload queue.
    db: Arc<CacheDb<LocalDb>>,
    /// Feeds the upload-resender task with `(upload id, or `None` for a plain request,
    /// request, sender for the response parts)` tuples.
    upload_resender: mpsc::UnboundedSender<(
        Option<UploadId>,
        Arc<Request>,
        mpsc::UnboundedSender<ResponsePartWithSidecar>,
    )>,
    /// User callback for connection events, replaceable through `on_connection_event`.
    connection_event_cb: Arc<RwLock<Box<dyn CrdbSyncFn<ConnectionEvent>>>>,
}
impl<LocalDb: ClientSideDb> ApiDb<LocalDb> {
pub(crate) async fn new<C, GSO, GSQ, EH, EHF, RRL>(
db: Arc<CacheDb<LocalDb>>,
get_saved_objects: GSO,
get_saved_queries: GSQ,
error_handler: EH,
require_relogin: RRL,
) -> crate::Result<(ApiDb<LocalDb>, mpsc::UnboundedReceiver<Updates>)>
where
C: crdb_core::Config,
GSO: 'static + waaaa::Send + FnMut() -> HashMap<ObjectId, SavedObject>,
GSQ: 'static + Send + FnMut() -> HashMap<QueryId, SavedQuery>,
EH: 'static + waaaa::Send + Fn(Upload, crate::Error) -> EHF,
EHF: 'static + waaaa::Future<Output = OnError>,
RRL: 'static + waaaa::Send + Fn(),
{
let (update_sender, update_receiver) = mpsc::unbounded();
let connection_event_cb: Arc<RwLock<Box<dyn CrdbSyncFn<ConnectionEvent>>>> =
Arc::new(RwLock::new(Box::new(|_| ()) as _));
let event_cb = {
let connection_event_cb = connection_event_cb.clone();
Box::new(move |evt| {
let need_relogin = match evt {
ConnectionEvent::LoggingIn => false,
ConnectionEvent::FailedConnecting(_) => true,
ConnectionEvent::FailedSendingToken(_) => true,
ConnectionEvent::LostConnection(_) => false,
ConnectionEvent::InvalidToken(_) => true,
ConnectionEvent::Connected => false,
ConnectionEvent::TimeOffset(_) => false,
ConnectionEvent::LoggedOut => true,
};
if need_relogin {
(require_relogin)();
}
connection_event_cb.read().unwrap()(evt);
})
};
let (connection, commands) = mpsc::unbounded();
let (requests, requests_receiver) = mpsc::unbounded();
waaaa::spawn(
Connection::new(
commands,
requests_receiver,
event_cb,
update_sender,
get_saved_objects,
get_saved_queries,
)
.run(),
);
let all_uploads = db
.list_uploads()
.await
.wrap_context("listing upload queue")?;
let (upload_queue_watcher_sender, upload_queue_watcher_receiver) =
watch::channel(all_uploads.clone());
let upload_queue_watcher_sender = Arc::new(Mutex::new(upload_queue_watcher_sender));
let (upload_resender_sender, upload_resender_receiver) = mpsc::unbounded();
waaaa::spawn(upload_resender::<C, _, _, _>(
db.clone(),
upload_resender_receiver,
requests,
upload_queue_watcher_sender.clone(),
error_handler,
));
for upload_id in all_uploads {
let upload = db
.get_upload(upload_id)
.await
.wrap_context("retrieving upload")?
.ok_or_else(|| {
crate::Error::Other(anyhow!(
"Upload vanished from queue while doing the initial read"
))
})?;
let request = Arc::new(Request::Upload(upload));
let (sender, _) = mpsc::unbounded(); upload_resender_sender
.unbounded_send((Some(upload_id), request, sender))
.expect("connection cannot go away before apidb does");
}
Ok((
ApiDb {
db,
upload_queue_watcher_sender,
upload_queue_watcher_receiver,
connection,
upload_resender: upload_resender_sender,
connection_event_cb,
},
update_receiver,
))
}
pub fn watch_upload_queue(&self) -> watch::Receiver<Vec<UploadId>> {
self.upload_queue_watcher_receiver.clone()
}
pub fn on_connection_event(&self, cb: impl 'static + CrdbSyncFn<ConnectionEvent>) {
*self.connection_event_cb.write().unwrap() = Box::new(cb);
}
pub fn login(&self, url: Arc<String>, token: SessionToken) {
self.connection
.unbounded_send(Command::Login { url, token })
.expect("connection cannot go away before sender does")
}
pub fn logout(&self) {
self.connection
.unbounded_send(Command::Logout)
.expect("connection cannot go away before sender does")
}
fn request(&self, request: Arc<Request>) -> mpsc::UnboundedReceiver<ResponsePartWithSidecar> {
let (sender, response) = mpsc::unbounded();
self.upload_resender
.unbounded_send((None, request, sender))
.expect("connection cannot go away before sender does");
response
}
pub fn rename_session(&self, name: String) -> oneshot::Receiver<crate::Result<()>> {
let response_receiver = self.request(Arc::new(Request::RenameSession(name)));
expect_simple_response(response_receiver)
}
pub async fn current_session(&self) -> crate::Result<Session> {
let response = self
.request(Arc::new(Request::CurrentSession))
.next()
.await
.ok_or_else(|| crate::Error::Other(anyhow!("Connection thread went down too early")))?;
match response.response {
ResponsePart::Sessions(mut sessions) if sessions.len() == 1 => {
Ok(sessions.pop().unwrap())
}
ResponsePart::Error(err) => Err(err.into()),
_ => Err(crate::Error::Other(anyhow!(
"Unexpected server response to CurrentSession: {:?}",
response.response
))),
}
}
pub async fn list_sessions(&self) -> crate::Result<Vec<Session>> {
let response = self
.request(Arc::new(Request::ListSessions))
.next()
.await
.ok_or_else(|| crate::Error::Other(anyhow!("Connection thread went down too early")))?;
match response.response {
ResponsePart::Sessions(sessions) => Ok(sessions),
ResponsePart::Error(err) => Err(err.into()),
_ => Err(crate::Error::Other(anyhow!(
"Unexpected server response to ListSessions: {:?}",
response.response
))),
}
}
pub fn disconnect_session(
&self,
session_ref: SessionRef,
) -> oneshot::Receiver<crate::Result<()>> {
let response_receiver = self.request(Arc::new(Request::DisconnectSession(session_ref)));
expect_simple_response(response_receiver)
}
pub fn unsubscribe(&self, object_ids: HashSet<ObjectId>) {
self.request(Arc::new(Request::Unsubscribe(object_ids)));
}
pub fn unsubscribe_query(&self, query_id: QueryId) {
self.request(Arc::new(Request::UnsubscribeQuery(query_id)));
}
async fn handle_upload_response(
mut receiver: mpsc::UnboundedReceiver<ResponsePartWithSidecar>,
) -> crate::Result<()> {
match receiver.next().await {
None => Err(crate::Error::Other(anyhow!(
"Connection did not return any answer to query"
))),
Some(ResponsePartWithSidecar {
sidecar: Some(_), ..
}) => Err(crate::Error::Other(anyhow!(
"Connection returned sidecar while we expected a simple result"
))),
Some(ResponsePartWithSidecar { response, .. }) => match response {
ResponsePart::Success => Ok(()),
ResponsePart::Error(err) => Err(err.into()),
ResponsePart::Sessions(_)
| ResponsePart::CurrentTime(_)
| ResponsePart::Objects { .. }
| ResponsePart::Binaries(_) => Err(crate::Error::Other(anyhow!(
"Connection returned unexpected answer while expecting a simple result"
))),
},
}
}
pub async fn create<T: Object>(
&self,
object_id: ObjectId,
created_at: EventId,
object: Arc<T>,
subscribe: bool,
) -> crate::Result<impl Future<Output = crate::Result<()>>> {
let required_binaries = object.required_binaries();
let upload = Upload::Object {
object_id,
type_id: *T::type_ulid(),
created_at,
snapshot_version: T::snapshot_version(),
object: Arc::new(
serde_json::to_value(object)
.wrap_context("serializing object for sending to api")?,
),
subscribe,
};
let request = Arc::new(Request::Upload(upload.clone()));
let (result_sender, result_receiver) = mpsc::unbounded();
let upload_id = self
.db
.enqueue_upload(upload, required_binaries)
.await
.wrap_context("enqueuing upload")?;
let upload_list = self
.db
.list_uploads()
.await
.wrap_context("listing uploads")?;
self.upload_queue_watcher_sender
.lock()
.unwrap()
.send_replace(upload_list);
self.upload_resender
.unbounded_send((Some(upload_id), request, result_sender))
.map_err(|_| crate::Error::Other(anyhow!("Upload resender went out too early")))?;
Ok(Self::handle_upload_response(result_receiver))
}
pub async fn submit<T: Object>(
&self,
object_id: ObjectId,
event_id: EventId,
event: Arc<T::Event>,
subscribe: bool,
) -> crate::Result<impl Future<Output = crate::Result<()>>> {
let required_binaries = event.required_binaries();
let upload = Upload::Event {
object_id,
type_id: *T::type_ulid(),
event_id,
event: Arc::new(
serde_json::to_value(event).wrap_context("serializing event for sending to api")?,
),
subscribe,
};
let request = Arc::new(Request::Upload(upload.clone()));
let (result_sender, result_receiver) = mpsc::unbounded();
let upload_id = self
.db
.enqueue_upload(upload, required_binaries)
.await
.wrap_context("enqueuing upload")?;
let upload_list = self
.db
.list_uploads()
.await
.wrap_context("listing uploads")?;
self.upload_queue_watcher_sender
.lock()
.unwrap()
.send_replace(upload_list);
self.upload_resender
.unbounded_send((Some(upload_id), request, result_sender))
.map_err(|_| crate::Error::Other(anyhow!("Upload resender went out too early")))?;
Ok(Self::handle_upload_response(result_receiver))
}
pub async fn get(&self, object_id: ObjectId, subscribe: bool) -> crate::Result<ObjectData> {
let mut object_ids = HashMap::new();
object_ids.insert(object_id, None); let request = Arc::new(Request::Get {
object_ids,
subscribe,
});
let mut response = self.request(request);
match response.next().await {
None => Err(crate::Error::Other(anyhow!(
"Connection-handling thread went out before ApiDb"
))),
Some(response) => match response.response {
ResponsePart::Error(err) => Err(err.into()),
ResponsePart::Objects { mut data, .. } if data.len() == 1 => {
match data.pop().unwrap() {
MaybeObject::AlreadySubscribed(_) => Err(crate::Error::Other(anyhow!(
"Server unexpectedly told us we already know unknown {object_id:?}"
))),
MaybeObject::NotYetSubscribed(res) => Ok(res),
}
}
_ => Err(crate::Error::Other(anyhow!(
"Unexpected response to GetSubscribe request: {:?}",
response.response
))),
},
}
}
pub fn query<T: Object>(
&self,
query_id: QueryId,
only_updated_since: Option<Updatedness>,
subscribe: bool,
query: Arc<Query>,
) -> impl waaaa::Stream<Item = crate::Result<(MaybeObject, Option<Updatedness>)>> {
let request = Arc::new(Request::Query {
query_id,
type_id: *T::type_ulid(),
query,
only_updated_since,
subscribe,
});
self.request(request).flat_map(move |response| {
match response.response {
ResponsePart::Error(err) => Either::Left(stream::iter(iter::once(Err(err.into())))),
ResponsePart::Objects {
data,
now_have_all_until,
} => {
let data_len = data.len();
Either::Right(stream::iter(data.into_iter().enumerate().map(
move |(i, d)| {
let now_have_all_until = if i + 1 == data_len {
now_have_all_until
} else {
None
};
Ok((d, now_have_all_until))
},
)))
}
resp => Either::Left(stream::iter(iter::once(Err(crate::Error::Other(anyhow!(
"Server gave unexpected answer to QuerySubscribe request: {resp:?}"
)))))),
}
})
}
pub async fn get_binary(&self, binary_id: BinPtr) -> crate::Result<Option<Arc<[u8]>>> {
let mut binary_ids = HashSet::new();
binary_ids.insert(binary_id);
let request = Arc::new(Request::GetBinaries(binary_ids));
let mut response = self.request(request);
match response.next().await {
None => Err(crate::Error::Other(anyhow!(
"Connection-handling thread went out before ApiDb"
))),
Some(response) => match response.response {
ResponsePart::Error(err) => Err(err.into()),
ResponsePart::Binaries(1) => {
let bin = response.sidecar.ok_or_else(|| {
crate::Error::Other(anyhow!(
"Connection thread claimed to send us one binary but actually did not"
))
})?;
Ok(Some(bin))
}
_ => Err(crate::Error::Other(anyhow!(
"Unexpected response to get-binary request: {:?}",
response.response
))),
},
}
}
}
/// Background task sitting between [`ApiDb`] and the connection task.
///
/// It receives `(upload_id, request, response_sender)` tuples from `ApiDb`:
///
/// * Tuples without an upload id are plain requests. They are forwarded to the connection
///   together with the caller's `response_sender`.
/// * Tuples with an upload id are uploads from the persisted upload queue. All uploads that
///   are immediately available are grouped into one batch. The whole batch is sent, then the
///   answer to each upload is awaited in batch order. Uploads that did not reach a final
///   outcome are sent again in the next round, until the batch is empty. No new request is
///   read while a batch is being processed.
///
/// Per-upload outcomes:
///
/// * `Success`: the upload is removed from the queue with `upload_finished`, the upload-queue
///   watcher is refreshed, and the success is forwarded to the caller.
/// * `MissingBinaries`: the listed binaries are read from `db` and sent in an
///   `UploadBinaries` request at the start of the next round; the upload is then retried.
/// * `ObjectDoesNotExist` after an earlier upload of the same round reported missing
///   binaries: the upload is retried in the next round.
/// * Any other error: `error_handler` picks an [`OnError`], which is applied
///   (`Rollback` / `ReplaceWith` undo the upload locally and remove it from the queue;
///   `KeepLocal` leaves local state and the persisted queue untouched). The error is then
///   forwarded to the caller.
///
/// If a step of the above fails (e.g. dequeuing), the upload is sent again in the next round.
///
/// The task returns when the request channel is closed, or as soon as the connection drops
/// the response channel of an upload.
async fn upload_resender<C, LocalDb, EH, EHF>(
    db: Arc<CacheDb<LocalDb>>,
    requests: mpsc::UnboundedReceiver<(
        Option<UploadId>,
        Arc<Request>,
        mpsc::UnboundedSender<ResponsePartWithSidecar>,
    )>,
    connection: mpsc::UnboundedSender<(ResponseSender, Arc<RequestWithSidecar>)>,
    upload_queue_watcher_sender: Arc<Mutex<watch::Sender<Vec<UploadId>>>>,
    error_handler: EH,
) where
    C: crdb_core::Config,
    LocalDb: ClientSideDb,
    EH: 'static + waaaa::Send + Fn(Upload, crate::Error) -> EHF,
    EHF: 'static + waaaa::Future<Output = OnError>,
{
    let requests = requests.peekable();
    pin_mut!(requests);
    // Evaluates to `true` iff a request is available right now (without waiting) and it
    // satisfies `$cond`. The request is only peeked at, not consumed.
    macro_rules! poll_next_if {
        ($cond:expr) => {
            requests
                .as_mut()
                .peek()
                .now_or_never()
                .and_then(|req| req)
                .map($cond)
                .unwrap_or(false)
        };
    }
    // Wait for at least one request; `None` means every sender is gone, so the task ends.
    while requests.as_mut().peek().await.is_some() {
        // First, forward every immediately-available plain (non-upload) request.
        while poll_next_if!(|(id, _, _)| id.is_none()) {
            let (upload_id, request, sender) = requests.next().await.unwrap();
            tracing::trace!(?request, "resender received non-upload request");
            assert!(upload_id.is_none(), "non-upload should not have an id");
            let _ = connection.unbounded_send((
                sender,
                Arc::new(RequestWithSidecar {
                    request,
                    sidecar: Vec::new(),
                }),
            ));
        }
        // Then gather every immediately-available upload into one batch. Each entry is
        // `(upload id, request, caller's sender (taken once the upload reached its final
        // outcome), sender/receiver pair used to exchange with the connection)`.
        let mut upload_reqs = VecDeque::new();
        while poll_next_if!(|(id, _, _)| id.is_some()) {
            let (upload_id, request, final_sender) = requests.next().await.unwrap();
            let upload_id = upload_id.unwrap();
            tracing::trace!(?upload_id, ?request, "resender received upload request");
            let (sender, receiver) = mpsc::unbounded();
            upload_reqs.push_back((
                upload_id,
                Arc::new(RequestWithSidecar {
                    request,
                    sidecar: Vec::new(),
                }),
                Some(final_sender),
                sender,
                receiver,
            ));
        }
        let mut upload_missing_binaries = None;
        // Keep re-sending the batch until every upload reached its final outcome.
        while !upload_reqs.is_empty() {
            // Binaries reported missing during the previous round go first. The answer to
            // this request is not looked at: its receiving end is dropped right away.
            if let Some(upload_missing_binaries) = upload_missing_binaries.take() {
                let (sender, _) = mpsc::unbounded();
                let _ = connection.unbounded_send((sender, upload_missing_binaries));
            }
            // Send the whole batch before awaiting any answer.
            for (_, request, _, sender, _) in upload_reqs.iter() {
                let _ = connection.unbounded_send((sender.clone(), request.clone()));
            }
            let mut missing_binaries = HashSet::new();
            // Await the answers in batch order. An upload whose `final_sender` is still
            // `Some` after this loop will be sent again in the next round.
            for (upload_id, request, final_sender, _, receiver) in upload_reqs.iter_mut() {
                match receiver.next().await {
                    // The connection dropped this upload's response channel: stop the task.
                    None => return,
                    Some(ResponsePartWithSidecar {
                        sidecar: Some(_), ..
                    }) => {
                        // NOTE(review): `final_sender` is kept, so this upload is re-sent
                        // right away in the next round.
                        tracing::error!("got response to upload that had a sidecar");
                        continue;
                    }
                    Some(ResponsePartWithSidecar { response, .. }) => match response {
                        ResponsePart::Success => {
                            // If dequeuing fails, `final_sender` is kept and the upload is
                            // sent again in the next round.
                            if let Err(err) = db.upload_finished(*upload_id).await {
                                tracing::error!(?err, "failed dequeuing upload");
                            } else {
                                // Failing to refresh the watcher is only logged.
                                match db.list_uploads().await.wrap_context("listing uploads") {
                                    Err(err) => {
                                        tracing::error!(?err, "failed listing upload queue");
                                    }
                                    Ok(upload_list) => {
                                        upload_queue_watcher_sender
                                            .lock()
                                            .unwrap()
                                            .send_replace(upload_list);
                                    }
                                }
                                let _ = final_sender.take().unwrap().unbounded_send(
                                    ResponsePartWithSidecar {
                                        response,
                                        sidecar: None,
                                    },
                                );
                            }
                        }
                        ResponsePart::Error(crdb_core::SerializableError::MissingBinaries(
                            bins,
                        )) => {
                            // Retried in the next round, after these binaries were sent.
                            missing_binaries.extend(bins);
                        }
                        ResponsePart::Error(crdb_core::SerializableError::ObjectDoesNotExist(
                            _,
                        )) if !missing_binaries.is_empty() => {
                            // An earlier upload of this round was rejected for missing
                            // binaries. Presumably this upload depends on it (e.g. an event on
                            // an object whose creation was rejected), so it is simply retried
                            // in the next round.
                        }
                        ResponsePart::Error(ref err) => {
                            let Request::Upload(upload) = &*request.request else {
                                panic!("is_upload == true but does not match Upload");
                            };
                            match error_handler((*upload).clone(), (*err).clone().into()).await {
                                OnError::Rollback => {
                                    // Undo locally, then dequeue. If any step fails,
                                    // `final_sender` is kept and the upload is sent again.
                                    if let Err(err) = undo_upload::<C, _>(&db, upload).await {
                                        tracing::error!(?err, ?upload, "failed undoing upload");
                                    } else if let Err(err) = db.upload_finished(*upload_id).await {
                                        tracing::error!(?err, "failed dequeuing upload");
                                    } else {
                                        match db
                                            .list_uploads()
                                            .await
                                            .wrap_context("listing uploads")
                                        {
                                            Err(err) => {
                                                tracing::error!(
                                                    ?err,
                                                    "failed listing upload queue"
                                                );
                                            }
                                            Ok(upload_list) => {
                                                upload_queue_watcher_sender
                                                    .lock()
                                                    .unwrap()
                                                    .send_replace(upload_list);
                                            }
                                        }
                                        let _ = final_sender.take().unwrap().unbounded_send(
                                            ResponsePartWithSidecar {
                                                response,
                                                sidecar: None,
                                            },
                                        );
                                    }
                                }
                                OnError::KeepLocal => {
                                    // Local state and the persisted queue entry are left as-is;
                                    // only the error is reported. This task does not retry it.
                                    let _ = final_sender.take().unwrap().unbounded_send(
                                        ResponsePartWithSidecar {
                                            response,
                                            sidecar: None,
                                        },
                                    );
                                }
                                OnError::ReplaceWith(new_upload) => {
                                    // Undo locally, apply `new_upload` locally in its place,
                                    // then dequeue the original upload.
                                    // NOTE(review): `new_upload` is only applied to the local
                                    // db here, not enqueued for sending by this code.
                                    if let Err(err) = undo_upload::<C, _>(&db, upload).await {
                                        tracing::error!(?err, ?upload, "failed undoing upload");
                                    } else if let Err(err) =
                                        do_upload::<C, _>(&db, &new_upload).await
                                    {
                                        tracing::error!(
                                            ?err,
                                            ?new_upload,
                                            "failed doing replacement upload"
                                        );
                                    } else if let Err(err) = db.upload_finished(*upload_id).await {
                                        tracing::error!(?err, "failed dequeuing upload");
                                    } else {
                                        match db
                                            .list_uploads()
                                            .await
                                            .wrap_context("listing uploads")
                                        {
                                            Err(err) => {
                                                tracing::error!(
                                                    ?err,
                                                    "failed listing upload queue"
                                                );
                                            }
                                            Ok(upload_list) => {
                                                upload_queue_watcher_sender
                                                    .lock()
                                                    .unwrap()
                                                    .send_replace(upload_list);
                                            }
                                        }
                                        let _ = final_sender.take().unwrap().unbounded_send(
                                            ResponsePartWithSidecar {
                                                response,
                                                sidecar: None,
                                            },
                                        );
                                    }
                                }
                            }
                        }
                        _ => {
                            // NOTE(review): `final_sender` is kept, so this upload is re-sent
                            // right away in the next round.
                            tracing::error!(?response, "Unexpected response to upload submission");
                            continue;
                        }
                    },
                }
            }
            // Drop the uploads that reached their final outcome.
            upload_reqs.retain(|(_, _, final_sender, _, _)| final_sender.is_some());
            if !missing_binaries.is_empty() {
                // Load the missing binaries from the local db, at most 16 reads in flight.
                // Binaries whose read fails or yields `None` are silently skipped.
                let db = db.clone();
                let binaries = stream::iter(missing_binaries.into_iter())
                    .map(move |b| {
                        let db = db.clone();
                        async move { db.get_binary(b).await }
                    })
                    .buffer_unordered(16)
                    .filter_map(|res| async move { res.ok().and_then(|o| o) })
                    .collect::<Vec<Arc<[u8]>>>()
                    .await;
                // Sent at the start of the next round, before the batch is re-sent.
                upload_missing_binaries = Some(Arc::new(RequestWithSidecar {
                    request: Arc::new(Request::UploadBinaries(binaries.len())),
                    sidecar: binaries,
                }));
            }
        }
    }
}
async fn undo_upload<C: crdb_core::Config, LocalDb: ClientSideDb>(
local_db: &CacheDb<LocalDb>,
upload: &Upload,
) -> crate::Result<()> {
match upload {
Upload::Object { object_id, .. } => local_db.remove(*object_id).await,
Upload::Event {
object_id,
type_id,
event_id,
..
} => match C::remove_event(local_db, *type_id, *object_id, *event_id).await {
Err(crate::Error::EventTooEarly { .. }) => {
Ok(())
}
res => res,
},
}
}
/// Applies `upload` to the local database, without sending anything to the server.
///
/// * An event submission goes through `C::submit` (with `None` and `Importance::NONE`); its
///   success value is discarded.
/// * An object creation goes through `C::create`.
async fn do_upload<C: crdb_core::Config, LocalDb: ClientSideDb>(
    local_db: &CacheDb<LocalDb>,
    upload: &Upload,
) -> crate::Result<()> {
    match upload {
        Upload::Event {
            object_id,
            type_id,
            event_id,
            event,
            ..
        } => {
            C::submit(
                local_db,
                *type_id,
                *object_id,
                *event_id,
                event,
                None,
                Importance::NONE,
            )
            .await?;
            Ok(())
        }
        Upload::Object {
            object_id,
            type_id,
            created_at,
            snapshot_version,
            object,
            ..
        } => {
            let created = C::create(
                local_db,
                *type_id,
                *object_id,
                *created_at,
                *snapshot_version,
                object,
            )
            .await;
            created
        }
    }
}
fn expect_simple_response(
mut response_receiver: mpsc::UnboundedReceiver<ResponsePartWithSidecar>,
) -> oneshot::Receiver<crate::Result<()>> {
let (sender, receiver) = oneshot::channel();
waaaa::spawn(async move {
let Some(response) = response_receiver.next().await else {
let _ = sender.send(Err(crate::Error::Other(anyhow!(
"Connection thread went down too ealy"
))));
return;
};
let _ = match response.response {
ResponsePart::Success => sender.send(Ok(())),
ResponsePart::Error(err) => sender.send(Err(err.into())),
_ => sender.send(Err(crate::Error::Other(anyhow!(
"Unexpected server response to DisconnectSession: {:?}",
response.response
)))),
};
});
receiver
}