use crate::StorageBackend;
use crate::{PersistentError, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::{collections::HashMap, hash::Hash, str::FromStr};
use tokio_rusqlite::{params, Connection};
/// Storage backend that keeps its entries in a SQLite database.
///
/// All database access goes through the wrapped `tokio_rusqlite` connection.
#[derive(Debug)]
pub struct SqliteBackend {
/// Async handle to the underlying SQLite connection.
conn: Connection,
}
impl SqliteBackend {
    /// Opens the SQLite database at `db_path` and makes sure the `kv` table exists.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be opened or if the schema
    /// statement fails.
    pub async fn new(db_path: &str) -> Result<Self> {
        let conn = Connection::open(db_path).await?;
        conn.call(|c| {
            // `key` is declared PRIMARY KEY, which SQLite already backs with a
            // unique index; a separate index on the same column would only add
            // write overhead, so none is created here.
            c.execute(
                "CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value TEXT NOT NULL)",
                [],
            )
            .map_err(tokio_rusqlite::Error::Rusqlite)
        })
        .await?;
        Ok(Self { conn })
    }

    /// Returns the file path SQLite reports for this connection's database.
    ///
    /// The value is read from the third column (`file`) of the first row
    /// returned by `PRAGMA database_list`.
    ///
    /// # Errors
    ///
    /// Returns an error if the pragma query fails or the column cannot be read
    /// as a string.
    pub async fn db_path(&self) -> Result<String> {
        let path = self
            .conn
            .call(|c| {
                c.query_row("PRAGMA database_list", [], |row| row.get::<_, String>(2))
                    .map_err(tokio_rusqlite::Error::Rusqlite)
            })
            .await?;
        Ok(path)
    }
}
#[async_trait::async_trait]
impl<K, V> StorageBackend<K, V> for SqliteBackend
where
K: Eq
+ Hash
+ Clone
+ Serialize
+ DeserializeOwned
+ Send
+ Sync
+ 'static
+ ToString
+ FromStr,
<K as FromStr>::Err: std::error::Error + Send + Sync + 'static,
V: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
async fn load_all(&self) -> Result<HashMap<K, V>, PersistentError> {
let rows = self
.conn
.call(|c| {
let mut stmt = c.prepare_cached("SELECT key, value FROM kv")?;
let mut map = HashMap::with_capacity(100); let mut rows_iter = stmt.query_map([], |r| {
let key_str: String = r.get(0)?;
let val_str: String = r.get(1)?;
Ok((key_str, val_str))
})?;
while let Some(Ok((k_str, v_str))) = rows_iter.next() {
let value: V = serde_json::from_str(&v_str)
.map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?;
let key = k_str
.parse()
.map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?;
map.insert(key, value);
}
Ok(map)
})
.await?;
Ok(rows)
}
async fn save(&self, key: K, value: V) -> Result<(), PersistentError> {
let key_str = key.to_string();
let val_json = serde_json::to_string(&value)?;
self.conn
.call(move |c| {
c.execute(
"INSERT OR REPLACE INTO kv (key, value) VALUES (?1, ?2)",
params![key_str, val_json],
)
.map_err(tokio_rusqlite::Error::Rusqlite)
})
.await?;
Ok(())
}
#[inline]
async fn delete(&self, key: &K) -> Result<(), PersistentError> {
let key_str = key.to_string();
self.conn
.call(move |c| {
c.execute("DELETE FROM kv WHERE key = ?1", params![key_str])
.map_err(tokio_rusqlite::Error::Rusqlite)
})
.await?;
Ok(())
}
async fn flush(&self) -> Result<(), PersistentError> {
self.conn
.call(|c| {
c.execute("PRAGMA synchronous = FULL", [])
.map_err(tokio_rusqlite::Error::Rusqlite)
})
.await?;
Ok(())
}
}