use crate::{
adapter::{AdapterKind, DatabaseConfig},
error::{DataError, DataResult},
query::{FilterOperator, Query, SortDirection},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{cmp::Ordering, collections::BTreeMap, time::Instant};
use tracing::{info, warn};
/// Column values of a single record, keyed by column name.
///
/// A `BTreeMap` keeps keys sorted, so iterating over a row is deterministic.
pub type Row = BTreeMap<String, Value>;

/// A row as held by a repository: its numeric id together with its column data.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct StoredRow {
    /// Row identifier (`MemoryRepo` hands these out from a counter on insert).
    pub id: u64,
    /// The row's column values.
    pub data: Row,
}
pub trait AdapterDriver: Send + Sync {
fn kind(&self) -> AdapterKind;
}
#[derive(Debug, Default, Clone, Copy)]
pub struct PostgresAdapter;
impl AdapterDriver for PostgresAdapter {
fn kind(&self) -> AdapterKind {
AdapterKind::Postgres
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct MySqlAdapter;
impl AdapterDriver for MySqlAdapter {
fn kind(&self) -> AdapterKind {
AdapterKind::MySql
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct SqliteAdapter;
impl AdapterDriver for SqliteAdapter {
fn kind(&self) -> AdapterKind {
AdapterKind::Sqlite
}
}
/// Resolves the driver for the adapter selected in `config`.
///
/// # Errors
///
/// Returns [`DataError::Adapter`] when the configured adapter is
/// [`AdapterKind::None`], i.e. no database backend has been chosen.
pub fn adapter_for(config: &DatabaseConfig) -> DataResult<Box<dyn AdapterDriver>> {
    let driver: Box<dyn AdapterDriver> = match config.adapter {
        AdapterKind::None => {
            return Err(DataError::Adapter(String::from(
                "database adapter is `none`; select postgres/mysql/sqlite in shelly.data.toml",
            )));
        }
        AdapterKind::Postgres => Box::new(PostgresAdapter),
        AdapterKind::MySql => Box::new(MySqlAdapter),
        AdapterKind::Sqlite => Box::new(SqliteAdapter),
    };
    Ok(driver)
}
/// Row-level persistence operations over named tables.
///
/// Every fallible operation reports failure through [`DataResult`].
pub trait Repo {
    /// The database backend this repository reports.
    fn adapter_kind(&self) -> AdapterKind;

    /// Looks up row `id` in `table`; `Ok(None)` signals that no such row exists.
    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>>;

    /// Returns the rows of `table` selected by `query` (filters, sorts, pagination).
    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>>;

    /// Stores `data` as a new row of `table`, returning it together with its id.
    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow>;

    /// Replaces the data of row `id` in `table`, returning the updated row.
    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow>;

    /// Removes row `id` from `table`.
    fn delete(&mut self, table: &str, id: u64) -> DataResult<()>;
}
/// An in-memory [`Repo`] that keeps each table as a `Vec` of rows.
///
/// The driver is used only to report the adapter kind and to tag log events; all
/// data lives in `tables`. Row ids come from one counter shared by every table.
pub struct MemoryRepo {
    /// Driver whose kind this repository reports.
    driver: Box<dyn AdapterDriver>,
    /// Table name -> rows, in insertion order.
    tables: BTreeMap<String, Vec<StoredRow>>,
    /// Id assigned to the next inserted row.
    next_id: u64,
}

impl MemoryRepo {
    /// Creates an empty repository backed by `driver`; the first inserted row gets id 1.
    pub fn new(driver: Box<dyn AdapterDriver>) -> Self {
        let tables = BTreeMap::new();
        Self {
            next_id: 1,
            tables,
            driver,
        }
    }
}
impl Repo for MemoryRepo {
fn adapter_kind(&self) -> AdapterKind {
self.driver.kind()
}
fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow> {
let started_at = Instant::now();
let result = {
let entry = self.tables.entry(table.to_string()).or_default();
let row = StoredRow {
id: self.next_id,
data,
};
self.next_id += 1;
entry.push(row.clone());
Ok(row)
};
match &result {
Ok(row) => info!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "insert",
table,
row_id = row.id,
duration_ms = started_at.elapsed().as_millis() as u64,
"Shelly data query executed"
),
Err(err) => warn!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "insert",
table,
duration_ms = started_at.elapsed().as_millis() as u64,
error = %err,
"Shelly data query failed"
),
}
result
}
fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow> {
let started_at = Instant::now();
let result = {
let rows = self.tables.entry(table.to_string()).or_default();
match rows.iter_mut().find(|row| row.id == id) {
Some(existing) => {
existing.data = data;
Ok(existing.clone())
}
None => Err(DataError::Query(format!(
"row id {id} not found in table `{table}`"
))),
}
};
match &result {
Ok(row) => info!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "update",
table,
row_id = row.id,
duration_ms = started_at.elapsed().as_millis() as u64,
"Shelly data query executed"
),
Err(err) => warn!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "update",
table,
row_id = id,
duration_ms = started_at.elapsed().as_millis() as u64,
error = %err,
"Shelly data query failed"
),
}
result
}
fn delete(&mut self, table: &str, id: u64) -> DataResult<()> {
let started_at = Instant::now();
let result = {
let rows = self.tables.entry(table.to_string()).or_default();
let initial_len = rows.len();
rows.retain(|row| row.id != id);
if rows.len() == initial_len {
Err(DataError::Query(format!(
"row id {id} not found in table `{table}`"
)))
} else {
Ok(())
}
};
match &result {
Ok(()) => info!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "delete",
table,
row_id = id,
duration_ms = started_at.elapsed().as_millis() as u64,
"Shelly data query executed"
),
Err(err) => warn!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "delete",
table,
row_id = id,
duration_ms = started_at.elapsed().as_millis() as u64,
error = %err,
"Shelly data query failed"
),
}
result
}
fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>> {
let started_at = Instant::now();
let result = Ok(self
.tables
.get(table)
.and_then(|rows| rows.iter().find(|row| row.id == id))
.cloned());
match &result {
Ok(row) => info!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "find",
table,
row_id = id,
found = row.is_some(),
duration_ms = started_at.elapsed().as_millis() as u64,
"Shelly data query executed"
),
Err(err) => warn!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "find",
table,
row_id = id,
duration_ms = started_at.elapsed().as_millis() as u64,
error = %err,
"Shelly data query failed"
),
}
result
}
fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>> {
let started_at = Instant::now();
let result = {
let mut rows = self.tables.get(table).cloned().unwrap_or_default();
if !query.filters.is_empty() {
rows.retain(|row| {
query
.filters
.iter()
.all(|filter| matches_filter(row, filter))
});
}
for sort in query.sorts.iter().rev() {
rows.sort_by(|left, right| compare_for_sort(left, right, sort.field.as_str()));
if sort.direction == SortDirection::Desc {
rows.reverse();
}
}
if let Some(pagination) = query.pagination {
let offset = (pagination.page.saturating_sub(1)) * pagination.per_page;
rows = rows
.into_iter()
.skip(offset)
.take(pagination.per_page)
.collect();
}
Ok(rows)
};
match &result {
Ok(rows) => info!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "list",
table,
row_count = rows.len(),
filter_count = query.filters.len(),
sort_count = query.sorts.len(),
page = query.pagination.map(|value| value.page),
per_page = query.pagination.map(|value| value.per_page),
duration_ms = started_at.elapsed().as_millis() as u64,
"Shelly data query executed"
),
Err(err) => warn!(
target: "shelly.data.query",
source = "memory_repo",
adapter = self.driver.kind().as_str(),
operation = "list",
table,
filter_count = query.filters.len(),
sort_count = query.sorts.len(),
page = query.pagination.map(|value| value.page),
per_page = query.pagination.map(|value| value.per_page),
duration_ms = started_at.elapsed().as_millis() as u64,
error = %err,
"Shelly data query failed"
),
}
result
}
}
fn matches_filter(row: &StoredRow, filter: &crate::query::Filter) -> bool {
let Some(candidate) = row.data.get(&filter.field) else {
return false;
};
match filter.op {
FilterOperator::Eq => candidate == &filter.value,
FilterOperator::Neq => candidate != &filter.value,
FilterOperator::Contains => candidate
.as_str()
.zip(filter.value.as_str())
.is_some_and(|(left, right)| left.contains(right)),
FilterOperator::Gt => {
compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Greater)
}
FilterOperator::Gte => compare_numbers(candidate, &filter.value)
.is_some_and(|ord| ord == Ordering::Greater || ord == Ordering::Equal),
FilterOperator::Lt => {
compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Less)
}
FilterOperator::Lte => compare_numbers(candidate, &filter.value)
.is_some_and(|ord| ord == Ordering::Less || ord == Ordering::Equal),
}
}
fn compare_for_sort(left: &StoredRow, right: &StoredRow, field: &str) -> Ordering {
let left_value = left.data.get(field);
let right_value = right.data.get(field);
match (left_value, right_value) {
(Some(Value::Number(left_num)), Some(Value::Number(right_num))) => left_num
.as_f64()
.partial_cmp(&right_num.as_f64())
.unwrap_or(Ordering::Equal),
(Some(Value::String(left_text)), Some(Value::String(right_text))) => {
left_text.cmp(right_text)
}
_ => left.id.cmp(&right.id),
}
}
/// Numerically compares two JSON values through `f64`.
///
/// Returns `None` unless both values are numbers, or when `partial_cmp` itself
/// has no answer.
fn compare_numbers(left: &Value, right: &Value) -> Option<Ordering> {
    let lhs = left.as_f64()?;
    let rhs = right.as_f64()?;
    lhs.partial_cmp(&rhs)
}
// Unit tests for adapter selection and for `MemoryRepo`'s CRUD, filtering,
// sorting and pagination behavior.
#[cfg(test)]
mod tests {
    use super::{adapter_for, DatabaseConfig, MemoryRepo, Repo, Row};
    use crate::{AdapterKind, DataError, Filter, FilterOperator, Query, SortDirection};
    use serde_json::json;

