1use std::{
6 fmt::{self, Debug},
7 sync::{Arc, Mutex},
8};
9
10use rusqlite::{params, Connection, OptionalExtension};
11
12use crate::SqLiteDataStorageError;
13
/// Upsert statement for a single `(key, value)` pair.
///
/// On a `key` conflict the stored value is overwritten, but only when it actually
/// differs (`WHERE value != excluded.value`), so re-inserting an identical value
/// reports zero changed rows.
const INSERT_SQL: &str = concat!(
    "INSERT INTO kvs (key, value) VALUES (?,?) ",
    "ON CONFLICT(key) DO UPDATE SET value=excluded.value ",
    "WHERE value != excluded.value"
);
16
17#[derive(Debug, Clone)]
18pub struct SqLiteApplicationStorage {
20 connection: Arc<Mutex<Connection>>,
21}
22
23impl SqLiteApplicationStorage {
24 pub(crate) fn new(connection: Connection) -> SqLiteApplicationStorage {
25 SqLiteApplicationStorage {
26 connection: Arc::new(Mutex::new(connection)),
27 }
28 }
29
30 pub fn insert(&self, key: &str, value: &[u8]) -> Result<usize, SqLiteDataStorageError> {
35 let connection = self.connection.lock().unwrap();
36
37 connection
39 .execute(INSERT_SQL, params![key, value])
40 .map_err(sql_engine_error)
41 }
42
43 pub fn transact_insert(&self, items: &[Item]) -> Result<usize, SqLiteDataStorageError> {
46 let mut connection = self.connection.lock().unwrap();
47
48 let tx = connection.transaction().map_err(sql_engine_error)?;
50
51 let total_modified = items.iter().try_fold(0, |acc, item| {
52 tx.execute(INSERT_SQL, params![item.key, item.value])
53 .map_err(sql_engine_error)
54 .map(|rows| acc + rows)
55 })?;
56
57 tx.commit().map_err(sql_engine_error)?;
58
59 Ok(total_modified)
60 }
61
62 pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>, SqLiteDataStorageError> {
64 let connection = self.connection.lock().unwrap();
65
66 connection
67 .query_row("SELECT value FROM kvs WHERE key = ?", params![key], |row| {
68 row.get(0)
69 })
70 .optional()
71 .map_err(sql_engine_error)
72 }
73
74 pub fn delete(&self, key: &str) -> Result<usize, SqLiteDataStorageError> {
77 let connection = self.connection.lock().unwrap();
78
79 connection
80 .execute("DELETE FROM kvs WHERE key = ?", params![key])
81 .map_err(sql_engine_error)
82 }
83
84 pub fn get_by_prefix(&self, key_prefix: &str) -> Result<Vec<Item>, SqLiteDataStorageError> {
86 let connection = self.connection.lock().unwrap();
87 let mut key_prefix = sanitize(key_prefix);
88 key_prefix.push('%');
89
90 let mut stmt = connection
91 .prepare("SELECT key, value FROM kvs WHERE key LIKE ? ESCAPE '$'")
92 .map_err(sql_engine_error)?;
93
94 let rows = stmt
95 .query(params![key_prefix])
96 .map_err(sql_engine_error)?
97 .mapped(|row| Ok(Item::new(row.get(0)?, row.get(1)?)));
98
99 rows.collect::<Result<_, _>>().map_err(sql_engine_error)
100 }
101
102 pub fn delete_by_prefix(&self, key_prefix: &str) -> Result<usize, SqLiteDataStorageError> {
105 let connection = self.connection.lock().unwrap();
106 let mut key_prefix = sanitize(key_prefix);
107 key_prefix.push('%');
108
109 connection
110 .execute(
111 "DELETE FROM kvs WHERE key LIKE ? ESCAPE '$'",
112 params![key_prefix],
113 )
114 .map_err(sql_engine_error)
115 }
116}
117
/// Escapes `string` for literal use inside a SQLite `LIKE ... ESCAPE '$'` pattern.
///
/// Each `_`, `%` and the escape character `$` itself is prefixed with `$`, so
/// each matches only its own literal character.
///
/// Escaping `$` is essential. An unescaped `$` makes SQLite treat the following
/// character as an escaped literal. For example, a trailing `$` would swallow the
/// `%` wildcard that callers append, and `$_` would leave the next `_` as a
/// wildcard.
fn sanitize(string: &str) -> String {
    // Pre-size for the common case where nothing needs escaping.
    let mut escaped = String::with_capacity(string.len());
    for c in string.chars() {
        // Single pass, so escapes we insert are never re-escaped.
        if matches!(c, '$' | '_' | '%') {
            escaped.push('$');
        }
        escaped.push(c);
    }
    escaped
}
121
122fn sql_engine_error(e: rusqlite::Error) -> SqLiteDataStorageError {
123 SqLiteDataStorageError::SqlEngineError(e.into())
124}
125
/// A key/value pair as written by `transact_insert` and read back by
/// `get_by_prefix`.
///
/// The derived ordering and equality compare `key` first, then `value`
/// (field declaration order).
#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Item {
    /// The entry's key.
    pub key: String,
    /// The entry's raw value bytes.
    pub value: Vec<u8>,
}
131
132impl Debug for Item {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("Item")
135 .field("key", &self.key)
136 .field("value", &mls_rs_core::debug::pretty_bytes(&self.value))
137 .finish()
138 }
139}
140
141impl Item {
142 pub fn new(key: String, value: Vec<u8>) -> Self {
143 Self { key, value }
144 }
145
146 pub fn key(&self) -> &str {
147 &self.key
148 }
149
150 pub fn value(&self) -> &[u8] {
151 &self.value
152 }
153}
154
#[cfg(test)]
mod tests {
    use crate::{
        application::Item, connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes,
        SqLiteDataStorageEngine,
    };

