use super::actor::{ActorImpl, ActorMessage, TxnMessage};
use super::WrappedStorage;
use crate::errors::Result;
use crate::operation::Operation;
use crate::storage::{Storage, StorageTxn, TaskMap, VersionId};
use async_trait::async_trait;
use std::future::Future;
use tokio::sync::{mpsc, oneshot};
use uuid::Uuid;
/// A storage handle that forwards requests over a channel to an actor which
/// owns the underlying storage value.
///
/// On non-wasm targets the actor runs on a dedicated OS thread; on `wasm32`
/// it runs as a task spawned with `wasm_bindgen_futures::spawn_local`.
pub(in crate::storage) struct Wrapper {
// Request channel to the actor. It is only reset to `None` in `Drop`, before
// the actor thread is joined (presumably so the actor's receive loop sees the
// channel close and returns — `ActorImpl::run` is not visible here).
sender: Option<mpsc::UnboundedSender<ActorMessage>>,
// Handle of the thread running the actor; taken and joined in `Drop`.
#[cfg(not(target_arch = "wasm32"))]
thread: Option<std::thread::JoinHandle<()>>,
}
impl Wrapper {
pub(in crate::storage) async fn new<S, FN, FUT>(constructor: FN) -> Result<Self>
where
S: WrappedStorage,
FUT: Future<Output = Result<S>>,
FN: FnOnce() -> FUT + Send + 'static,
{
let (sender, receiver) = mpsc::unbounded_channel();
let (init_sender, init_receiver): (oneshot::Sender<Result<_>>, _) = oneshot::channel();
let in_thread = async move |init_sender: oneshot::Sender<Result<_>>| {
match constructor().await {
Ok(storage) => {
let _ = init_sender.send(Ok(()));
let mut actor = ActorImpl::new(storage, receiver);
actor.run().await;
}
Err(e) => {
let _ = init_sender.send(Err(e));
}
}
};
#[cfg(target_arch = "wasm32")]
{
wasm_bindgen_futures::spawn_local(in_thread(init_sender));
}
#[cfg(not(target_arch = "wasm32"))]
let thread = {
use std::thread;
use tokio::runtime;
thread::spawn(move || {
let rt = match runtime::Builder::new_current_thread().build() {
Ok(rt) => rt,
Err(e) => {
let _ = init_sender.send(Err(e.into()));
return;
}
};
rt.block_on(in_thread(init_sender));
})
};
init_receiver.await??;
Ok(Self {
sender: Some(sender),
#[cfg(not(target_arch = "wasm32"))]
thread: Some(thread),
})
}
}
#[async_trait]
impl Storage for Wrapper {
    /// Opens a transaction by asking the actor for a per-transaction channel
    /// and wrapping it in a `WrapperTxn`.
    ///
    /// # Errors
    ///
    /// Fails if the actor's request channel is closed, if the reply is dropped
    /// without an answer, or if the actor answers with an error.
    ///
    /// # Panics
    ///
    /// Panics with "txn called after drop" if `sender` is `None`; in this file
    /// that field is only cleared by `Drop`.
    async fn txn<'a>(&'a mut self) -> Result<Box<dyn StorageTxn + Send + 'a>> {
        let actor = self.sender.as_ref().expect("txn called after drop");
        let (begin_reply, begin_rx) = oneshot::channel();
        actor.send(ActorMessage::BeginTxn(begin_reply))?;
        let txn_channel = begin_rx.await??;
        Ok(Box::new(WrapperTxn::new(txn_channel)))
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl Drop for Wrapper {
    /// Closes the actor's request channel, then blocks until the actor thread
    /// has exited. A panic on the actor thread is ignored.
    fn drop(&mut self) {
        // Release the sender before joining; otherwise the actor would still
        // have an open channel to wait on.
        drop(self.sender.take());
        let handle = self.thread.take().expect("thread joined twice");
        let _ = handle.join();
    }
}
/// Client-side handle for a transaction executed by the storage actor.
///
/// Every `StorageTxn` method sends a `TxnMessage` over `sender` and awaits the
/// reply. If dropped without a successful commit, it sends
/// `TxnMessage::Rollback`.
struct WrapperTxn {
// Channel for this transaction's requests, as returned by the actor in reply
// to `ActorMessage::BeginTxn`.
sender: mpsc::UnboundedSender<TxnMessage>,
// Set once `commit` succeeds; suppresses the rollback sent in `Drop`.
committed: bool,
}
impl WrapperTxn {
fn new(sender: mpsc::UnboundedSender<TxnMessage>) -> Self {
Self {
sender,
committed: false,
}
}
async fn call<R, F>(&self, f: F) -> Result<R>
where
F: FnOnce(oneshot::Sender<Result<R>>) -> TxnMessage,
R: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.sender.send(f(tx))?;
rx.await?
}
}
impl Drop for WrapperTxn {
    /// Asks the actor to roll back unless the transaction was committed.
    /// Best-effort: if the actor is already gone, the send error is ignored.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        let _ = self.sender.send(TxnMessage::Rollback);
    }
}
/// Each method forwards one `TxnMessage` to the storage actor through
/// `WrapperTxn::call` and returns the actor's reply, so all of them share
/// `call`'s error cases.
#[async_trait]
impl StorageTxn for WrapperTxn {
    async fn commit(&mut self) -> Result<()> {
        let outcome = self.call(TxnMessage::Commit).await;
        // Only a successful commit marks the transaction as committed; this
        // suppresses the rollback that `Drop` would otherwise send.
        if outcome.is_ok() {
            self.committed = true;
        }
        outcome
    }

