use std::marker::PhantomData;
use std::path::Path;
use std::time::Duration;

use serde::de::DeserializeOwned;
use serde::Serialize;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions};
use sqlx::Row;

use crate::StorageError;
/// Describes one secondary column that [`JsonStore`] stores next to the JSON
/// payload, so records can be filtered on it and constrained by it in SQL.
pub struct IndexSpec<T> {
    /// SQL column name.
    ///
    /// It is interpolated verbatim into DDL and queries, so it must be a
    /// trusted, valid SQLite identifier.
    pub column: &'static str,
    /// Computes the column value from a record on every `put`.
    ///
    /// Returning `None` stores SQL `NULL`.
    pub extractor: fn(&T) -> Option<String>,
    /// Whether [`JsonStore`] declares a `UNIQUE` constraint covering this
    /// column when it creates the table.
    pub unique: bool,
}
/// Static description of a table managed by [`JsonStore`]: its name, its
/// primary-key column, its secondary indexed columns and any extra `UNIQUE`
/// constraints.
pub struct JsonTable<T: 'static> {
    /// SQL table name; interpolated verbatim into every statement.
    pub name: &'static str,
    /// Name of the `TEXT PRIMARY KEY` column holding each record's id.
    pub primary_key: &'static str,
    /// Secondary columns stored alongside the JSON payload.
    ///
    /// Each one also gets its own `CREATE INDEX`.
    pub indexes: &'static [IndexSpec<T>],
    /// Additional `UNIQUE` constraints, one per inner slice.
    ///
    /// A multi-column slice forms a compound constraint, and empty slices are
    /// skipped. Every column named here must also be declared in `indexes`;
    /// debug builds assert this when the schema is created.
    pub unique_constraints: &'static [&'static [&'static str]],
}
/// Typed store that keeps each record of type `T` as a JSON blob in a single
/// SQLite table.
///
/// Optional secondary columns are extracted from each record (see
/// [`IndexSpec`]) for filtering and uniqueness checks.
pub struct JsonStore<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Connection pool for the backing database.
    pool: SqlitePool,
    /// Schema description this store was created with.
    table: JsonTable<T>,
    /// Ties the store to `T` without owning a `T`.
    ///
    /// `PhantomData<fn() -> T>` is `Send` and `Sync` regardless of `T`.
    _marker: PhantomData<fn() -> T>,
}
impl<T> JsonStore<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Opens a file-backed store at `path`, creating the file if it is
    /// missing, and makes sure the table and indexes described by `table`
    /// exist.
    ///
    /// The path is handed to SQLite as a plain filename instead of being
    /// spliced into a `sqlite:` URL, so characters such as `?`, `#` or `%` in
    /// the path are not misread as URL syntax.
    ///
    /// WAL journaling and a 5-second busy timeout are configured on the
    /// connect options so that every pooled connection receives them.
    /// `busy_timeout` is a per-connection setting: a single `PRAGMA`
    /// executed on the pool would only reach whichever connection ran it.
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be opened or created, or if
    /// any schema statement fails.
    pub async fn open<P: AsRef<Path>>(path: P, table: JsonTable<T>) -> Result<Self, StorageError> {
        let options = SqliteConnectOptions::new()
            .filename(path.as_ref())
            // Equivalent of the `mode=rwc` URL flag: read/write, create if missing.
            .create_if_missing(true)
            .journal_mode(SqliteJournalMode::Wal)
            .busy_timeout(Duration::from_secs(5));
        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(options)
            .await?;
        Self::from_pool(pool, table).await
    }

    /// Creates a store backed by an in-memory SQLite database and initialises
    /// the schema described by `table`.
    ///
    /// SQLite discards an in-memory database as soon as the last connection
    /// to it closes. The pool's idle-timeout and max-lifetime reaping are
    /// therefore disabled; otherwise an idle store would silently lose all
    /// of its data once the reaper closed every connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the pool cannot connect or a schema statement
    /// fails.
    pub async fn in_memory(table: JsonTable<T>) -> Result<Self, StorageError> {
        let pool = SqlitePoolOptions::new()
            .idle_timeout(None::<Duration>)
            .max_lifetime(None::<Duration>)
            .connect(":memory:")
            .await?;
        Self::from_pool(pool, table).await
    }

    /// Wraps an already-configured pool and runs schema initialisation.
    async fn from_pool(pool: SqlitePool, table: JsonTable<T>) -> Result<Self, StorageError> {
        let this = Self {
            pool,
            table,
            _marker: PhantomData,
        };
        this.init_schema().await?;
        Ok(this)
    }

    /// Creates the backing table, then one secondary index per
    /// [`IndexSpec`]. Both steps are idempotent (`IF NOT EXISTS`).
    ///
    /// Every index flagged `unique` gets its own single-column `UNIQUE`
    /// constraint. Every non-empty entry of `unique_constraints` becomes one
    /// (possibly compound) `UNIQUE` constraint.
    async fn init_schema(&self) -> Result<(), StorageError> {
        let table_name = self.table.name;
        let mut ddl = String::with_capacity(256);
        ddl.push_str("CREATE TABLE IF NOT EXISTS ");
        ddl.push_str(table_name);
        ddl.push_str(" (\n ");
        ddl.push_str(self.table.primary_key);
        ddl.push_str(" TEXT PRIMARY KEY NOT NULL");
        for idx in self.table.indexes {
            ddl.push_str(",\n ");
            ddl.push_str(idx.column);
            ddl.push_str(" TEXT");
        }
        ddl.push_str(",\n data_json TEXT NOT NULL");
        ddl.push_str(",\n created_at TEXT NOT NULL");
        ddl.push_str(",\n updated_at TEXT NOT NULL");

        // One constraint per unique column. Folding them into a single
        // UNIQUE(a, b, ...) would only forbid duplicate *combinations*, which
        // lets two rows share a value in a column that was declared unique.
        for idx in self.table.indexes.iter().filter(|i| i.unique) {
            Self::push_unique_clause(&mut ddl, &[idx.column]);
        }

        let indexed_columns: std::collections::HashSet<&'static str> =
            self.table.indexes.iter().map(|i| i.column).collect();
        for &constraint in self.table.unique_constraints {
            if constraint.is_empty() {
                continue;
            }
            for col in constraint {
                debug_assert!(
                    indexed_columns.contains(col),
                    "unique_constraints references column '{col}' on table \
                     '{table_name}' that is not declared in JsonTable::indexes"
                );
            }
            Self::push_unique_clause(&mut ddl, constraint);
        }
        ddl.push_str("\n)");
        sqlx::query(&ddl).execute(&self.pool).await?;

        for idx in self.table.indexes {
            let idx_ddl = format!(
                "CREATE INDEX IF NOT EXISTS idx_{table}_{col} ON {table}({col})",
                table = table_name,
                col = idx.column,
            );
            sqlx::query(&idx_ddl).execute(&self.pool).await?;
        }
        Ok(())
    }

    /// Appends a `UNIQUE(col, ...)` table constraint over `columns` to a
    /// `CREATE TABLE` body under construction.
    fn push_unique_clause(ddl: &mut String, columns: &[&str]) {
        ddl.push_str(",\n UNIQUE(");
        ddl.push_str(&columns.join(", "));
        ddl.push(')');
    }

    /// Inserts `value` under `id`, or updates the existing row with that id.
    ///
    /// On a primary-key conflict, the indexed columns, `data_json` and
    /// `updated_at` are overwritten, while `created_at` keeps its original
    /// value. `id` is used as the primary key as given; it is not compared
    /// with anything inside `value`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::AlreadyExists`] when the write would violate a
    /// `UNIQUE` constraint. Returns other errors on serialization or database
    /// failure.
    pub async fn put(&self, id: &str, value: &T) -> Result<(), StorageError> {
        let data_json = serde_json::to_string(value)?;
        let now = chrono::Utc::now().to_rfc3339();
        let table_name = self.table.name;
        let indexes = self.table.indexes;

        // Column order: primary key, indexed columns, then bookkeeping
        // columns. The binds below must follow exactly the same order.
        let mut columns: Vec<&str> = Vec::with_capacity(indexes.len() + 4);
        columns.push(self.table.primary_key);
        columns.extend(indexes.iter().map(|i| i.column));
        columns.push("data_json");
        columns.push("created_at");
        columns.push("updated_at");
        let placeholders = vec!["?"; columns.len()].join(", ");

        // `created_at` is deliberately absent so an upsert keeps the original
        // creation timestamp.
        let update_assignments = indexes
            .iter()
            .map(|i| format!("{col} = excluded.{col}", col = i.column))
            .chain([
                "data_json = excluded.data_json".to_string(),
                "updated_at = excluded.updated_at".to_string(),
            ])
            .collect::<Vec<_>>()
            .join(", ");

        let sql = format!(
            "INSERT INTO {table} ({cols}) VALUES ({placeholders}) \
             ON CONFLICT({pk}) DO UPDATE SET {updates}",
            table = table_name,
            cols = columns.join(", "),
            placeholders = placeholders,
            pk = self.table.primary_key,
            updates = update_assignments,
        );
        let mut query = sqlx::query(&sql).bind(id);
        for idx in indexes {
            query = query.bind((idx.extractor)(value));
        }
        query = query.bind(&data_json).bind(&now).bind(&now);
        query
            .execute(&self.pool)
            .await
            .map(|_| ())
            .map_err(|err| map_sqlx_err(err, table_name))
    }

    /// Fetches the record stored under `id`, or `None` if there is none.
    ///
    /// # Errors
    ///
    /// Returns an error on database failure or if the stored JSON cannot be
    /// deserialized into `T`.
    pub async fn get(&self, id: &str) -> Result<Option<T>, StorageError> {
        let sql = format!(
            "SELECT data_json FROM {table} WHERE {pk} = ?",
            table = self.table.name,
            pk = self.table.primary_key,
        );
        let row: Option<(String,)> = sqlx::query_as(&sql)
            .bind(id)
            .fetch_optional(&self.pool)
            .await?;
        row.map(|(data_json,)| Self::decode(&data_json)).transpose()
    }

    /// Returns every record, ordered by primary key.
    ///
    /// # Errors
    ///
    /// Fails on the first row whose JSON cannot be deserialized. Use
    /// [`Self::list_lossy`] to skip such rows instead.
    pub async fn list(&self) -> Result<Vec<T>, StorageError> {
        let sql = format!(
            "SELECT data_json FROM {table} ORDER BY {pk} ASC",
            table = self.table.name,
            pk = self.table.primary_key,
        );
        let rows: Vec<(String,)> = sqlx::query_as(&sql).fetch_all(&self.pool).await?;
        Self::decode_all(rows)
    }

    /// Like [`Self::list`], but rows whose JSON cannot be deserialized are
    /// logged at error level (with their primary key) and skipped.
    ///
    /// # Errors
    ///
    /// Returns an error only if the query itself fails.
    pub async fn list_lossy(&self) -> Result<Vec<T>, StorageError> {
        let sql = format!(
            "SELECT {pk}, data_json FROM {table} ORDER BY {pk} ASC",
            table = self.table.name,
            pk = self.table.primary_key,
        );
        let rows: Vec<(String, String)> = sqlx::query_as(&sql).fetch_all(&self.pool).await?;
        let mut out = Vec::with_capacity(rows.len());
        for (key, data_json) in rows {
            match serde_json::from_str::<T>(&data_json) {
                Ok(value) => out.push(value),
                Err(e) => {
                    tracing::error!(
                        table = %self.table.name,
                        key = %key,
                        error = %e,
                        "skipping un-deserializable row in list_lossy()"
                    );
                }
            }
        }
        Ok(out)
    }

    /// Deletes the record stored under `id`.
    ///
    /// Returns `true` if a row was removed and `false` if none existed.
    ///
    /// # Errors
    ///
    /// Returns an error on database failure.
    pub async fn delete(&self, id: &str) -> Result<bool, StorageError> {
        let sql = format!(
            "DELETE FROM {table} WHERE {pk} = ?",
            table = self.table.name,
            pk = self.table.primary_key,
        );
        let result = sqlx::query(&sql).bind(id).execute(&self.pool).await?;
        Ok(result.rows_affected() > 0)
    }

    /// Returns one record whose indexed `column` equals `value`.
    ///
    /// Any declared index column is accepted, not only unique ones. When
    /// several rows match, the one with the smallest primary key is returned,
    /// so the result is deterministic.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Other`] if `column` is not declared in
    /// [`JsonTable::indexes`]. Also fails on database or deserialization
    /// errors.
    pub async fn get_by_unique(
        &self,
        column: &str,
        value: &str,
    ) -> Result<Option<T>, StorageError> {
        self.ensure_indexed_column(column)?;
        let sql = format!(
            "SELECT data_json FROM {table} WHERE {col} = ? ORDER BY {pk} ASC LIMIT 1",
            table = self.table.name,
            col = column,
            pk = self.table.primary_key,
        );
        let row: Option<(String,)> = sqlx::query_as(&sql)
            .bind(value)
            .fetch_optional(&self.pool)
            .await?;
        row.map(|(data_json,)| Self::decode(&data_json)).transpose()
    }

    /// Returns all records whose indexed `column` equals `value`, ordered by
    /// primary key.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Self::list_where_opt`].
    pub async fn list_where(&self, column: &str, value: &str) -> Result<Vec<T>, StorageError> {
        self.list_where_opt(column, Some(value)).await
    }

    /// Returns all records whose indexed `column` is SQL `NULL`, ordered by
    /// primary key.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Self::list_where_opt`].
    pub async fn list_where_null(&self, column: &str) -> Result<Vec<T>, StorageError> {
        self.list_where_opt(column, None).await
    }

    /// Returns all records whose indexed `column` equals `value`, or is
    /// `NULL` when `value` is `None`, ordered by primary key.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Other`] if `column` is not declared in
    /// [`JsonTable::indexes`]. Fails on the first row that cannot be
    /// deserialized.
    pub async fn list_where_opt(
        &self,
        column: &str,
        value: Option<&str>,
    ) -> Result<Vec<T>, StorageError> {
        self.ensure_indexed_column(column)?;
        let rows: Vec<(String,)> = if let Some(v) = value {
            let sql = format!(
                "SELECT data_json FROM {table} WHERE {col} = ? ORDER BY {pk} ASC",
                table = self.table.name,
                col = column,
                pk = self.table.primary_key,
            );
            sqlx::query_as(&sql).bind(v).fetch_all(&self.pool).await?
        } else {
            // `= NULL` never matches in SQL, so NULL lookups need `IS NULL`.
            let sql = format!(
                "SELECT data_json FROM {table} WHERE {col} IS NULL ORDER BY {pk} ASC",
                table = self.table.name,
                col = column,
                pk = self.table.primary_key,
            );
            sqlx::query_as(&sql).fetch_all(&self.pool).await?
        };
        Self::decode_all(rows)
    }

    /// Borrows the underlying connection pool, for queries this type does not
    /// cover.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }

    /// Returns the number of rows in the table.
    ///
    /// # Errors
    ///
    /// Returns an error on database failure.
    pub async fn count(&self) -> Result<u64, StorageError> {
        let sql = format!("SELECT COUNT(*) FROM {table}", table = self.table.name);
        let row = sqlx::query(&sql).fetch_one(&self.pool).await?;
        let raw: i64 = row.try_get(0)?;
        // COUNT(*) is never negative; the fallback only guards the conversion.
        Ok(u64::try_from(raw).unwrap_or(0))
    }

    /// Rejects column names not declared in [`JsonTable::indexes`].
    ///
    /// The caller-supplied column name is interpolated into SQL text, so this
    /// allow-list check also stops arbitrary identifiers or SQL from reaching
    /// the query.
    fn ensure_indexed_column(&self, column: &str) -> Result<(), StorageError> {
        if self.table.indexes.iter().any(|i| i.column == column) {
            Ok(())
        } else {
            Err(StorageError::Other(format!(
                "unknown column '{column}' for table '{table}'",
                table = self.table.name
            )))
        }
    }

    /// Deserializes one stored `data_json` payload.
    fn decode(data_json: &str) -> Result<T, StorageError> {
        Ok(serde_json::from_str(data_json)?)
    }

    /// Deserializes every `(data_json,)` row, stopping at the first bad
    /// payload.
    fn decode_all(rows: Vec<(String,)>) -> Result<Vec<T>, StorageError> {
        rows.iter()
            .map(|(data_json,)| Self::decode(data_json))
            .collect()
    }
}
fn map_sqlx_err(err: sqlx::Error, table: &str) -> StorageError {
if let sqlx::Error::Database(db) = &err {
if let Some(code) = db.code() {
if code == "2067" || code == "1555" {
return StorageError::AlreadyExists(format!(
"UNIQUE constraint failed on table '{table}': {db}"
));
}
}
let msg = db.message();
if msg.contains("UNIQUE constraint failed") {
return StorageError::AlreadyExists(format!(
"UNIQUE constraint failed on table '{table}': {msg}"
));
}
}
StorageError::from(err)
}
#[cfg(test)]
mod tests {
    // Unit tests for JsonStore. Most tests use the in-memory constructor; one
    // exercises the file-backed `open` path against a temp-dir database.
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct TestRecord {
        id: String,
        name: String,
        value: String,
    }

