Skip to main content

shelly_data/
repo.rs

1use crate::{
2    adapter::{AdapterKind, DatabaseConfig},
3    error::{DataError, DataResult},
4    query::{FilterOperator, Query, SortDirection},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::{cmp::Ordering, collections::BTreeMap, time::Instant};
9use tracing::{info, warn};
10
/// The column values of one record, keyed by column name.
///
/// Being a `BTreeMap`, a row iterates (and therefore serializes) its columns in
/// sorted key order.
pub type Row = BTreeMap<String, Value>;

/// A [`Row`] paired with the numeric id the repository assigned to it on insert.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StoredRow {
    /// Identifier assigned by the repository when the row was inserted.
    pub id: u64,
    /// The row's column values.
    pub data: Row,
}
18
/// A database backend driver, currently only able to report which backend it is.
///
/// The `Send + Sync` supertraits make `Box<dyn AdapterDriver>` safe to move and share
/// across threads.
pub trait AdapterDriver: Send + Sync {
    /// Returns the backend kind this driver represents.
    fn kind(&self) -> AdapterKind;
}
22
23#[derive(Debug, Default, Clone, Copy)]
24pub struct PostgresAdapter;
25
26impl AdapterDriver for PostgresAdapter {
27    fn kind(&self) -> AdapterKind {
28        AdapterKind::Postgres
29    }
30}
31
32#[derive(Debug, Default, Clone, Copy)]
33pub struct MySqlAdapter;
34
35impl AdapterDriver for MySqlAdapter {
36    fn kind(&self) -> AdapterKind {
37        AdapterKind::MySql
38    }
39}
40
41#[derive(Debug, Default, Clone, Copy)]
42pub struct SqliteAdapter;
43
44impl AdapterDriver for SqliteAdapter {
45    fn kind(&self) -> AdapterKind {
46        AdapterKind::Sqlite
47    }
48}
49
50pub fn adapter_for(config: &DatabaseConfig) -> DataResult<Box<dyn AdapterDriver>> {
51    match config.adapter {
52        AdapterKind::Postgres => Ok(Box::new(PostgresAdapter)),
53        AdapterKind::MySql => Ok(Box::new(MySqlAdapter)),
54        AdapterKind::Sqlite => Ok(Box::new(SqliteAdapter)),
55        AdapterKind::None => Err(DataError::Adapter(
56            "database adapter is `none`; select postgres/mysql/sqlite in shelly.data.toml"
57                .to_string(),
58        )),
59    }
60}
61
/// Table-oriented CRUD access to [`StoredRow`]s.
///
/// Operations that change data take `&mut self`; lookups take `&self`.
pub trait Repo {
    /// Returns the backend kind this repository reports.
    fn adapter_kind(&self) -> AdapterKind;
    /// Stores `data` as a new row of `table` and returns it together with its new id.
    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow>;
    /// Replaces the data of row `id` in `table` and returns the updated row.
    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow>;
    /// Removes row `id` from `table`.
    fn delete(&mut self, table: &str, id: u64) -> DataResult<()>;
    /// Looks up row `id` in `table`; `Ok(None)` means no such row.
    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>>;
    /// Returns the rows of `table` selected, ordered and paginated according to `query`.
    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>>;
}
70
/// An in-process [`Repo`] that keeps every table in memory.
///
/// The driver is only used to report [`Repo::adapter_kind`] and to label log events;
/// this type opens no database connection.
pub struct MemoryRepo {
    // Driver whose kind is reported and attached to every log event.
    driver: Box<dyn AdapterDriver>,
    // Rows per table name, kept in insertion order.
    tables: BTreeMap<String, Vec<StoredRow>>,
    // Next id to assign; a single counter shared by all tables.
    next_id: u64,
}
76
77impl MemoryRepo {
78    pub fn new(driver: Box<dyn AdapterDriver>) -> Self {
79        Self {
80            driver,
81            tables: BTreeMap::new(),
82            next_id: 1,
83        }
84    }
85}
86
87impl Repo for MemoryRepo {
88    fn adapter_kind(&self) -> AdapterKind {
89        self.driver.kind()
90    }
91
92    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow> {
93        let started_at = Instant::now();
94        let result = {
95            let entry = self.tables.entry(table.to_string()).or_default();
96            let row = StoredRow {
97                id: self.next_id,
98                data,
99            };
100            self.next_id += 1;
101            entry.push(row.clone());
102            Ok(row)
103        };
104
105        match &result {
106            Ok(row) => info!(
107                target: "shelly.data.query",
108                source = "memory_repo",
109                adapter = self.driver.kind().as_str(),
110                operation = "insert",
111                table,
112                row_id = row.id,
113                duration_ms = started_at.elapsed().as_millis() as u64,
114                "Shelly data query executed"
115            ),
116            Err(err) => warn!(
117                target: "shelly.data.query",
118                source = "memory_repo",
119                adapter = self.driver.kind().as_str(),
120                operation = "insert",
121                table,
122                duration_ms = started_at.elapsed().as_millis() as u64,
123                error = %err,
124                "Shelly data query failed"
125            ),
126        }
127
128        result
129    }
130
131    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow> {
132        let started_at = Instant::now();
133        let result = {
134            let rows = self.tables.entry(table.to_string()).or_default();
135            match rows.iter_mut().find(|row| row.id == id) {
136                Some(existing) => {
137                    existing.data = data;
138                    Ok(existing.clone())
139                }
140                None => Err(DataError::Query(format!(
141                    "row id {id} not found in table `{table}`"
142                ))),
143            }
144        };
145
146        match &result {
147            Ok(row) => info!(
148                target: "shelly.data.query",
149                source = "memory_repo",
150                adapter = self.driver.kind().as_str(),
151                operation = "update",
152                table,
153                row_id = row.id,
154                duration_ms = started_at.elapsed().as_millis() as u64,
155                "Shelly data query executed"
156            ),
157            Err(err) => warn!(
158                target: "shelly.data.query",
159                source = "memory_repo",
160                adapter = self.driver.kind().as_str(),
161                operation = "update",
162                table,
163                row_id = id,
164                duration_ms = started_at.elapsed().as_millis() as u64,
165                error = %err,
166                "Shelly data query failed"
167            ),
168        }
169
170        result
171    }
172
173    fn delete(&mut self, table: &str, id: u64) -> DataResult<()> {
174        let started_at = Instant::now();
175        let result = {
176            let rows = self.tables.entry(table.to_string()).or_default();
177            let initial_len = rows.len();
178            rows.retain(|row| row.id != id);
179            if rows.len() == initial_len {
180                Err(DataError::Query(format!(
181                    "row id {id} not found in table `{table}`"
182                )))
183            } else {
184                Ok(())
185            }
186        };
187
188        match &result {
189            Ok(()) => info!(
190                target: "shelly.data.query",
191                source = "memory_repo",
192                adapter = self.driver.kind().as_str(),
193                operation = "delete",
194                table,
195                row_id = id,
196                duration_ms = started_at.elapsed().as_millis() as u64,
197                "Shelly data query executed"
198            ),
199            Err(err) => warn!(
200                target: "shelly.data.query",
201                source = "memory_repo",
202                adapter = self.driver.kind().as_str(),
203                operation = "delete",
204                table,
205                row_id = id,
206                duration_ms = started_at.elapsed().as_millis() as u64,
207                error = %err,
208                "Shelly data query failed"
209            ),
210        }
211
212        result
213    }
214
215    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>> {
216        let started_at = Instant::now();
217        let result = Ok(self
218            .tables
219            .get(table)
220            .and_then(|rows| rows.iter().find(|row| row.id == id))
221            .cloned());
222
223        match &result {
224            Ok(row) => info!(
225                target: "shelly.data.query",
226                source = "memory_repo",
227                adapter = self.driver.kind().as_str(),
228                operation = "find",
229                table,
230                row_id = id,
231                found = row.is_some(),
232                duration_ms = started_at.elapsed().as_millis() as u64,
233                "Shelly data query executed"
234            ),
235            Err(err) => warn!(
236                target: "shelly.data.query",
237                source = "memory_repo",
238                adapter = self.driver.kind().as_str(),
239                operation = "find",
240                table,
241                row_id = id,
242                duration_ms = started_at.elapsed().as_millis() as u64,
243                error = %err,
244                "Shelly data query failed"
245            ),
246        }
247
248        result
249    }
250
251    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>> {
252        let started_at = Instant::now();
253        let result = {
254            let mut rows = self.tables.get(table).cloned().unwrap_or_default();
255
256            if !query.filters.is_empty() {
257                rows.retain(|row| {
258                    query
259                        .filters
260                        .iter()
261                        .all(|filter| matches_filter(row, filter))
262                });
263            }
264
265            for sort in query.sorts.iter().rev() {
266                rows.sort_by(|left, right| compare_for_sort(left, right, sort.field.as_str()));
267                if sort.direction == SortDirection::Desc {
268                    rows.reverse();
269                }
270            }
271
272            if let Some(pagination) = query.pagination {
273                let offset = (pagination.page.saturating_sub(1)) * pagination.per_page;
274                rows = rows
275                    .into_iter()
276                    .skip(offset)
277                    .take(pagination.per_page)
278                    .collect();
279            }
280
281            Ok(rows)
282        };
283
284        match &result {
285            Ok(rows) => info!(
286                target: "shelly.data.query",
287                source = "memory_repo",
288                adapter = self.driver.kind().as_str(),
289                operation = "list",
290                table,
291                row_count = rows.len(),
292                filter_count = query.filters.len(),
293                sort_count = query.sorts.len(),
294                page = query.pagination.map(|value| value.page),
295                per_page = query.pagination.map(|value| value.per_page),
296                duration_ms = started_at.elapsed().as_millis() as u64,
297                "Shelly data query executed"
298            ),
299            Err(err) => warn!(
300                target: "shelly.data.query",
301                source = "memory_repo",
302                adapter = self.driver.kind().as_str(),
303                operation = "list",
304                table,
305                filter_count = query.filters.len(),
306                sort_count = query.sorts.len(),
307                page = query.pagination.map(|value| value.page),
308                per_page = query.pagination.map(|value| value.per_page),
309                duration_ms = started_at.elapsed().as_millis() as u64,
310                error = %err,
311                "Shelly data query failed"
312            ),
313        }
314
315        result
316    }
317}
318
319fn matches_filter(row: &StoredRow, filter: &crate::query::Filter) -> bool {
320    let Some(candidate) = row.data.get(&filter.field) else {
321        return false;
322    };
323    match filter.op {
324        FilterOperator::Eq => candidate == &filter.value,
325        FilterOperator::Neq => candidate != &filter.value,
326        FilterOperator::Contains => candidate
327            .as_str()
328            .zip(filter.value.as_str())
329            .is_some_and(|(left, right)| left.contains(right)),
330        FilterOperator::Gt => {
331            compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Greater)
332        }
333        FilterOperator::Gte => compare_numbers(candidate, &filter.value)
334            .is_some_and(|ord| ord == Ordering::Greater || ord == Ordering::Equal),
335        FilterOperator::Lt => {
336            compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Less)
337        }
338        FilterOperator::Lte => compare_numbers(candidate, &filter.value)
339            .is_some_and(|ord| ord == Ordering::Less || ord == Ordering::Equal),
340    }
341}
342
343fn compare_for_sort(left: &StoredRow, right: &StoredRow, field: &str) -> Ordering {
344    let left_value = left.data.get(field);
345    let right_value = right.data.get(field);
346    match (left_value, right_value) {
347        (Some(Value::Number(left_num)), Some(Value::Number(right_num))) => left_num
348            .as_f64()
349            .partial_cmp(&right_num.as_f64())
350            .unwrap_or(Ordering::Equal),
351        (Some(Value::String(left_text)), Some(Value::String(right_text))) => {
352            left_text.cmp(right_text)
353        }
354        _ => left.id.cmp(&right.id),
355    }
356}
357
358fn compare_numbers(left: &Value, right: &Value) -> Option<Ordering> {
359    left.as_f64()
360        .zip(right.as_f64())
361        .and_then(|(left, right)| left.partial_cmp(&right))
362}
363
#[cfg(test)]
mod tests {
    use super::{adapter_for, DatabaseConfig, MemoryRepo, Repo, Row};
    use crate::{AdapterKind, DataError, Filter, FilterOperator, Query, SortDirection};
    use serde_json::json;