    async fn get_task(&mut self, uuid: Uuid) -> Result<Option<TaskMap>> {
        self.call(|reply| TxnMessage::GetTask(uuid, reply)).await
    }

    async fn create_task(&mut self, uuid: Uuid) -> Result<bool> {
        self.call(|reply| TxnMessage::CreateTask(uuid, reply)).await
    }

    async fn set_task(&mut self, uuid: Uuid, task: TaskMap) -> Result<()> {
        self.call(|reply| TxnMessage::SetTask(uuid, task, reply)).await
    }

    async fn delete_task(&mut self, uuid: Uuid) -> Result<bool> {
        self.call(|reply| TxnMessage::DeleteTask(uuid, reply)).await
    }

    async fn get_pending_tasks(&mut self) -> Result<Vec<(Uuid, TaskMap)>> {
        self.call(TxnMessage::GetPendingTasks).await
    }

    async fn all_tasks(&mut self) -> Result<Vec<(Uuid, TaskMap)>> {
        self.call(TxnMessage::AllTasks).await
    }

    async fn all_task_uuids(&mut self) -> Result<Vec<Uuid>> {
        self.call(TxnMessage::AllTaskUuids).await
    }

    async fn base_version(&mut self) -> Result<VersionId> {
        self.call(TxnMessage::BaseVersion).await
    }

    async fn set_base_version(&mut self, version: VersionId) -> Result<()> {
        self.call(|reply| TxnMessage::SetBaseVersion(version, reply))
            .await
    }

    async fn get_task_operations(&mut self, uuid: Uuid) -> Result<Vec<Operation>> {
        self.call(|reply| TxnMessage::GetTaskOperations(uuid, reply))
            .await
    }

    async fn unsynced_operations(&mut self) -> Result<Vec<Operation>> {
        self.call(TxnMessage::UnsyncedOperations).await
    }

    async fn num_unsynced_operations(&mut self) -> Result<usize> {
        self.call(TxnMessage::NumUnsyncedOperations).await
    }

    async fn add_operation(&mut self, op: Operation) -> Result<()> {
        self.call(|reply| TxnMessage::AddOperation(op, reply)).await
    }

    async fn remove_operation(&mut self, op: Operation) -> Result<()> {
        self.call(|reply| TxnMessage::RemoveOperation(op, reply)).await
    }

    async fn sync_complete(&mut self) -> Result<()> {
        self.call(TxnMessage::SyncComplete).await
    }

    async fn get_working_set(&mut self) -> Result<Vec<Option<Uuid>>> {
        self.call(TxnMessage::GetWorkingSet).await
    }

    async fn add_to_working_set(&mut self, uuid: Uuid) -> Result<usize> {
        self.call(|reply| TxnMessage::AddToWorkingSet(uuid, reply)).await
    }

    async fn set_working_set_item(&mut self, index: usize, uuid: Option<Uuid>) -> Result<()> {
        self.call(|reply| TxnMessage::SetWorkingSetItem(index, uuid, reply))
            .await
    }

    async fn clear_working_set(&mut self) -> Result<()> {
        self.call(TxnMessage::ClearWorkingSet).await
    }

    async fn is_empty(&mut self) -> Result<bool> {
        self.call(TxnMessage::IsEmpty).await
    }
}