Skip to main content

openauth_core/db/
memory.rs

1//! In-memory database adapter for local development and tests.
2
3use std::cmp::Ordering;
4use std::sync::Arc;
5
6use indexmap::IndexMap;
7use tokio::sync::Mutex;
8
9use super::{
10    auth_schema, run_transaction_without_native_support, AdapterCapabilities, AdapterFuture, Count,
11    Create, DbAdapter, DbRecord, DbSchema, DbValue, Delete, DeleteMany, FindMany, FindOne,
12    JoinAdapter, SchemaCreation, SortDirection, TransactionCallback, Update, UpdateMany, Where,
13    WhereMode, WhereOperator,
14};
15use crate::error::OpenAuthError;
16
17/// Async-safe in-memory adapter backed by shared state.
18#[derive(Debug, Clone, Default)]
19pub struct MemoryAdapter {
20    state: Arc<Mutex<MemoryState>>,
21}
22
23#[derive(Debug, Default)]
24struct MemoryState {
25    records: IndexMap<String, Vec<DbRecord>>,
26}
27
28impl MemoryAdapter {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    /// Return a snapshot of all records stored for a model.
34    pub async fn records(&self, model: &str) -> Vec<DbRecord> {
35        self.state
36            .lock()
37            .await
38            .records
39            .get(model)
40            .cloned()
41            .unwrap_or_default()
42    }
43
44    /// Return the number of records stored for a model.
45    pub async fn len(&self, model: &str) -> usize {
46        self.state
47            .lock()
48            .await
49            .records
50            .get(model)
51            .map(Vec::len)
52            .unwrap_or_default()
53    }
54
55    /// Return true when no records are stored for a model.
56    pub async fn is_empty(&self, model: &str) -> bool {
57        self.len(model).await == 0
58    }
59}
60
61impl DbAdapter for MemoryAdapter {
62    fn id(&self) -> &str {
63        "memory"
64    }
65
66    fn capabilities(&self) -> AdapterCapabilities {
67        AdapterCapabilities::new(self.id())
68            .named("Memory Adapter")
69            .with_json()
70            .with_arrays()
71    }
72
73    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
74        Box::pin(async move {
75            let mut state = self.state.lock().await;
76            state
77                .records
78                .entry(query.model)
79                .or_default()
80                .push(query.data.clone());
81            Ok(select_record(query.data, &query.select))
82        })
83    }
84
85    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
86        Box::pin(async move {
87            if !query.joins.is_empty() {
88                let adapter = JoinAdapter::new(
89                    auth_schema(Default::default()),
90                    Arc::new(self.clone()),
91                    false,
92                );
93                return adapter.find_one(query).await;
94            }
95            let state = self.state.lock().await;
96            Ok(state.records.get(&query.model).and_then(|records| {
97                records
98                    .iter()
99                    .find(|record| matches_where(record, &query.where_clauses))
100                    .map(|record| select_record(record.clone(), &query.select))
101            }))
102        })
103    }
104
105    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
106        Box::pin(async move {
107            if !query.joins.is_empty() {
108                let adapter = JoinAdapter::new(
109                    auth_schema(Default::default()),
110                    Arc::new(self.clone()),
111                    false,
112                );
113                return adapter.find_many(query).await;
114            }
115            let state = self.state.lock().await;
116            let mut records = state
117                .records
118                .get(&query.model)
119                .map(|records| {
120                    records
121                        .iter()
122                        .filter(|record| matches_where(record, &query.where_clauses))
123                        .cloned()
124                        .collect::<Vec<_>>()
125                })
126                .unwrap_or_default();
127
128            if let Some(sort) = &query.sort_by {
129                records.sort_by(|left, right| compare_records(left, right, &sort.field));
130                if sort.direction == SortDirection::Desc {
131                    records.reverse();
132                }
133            }
134
135            let offset = query.offset.unwrap_or(0);
136            let iter = records.into_iter().skip(offset);
137            let records: Vec<DbRecord> = match query.limit {
138                Some(limit) => iter.take(limit).collect(),
139                None => iter.collect(),
140            };
141
142            Ok(records
143                .into_iter()
144                .map(|record| select_record(record, &query.select))
145                .collect())
146        })
147    }
148
149    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
150        Box::pin(async move {
151            let state = self.state.lock().await;
152            let count = state
153                .records
154                .get(&query.model)
155                .map(|records| {
156                    records
157                        .iter()
158                        .filter(|record| matches_where(record, &query.where_clauses))
159                        .count()
160                })
161                .unwrap_or_default();
162            Ok(count as u64)
163        })
164    }
165
166    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
167        Box::pin(async move {
168            let mut state = self.state.lock().await;
169            let Some(records) = state.records.get_mut(&query.model) else {
170                return Ok(None);
171            };
172            let Some(record) = records
173                .iter_mut()
174                .find(|record| matches_where(record, &query.where_clauses))
175            else {
176                return Ok(None);
177            };
178            apply_update(record, query.data);
179            Ok(Some(record.clone()))
180        })
181    }
182
183    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
184        Box::pin(async move {
185            let mut state = self.state.lock().await;
186            let Some(records) = state.records.get_mut(&query.model) else {
187                return Ok(0);
188            };
189            let mut updated = 0;
190            for record in records
191                .iter_mut()
192                .filter(|record| matches_where(record, &query.where_clauses))
193            {
194                apply_update(record, query.data.clone());
195                updated += 1;
196            }
197            Ok(updated)
198        })
199    }
200
201    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
202        Box::pin(async move {
203            let mut state = self.state.lock().await;
204            let Some(records) = state.records.get_mut(&query.model) else {
205                return Ok(());
206            };
207            if let Some(index) = records
208                .iter()
209                .position(|record| matches_where(record, &query.where_clauses))
210            {
211                records.remove(index);
212            }
213            Ok(())
214        })
215    }
216
217    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
218        Box::pin(async move {
219            let mut state = self.state.lock().await;
220            let Some(records) = state.records.get_mut(&query.model) else {
221                return Ok(0);
222            };
223            let before = records.len();
224            records.retain(|record| !matches_where(record, &query.where_clauses));
225            Ok((before - records.len()) as u64)
226        })
227    }
228
229    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
230        run_transaction_without_native_support(self, callback)
231    }
232
233    fn create_schema<'a>(
234        &'a self,
235        _schema: &'a DbSchema,
236        _file: Option<&'a str>,
237    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
238        Box::pin(async { Ok(None) })
239    }
240
241    fn run_migrations<'a>(&'a self, _schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
242        Box::pin(async {
243            Err(OpenAuthError::InvalidConfig(
244                "MemoryAdapter does not support migrations".to_owned(),
245            ))
246        })
247    }
248}
249
250fn apply_update(record: &mut DbRecord, data: DbRecord) {
251    for (field, value) in data {
252        record.insert(field, value);
253    }
254}
255
256fn select_record(record: DbRecord, select: &[String]) -> DbRecord {
257    if select.is_empty() {
258        return record;
259    }
260    select
261        .iter()
262        .filter_map(|field| {
263            record
264                .get(field)
265                .cloned()
266                .map(|value| (field.clone(), value))
267        })
268        .collect()
269}
270
271fn matches_where(record: &DbRecord, where_clauses: &[Where]) -> bool {
272    let Some((first, rest)) = where_clauses.split_first() else {
273        return true;
274    };
275    let mut result = matches_clause(record, first);
276    for clause in rest {
277        if clause.connector == super::Connector::Or {
278            result = result || matches_clause(record, clause);
279        } else {
280            result = result && matches_clause(record, clause);
281        }
282    }
283    result
284}
285
286fn matches_clause(record: &DbRecord, clause: &Where) -> bool {
287    let Some(actual) = record.get(&clause.field) else {
288        return false;
289    };
290    match clause.operator {
291        WhereOperator::Eq => values_equal(actual, &clause.value, clause.mode),
292        WhereOperator::Ne => !values_equal(actual, &clause.value, clause.mode),
293        WhereOperator::Lt => compare_values(actual, &clause.value, clause.mode)
294            .is_some_and(|ordering| ordering == Ordering::Less),
295        WhereOperator::Lte => compare_values(actual, &clause.value, clause.mode)
296            .is_some_and(|ordering| ordering != Ordering::Greater),
297        WhereOperator::Gt => compare_values(actual, &clause.value, clause.mode)
298            .is_some_and(|ordering| ordering == Ordering::Greater),
299        WhereOperator::Gte => compare_values(actual, &clause.value, clause.mode)
300            .is_some_and(|ordering| ordering != Ordering::Less),
301        WhereOperator::In => value_in(actual, &clause.value, clause.mode),
302        WhereOperator::NotIn => !value_in(actual, &clause.value, clause.mode),
303        WhereOperator::Contains => {
304            string_predicate(actual, &clause.value, clause.mode, contains_string)
305        }
306        WhereOperator::StartsWith => {
307            string_predicate(actual, &clause.value, clause.mode, starts_with_string)
308        }
309        WhereOperator::EndsWith => {
310            string_predicate(actual, &clause.value, clause.mode, ends_with_string)
311        }
312    }
313}
314
315fn values_equal(left: &DbValue, right: &DbValue, mode: WhereMode) -> bool {
316    match (left, right) {
317        (DbValue::String(left), DbValue::String(right)) => strings_equal(left, right, mode),
318        _ => left == right,
319    }
320}
321
322fn compare_records(left: &DbRecord, right: &DbRecord, field: &str) -> Ordering {
323    match (left.get(field), right.get(field)) {
324        (Some(left), Some(right)) => {
325            compare_values(left, right, WhereMode::Sensitive).unwrap_or(Ordering::Equal)
326        }
327        (Some(_), None) => Ordering::Less,
328        (None, Some(_)) => Ordering::Greater,
329        (None, None) => Ordering::Equal,
330    }
331}
332
333fn compare_values(left: &DbValue, right: &DbValue, mode: WhereMode) -> Option<Ordering> {
334    match (left, right) {
335        (DbValue::String(left), DbValue::String(right)) => Some(compare_strings(left, right, mode)),
336        (DbValue::Number(left), DbValue::Number(right)) => Some(left.cmp(right)),
337        (DbValue::Boolean(left), DbValue::Boolean(right)) => Some(left.cmp(right)),
338        (DbValue::Timestamp(left), DbValue::Timestamp(right)) => left
339            .unix_timestamp_nanos()
340            .partial_cmp(&right.unix_timestamp_nanos()),
341        _ => None,
342    }
343}
344
345fn value_in(actual: &DbValue, expected: &DbValue, mode: WhereMode) -> bool {
346    match expected {
347        DbValue::StringArray(values) => values
348            .iter()
349            .any(|value| values_equal(actual, &DbValue::String(value.clone()), mode)),
350        DbValue::NumberArray(values) => values
351            .iter()
352            .any(|value| values_equal(actual, &DbValue::Number(*value), mode)),
353        DbValue::Json(serde_json::Value::Array(values)) => values.iter().any(|value| {
354            json_value_to_db_value(value)
355                .as_ref()
356                .is_some_and(|candidate| values_equal(actual, candidate, mode))
357        }),
358        _ => false,
359    }
360}
361
362fn json_value_to_db_value(value: &serde_json::Value) -> Option<DbValue> {
363    match value {
364        serde_json::Value::String(value) => Some(DbValue::String(value.clone())),
365        serde_json::Value::Number(value) => value.as_i64().map(DbValue::Number),
366        serde_json::Value::Bool(value) => Some(DbValue::Boolean(*value)),
367        serde_json::Value::Null => Some(DbValue::Null),
368        _ => None,
369    }
370}
371
372fn string_predicate(
373    actual: &DbValue,
374    expected: &DbValue,
375    mode: WhereMode,
376    predicate: fn(&str, &str) -> bool,
377) -> bool {
378    let (DbValue::String(actual), DbValue::String(expected)) = (actual, expected) else {
379        return false;
380    };
381    if mode == WhereMode::Insensitive {
382        return predicate(&actual.to_ascii_lowercase(), &expected.to_ascii_lowercase());
383    }
384    predicate(actual, expected)
385}
386
387fn strings_equal(left: &str, right: &str, mode: WhereMode) -> bool {
388    if mode == WhereMode::Insensitive {
389        return left.eq_ignore_ascii_case(right);
390    }
391    left == right
392}
393
394fn compare_strings(left: &str, right: &str, mode: WhereMode) -> Ordering {
395    if mode == WhereMode::Insensitive {
396        return left.to_ascii_lowercase().cmp(&right.to_ascii_lowercase());
397    }
398    left.cmp(right)
399}
400
/// `Contains` predicate: `pattern` occurs anywhere in `value`.
fn contains_string(value: &str, pattern: &str) -> bool {
    str::contains(value, pattern)
}
404
/// `StartsWith` predicate: `value` begins with `pattern`.
fn starts_with_string(value: &str, pattern: &str) -> bool {
    str::starts_with(value, pattern)
}
408
/// `EndsWith` predicate: `value` finishes with `pattern`.
fn ends_with_string(value: &str, pattern: &str) -> bool {
    str::ends_with(value, pattern)
}