    // Smoke test: a repo built from a selected adapter can insert a row and then
    // list it back with a filter and a sort applied.
    #[test]
    fn memory_repo_works_for_adapter_selection() {
        let mut repo = MemoryRepo::new(
            adapter_for(&DatabaseConfig {
                adapter: AdapterKind::Sqlite,
                url: None,
                url_env: None,
            })
            .unwrap(),
        );
        let mut row = Row::new();
        row.insert("title".to_string(), json!("Alpha"));
        row.insert("score".to_string(), json!(10));
        repo.insert("posts", row).unwrap();
        let rows = repo
            .list(
                "posts",
                &Query::new()
                    .where_filter(Filter::contains("title", "Al"))
                    .order_by("score", SortDirection::Desc),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].data.get("title"), Some(&json!("Alpha")));
    }

    // `adapter_for` rejects `AdapterKind::None` with an adapter error and returns a
    // driver reporting the requested kind for every concrete backend.
    #[test]
    fn adapter_for_rejects_none_and_selects_expected_driver() {
        let none_result = adapter_for(&DatabaseConfig {
            adapter: AdapterKind::None,
            url: None,
            url_env: None,
        });
        assert!(matches!(none_result, Err(DataError::Adapter(_))));
        for kind in [
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ] {
            let driver = adapter_for(&DatabaseConfig {
                adapter: kind,
                url: None,
                url_env: None,
            })
            .expect("driver should be created");
            assert_eq!(driver.kind(), kind);
        }
    }

    // `find` returns `None` (not an error) for unknown ids and unknown tables,
    // while `update` and `delete` report `DataError::Query` for a missing row,
    // including a second delete of an already-removed row.
    #[test]
    fn update_delete_and_find_cover_missing_rows() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));
        let mut row = Row::new();
        row.insert("title".to_string(), json!("Draft"));
        let inserted = repo.insert("posts", row).expect("insert should work");
        assert_eq!(
            repo.find("posts", inserted.id)
                .expect("find should not fail")
                .map(|it| it.id),
            Some(inserted.id)
        );
        assert!(repo
            .find("posts", 999)
            .expect("find should not fail")
            .is_none());
        assert!(repo
            .find("missing_table", inserted.id)
            .expect("find should not fail")
            .is_none());
        let mut updated = Row::new();
        updated.insert("title".to_string(), json!("Published"));
        let updated_row = repo
            .update("posts", inserted.id, updated)
            .expect("update should work");
        assert_eq!(updated_row.data.get("title"), Some(&json!("Published")));
        let update_err = repo
            .update("posts", 404, Row::new())
            .expect_err("missing row should fail update");
        assert!(matches!(update_err, DataError::Query(_)));
        repo.delete("posts", inserted.id)
            .expect("delete should remove row");
        let delete_err = repo
            .delete("posts", inserted.id)
            .expect_err("deleting missing row should fail");
        assert!(matches!(delete_err, DataError::Query(_)));
    }

    // Exercises every filter operator, the sort fallback for a missing field
    // combined with pagination, and a two-key sort. Fixture rows get ids 1..=3
    // in insertion order: Alpha(score 10), Beta(20), Gamma(15, numeric `tag`).
    #[test]
    fn list_applies_filters_sorts_and_pagination() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));
        let mut alpha = Row::new();
        alpha.insert("title".to_string(), json!("Alpha"));
        alpha.insert("score".to_string(), json!(10));
        alpha.insert("tag".to_string(), json!("core"));
        repo.insert("posts", alpha).expect("insert alpha");
        let mut beta = Row::new();
        beta.insert("title".to_string(), json!("Beta"));
        beta.insert("score".to_string(), json!(20));
        beta.insert("tag".to_string(), json!("ops"));
        repo.insert("posts", beta).expect("insert beta");
        let mut gamma = Row::new();
        gamma.insert("title".to_string(), json!("Gamma"));
        gamma.insert("score".to_string(), json!(15));
        gamma.insert("tag".to_string(), json!(123));
        repo.insert("posts", gamma).expect("insert gamma");
        let eq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::eq("title", json!("Alpha"))),
            )
            .expect("eq filter");
        assert_eq!(eq_rows.len(), 1);
        assert_eq!(eq_rows[0].data.get("title"), Some(&json!("Alpha")));
        let neq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(crate::Filter {
                    field: "title".to_string(),
                    op: FilterOperator::Neq,
                    value: json!("Alpha"),
                }),
            )
            .expect("neq filter");
        assert_eq!(neq_rows.len(), 2);
        let contains_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("title", "mm")),
            )
            .expect("contains filter");
        assert_eq!(contains_rows.len(), 1);
        assert_eq!(contains_rows[0].data.get("title"), Some(&json!("Gamma")));
        // `Contains` only matches strings: Gamma's numeric tag 123 must not match "2".
        let contains_non_string_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("tag", "2")),
            )
            .expect("contains on mixed type");
        assert!(contains_non_string_rows.is_empty());
        // Numeric comparisons against 15; unsorted results keep insertion order.
        for (op, expected_titles) in [
            (FilterOperator::Gt, vec!["Beta"]),
            (FilterOperator::Gte, vec!["Beta", "Gamma"]),
            (FilterOperator::Lt, vec!["Alpha"]),
            (FilterOperator::Lte, vec!["Alpha", "Gamma"]),
        ] {
            let rows = repo
                .list(
                    "posts",
                    &Query::new().where_filter(crate::Filter {
                        field: "score".to_string(),
                        op,
                        value: json!(15),
                    }),
                )
                .expect("numeric filter");
            let titles: Vec<&str> = rows
                .iter()
                .map(|row| {
                    row.data
                        .get("title")
                        .and_then(|value| value.as_str())
                        .expect("title")
                })
                .collect();
            assert_eq!(titles, expected_titles);
        }
        // No row has `missing`, so the sort falls back to row id; Desc gives ids
        // 3, 2, 1 and page 1 of size 2 keeps ids 3 and 2.
        let unknown_field_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("missing", SortDirection::Desc)
                    .paginate(1, 2),
            )
            .expect("fallback sort");
        assert_eq!(unknown_field_sort.len(), 2);
        assert_eq!(unknown_field_sort[0].id, 3);
        assert_eq!(unknown_field_sort[1].id, 2);
        // Score is the primary key (Desc); all scores differ, so the title key
        // never breaks a tie here.
        let score_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("score", SortDirection::Desc)
                    .order_by("title", SortDirection::Asc),
            )
            .expect("score sort");
        let score_titles: Vec<&str> = score_sort
            .iter()
            .map(|row| {
                row.data
                    .get("title")
                    .and_then(|value| value.as_str())
                    .expect("title")
            })
            .collect();
        assert_eq!(score_titles, vec!["Beta", "Gamma", "Alpha"]);
    }
}