use crate::{Error, Result, Runtime};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use tokio_context::context::Context;
/// Asynchronous client interface to a key-value store.
///
/// Implementors must also be `Clone`, `Display`, `Send` and `Sync`.
/// Every operation takes a `Context` and an owned `key`; the `Storage`
/// implementation in this file forwards both to `Runtime::call`.
#[async_trait]
pub trait KV: Clone + Display + Send + Sync {
/// Reads the value stored under `key` and returns it deserialized as `T`.
///
/// # Errors
///
/// Returns an error when the operation fails; the exact conditions are
/// defined by the implementor.
async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
where
T: Deserialize<'static> + Send;
/// Stores `val` under `key`.
///
/// # Errors
///
/// Returns an error when the operation fails; the exact conditions are
/// defined by the implementor.
async fn put<T>(&self, ctx: Context, key: String, val: T) -> Result<()>
where
T: Serialize + Send;
/// Compare-and-set: asks the store to change the value of `key` from
/// `from` to `to`.
///
/// The `Storage` implementation sends `put` on the wire as
/// `create_if_not_exists`; presumably it lets the store create a missing
/// key instead of failing -- confirm against the service documentation.
///
/// # Errors
///
/// Returns an error when the operation fails; the exact conditions are
/// defined by the implementor.
async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
where
T: Serialize + Deserialize<'static> + Send;
}
/// Handle to a key-value service that is reached through a [`Runtime`].
///
/// Instances are created by the `lin_kv`, `seq_kv`, `lww_kv` and `tso_kv`
/// constructors, which differ only in the service identifier they store.
#[derive(Clone)]
pub struct Storage {
/// Service identifier handed to `Runtime::call` with every request
/// (for example `"lin-kv"`); also shown by the `Display` implementation.
typ: &'static str,
/// Runtime used to issue the calls.
runtime: Runtime,
}
/// Builds a [`Storage`] whose requests are issued through `runtime` and
/// addressed to the `"lin-kv"` service (presumably the linearizable
/// key-value store -- see the service documentation).
#[must_use]
pub fn lin_kv(runtime: Runtime) -> Storage {
    let typ = "lin-kv";
    Storage { typ, runtime }
}
/// Builds a [`Storage`] whose requests are issued through `runtime` and
/// addressed to the `"seq-kv"` service (presumably the sequentially
/// consistent key-value store -- see the service documentation).
#[must_use]
pub fn seq_kv(runtime: Runtime) -> Storage {
    let typ = "seq-kv";
    Storage { typ, runtime }
}
/// Builds a [`Storage`] whose requests are issued through `runtime` and
/// addressed to the `"lww-kv"` service (presumably the last-write-wins
/// key-value store -- see the service documentation).
#[must_use]
pub fn lww_kv(runtime: Runtime) -> Storage {
    let typ = "lww-kv";
    Storage { typ, runtime }
}
/// Builds a [`Storage`] whose requests are issued through `runtime` and
/// addressed to the `"lin-tso"` service.
///
/// NOTE(review): the identifier is `"lin-tso"`, not `"tso-kv"`; presumably
/// this targets a timestamp-oracle service rather than a plain key-value
/// store -- confirm that the `KV` operations are meaningful against it.
#[must_use]
pub fn tso_kv(runtime: Runtime) -> Storage {
    let typ = "lin-tso";
    Storage { typ, runtime }
}
/// Renders the handle as `Storage(<service identifier>)`, e.g. `Storage(lin-kv)`.
impl Display for Storage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Emit the three pieces directly; the result is the same text that a
        // single `Storage({})` format string would produce.
        f.write_str("Storage(")?;
        f.write_str(self.typ)?;
        f.write_str(")")
    }
}
#[async_trait]
impl KV for Storage {
async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
where
T: Deserialize<'static> + Send,
{
let req = Message::Read::<String> { key };
let msg = self.runtime.call(ctx, self.typ, req).await?;
let data = msg.body.as_obj::<Message<T>>()?;
match data {
Message::ReadOk { value } => Ok(value),
_ => Err(Box::new(Error::Custom(
-1,
"kv: protocol violated".to_string(),
))),
}
}
async fn put<T>(&self, ctx: Context, key: String, value: T) -> Result<()>
where
T: Serialize + Send,
{
let req = Message::Write::<T> { key, value };
let _msg = self.runtime.call(ctx, self.typ, req).await?;
Ok(())
}
async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
where
T: Serialize + Deserialize<'static> + Send,
{
let req = Message::Cas::<T> { key, from, to, put };
let _msg = self.runtime.call(ctx, self.typ, req).await?;
Ok(())
}
}
/// Wire representation of key-value requests and replies.
///
/// serde encodes it as an internally tagged object: the variant name, in
/// snake_case, goes into the `type` field (e.g. `ReadOk` becomes
/// `"read_ok"`), next to the variant's own fields.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
enum Message<T> {
/// Request for the value stored under `key`.
Read {
key: String,
},
/// Reply to `Read`, carrying the stored `value`.
ReadOk {
value: T,
},
/// Request to store `value` under `key`.
Write {
key: String,
value: T,
},
/// Compare-and-set request: change `key` from `from` to `to`.
Cas {
key: String,
from: T,
to: T,
// Serialized as `create_if_not_exists`; left out of the output when
// `false`, and read back as `false` when the field is absent.
#[serde(
default,
rename = "create_if_not_exists",
skip_serializing_if = "is_ref_false"
)]
put: bool,
},
/// Reply without fields; presumably the acknowledgement of `Cas` (it is
/// never constructed or inspected individually in this file).
CasOk {},
}
/// Returns `true` when the referenced flag is `false`.
///
/// Used as the `skip_serializing_if` predicate for `Message::Cas::put`.
// serde calls `skip_serializing_if` predicates with a reference to the field,
// so the parameter has to stay `&bool` even though `bool` is `Copy`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_ref_false(b: &bool) -> bool {
    matches!(*b, false)
}