    /// Table with a unique `name` column and a non-unique `value` column.
    fn indexed_table() -> JsonTable<TestRecord> {
        static INDEXES: &[IndexSpec<TestRecord>] = &[
            IndexSpec {
                column: "name",
                extractor: |r| Some(r.name.clone()),
                unique: true,
            },
            IndexSpec {
                column: "value",
                extractor: |r| Some(r.value.clone()),
                unique: false,
            },
        ];
        JsonTable {
            name: "test_records",
            primary_key: "id",
            indexes: INDEXES,
            unique_constraints: &[],
        }
    }

    /// Table with no secondary columns: only id + JSON payload.
    fn blob_only_table() -> JsonTable<TestRecord> {
        JsonTable {
            name: "test_blobs",
            primary_key: "id",
            indexes: &[],
            unique_constraints: &[],
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct ScopedRecord {
        id: String,
        name: String,
        scope: Option<String>,
    }

    /// Table with a compound UNIQUE(name, scope) constraint and a nullable
    /// `scope` column.
    fn scoped_table() -> JsonTable<ScopedRecord> {
        static INDEXES: &[IndexSpec<ScopedRecord>] = &[
            IndexSpec {
                column: "name",
                extractor: |r| Some(r.name.clone()),
                unique: false,
            },
            IndexSpec {
                column: "scope",
                extractor: |r| r.scope.clone(),
                unique: false,
            },
        ];
        static UNIQUES: &[&[&str]] = &[&["name", "scope"]];
        JsonTable {
            name: "scoped_records",
            primary_key: "id",
            indexes: INDEXES,
            unique_constraints: UNIQUES,
        }
    }

    fn make_scoped(id: &str, name: &str, scope: Option<&str>) -> ScopedRecord {
        ScopedRecord {
            id: id.to_string(),
            name: name.to_string(),
            scope: scope.map(str::to_string),
        }
    }

    fn make(id: &str, name: &str, value: &str) -> TestRecord {
        TestRecord {
            id: id.to_string(),
            name: name.to_string(),
            value: value.to_string(),
        }
    }

    #[tokio::test]
    async fn in_memory_round_trip_with_indexes() {
        let store = JsonStore::in_memory(indexed_table()).await.unwrap();
        let rec = make("id-1", "alpha", "v1");
        store.put(&rec.id, &rec).await.unwrap();
        let got = store.get("id-1").await.unwrap().expect("must exist");
        assert_eq!(got, rec);
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0], rec);
        assert!(store.delete("id-1").await.unwrap());
        assert!(store.get("id-1").await.unwrap().is_none());
        assert!(!store.delete("id-1").await.unwrap());
    }

