1use crate::{
2 adapter::{AdapterKind, DatabaseConfig},
3 error::{DataError, DataResult},
4 query::{FilterOperator, Query, SortDirection},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::{cmp::Ordering, collections::BTreeMap, time::Instant};
9use tracing::{info, warn};
10
11pub type Row = BTreeMap<String, Value>;
12
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14pub struct StoredRow {
15 pub id: u64,
16 pub data: Row,
17}
18
19pub trait AdapterDriver: Send + Sync {
20 fn kind(&self) -> AdapterKind;
21}
22
23#[derive(Debug, Default, Clone, Copy)]
24pub struct PostgresAdapter;
25
26impl AdapterDriver for PostgresAdapter {
27 fn kind(&self) -> AdapterKind {
28 AdapterKind::Postgres
29 }
30}
31
32#[derive(Debug, Default, Clone, Copy)]
33pub struct MySqlAdapter;
34
35impl AdapterDriver for MySqlAdapter {
36 fn kind(&self) -> AdapterKind {
37 AdapterKind::MySql
38 }
39}
40
41#[derive(Debug, Default, Clone, Copy)]
42pub struct SqliteAdapter;
43
44impl AdapterDriver for SqliteAdapter {
45 fn kind(&self) -> AdapterKind {
46 AdapterKind::Sqlite
47 }
48}
49
50pub fn adapter_for(config: &DatabaseConfig) -> DataResult<Box<dyn AdapterDriver>> {
51 match config.adapter {
52 AdapterKind::Postgres => Ok(Box::new(PostgresAdapter)),
53 AdapterKind::MySql => Ok(Box::new(MySqlAdapter)),
54 AdapterKind::Sqlite => Ok(Box::new(SqliteAdapter)),
55 AdapterKind::None => Err(DataError::Adapter(
56 "database adapter is `none`; select postgres/mysql/sqlite in shelly.data.toml"
57 .to_string(),
58 )),
59 }
60}
61
62pub trait Repo {
63 fn adapter_kind(&self) -> AdapterKind;
64 fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow>;
65 fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow>;
66 fn delete(&mut self, table: &str, id: u64) -> DataResult<()>;
67 fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>>;
68 fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>>;
69}
70
71pub struct MemoryRepo {
72 driver: Box<dyn AdapterDriver>,
73 tables: BTreeMap<String, Vec<StoredRow>>,
74 next_id: u64,
75}
76
77impl MemoryRepo {
78 pub fn new(driver: Box<dyn AdapterDriver>) -> Self {
79 Self {
80 driver,
81 tables: BTreeMap::new(),
82 next_id: 1,
83 }
84 }
85}
86
87impl Repo for MemoryRepo {
88 fn adapter_kind(&self) -> AdapterKind {
89 self.driver.kind()
90 }
91
92 fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow> {
93 let started_at = Instant::now();
94 let result = {
95 let entry = self.tables.entry(table.to_string()).or_default();
96 let row = StoredRow {
97 id: self.next_id,
98 data,
99 };
100 self.next_id += 1;
101 entry.push(row.clone());
102 Ok(row)
103 };
104
105 match &result {
106 Ok(row) => info!(
107 target: "shelly.data.query",
108 source = "memory_repo",
109 adapter = self.driver.kind().as_str(),
110 operation = "insert",
111 table,
112 row_id = row.id,
113 duration_ms = started_at.elapsed().as_millis() as u64,
114 "Shelly data query executed"
115 ),
116 Err(err) => warn!(
117 target: "shelly.data.query",
118 source = "memory_repo",
119 adapter = self.driver.kind().as_str(),
120 operation = "insert",
121 table,
122 duration_ms = started_at.elapsed().as_millis() as u64,
123 error = %err,
124 "Shelly data query failed"
125 ),
126 }
127
128 result
129 }
130
131 fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow> {
132 let started_at = Instant::now();
133 let result = {
134 let rows = self.tables.entry(table.to_string()).or_default();
135 match rows.iter_mut().find(|row| row.id == id) {
136 Some(existing) => {
137 existing.data = data;
138 Ok(existing.clone())
139 }
140 None => Err(DataError::Query(format!(
141 "row id {id} not found in table `{table}`"
142 ))),
143 }
144 };
145
146 match &result {
147 Ok(row) => info!(
148 target: "shelly.data.query",
149 source = "memory_repo",
150 adapter = self.driver.kind().as_str(),
151 operation = "update",
152 table,
153 row_id = row.id,
154 duration_ms = started_at.elapsed().as_millis() as u64,
155 "Shelly data query executed"
156 ),
157 Err(err) => warn!(
158 target: "shelly.data.query",
159 source = "memory_repo",
160 adapter = self.driver.kind().as_str(),
161 operation = "update",
162 table,
163 row_id = id,
164 duration_ms = started_at.elapsed().as_millis() as u64,
165 error = %err,
166 "Shelly data query failed"
167 ),
168 }
169
170 result
171 }
172
173 fn delete(&mut self, table: &str, id: u64) -> DataResult<()> {
174 let started_at = Instant::now();
175 let result = {
176 let rows = self.tables.entry(table.to_string()).or_default();
177 let initial_len = rows.len();
178 rows.retain(|row| row.id != id);
179 if rows.len() == initial_len {
180 Err(DataError::Query(format!(
181 "row id {id} not found in table `{table}`"
182 )))
183 } else {
184 Ok(())
185 }
186 };
187
188 match &result {
189 Ok(()) => info!(
190 target: "shelly.data.query",
191 source = "memory_repo",
192 adapter = self.driver.kind().as_str(),
193 operation = "delete",
194 table,
195 row_id = id,
196 duration_ms = started_at.elapsed().as_millis() as u64,
197 "Shelly data query executed"
198 ),
199 Err(err) => warn!(
200 target: "shelly.data.query",
201 source = "memory_repo",
202 adapter = self.driver.kind().as_str(),
203 operation = "delete",
204 table,
205 row_id = id,
206 duration_ms = started_at.elapsed().as_millis() as u64,
207 error = %err,
208 "Shelly data query failed"
209 ),
210 }
211
212 result
213 }
214
215 fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>> {
216 let started_at = Instant::now();
217 let result = Ok(self
218 .tables
219 .get(table)
220 .and_then(|rows| rows.iter().find(|row| row.id == id))
221 .cloned());
222
223 match &result {
224 Ok(row) => info!(
225 target: "shelly.data.query",
226 source = "memory_repo",
227 adapter = self.driver.kind().as_str(),
228 operation = "find",
229 table,
230 row_id = id,
231 found = row.is_some(),
232 duration_ms = started_at.elapsed().as_millis() as u64,
233 "Shelly data query executed"
234 ),
235 Err(err) => warn!(
236 target: "shelly.data.query",
237 source = "memory_repo",
238 adapter = self.driver.kind().as_str(),
239 operation = "find",
240 table,
241 row_id = id,
242 duration_ms = started_at.elapsed().as_millis() as u64,
243 error = %err,
244 "Shelly data query failed"
245 ),
246 }
247
248 result
249 }
250
251 fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>> {
252 let started_at = Instant::now();
253 let result = {
254 let mut rows = self.tables.get(table).cloned().unwrap_or_default();
255
256 if !query.filters.is_empty() {
257 rows.retain(|row| {
258 query
259 .filters
260 .iter()
261 .all(|filter| matches_filter(row, filter))
262 });
263 }
264
265 for sort in query.sorts.iter().rev() {
266 rows.sort_by(|left, right| compare_for_sort(left, right, sort.field.as_str()));
267 if sort.direction == SortDirection::Desc {
268 rows.reverse();
269 }
270 }
271
272 if let Some(pagination) = query.pagination {
273 let offset = (pagination.page.saturating_sub(1)) * pagination.per_page;
274 rows = rows
275 .into_iter()
276 .skip(offset)
277 .take(pagination.per_page)
278 .collect();
279 }
280
281 Ok(rows)
282 };
283
284 match &result {
285 Ok(rows) => info!(
286 target: "shelly.data.query",
287 source = "memory_repo",
288 adapter = self.driver.kind().as_str(),
289 operation = "list",
290 table,
291 row_count = rows.len(),
292 filter_count = query.filters.len(),
293 sort_count = query.sorts.len(),
294 page = query.pagination.map(|value| value.page),
295 per_page = query.pagination.map(|value| value.per_page),
296 duration_ms = started_at.elapsed().as_millis() as u64,
297 "Shelly data query executed"
298 ),
299 Err(err) => warn!(
300 target: "shelly.data.query",
301 source = "memory_repo",
302 adapter = self.driver.kind().as_str(),
303 operation = "list",
304 table,
305 filter_count = query.filters.len(),
306 sort_count = query.sorts.len(),
307 page = query.pagination.map(|value| value.page),
308 per_page = query.pagination.map(|value| value.per_page),
309 duration_ms = started_at.elapsed().as_millis() as u64,
310 error = %err,
311 "Shelly data query failed"
312 ),
313 }
314
315 result
316 }
317}
318
319fn matches_filter(row: &StoredRow, filter: &crate::query::Filter) -> bool {
320 let Some(candidate) = row.data.get(&filter.field) else {
321 return false;
322 };
323 match filter.op {
324 FilterOperator::Eq => candidate == &filter.value,
325 FilterOperator::Neq => candidate != &filter.value,
326 FilterOperator::Contains => candidate
327 .as_str()
328 .zip(filter.value.as_str())
329 .is_some_and(|(left, right)| left.contains(right)),
330 FilterOperator::Gt => {
331 compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Greater)
332 }
333 FilterOperator::Gte => compare_numbers(candidate, &filter.value)
334 .is_some_and(|ord| ord == Ordering::Greater || ord == Ordering::Equal),
335 FilterOperator::Lt => {
336 compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Less)
337 }
338 FilterOperator::Lte => compare_numbers(candidate, &filter.value)
339 .is_some_and(|ord| ord == Ordering::Less || ord == Ordering::Equal),
340 }
341}
342
343fn compare_for_sort(left: &StoredRow, right: &StoredRow, field: &str) -> Ordering {
344 let left_value = left.data.get(field);
345 let right_value = right.data.get(field);
346 match (left_value, right_value) {
347 (Some(Value::Number(left_num)), Some(Value::Number(right_num))) => left_num
348 .as_f64()
349 .partial_cmp(&right_num.as_f64())
350 .unwrap_or(Ordering::Equal),
351 (Some(Value::String(left_text)), Some(Value::String(right_text))) => {
352 left_text.cmp(right_text)
353 }
354 _ => left.id.cmp(&right.id),
355 }
356}
357
358fn compare_numbers(left: &Value, right: &Value) -> Option<Ordering> {
359 left.as_f64()
360 .zip(right.as_f64())
361 .and_then(|(left, right)| left.partial_cmp(&right))
362}
363
#[cfg(test)]
mod tests {
    use super::{adapter_for, DatabaseConfig, MemoryRepo, Repo, Row};
    use crate::{AdapterKind, DataError, Filter, FilterOperator, Query, SortDirection};
    use serde_json::json;