    // Smoke test: build a repo from a config-selected driver, insert one row and read
    // it back through a combined filter + sort query.
    #[test]
    fn memory_repo_works_for_adapter_selection() {
        let mut repo = MemoryRepo::new(
            adapter_for(&DatabaseConfig {
                adapter: AdapterKind::Sqlite,
                url: None,
                url_env: None,
            })
            .unwrap(),
        );
        let mut row = Row::new();
        row.insert("title".to_string(), json!("Alpha"));
        row.insert("score".to_string(), json!(10));
        repo.insert("posts", row).unwrap();

        let rows = repo
            .list(
                "posts",
                &Query::new()
                    .where_filter(Filter::contains("title", "Al"))
                    .order_by("score", SortDirection::Desc),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].data.get("title"), Some(&json!("Alpha")));
    }

    // `adapter_for` refuses the `none` adapter and maps every real kind to a driver
    // reporting that same kind.
    #[test]
    fn adapter_for_rejects_none_and_selects_expected_driver() {
        let none_result = adapter_for(&DatabaseConfig {
            adapter: AdapterKind::None,
            url: None,
            url_env: None,
        });
        assert!(matches!(none_result, Err(DataError::Adapter(_))));

        for kind in [
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ] {
            let driver = adapter_for(&DatabaseConfig {
                adapter: kind,
                url: None,
                url_env: None,
            })
            .expect("driver should be created");
            assert_eq!(driver.kind(), kind);
        }
    }

    // Happy paths and "not found" paths for find/update/delete, including a lookup
    // in a table that was never created.
    #[test]
    fn update_delete_and_find_cover_missing_rows() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        let mut row = Row::new();
        row.insert("title".to_string(), json!("Draft"));
        let inserted = repo.insert("posts", row).expect("insert should work");

        assert_eq!(
            repo.find("posts", inserted.id)
                .expect("find should not fail")
                .map(|it| it.id),
            Some(inserted.id)
        );
        assert!(repo
            .find("posts", 999)
            .expect("find should not fail")
            .is_none());
        assert!(repo
            .find("missing_table", inserted.id)
            .expect("find should not fail")
            .is_none());

        let mut updated = Row::new();
        updated.insert("title".to_string(), json!("Published"));
        let updated_row = repo
            .update("posts", inserted.id, updated)
            .expect("update should work");
        assert_eq!(updated_row.data.get("title"), Some(&json!("Published")));

        let update_err = repo
            .update("posts", 404, Row::new())
            .expect_err("missing row should fail update");
        assert!(matches!(update_err, DataError::Query(_)));

        // Deleting the same id twice: the second call must report the row as missing.
        repo.delete("posts", inserted.id)
            .expect("delete should remove row");
        let delete_err = repo
            .delete("posts", inserted.id)
            .expect_err("deleting missing row should fail");
        assert!(matches!(delete_err, DataError::Query(_)));
    }

    // Covers each filter operator, the id fallback for sorting on an absent field,
    // pagination, and a two-key sort.
    #[test]
    fn list_applies_filters_sorts_and_pagination() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        // Ids are assigned in insertion order on a fresh repo: alpha = 1, beta = 2,
        // gamma = 3. The fallback-sort assertions below rely on this.
        let mut alpha = Row::new();
        alpha.insert("title".to_string(), json!("Alpha"));
        alpha.insert("score".to_string(), json!(10));
        alpha.insert("tag".to_string(), json!("core"));
        repo.insert("posts", alpha).expect("insert alpha");

        let mut beta = Row::new();
        beta.insert("title".to_string(), json!("Beta"));
        beta.insert("score".to_string(), json!(20));
        beta.insert("tag".to_string(), json!("ops"));
        repo.insert("posts", beta).expect("insert beta");

        // gamma's `tag` is a number, to exercise `Contains` against a non-string value.
        let mut gamma = Row::new();
        gamma.insert("title".to_string(), json!("Gamma"));
        gamma.insert("score".to_string(), json!(15));
        gamma.insert("tag".to_string(), json!(123));
        repo.insert("posts", gamma).expect("insert gamma");

        let eq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::eq("title", json!("Alpha"))),
            )
            .expect("eq filter");
        assert_eq!(eq_rows.len(), 1);
        assert_eq!(eq_rows[0].data.get("title"), Some(&json!("Alpha")));

        let neq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(crate::Filter {
                    field: "title".to_string(),
                    op: FilterOperator::Neq,
                    value: json!("Alpha"),
                }),
            )
            .expect("neq filter");
        assert_eq!(neq_rows.len(), 2);

        let contains_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("title", "mm")),
            )
            .expect("contains filter");
        assert_eq!(contains_rows.len(), 1);
        assert_eq!(contains_rows[0].data.get("title"), Some(&json!("Gamma")));

        // 123 contains the digit "2" textually, but Contains only applies to strings.
        let contains_non_string_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("tag", "2")),
            )
            .expect("contains on mixed type");
        assert!(contains_non_string_rows.is_empty());

        // No sort is requested, so results come back in insertion order.
        for (op, expected_titles) in [
            (FilterOperator::Gt, vec!["Beta"]),
            (FilterOperator::Gte, vec!["Beta", "Gamma"]),
            (FilterOperator::Lt, vec!["Alpha"]),
            (FilterOperator::Lte, vec!["Alpha", "Gamma"]),
        ] {
            let rows = repo
                .list(
                    "posts",
                    &Query::new().where_filter(crate::Filter {
                        field: "score".to_string(),
                        op,
                        value: json!(15),
                    }),
                )
                .expect("numeric filter");
            let titles: Vec<&str> = rows
                .iter()
                .map(|row| {
                    row.data
                        .get("title")
                        .and_then(|value| value.as_str())
                        .expect("title")
                })
                .collect();
            assert_eq!(titles, expected_titles);
        }

        // No row has `missing`, so ordering falls back to row id; Desc puts the newest
        // ids first and page 1 of size 2 keeps ids 3 and 2.
        let unknown_field_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("missing", SortDirection::Desc)
                    .paginate(1, 2),
            )
            .expect("fallback sort");
        assert_eq!(unknown_field_sort.len(), 2);
        assert_eq!(unknown_field_sort[0].id, 3);
        assert_eq!(unknown_field_sort[1].id, 2);

        // Scores are all distinct, so the secondary `title` key never decides here.
        let score_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("score", SortDirection::Desc)
                    .order_by("title", SortDirection::Asc),
            )
            .expect("score sort");
        let score_titles: Vec<&str> = score_sort
            .iter()
            .map(|row| {
                row.data
                    .get("title")
                    .and_then(|value| value.as_str())
                    .expect("title")
            })
            .collect();
        assert_eq!(score_titles, vec!["Beta", "Gamma", "Alpha"]);
    }
}