mls_rs_provider_sqlite/
application.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use std::{
6    fmt::{self, Debug},
7    sync::{Arc, Mutex},
8};
9
10use rusqlite::{params, Connection, OptionalExtension};
11
12use crate::SqLiteDataStorageError;
13
/// Upsert statement for the `kvs` table, bound as `(key, value)`.
///
/// A new key is inserted. An existing key is only rewritten when the stored value differs,
/// so writing an identical pair again reports zero modified rows.
const INSERT_SQL: &str = concat!(
    "INSERT INTO kvs (key, value) VALUES (?,?) ",
    "ON CONFLICT(key) DO UPDATE SET value=excluded.value ",
    "WHERE value != excluded.value"
);
16
17#[derive(Debug, Clone)]
18/// SQLite key-value storage for application specific data.
19pub struct SqLiteApplicationStorage {
20    connection: Arc<Mutex<Connection>>,
21}
22
23impl SqLiteApplicationStorage {
24    pub(crate) fn new(connection: Connection) -> SqLiteApplicationStorage {
25        SqLiteApplicationStorage {
26            connection: Arc::new(Mutex::new(connection)),
27        }
28    }
29
30    /// Insert `value` into storage indexed by `key`.
31    ///
32    /// If a value already exists for `key` it will be overwritten.
33    /// Returns the number of rows modified (0 if the key-value pair already exists).
34    pub fn insert(&self, key: &str, value: &[u8]) -> Result<usize, SqLiteDataStorageError> {
35        let connection = self.connection.lock().unwrap();
36
37        // Use a query that only updates if the value is different
38        connection
39            .execute(INSERT_SQL, params![key, value])
40            .map_err(sql_engine_error)
41    }
42
43    /// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction.
44    /// Returns the total number of rows modified.
45    pub fn transact_insert(&self, items: &[Item]) -> Result<usize, SqLiteDataStorageError> {
46        let mut connection = self.connection.lock().unwrap();
47
48        // Upsert into the database
49        let tx = connection.transaction().map_err(sql_engine_error)?;
50
51        let total_modified = items.iter().try_fold(0, |acc, item| {
52            tx.execute(INSERT_SQL, params![item.key, item.value])
53                .map_err(sql_engine_error)
54                .map(|rows| acc + rows)
55        })?;
56
57        tx.commit().map_err(sql_engine_error)?;
58
59        Ok(total_modified)
60    }
61
62    /// Get a value from storage based on its `key`.
63    pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>, SqLiteDataStorageError> {
64        let connection = self.connection.lock().unwrap();
65
66        connection
67            .query_row("SELECT value FROM kvs WHERE key = ?", params![key], |row| {
68                row.get(0)
69            })
70            .optional()
71            .map_err(sql_engine_error)
72    }
73
74    /// Delete a value from storage based on its `key`.
75    /// Returns the number of rows modified (0 if the key-value pair didnt exist).
76    pub fn delete(&self, key: &str) -> Result<usize, SqLiteDataStorageError> {
77        let connection = self.connection.lock().unwrap();
78
79        connection
80            .execute("DELETE FROM kvs WHERE key = ?", params![key])
81            .map_err(sql_engine_error)
82    }
83
84    /// Get all keys and values from storage for which key starts with `key_prefix`.
85    pub fn get_by_prefix(&self, key_prefix: &str) -> Result<Vec<Item>, SqLiteDataStorageError> {
86        let connection = self.connection.lock().unwrap();
87        let mut key_prefix = sanitize(key_prefix);
88        key_prefix.push('%');
89
90        let mut stmt = connection
91            .prepare("SELECT key, value FROM kvs WHERE key LIKE ? ESCAPE '$'")
92            .map_err(sql_engine_error)?;
93
94        let rows = stmt
95            .query(params![key_prefix])
96            .map_err(sql_engine_error)?
97            .mapped(|row| Ok(Item::new(row.get(0)?, row.get(1)?)));
98
99        rows.collect::<Result<_, _>>().map_err(sql_engine_error)
100    }
101
102    /// Delete all values from storage for which key starts with `key_prefix`.
103    /// Returns the total number of rows modified.
104    pub fn delete_by_prefix(&self, key_prefix: &str) -> Result<usize, SqLiteDataStorageError> {
105        let connection = self.connection.lock().unwrap();
106        let mut key_prefix = sanitize(key_prefix);
107        key_prefix.push('%');
108
109        connection
110            .execute(
111                "DELETE FROM kvs WHERE key LIKE ? ESCAPE '$'",
112                params![key_prefix],
113            )
114            .map_err(sql_engine_error)
115    }
116}
117
/// Escape `string` so it matches literally inside a `LIKE ... ESCAPE '$'` pattern.
///
/// `%` and `_` are `LIKE` wildcards, and `$` is the escape character itself, so each of
/// the three is prefixed with `$`. `$` must be escaped too. Otherwise a `$` in the input
/// would escape whatever follows it, for example turning the `%` that callers append into
/// a literal percent sign.
fn sanitize(string: &str) -> String {
    let mut escaped = String::with_capacity(string.len());

    for c in string.chars() {
        if matches!(c, '$' | '%' | '_') {
            escaped.push('$');
        }
        escaped.push(c);
    }

    escaped
}
121
122fn sql_engine_error(e: rusqlite::Error) -> SqLiteDataStorageError {
123    SqLiteDataStorageError::SqlEngineError(e.into())
124}
125
/// A single key-value entry, as returned by prefix queries and accepted by batch inserts.
///
/// Equality, hashing and ordering are derived field by field: `key` first, then `value`.
#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Item {
    /// Key the value is stored under.
    pub key: String,
    /// Raw bytes stored for `key`.
    pub value: Vec<u8>,
}
131
132impl Debug for Item {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("Item")
135            .field("key", &self.key)
136            .field("value", &mls_rs_core::debug::pretty_bytes(&self.value))
137            .finish()
138    }
139}
140
141impl Item {
142    pub fn new(key: String, value: Vec<u8>) -> Self {
143        Self { key, value }
144    }
145
146    pub fn key(&self) -> &str {
147        &self.key
148    }
149
150    pub fn value(&self) -> &[u8] {
151        &self.value
152    }
153}
154
#[cfg(test)]
mod tests {
    use crate::{
        application::Item, connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes,
        SqLiteDataStorageEngine,
    };