    /// Config selecting `adapter`, with no connection URL set.
    fn config_for(adapter: AdapterKind) -> DatabaseConfig {
        DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        }
    }

    /// Builds a row from `(field, value)` pairs.
    fn row_from(fields: &[(&str, serde_json::Value)]) -> Row {
        fields
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    /// Collects the `title` of every row, panicking when one is missing or not a string.
    fn titles_of(rows: &[super::StoredRow]) -> Vec<&str> {
        rows.iter()
            .map(|row| {
                row.data
                    .get("title")
                    .and_then(|value| value.as_str())
                    .expect("title")
            })
            .collect()
    }

    #[test]
    fn memory_repo_works_for_adapter_selection() {
        let driver = adapter_for(&config_for(AdapterKind::Sqlite)).unwrap();
        let mut repo = MemoryRepo::new(driver);
        repo.insert(
            "posts",
            row_from(&[("title", json!("Alpha")), ("score", json!(10))]),
        )
        .unwrap();

        let query = Query::new()
            .where_filter(Filter::contains("title", "Al"))
            .order_by("score", SortDirection::Desc);
        let rows = repo.list("posts", &query).unwrap();

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].data.get("title"), Some(&json!("Alpha")));
    }

    #[test]
    fn adapter_for_rejects_none_and_selects_expected_driver() {
        let none_result = adapter_for(&config_for(AdapterKind::None));
        assert!(matches!(none_result, Err(DataError::Adapter(_))));

        let supported = [
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ];
        for kind in supported {
            let driver = adapter_for(&config_for(kind)).expect("driver should be created");
            assert_eq!(driver.kind(), kind);
        }
    }

    #[test]
    fn update_delete_and_find_cover_missing_rows() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        let inserted = repo
            .insert("posts", row_from(&[("title", json!("Draft"))]))
            .expect("insert should work");

        // Lookups: existing row, unknown id, unknown table.
        let found_id = repo
            .find("posts", inserted.id)
            .expect("find should not fail")
            .map(|it| it.id);
        assert_eq!(found_id, Some(inserted.id));

        let unknown_id = repo.find("posts", 999).expect("find should not fail");
        assert!(unknown_id.is_none());

        let unknown_table = repo
            .find("missing_table", inserted.id)
            .expect("find should not fail");
        assert!(unknown_table.is_none());

        // Updates: existing row succeeds, unknown id fails.
        let updated_row = repo
            .update(
                "posts",
                inserted.id,
                row_from(&[("title", json!("Published"))]),
            )
            .expect("update should work");
        assert_eq!(updated_row.data.get("title"), Some(&json!("Published")));

        let update_err = repo
            .update("posts", 404, Row::new())
            .expect_err("missing row should fail update");
        assert!(matches!(update_err, DataError::Query(_)));

        // Deletes: the first removes the row, the second finds nothing left.
        repo.delete("posts", inserted.id)
            .expect("delete should remove row");
        let delete_err = repo
            .delete("posts", inserted.id)
            .expect_err("deleting missing row should fail");
        assert!(matches!(delete_err, DataError::Query(_)));
    }

    #[test]
    fn list_applies_filters_sorts_and_pagination() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        repo.insert(
            "posts",
            row_from(&[
                ("title", json!("Alpha")),
                ("score", json!(10)),
                ("tag", json!("core")),
            ]),
        )
        .expect("insert alpha");
        repo.insert(
            "posts",
            row_from(&[
                ("title", json!("Beta")),
                ("score", json!(20)),
                ("tag", json!("ops")),
            ]),
        )
        .expect("insert beta");
        repo.insert(
            "posts",
            row_from(&[
                ("title", json!("Gamma")),
                ("score", json!(15)),
                ("tag", json!(123)),
            ]),
        )
        .expect("insert gamma");

        let eq_query = Query::new().where_filter(Filter::eq("title", json!("Alpha")));
        let eq_rows = repo.list("posts", &eq_query).expect("eq filter");
        assert_eq!(eq_rows.len(), 1);
        assert_eq!(eq_rows[0].data.get("title"), Some(&json!("Alpha")));

        let neq_query = Query::new().where_filter(Filter {
            field: "title".to_string(),
            op: FilterOperator::Neq,
            value: json!("Alpha"),
        });
        let neq_rows = repo.list("posts", &neq_query).expect("neq filter");
        assert_eq!(neq_rows.len(), 2);

        let contains_query = Query::new().where_filter(Filter::contains("title", "mm"));
        let contains_rows = repo
            .list("posts", &contains_query)
            .expect("contains filter");
        assert_eq!(contains_rows.len(), 1);
        assert_eq!(contains_rows[0].data.get("title"), Some(&json!("Gamma")));

        // `tag` is a number on one row; `contains` must not match non-string values.
        let mixed_query = Query::new().where_filter(Filter::contains("tag", "2"));
        let contains_non_string_rows = repo
            .list("posts", &mixed_query)
            .expect("contains on mixed type");
        assert!(contains_non_string_rows.is_empty());

        let numeric_cases = [
            (FilterOperator::Gt, vec!["Beta"]),
            (FilterOperator::Gte, vec!["Beta", "Gamma"]),
            (FilterOperator::Lt, vec!["Alpha"]),
            (FilterOperator::Lte, vec!["Alpha", "Gamma"]),
        ];
        for (op, expected_titles) in numeric_cases {
            let query = Query::new().where_filter(Filter {
                field: "score".to_string(),
                op,
                value: json!(15),
            });
            let rows = repo.list("posts", &query).expect("numeric filter");
            assert_eq!(titles_of(&rows), expected_titles);
        }

        // Sorting on a field no row has falls back to ids, here in descending order.
        let fallback_query = Query::new()
            .order_by("missing", SortDirection::Desc)
            .paginate(1, 2);
        let unknown_field_sort = repo.list("posts", &fallback_query).expect("fallback sort");
        let ids: Vec<u64> = unknown_field_sort.iter().map(|row| row.id).collect();
        assert_eq!(ids, vec![3, 2]);

        let score_query = Query::new()
            .order_by("score", SortDirection::Desc)
            .order_by("title", SortDirection::Asc);
        let score_sort = repo.list("posts", &score_query).expect("score sort");
        assert_eq!(titles_of(&score_sort), vec!["Beta", "Gamma", "Alpha"]);
    }
}