maelstrom/
kv.rs

1use crate::{Error, Result, Runtime};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::fmt::{Display, Formatter};
5use tokio_context::context::Context;
6
7#[async_trait]
8pub trait KV: Clone + Display + Send + Sync {
9    /// Get returns the value for a given key in the key/value store.
10    /// Returns an RPCError error with a KeyDoesNotExist code if the key does not exist.
11    async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
12    where
13        T: Deserialize<'static> + Send;
14
15    /// Put overwrites the value for a given key in the key/value store.
16    async fn put<T>(&self, ctx: Context, key: String, val: T) -> Result<()>
17    where
18        T: Serialize + Send;
19
20    /// CAS updates the value for a key if its current value matches the
21    /// previous value. Creates the key if it is not exist is requested.
22    ///
23    /// Returns an RPCError with a code of PreconditionFailed if the previous value
24    /// does not match. Return a code of KeyDoesNotExist if the key did not exist.
25    async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
26    where
27        T: Serialize + Deserialize<'static> + Send;
28}
29
30#[derive(Clone)]
31pub struct Storage {
32    typ: &'static str,
33    runtime: Runtime,
34}
35
36/// Creates a linearizable storage.
37#[must_use]
38pub fn lin_kv(runtime: Runtime) -> Storage {
39    Storage {
40        typ: "lin-kv",
41        runtime,
42    }
43}
44
45/// Creates a sequentially consistent storage.
46#[must_use]
47pub fn seq_kv(runtime: Runtime) -> Storage {
48    Storage {
49        typ: "seq-kv",
50        runtime,
51    }
52}
53
54/// Creates last-write-wins storage type.
55#[must_use]
56pub fn lww_kv(runtime: Runtime) -> Storage {
57    Storage {
58        typ: "lww-kv",
59        runtime,
60    }
61}
62
63/// Creates total-store-order kind of storage.
64#[must_use]
65pub fn tso_kv(runtime: Runtime) -> Storage {
66    Storage {
67        typ: "lin-tso",
68        runtime,
69    }
70}
71
72impl Display for Storage {
73    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
74        write!(f, "Storage({})", self.typ)
75    }
76}
77
78#[async_trait]
79impl KV for Storage {
80    async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
81    where
82        T: Deserialize<'static> + Send,
83    {
84        let req = Message::Read::<String> { key };
85        let msg = self.runtime.call(ctx, self.typ, req).await?;
86        let data = msg.body.as_obj::<Message<T>>()?;
87        match data {
88            Message::ReadOk { value } => Ok(value),
89            _ => Err(Box::new(Error::Custom(
90                -1,
91                "kv: protocol violated".to_string(),
92            ))),
93        }
94    }
95
96    async fn put<T>(&self, ctx: Context, key: String, value: T) -> Result<()>
97    where
98        T: Serialize + Send,
99    {
100        let req = Message::Write::<T> { key, value };
101        let _msg = self.runtime.call(ctx, self.typ, req).await?;
102        Ok(())
103    }
104
105    async fn cas<T>(&self, ctx: Context, key: String, from: T, to: T, put: bool) -> Result<()>
106    where
107        T: Serialize + Deserialize<'static> + Send,
108    {
109        let req = Message::Cas::<T> { key, from, to, put };
110        let _msg = self.runtime.call(ctx, self.typ, req).await?;
111        Ok(())
112    }
113}
114
115#[derive(Serialize, Deserialize)]
116#[serde(rename_all = "snake_case", tag = "type")]
117enum Message<T> {
118    /// KVReadMessageBody represents the body for the KV "read" message.
119    Read {
120        key: String,
121    },
122    /// KVReadOKMessageBody represents the response body for the KV "read_ok" message.
123    ReadOk {
124        value: T,
125    },
126    /// KVWriteMessageBody represents the body for the KV "cas" message.
127    Write {
128        key: String,
129        value: T,
130    },
131    /// KVCASMessageBody represents the body for the KV "cas" message.
132    Cas {
133        key: String,
134        from: T,
135        to: T,
136        #[serde(
137            default,
138            rename = "create_if_not_exists",
139            skip_serializing_if = "is_ref_false"
140        )]
141        put: bool,
142    },
143    CasOk {},
144}
145
/// `skip_serializing_if` predicate: returns `true` when the flag is unset.
///
/// Serde hands the field to this predicate by reference, so the `&bool`
/// parameter is required even though clippy would prefer `bool` by value.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_ref_false(flag: &bool) -> bool {
    !flag
}