1use crate::{Error, Result, Runtime};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::fmt::{Display, Formatter};
5use tokio_context::context::Context;
6
7#[async_trait]
8pub trait KV: Clone + Display + Send + Sync {
9 async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
12 where
13 T: Deserialize<'static> + Send;
14
15 async fn put<T>(&self, ctx: Context, key: String, val: T) -> Result<()>
17 where
18 T: Serialize + Send;
19
20 async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
26 where
27 T: Serialize + Deserialize<'static> + Send;
28}
29
30#[derive(Clone)]
31pub struct Storage {
32 typ: &'static str,
33 runtime: Runtime,
34}
35
36#[must_use]
38pub fn lin_kv(runtime: Runtime) -> Storage {
39 Storage {
40 typ: "lin-kv",
41 runtime,
42 }
43}
44
45#[must_use]
47pub fn seq_kv(runtime: Runtime) -> Storage {
48 Storage {
49 typ: "seq-kv",
50 runtime,
51 }
52}
53
54#[must_use]
56pub fn lww_kv(runtime: Runtime) -> Storage {
57 Storage {
58 typ: "lww-kv",
59 runtime,
60 }
61}
62
63#[must_use]
65pub fn tso_kv(runtime: Runtime) -> Storage {
66 Storage {
67 typ: "lin-tso",
68 runtime,
69 }
70}
71
72impl Display for Storage {
73 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
74 write!(f, "Storage({})", self.typ)
75 }
76}
77
78#[async_trait]
79impl KV for Storage {
80 async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
81 where
82 T: Deserialize<'static> + Send,
83 {
84 let req = Message::Read::<String> { key };
85 let msg = self.runtime.call(ctx, self.typ, req).await?;
86 let data = msg.body.as_obj::<Message<T>>()?;
87 match data {
88 Message::ReadOk { value } => Ok(value),
89 _ => Err(Box::new(Error::Custom(
90 -1,
91 "kv: protocol violated".to_string(),
92 ))),
93 }
94 }
95
96 async fn put<T>(&self, ctx: Context, key: String, value: T) -> Result<()>
97 where
98 T: Serialize + Send,
99 {
100 let req = Message::Write::<T> { key, value };
101 let _msg = self.runtime.call(ctx, self.typ, req).await?;
102 Ok(())
103 }
104
105 async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
106 where
107 T: Serialize + Deserialize<'static> + Send,
108 {
109 let req = Message::Cas::<T> { key, from, to, put };
110 let _msg = self.runtime.call(ctx, self.typ, req).await?;
111 Ok(())
112 }
113}
114
115#[derive(Serialize, Deserialize)]
116#[serde(rename_all = "snake_case", tag = "type")]
117enum Message<T> {
118 Read {
120 key: String,
121 },
122 ReadOk {
124 value: T,
125 },
126 Write {
128 key: String,
129 value: T,
130 },
131 Cas {
133 key: String,
134 from: T,
135 to: T,
136 #[serde(
137 default,
138 rename = "create_if_not_exists",
139 skip_serializing_if = "is_ref_false"
140 )]
141 put: bool,
142 },
143 CasOk {},
144}
145
/// Returns `true` when `*b` is `false`.
///
/// Used as the `skip_serializing_if` predicate for `Message::Cas::put`.
// serde calls `skip_serializing_if` predicates with a reference to the field,
// so the `&bool` parameter is required even though `bool` is `Copy`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_ref_false(b: &bool) -> bool {
    !*b
}