    use super::SqLiteApplicationStorage;

    /// A hex-encoded 32-byte key and a 64-byte value, both from `gen_rand_bytes`.
    fn test_kv() -> (String, Vec<u8>) {
        (hex::encode(gen_rand_bytes(32)), gen_rand_bytes(64))
    }

    /// Application storage from an engine built with `MemoryStrategy`.
    fn test_storage() -> SqLiteApplicationStorage {
        let engine = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();
        engine.application_data_storage().unwrap()
    }

    #[test]
    fn test_insert() {
        let storage = test_storage();
        let (key, value) = test_kv();

        assert_eq!(storage.insert(&key, &value).unwrap(), 1);
        assert_eq!(storage.get(&key).unwrap(), Some(value));
    }

    #[test]
    fn test_insert_existing_overwrite() {
        let storage = test_storage();
        let (key, first_value) = test_kv();
        let (_, replacement) = test_kv();

        for v in [&first_value, &replacement] {
            storage.insert(&key, v).unwrap();
        }

        assert_eq!(storage.get(&key).unwrap(), Some(replacement));
    }

    #[test]
    fn test_duplicate_insert() {
        let storage = test_storage();
        let (key, value) = test_kv();

        // The second, identical insert must not count as a modification.
        let changed: Vec<usize> = (0..2)
            .map(|_| storage.insert(&key, &value).unwrap())
            .collect();
        assert_eq!(changed, vec![1, 0]);

        assert_eq!(storage.get(&key).unwrap(), Some(value));
    }

    #[test]
    fn test_delete() {
        let storage = test_storage();
        let (key, value) = test_kv();

        storage.insert(&key, &value).unwrap();

        let first = storage.delete(&key).unwrap();
        let second = storage.delete(&key).unwrap();
        assert_eq!((first, second), (1, 0));

        assert_eq!(storage.get(&key).unwrap(), None);
    }

    #[test]
    fn test_by_prefix() {
        let keys = ["prefix one", "prefix two", "prefiy ", "prefiw "].map(String::from);
        let value = gen_rand_bytes(5);
        let storage = test_storage();

        for key in &keys {
            storage.insert(key, &value).unwrap();
        }

        // Only the first two keys start with "prefix".
        let mut expected: Vec<Item> = keys[..2]
            .iter()
            .map(|k| Item::new(k.clone(), value.clone()))
            .collect();
        expected.sort();

        let mut matched = storage.get_by_prefix("prefix").unwrap();
        matched.sort();
        assert_eq!(matched, expected);

        assert!(storage.get_by_prefix("a").unwrap().is_empty());
        assert_eq!(storage.get_by_prefix("").unwrap().len(), keys.len());

        assert_eq!(storage.delete_by_prefix("prefix").unwrap(), 2);

        let remaining = storage.get_by_prefix("").unwrap();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains(&Item::new("prefiy ".to_string(), value.clone())));
        assert!(remaining.contains(&Item::new("prefiw ".to_string(), value)));
    }

    #[test]
    fn test_special_characters() {
        let storage = test_storage();

        for key in ["%$_ƕ❤_$%", "%$_ƕ❤a$%", "%$_ƕ❤Ḉ$%"] {
            storage.insert(key, &gen_rand_bytes(5)).unwrap();
        }

        let keys: Vec<String> = storage
            .get_by_prefix("%$_ƕ❤_")
            .unwrap()
            .into_iter()
            .map(|item| item.key)
            .collect();

        assert_eq!(keys, vec!["%$_ƕ❤_$%".to_string()]);
    }

    #[test]
    fn batch_insert() {
        let storage = test_storage();
        let items: Vec<Item> = (0..3).map(|_| test_item()).collect();

        assert_eq!(storage.transact_insert(&items).unwrap(), 3);
        // Re-inserting identical items changes nothing.
        assert_eq!(storage.transact_insert(&items).unwrap(), 0);

        for Item { key, value } in items {
            assert_eq!(storage.get(&key).unwrap(), Some(value));
        }
    }

    /// An item with a hex-encoded 5-byte key and a 5-byte value from `gen_rand_bytes`.
    fn test_item() -> Item {
        Item::new(hex::encode(gen_rand_bytes(5)), gen_rand_bytes(5))
    }
}