    #[tokio::test]
    async fn in_memory_round_trip_blob_only() {
        let store = JsonStore::in_memory(blob_only_table()).await.unwrap();
        let rec = make("b-1", "bob", "42");
        store.put(&rec.id, &rec).await.unwrap();
        let got = store.get("b-1").await.unwrap().expect("must exist");
        assert_eq!(got, rec);
        let list = store.list().await.unwrap();
        assert_eq!(list, vec![rec]);
    }

    #[tokio::test]
    async fn unique_conflict_surfaces_as_already_exists() {
        let store = JsonStore::in_memory(indexed_table()).await.unwrap();
        store.put("a", &make("a", "dup", "v1")).await.unwrap();
        let err = store
            .put("b", &make("b", "dup", "v2"))
            .await
            .expect_err("unique violation must error");
        match err {
            StorageError::AlreadyExists(msg) => {
                assert!(
                    msg.contains("test_records"),
                    "message should mention the table name, got: {msg}"
                );
            }
            other => panic!("expected AlreadyExists, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn get_by_unique_success_and_unknown_column() {
        let store = JsonStore::in_memory(indexed_table()).await.unwrap();
        store.put("x", &make("x", "name-x", "val-x")).await.unwrap();
        store.put("y", &make("y", "name-y", "val-y")).await.unwrap();
        let found = store
            .get_by_unique("name", "name-x")
            .await
            .unwrap()
            .expect("must exist");
        assert_eq!(found.id, "x");
        // Lookups on non-unique indexed columns are allowed too.
        let by_value = store
            .get_by_unique("value", "val-y")
            .await
            .unwrap()
            .expect("must exist");
        assert_eq!(by_value.id, "y");
        let missing = store.get_by_unique("name", "nope").await.unwrap();
        assert!(missing.is_none());
        let err = store
            .get_by_unique("not_a_column", "whatever")
            .await
            .expect_err("unknown column must fail");
        match err {
            StorageError::Other(msg) => {
                assert!(msg.contains("not_a_column"));
                assert!(msg.contains("test_records"));
            }
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = JsonStore::in_memory(indexed_table()).await.unwrap();
        let rec = make("z", "z-name", "v1");
        store.put(&rec.id, &rec).await.unwrap();
        let (created_before, updated_before): (String, String) =
            sqlx::query_as("SELECT created_at, updated_at FROM test_records WHERE id = ?")
                .bind("z")
                .fetch_one(&store.pool)
                .await
                .unwrap();
        // Ensure the second timestamp is observably later.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        let updated = make("z", "z-name", "v2");
        store.put(&updated.id, &updated).await.unwrap();
        let (created_after, updated_after): (String, String) =
            sqlx::query_as("SELECT created_at, updated_at FROM test_records WHERE id = ?")
                .bind("z")
                .fetch_one(&store.pool)
                .await
                .unwrap();
        assert_eq!(
            created_before, created_after,
            "created_at must be preserved across upsert"
        );
        assert_ne!(
            updated_before, updated_after,
            "updated_at must advance on upsert"
        );
        let got = store.get("z").await.unwrap().unwrap();
        assert_eq!(got.value, "v2");
    }

    #[tokio::test]
    async fn count_tracks_inserts_and_deletes() {
        let store = JsonStore::in_memory(indexed_table()).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
        store.put("1", &make("1", "one", "v")).await.unwrap();
        store.put("2", &make("2", "two", "v")).await.unwrap();
        store.put("3", &make("3", "three", "v")).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 3);
        // An upsert of an existing id must not add a row.
        store.put("1", &make("1", "one", "v2")).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 3);
        assert!(store.delete("2").await.unwrap());
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_returns_every_row() {
        let store = JsonStore::in_memory(blob_only_table()).await.unwrap();
        store.put("c", &make("c", "cc", "3")).await.unwrap();
        store.put("a", &make("a", "aa", "1")).await.unwrap();
        store.put("b", &make("b", "bb", "2")).await.unwrap();
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
        let ids: Vec<&str> = list.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b", "c"]);
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct NamedRecord {
        name: String,
        payload: String,
    }

    /// Table whose primary-key column is `name` rather than `id`.
    fn named_pk_table() -> JsonTable<NamedRecord> {
        JsonTable {
            name: "named_records",
            primary_key: "name",
            indexes: &[],
            unique_constraints: &[],
        }
    }

    #[tokio::test]
    async fn custom_primary_key_round_trips() {
        let store = JsonStore::in_memory(named_pk_table()).await.unwrap();
        let pk_col: (String,) =
            sqlx::query_as("SELECT name FROM pragma_table_info('named_records') WHERE pk = 1")
                .fetch_one(&store.pool)
                .await
                .unwrap();
        assert_eq!(pk_col.0, "name", "primary key column must be `name`");
        let rec = NamedRecord {
            name: "alpha".to_string(),
            payload: "v1".to_string(),
        };
        store.put(&rec.name, &rec).await.unwrap();
        let got = store.get("alpha").await.unwrap().expect("must exist");
        assert_eq!(got, rec);
        let updated = NamedRecord {
            name: "alpha".to_string(),
            payload: "v2".to_string(),
        };
        store.put(&updated.name, &updated).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
        assert_eq!(store.get("alpha").await.unwrap().unwrap().payload, "v2");
        store
            .put(
                "beta",
                &NamedRecord {
                    name: "beta".to_string(),
                    payload: "vb".to_string(),
                },
            )
            .await
            .unwrap();
        let names: Vec<String> = store
            .list()
            .await
            .unwrap()
            .into_iter()
            .map(|r| r.name)
            .collect();
        assert_eq!(names, vec!["alpha".to_string(), "beta".to_string()]);
        assert!(store.delete("alpha").await.unwrap());
        assert!(store.get("alpha").await.unwrap().is_none());
        assert!(!store.delete("alpha").await.unwrap());
    }

    #[tokio::test]
    async fn list_lossy_skips_undeserializable_rows() {
        let store = JsonStore::in_memory(blob_only_table()).await.unwrap();
        store.put("a", &make("a", "aa", "1")).await.unwrap();
        store.put("c", &make("c", "cc", "3")).await.unwrap();
        // Bypass `put` to plant a row whose payload is not valid JSON.
        sqlx::query(
            "INSERT INTO test_blobs (id, data_json, created_at, updated_at) \
             VALUES (?, ?, ?, ?)",
        )
        .bind("b")
        .bind("{ not valid json for TestRecord")
        .bind("2026-01-01T00:00:00+00:00")
        .bind("2026-01-01T00:00:00+00:00")
        .execute(&store.pool)
        .await
        .unwrap();
        assert!(store.list().await.is_err());
        let good = store.list_lossy().await.unwrap();
        let ids: Vec<&str> = good.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["a", "c"]);
    }

    #[tokio::test]
    async fn compound_unique_rejects_duplicate_pair() {
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        store
            .put("id-1", &make_scoped("id-1", "foo", Some("proj-a")))
            .await
            .unwrap();
        let err = store
            .put("id-2", &make_scoped("id-2", "foo", Some("proj-a")))
            .await
            .expect_err("compound unique violation must error");
        match err {
            StorageError::AlreadyExists(msg) => {
                assert!(
                    msg.contains("scoped_records"),
                    "message should mention the table name, got: {msg}"
                );
            }
            other => panic!("expected AlreadyExists, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn compound_unique_permits_distinct_combinations() {
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        store
            .put("id-1", &make_scoped("id-1", "foo", Some("proj-a")))
            .await
            .unwrap();
        store
            .put("id-2", &make_scoped("id-2", "foo", Some("proj-b")))
            .await
            .unwrap();
        store
            .put("id-3", &make_scoped("id-3", "bar", Some("proj-a")))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn compound_unique_treats_null_scope_as_distinct() {
        // SQLite considers NULLs distinct from each other inside UNIQUE
        // constraints, so two rows sharing a name with no scope do not
        // collide. Pinned here so callers are not surprised by it.
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        store
            .put("n-1", &make_scoped("n-1", "foo", None))
            .await
            .unwrap();
        store
            .put("n-2", &make_scoped("n-2", "foo", None))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_where_filters_by_column_value() {
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        store
            .put("a", &make_scoped("a", "x", Some("s1")))
            .await
            .unwrap();
        store
            .put("b", &make_scoped("b", "y", Some("s1")))
            .await
            .unwrap();
        store
            .put("c", &make_scoped("c", "z", Some("s2")))
            .await
            .unwrap();
        let s1 = store.list_where("scope", "s1").await.unwrap();
        assert_eq!(s1.len(), 2);
        let ids: Vec<&str> = s1.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b"]);
        let s2 = store.list_where("scope", "s2").await.unwrap();
        assert_eq!(s2.len(), 1);
        assert_eq!(s2[0].id, "c");
        let missing = store.list_where("scope", "s3").await.unwrap();
        assert!(missing.is_empty());
    }

    #[tokio::test]
    async fn list_where_null_filters_null_rows() {
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        store.put("a", &make_scoped("a", "x", None)).await.unwrap();
        store
            .put("b", &make_scoped("b", "y", Some("s1")))
            .await
            .unwrap();
        store.put("c", &make_scoped("c", "z", None)).await.unwrap();
        let nulls = store.list_where_null("scope").await.unwrap();
        assert_eq!(nulls.len(), 2);
        let ids: Vec<&str> = nulls.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["a", "c"]);
    }

    #[tokio::test]
    async fn list_where_opt_unknown_column_returns_other() {
        let store = JsonStore::in_memory(scoped_table()).await.unwrap();
        let err = store
            .list_where_opt("not_a_column", Some("x"))
            .await
            .expect_err("unknown column must fail");
        match err {
            StorageError::Other(msg) => {
                assert!(msg.contains("not_a_column"));
                assert!(msg.contains("scoped_records"));
            }
            other => panic!("expected Other, got {other:?}"),
        }
        let err2 = store
            .list_where_opt("not_a_column", None)
            .await
            .expect_err("unknown column must fail");
        match err2 {
            StorageError::Other(_) => {}
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn open_on_disk_persists_across_reopen() {
        // Unique file name per run so parallel or repeated runs do not collide.
        let unique = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let path = std::env::temp_dir().join(format!(
            "json_store_test_{}_{unique}.db",
            std::process::id()
        ));
        let rec = make("disk-1", "disk-name", "disk-value");
        {
            let store = JsonStore::open(&path, indexed_table()).await.unwrap();
            store.put(&rec.id, &rec).await.unwrap();
            store.pool().close().await;
        }
        // Reopening re-runs schema creation, which must be idempotent, and
        // must see the previously written row.
        let reopened = JsonStore::open(&path, indexed_table()).await.unwrap();
        let got = reopened.get("disk-1").await.unwrap();
        reopened.pool().close().await;
        for suffix in ["", "-wal", "-shm"] {
            let mut file = path.clone().into_os_string();
            file.push(suffix);
            let _ = std::fs::remove_file(file);
        }
        assert_eq!(got, Some(rec));
    }
}