// chromoe_db/driver/sqlite_driver.rs

use rusqlite::{params, Connection, Error as RusqliteError, OptionalExtension, Result};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::{from_str, from_value, json, to_string, to_value, Error as SerdeJsonError, Value};

use crate::structure::SQLiteDriverOptions;
7
8/// SQLite database driver for storing and managing JSON data.
9///
10/// The `SQLiteDriver` provides methods for interacting with an SQLite database,
11/// including adding, retrieving, updating, and deleting JSON data in a specified table.
12///
13/// It abstracts the database operations and allows the user to interact with
14/// the database as if it were a key-value store, with the data being stored
15/// as serialised JSON.
16///
17/// # Fields
18///
19/// - `name`: The name of the SQLite database file.
20/// - `options`: Configuration options for the SQLite driver, including the
21///   database file name and table name.
22/// - `table`: The name of the table in the SQLite database to operate on.
23/// - `database`: The connection to the SQLite database.
24#[derive(Debug)]
25pub struct SQLiteDriver {
26    /// The name of the SQLite database file.
27    pub name: String,
28    /// Configuration options for the SQLite driver, including the
29    /// database file and table name
30    pub options: SQLiteDriverOptions,
31    /// The name of the table in the SQLite database to operate on.
32    pub table: String,
33    /// The connection to the SQLite database.
34    pub database: Connection,
35}
36
37impl SQLiteDriver {
38    /// Creates a new instance of the `SQLiteDriver` with the provided options.
39    /// If no options are provided, it defaults to using `json.sqlite` as the
40    /// database file and `json` as the table name.
41    ///
42    /// # Parameters
43    /// - `options`: Optional configuration options for the SQLite database.
44    ///
45    /// # Returns
46    /// A `Result` containing either the `SQLiteDriver` instance or an error.
47    pub fn new(options: Option<SQLiteDriverOptions>) -> Result<Self> {
48        let options = options.unwrap_or_else(|| SQLiteDriverOptions {
49            file_name: "json.sqlite".to_string(),
50            table_name: "json".to_string(),
51        });
52
53        let database = Connection::open(&options.file_name)?;
54
55        let driver = SQLiteDriver {
56            name: options.file_name.clone(),
57            options: options.clone(),
58            table: options.table_name.clone(),
59            database,
60        };
61
62        driver.prepare(&options.table_name)?;
63
64        Ok(driver)
65    }
66
67    /// Prepares the SQLite database by creating the table if it doesn't already exist.
68    ///
69    /// # Parameters
70    /// - `table`: The name of the table to create.
71    ///
72    /// # Returns
73    /// A `Result` indicating success or failure.
74    pub fn prepare(&self, table: &str) -> Result<()> {
75        self.database.execute(
76            &format!(
77                "CREATE TABLE IF NOT EXISTS {} (ID TEXT PRIMARY KEY, JSON TEXT)",
78                table
79            ),
80            [],
81        )?;
82        Ok(())
83    }
84
85    /// Adds a value to an existing entry or creates a new entry if it doesn't exist.
86    /// The value is added to the current value of the entry (if it exists).
87    ///
88    /// # Parameters
89    /// - `key`: The key for the entry to update.
90    /// - `value`: The value to add to the current entry.
91    ///
92    /// # Returns
93    /// The new value after adding `value` to the existing entry, or an error if
94    /// the value is not finite (e.g., NaN or infinity).
95    pub fn add(&self, key: &str, value: f64) -> Result<f64> {
96        let current_value: f64 = self.get(key)?.unwrap_or(0.0);
97
98        if !current_value.is_finite() {
99            return Err(RusqliteError::ToSqlConversionFailure(Box::new(
100                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-finite value"),
101            )));
102        }
103
104        let new_value = current_value + value;
105        self.set(key, new_value)?;
106        Ok(new_value)
107    }
108
109    /// Retrieves all data entries from the database as a vector of tuples.
110    ///
111    /// # Returns
112    /// A `Result` containing a vector of tuples where each tuple consists of
113    /// a key (`String`) and a corresponding value (`serde_json::Value`).
114    pub fn all(&self) -> Result<Vec<(String, Value)>> {
115        let mut stmt = self
116            .database
117            .prepare(&format!("SELECT * FROM {}", self.table))?;
118        let rows = stmt.query_map([], |row| {
119            let id: String = row.get(0)?;
120            let json_str: String = row.get(1)?;
121            let json: Value = from_str(&json_str).unwrap_or(Value::Null);
122            Ok((id, json))
123        })?;
124
125        let mut data = Vec::new();
126        for row in rows {
127            let (id, value) = row?;
128            data.push((id, value));
129        }
130
131        Ok(data)
132    }
133
134    /// Deletes a specific entry by key. If the key refers to a nested value,
135    /// it will remove the nested field within the JSON data.
136    ///
137    /// # Parameters
138    /// - `key`: The key of the entry to delete.
139    ///
140    /// # Returns
141    /// A `Result` indicating whether the deletion was successful.
142    pub fn delete(&self, key: &str) -> Result<bool> {
143        if key.contains('.') {
144            let split: Vec<&str> = key.split('.').collect();
145            let mut obj: Value = self.get(split[0])?.unwrap_or(Value::Null);
146            obj.as_object_mut().map(|obj| obj.remove(split[1]));
147            self.set(split[0], obj)?;
148            return Ok(true);
149        }
150
151        self.delete_row_key(key)?;
152        Ok(true)
153    }
154
155    /// Deletes all entries in the database.
156    ///
157    /// # Returns
158    /// A `Result` indicating whether the deletion was successful.
159    pub fn delete_all(&self) -> Result<bool> {
160        self.delete_rows()
161    }
162
163    /// Deletes a specific row from the table by key.
164    ///
165    /// # Parameters
166    /// - `key`: The key of the entry to delete.
167    ///
168    /// # Returns
169    /// A `Result` indicating whether the deletion was successful.
170    fn delete_row_key(&self, key: &str) -> Result<bool> {
171        self.database
172            .prepare(&format!("DELETE FROM {} WHERE ID = ?", self.table))?
173            .execute(params![key])?;
174        Ok(true)
175    }
176
177    /// Deletes all rows from the table.
178    ///
179    /// # Returns
180    /// A `Result` indicating whether the deletion was successful.
181    fn delete_rows(&self) -> Result<bool> {
182        self.database
183            .prepare(&format!("DELETE FROM {}", self.table))?
184            .execute([])?;
185        Ok(true)
186    }
187
188    /// Retrieves the value for a given key, potentially deserialising it into the specified type.
189    ///
190    /// # Parameters
191    /// - `key`: The key of the entry to retrieve.
192    ///
193    /// # Returns
194    /// A `Result` containing an `Option` of the deserialised value, or an error if the
195    /// deserialisation fails.
196    pub fn get<T>(&self, key: &str) -> Result<Option<T>>
197    where
198        T: DeserializeOwned + Default,
199    {
200        if key.contains('.') {
201            let split: Vec<&str> = key.split('.').collect();
202            let val: Value = self.get_row_key(split[0])?.unwrap_or_default();
203            let nested_value = val.pointer(&format!("/{}", split[1])).cloned();
204            Ok(nested_value.map(|v| from_str(&v.to_string()).unwrap_or_default()))
205        } else {
206            self.get_row_key(key)
207        }
208    }
209
210    /// Retrieves a value for a key, directly from the row.
211    ///
212    /// # Parameters
213    /// - `key`: The key of the entry to retrieve.
214    ///
215    /// # Returns
216    /// A `Result` containing the deserialised value, or `None` if the key doesn't exist.
217    fn get_row_key<T>(&self, key: &str) -> Result<Option<T>>
218    where
219        T: DeserializeOwned,
220    {
221        let mut stmt = self
222            .database
223            .prepare(&format!("SELECT JSON FROM {} WHERE ID = ?", self.table))?;
224
225        let row = stmt
226            .query_row(params![key], |row| row.get::<_, String>(0))
227            .optional()?;
228
229        if let Some(json_str) = row {
230            let json: Option<T> = from_str(&json_str).ok();
231            Ok(json)
232        } else {
233            Ok(None)
234        }
235    }
236
237    /// Checks if a given key exists in the database.
238    ///
239    /// # Parameters
240    /// - `key`: The key to check.
241    ///
242    /// # Returns
243    /// A `Result` containing a boolean indicating whether the key exists.
244    pub fn has(&self, key: &str) -> Result<bool> {
245        Ok(self.get::<Value>(key)?.is_some())
246    }
247
248    /// Removes a specific value from an array stored at the given key.
249    ///
250    /// # Parameters
251    /// - `key`: The key of the entry where the array is stored.
252    /// - `value`: The value to remove from the array.
253    ///
254    /// # Returns
255    /// A `Result` containing the updated array after removal.
256    pub fn pull<T>(&self, key: &str, value: T) -> Result<Vec<T>>
257    where
258        T: DeserializeOwned + std::cmp::PartialEq + Clone + Serialize,
259    {
260        let mut arr: Vec<T> = self.get(key)?.unwrap_or_default();
261
262        arr.retain(|x| x != &value);
263
264        self.set(key, arr.clone())?;
265
266        Ok(arr)
267    }
268
269    /// Appends a value to an array stored at the given key.
270    ///
271    /// # Parameters
272    /// - `key`: The key of the entry where the array is stored.
273    /// - `value`: The value to append to the array.
274    ///
275    /// # Returns
276    /// A `Result` containing the updated array after the value is appended.
277    pub fn push<T>(&self, key: &str, value: T) -> Result<Vec<T>>
278    where
279        T: DeserializeOwned + Clone + Serialize,
280    {
281        let mut arr: Vec<T> = self.get(key)?.unwrap_or_default();
282
283        arr.push(value);
284
285        self.set(key, arr.clone())?;
286
287        Ok(arr)
288    }
289
290    /// Sets or updates the value for a given key in the database.
291    ///
292    /// # Parameters
293    /// - `key`: The key for the entry.
294    /// - `value`: The value to store, which will be serialised into JSON.
295    ///
296    /// # Returns
297    /// A `Result` containing the value that was set.
298    pub fn set<T>(&self, key: &str, value: T) -> Result<()>
299    where
300        T: Serialize,
301    {
302        let parts: Vec<&str> = key.split('.').collect();
303        let root_key = parts[0];
304
305        let mut root_value: Value = self.get(root_key)?.unwrap_or_else(|| json!({}));
306
307        let mut current = &mut root_value;
308        for part in &parts[1..] {
309            current = current
310                .as_object_mut()
311                .unwrap()
312                .entry(part.to_string())
313                .or_insert(json!({}));
314        }
315        *current = json!(value);
316
317        let json_string = to_string(&root_value)
318            .map_err(|e: SerdeJsonError| RusqliteError::ToSqlConversionFailure(Box::new(e)))?;
319        self.database
320            .prepare(&format!(
321                "INSERT INTO {} (ID, JSON) VALUES (?, ?) ON CONFLICT(ID) DO UPDATE SET JSON = ?",
322                self.table
323            ))?
324            .execute(params![root_key, json_string, json_string])?;
325
326        Ok(())
327    }
328
329    /// Subtracts a value from an existing entry. If the entry does not exist,
330    /// it initialises it with the result.
331    ///
332    /// # Parameters
333    /// - `key`: The key of the entry to subtract from.
334    /// - `value`: The value to subtract from the current value.
335    ///
336    /// # Returns
337    /// The new value after subtraction.
338    pub fn subtract(&self, key: &str, value: f64) -> Result<f64> {
339        let current_value: f64 = self.get(key)?.unwrap_or(0.0);
340
341        if !current_value.is_finite() {
342            return Err(RusqliteError::ToSqlConversionFailure(Box::new(
343                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-finite value"),
344            )));
345        }
346
347        let new_value = current_value - value;
348        self.set(key, new_value)?;
349        Ok(new_value)
350    }
351}