    use super::SqLiteApplicationStorage;

    /// Random hex key (encoded from 32 random bytes) paired with a random 64-byte value.
    fn test_kv() -> (String, Vec<u8>) {
        let key = hex::encode(gen_rand_bytes(32));
        let value = gen_rand_bytes(64);

        (key, value)
    }

    /// Fresh application storage on an in-memory database.
    fn test_storage() -> SqLiteApplicationStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .application_data_storage()
            .unwrap()
    }

    #[test]
    fn test_insert() {
        let (key, value) = test_kv();
        let storage = test_storage();

        let modified_rows = storage.insert(&key, &value).unwrap();

        assert_eq!(modified_rows, 1);

        let from_storage = storage.get(&key).unwrap().unwrap();
        assert_eq!(from_storage, value);
    }

    #[test]
    fn test_insert_existing_overwrite() {
        let (key, value) = test_kv();
        let (_, new_value) = test_kv();

        let storage = test_storage();

        assert_eq!(storage.insert(&key, &value).unwrap(), 1);
        // Overwriting with a different value updates the existing row.
        assert_eq!(storage.insert(&key, &new_value).unwrap(), 1);

        let from_storage = storage.get(&key).unwrap().unwrap();
        assert_eq!(from_storage, new_value);
    }

    #[test]
    fn test_duplicate_insert() {
        let (key, value) = test_kv();

        let storage = test_storage();

        let modified_rows_first = storage.insert(&key, &value).unwrap();
        let modified_rows_second = storage.insert(&key, &value).unwrap();

        assert_eq!(modified_rows_first, 1);
        assert_eq!(modified_rows_second, 0);

        let from_storage = storage.get(&key).unwrap().unwrap();
        assert_eq!(from_storage, value);
    }

    #[test]
    fn test_get_missing() {
        let (key, _) = test_kv();
        let storage = test_storage();

        assert!(storage.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_delete() {
        let (key, value) = test_kv();
        let storage = test_storage();

        storage.insert(&key, &value).unwrap();
        let rows_deleted_some = storage.delete(&key).unwrap();
        let rows_deleted_none = storage.delete(&key).unwrap();

        assert_eq!(rows_deleted_some, 1);
        assert_eq!(rows_deleted_none, 0);

        assert!(storage.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_by_prefix() {
        let keys = ["prefix one", "prefix two", "prefiy ", "prefiw "].map(ToString::to_string);
        let value = gen_rand_bytes(5);

        let storage = test_storage();

        keys.iter().for_each(|k| {
            storage.insert(k, &value).unwrap();
        });

        let mut expected = vec![
            Item::new(keys[0].clone(), value.clone()),
            Item::new(keys[1].clone(), value.clone()),
        ];

        expected.sort();

        let mut result = storage.get_by_prefix("prefix").unwrap();
        result.sort();

        assert_eq!(result, expected);

        let result = storage.get_by_prefix("a").unwrap();
        assert!(result.is_empty());

        let result = storage.get_by_prefix("").unwrap();
        assert_eq!(result.len(), keys.len());

        let deleted_items = storage.delete_by_prefix("prefix").unwrap();
        assert_eq!(deleted_items, 2);

        let result = storage.get_by_prefix("").unwrap();
        assert_eq!(result.len(), 2);
        assert!(result.contains(&Item::new("prefiy ".to_string(), value.clone())));
        assert!(result.contains(&Item::new("prefiw ".to_string(), value)));
    }

    #[test]
    fn test_delete_by_prefix_no_match() {
        let (key, value) = test_kv();
        let storage = test_storage();

        storage.insert(&key, &value).unwrap();

        // Hex-encoded keys never contain 'z', so nothing may be deleted.
        assert_eq!(storage.delete_by_prefix("z").unwrap(), 0);
        assert_eq!(storage.get(&key).unwrap(), Some(value));
    }

    #[test]
    fn test_special_characters() {
        let storage = test_storage();

        storage.insert("%$_ƕ❤_$%", &gen_rand_bytes(5)).unwrap();
        storage.insert("%$_ƕ❤a$%", &gen_rand_bytes(5)).unwrap();
        storage.insert("%$_ƕ❤Ḉ$%", &gen_rand_bytes(5)).unwrap();

        let items = storage.get_by_prefix("%$_ƕ❤_").unwrap();
        let keys = items.into_iter().map(|i| i.key).collect::<Vec<_>>();
        assert_eq!(vec!["%$_ƕ❤_$%".to_string()], keys);
    }

    #[test]
    fn batch_insert() {
        let storage = test_storage();
        let items = vec![test_item(), test_item(), test_item()];

        let modified_rows = storage.transact_insert(&items).unwrap();
        assert_eq!(modified_rows, 3); // All 3 items should be inserted

        // Test duplicate inserts
        let modified_rows_duplicate = storage.transact_insert(&items).unwrap();
        assert_eq!(modified_rows_duplicate, 0); // No rows should be modified for duplicates

        // Verify all items were stored correctly
        for item in items {
            assert_eq!(storage.get(&item.key).unwrap(), Some(item.value));
        }
    }

    #[test]
    fn batch_insert_empty() {
        let storage = test_storage();

        assert_eq!(storage.transact_insert(&[]).unwrap(), 0);
    }

    /// Random item with a hex key (from 5 random bytes) and a random 5-byte value.
    fn test_item() -> Item {
        Item::new(hex::encode(gen_rand_bytes(5)), gen_rand_bytes(5))
    }
}