use parking_lot::Mutex;
use rusqlite::{params, Connection, OptionalExtension};
use serde_json::{json, Value};
use std::path::Path;
use std::sync::Arc;
use crate::memory::store::safety;
use crate::memory::types::MemoryKvRecord;
/// Schema and connection settings that `KvStore` runs via `execute_batch` on every connection
/// it opens (file-backed or in-memory).
///
/// - Sets `journal_mode = WAL` and `synchronous = NORMAL` on the connection.
/// - `kv_global`: one row per `key` (primary key) holding the JSON-encoded value and the
///   `updated_at` timestamp (seconds since the UNIX epoch, as written by `KvStore`).
/// - `kv_namespace`: one row per `(namespace, key)` pair with the same value/timestamp columns.
///
/// The table and index statements use `IF NOT EXISTS`, so running the script against an
/// existing database leaves existing tables and data in place.
///
/// NOTE(review): the composite `PRIMARY KEY (namespace, key)` already gives SQLite an index
/// whose leading column is `namespace`, so `idx_kv_ns` looks redundant — confirm with
/// `EXPLAIN QUERY PLAN` before dropping it.
const INIT_SQL: &str = "
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
CREATE TABLE IF NOT EXISTS kv_global (
key TEXT PRIMARY KEY,
value_json TEXT NOT NULL,
updated_at REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS kv_namespace (
namespace TEXT NOT NULL,
key TEXT NOT NULL,
value_json TEXT NOT NULL,
updated_at REAL NOT NULL,
PRIMARY KEY (namespace, key)
);
CREATE INDEX IF NOT EXISTS idx_kv_ns ON kv_namespace(namespace);
";
pub struct KvStore {
conn: Arc<Mutex<Connection>>,
}
impl KvStore {
pub fn open(db_path: &Path) -> anyhow::Result<Self> {
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(db_path)?;
conn.execute_batch(INIT_SQL)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
pub fn open_in_memory() -> anyhow::Result<Self> {
let conn = Connection::open_in_memory()?;
conn.execute_batch(INIT_SQL)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
fn now_ts() -> f64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.unwrap_or(0.0)
}
fn sanitize_namespace(namespace: &str) -> String {
namespace
.chars()
.map(|c| match c {
c if c.is_whitespace() => '_',
'#' => '_',
other => other,
})
.collect()
}
pub fn set_global(&self, key: &str, value: &Value) -> Result<(), String> {
if safety::has_likely_secret(key) {
return Err("kv key cannot contain secrets".to_string());
}
if safety::has_likely_pii(key) {
return Err("kv key cannot contain personal identifiers".to_string());
}
let sanitized = safety::sanitize_json(value);
let conn = self.conn.lock();
conn.execute(
"INSERT INTO kv_global (key, value_json, updated_at)
VALUES (?1, ?2, ?3)
ON CONFLICT(key) DO UPDATE SET value_json = excluded.value_json, updated_at = excluded.updated_at",
params![key, sanitized.value.to_string(), Self::now_ts()],
)
.map_err(|e| format!("set_global: {e}"))?;
Ok(())
}
pub fn get_global(&self, key: &str) -> Result<Option<Value>, String> {
let conn = self.conn.lock();
let value: Option<String> = conn
.query_row(
"SELECT value_json FROM kv_global WHERE key = ?1",
params![key],
|row| row.get(0),
)
.optional()
.map_err(|e| format!("get_global: {e}"))?;
Ok(value.and_then(|v| serde_json::from_str(&v).ok()))
}
pub fn set_namespace(&self, namespace: &str, key: &str, value: &Value) -> Result<(), String> {
if safety::has_likely_secret(namespace) || safety::has_likely_secret(key) {
return Err("kv namespace/key cannot contain secrets".to_string());
}
if safety::has_likely_pii(namespace) || safety::has_likely_pii(key) {
return Err("kv namespace/key cannot contain personal identifiers".to_string());
}
let sanitized = safety::sanitize_json(value);
let conn = self.conn.lock();
conn.execute(
"INSERT INTO kv_namespace (namespace, key, value_json, updated_at)
VALUES (?1, ?2, ?3, ?4)
ON CONFLICT(namespace, key) DO UPDATE SET value_json = excluded.value_json, updated_at = excluded.updated_at",
params![
Self::sanitize_namespace(namespace),
key,
sanitized.value.to_string(),
Self::now_ts()
],
)
.map_err(|e| format!("set_namespace: {e}"))?;
Ok(())
}
pub fn get_namespace(&self, namespace: &str, key: &str) -> Result<Option<Value>, String> {
let conn = self.conn.lock();
let value: Option<String> = conn
.query_row(
"SELECT value_json FROM kv_namespace WHERE namespace = ?1 AND key = ?2",
params![Self::sanitize_namespace(namespace), key],
|row| row.get(0),
)
.optional()
.map_err(|e| format!("get_namespace: {e}"))?;
Ok(value.and_then(|v| serde_json::from_str(&v).ok()))
}
pub fn delete_global(&self, key: &str) -> Result<bool, String> {
let conn = self.conn.lock();
let changed = conn
.execute("DELETE FROM kv_global WHERE key = ?1", params![key])
.map_err(|e| format!("delete_global: {e}"))?;
Ok(changed > 0)
}
pub fn delete_namespace(&self, namespace: &str, key: &str) -> Result<bool, String> {
let conn = self.conn.lock();
let changed = conn
.execute(
"DELETE FROM kv_namespace WHERE namespace = ?1 AND key = ?2",
params![Self::sanitize_namespace(namespace), key],
)
.map_err(|e| format!("delete_namespace: {e}"))?;
Ok(changed > 0)
}
pub fn list_namespace(&self, namespace: &str) -> Result<Vec<Value>, String> {
let conn = self.conn.lock();
let mut stmt = conn
.prepare(
"SELECT key, value_json, updated_at FROM kv_namespace
WHERE namespace = ?1 ORDER BY updated_at DESC",
)
.map_err(|e| format!("list_namespace prepare: {e}"))?;
let mut rows = stmt
.query(params![Self::sanitize_namespace(namespace)])
.map_err(|e| format!("list_namespace query: {e}"))?;
let mut out = Vec::new();
while let Some(row) = rows
.next()
.map_err(|e| format!("list_namespace row: {e}"))?
{
let value_raw: String = row.get(1).map_err(|e| e.to_string())?;
out.push(json!({
"key": row.get::<_, String>(0).map_err(|e| e.to_string())?,
"value": serde_json::from_str::<Value>(&value_raw).unwrap_or(Value::Null),
"updatedAt": row.get::<_, f64>(2).map_err(|e| e.to_string())?,
}));
}
Ok(out)
}
pub fn records_for_scope(&self, namespace: &str) -> Result<Vec<MemoryKvRecord>, String> {
let mut records = self.records_namespace(namespace)?;
records.extend(self.records_global()?);
records.sort_by(|a, b| {
b.updated_at
.partial_cmp(&a.updated_at)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(records)
}
pub fn records_namespace(&self, namespace: &str) -> Result<Vec<MemoryKvRecord>, String> {
let ns = Self::sanitize_namespace(namespace);
let conn = self.conn.lock();
let mut stmt = conn
.prepare(
"SELECT key, value_json, updated_at FROM kv_namespace
WHERE namespace = ?1 ORDER BY updated_at DESC",
)
.map_err(|e| format!("prepare records_namespace: {e}"))?;
let mut rows = stmt
.query(params![ns])
.map_err(|e| format!("query records_namespace: {e}"))?;
let mut out = Vec::new();
while let Some(row) = rows
.next()
.map_err(|e| format!("row records_namespace: {e}"))?
{
let value_raw: String = row.get(1).map_err(|e| e.to_string())?;
out.push(MemoryKvRecord {
namespace: Some(ns.clone()),
key: row.get(0).map_err(|e| e.to_string())?,
value: serde_json::from_str(&value_raw).unwrap_or(Value::Null),
updated_at: row.get(2).map_err(|e| e.to_string())?,
});
}
Ok(out)
}
pub fn records_global(&self) -> Result<Vec<MemoryKvRecord>, String> {
let conn = self.conn.lock();
let mut stmt = conn
.prepare("SELECT key, value_json, updated_at FROM kv_global ORDER BY updated_at DESC")
.map_err(|e| format!("prepare records_global: {e}"))?;
let mut rows = stmt
.query([])
.map_err(|e| format!("query records_global: {e}"))?;
let mut out = Vec::new();
while let Some(row) = rows
.next()
.map_err(|e| format!("row records_global: {e}"))?
{
let value_raw: String = row.get(1).map_err(|e| e.to_string())?;
out.push(MemoryKvRecord {
namespace: None,
key: row.get(0).map_err(|e| e.to_string())?,
value: serde_json::from_str(&value_raw).unwrap_or(Value::Null),
updated_at: row.get(2).map_err(|e| e.to_string())?,
});
}
Ok(out)
}
}
// Unit tests for `KvStore`, loaded from `kv_tests.rs` via `#[path]` and compiled only
// under `cfg(test)`.
#[cfg(test)]
#[path = "kv_tests.rs"]
mod tests;