1use std::cmp::Ordering;
4use std::sync::Arc;
5
6use indexmap::IndexMap;
7use tokio::sync::Mutex;
8
9use super::{
10 auth_schema, run_transaction_without_native_support, AdapterCapabilities, AdapterFuture, Count,
11 Create, DbAdapter, DbRecord, DbSchema, DbValue, Delete, DeleteMany, FindMany, FindOne,
12 JoinAdapter, SchemaCreation, SortDirection, TransactionCallback, Update, UpdateMany, Where,
13 WhereMode, WhereOperator,
14};
15use crate::error::OpenAuthError;
16
17#[derive(Debug, Clone, Default)]
19pub struct MemoryAdapter {
20 state: Arc<Mutex<MemoryState>>,
21}
22
23#[derive(Debug, Default)]
24struct MemoryState {
25 records: IndexMap<String, Vec<DbRecord>>,
26}
27
28impl MemoryAdapter {
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub async fn records(&self, model: &str) -> Vec<DbRecord> {
35 self.state
36 .lock()
37 .await
38 .records
39 .get(model)
40 .cloned()
41 .unwrap_or_default()
42 }
43
44 pub async fn len(&self, model: &str) -> usize {
46 self.state
47 .lock()
48 .await
49 .records
50 .get(model)
51 .map(Vec::len)
52 .unwrap_or_default()
53 }
54
55 pub async fn is_empty(&self, model: &str) -> bool {
57 self.len(model).await == 0
58 }
59}
60
61impl DbAdapter for MemoryAdapter {
62 fn id(&self) -> &str {
63 "memory"
64 }
65
66 fn capabilities(&self) -> AdapterCapabilities {
67 AdapterCapabilities::new(self.id())
68 .named("Memory Adapter")
69 .with_json()
70 .with_arrays()
71 }
72
73 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
74 Box::pin(async move {
75 let mut state = self.state.lock().await;
76 state
77 .records
78 .entry(query.model)
79 .or_default()
80 .push(query.data.clone());
81 Ok(select_record(query.data, &query.select))
82 })
83 }
84
85 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
86 Box::pin(async move {
87 if !query.joins.is_empty() {
88 let adapter = JoinAdapter::new(
89 auth_schema(Default::default()),
90 Arc::new(self.clone()),
91 false,
92 );
93 return adapter.find_one(query).await;
94 }
95 let state = self.state.lock().await;
96 Ok(state.records.get(&query.model).and_then(|records| {
97 records
98 .iter()
99 .find(|record| matches_where(record, &query.where_clauses))
100 .map(|record| select_record(record.clone(), &query.select))
101 }))
102 })
103 }
104
105 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
106 Box::pin(async move {
107 if !query.joins.is_empty() {
108 let adapter = JoinAdapter::new(
109 auth_schema(Default::default()),
110 Arc::new(self.clone()),
111 false,
112 );
113 return adapter.find_many(query).await;
114 }
115 let state = self.state.lock().await;
116 let mut records = state
117 .records
118 .get(&query.model)
119 .map(|records| {
120 records
121 .iter()
122 .filter(|record| matches_where(record, &query.where_clauses))
123 .cloned()
124 .collect::<Vec<_>>()
125 })
126 .unwrap_or_default();
127
128 if let Some(sort) = &query.sort_by {
129 records.sort_by(|left, right| compare_records(left, right, &sort.field));
130 if sort.direction == SortDirection::Desc {
131 records.reverse();
132 }
133 }
134
135 let offset = query.offset.unwrap_or(0);
136 let iter = records.into_iter().skip(offset);
137 let records: Vec<DbRecord> = match query.limit {
138 Some(limit) => iter.take(limit).collect(),
139 None => iter.collect(),
140 };
141
142 Ok(records
143 .into_iter()
144 .map(|record| select_record(record, &query.select))
145 .collect())
146 })
147 }
148
149 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
150 Box::pin(async move {
151 let state = self.state.lock().await;
152 let count = state
153 .records
154 .get(&query.model)
155 .map(|records| {
156 records
157 .iter()
158 .filter(|record| matches_where(record, &query.where_clauses))
159 .count()
160 })
161 .unwrap_or_default();
162 Ok(count as u64)
163 })
164 }
165
166 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
167 Box::pin(async move {
168 let mut state = self.state.lock().await;
169 let Some(records) = state.records.get_mut(&query.model) else {
170 return Ok(None);
171 };
172 let Some(record) = records
173 .iter_mut()
174 .find(|record| matches_where(record, &query.where_clauses))
175 else {
176 return Ok(None);
177 };
178 apply_update(record, query.data);
179 Ok(Some(record.clone()))
180 })
181 }
182
183 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
184 Box::pin(async move {
185 let mut state = self.state.lock().await;
186 let Some(records) = state.records.get_mut(&query.model) else {
187 return Ok(0);
188 };
189 let mut updated = 0;
190 for record in records
191 .iter_mut()
192 .filter(|record| matches_where(record, &query.where_clauses))
193 {
194 apply_update(record, query.data.clone());
195 updated += 1;
196 }
197 Ok(updated)
198 })
199 }
200
201 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
202 Box::pin(async move {
203 let mut state = self.state.lock().await;
204 let Some(records) = state.records.get_mut(&query.model) else {
205 return Ok(());
206 };
207 if let Some(index) = records
208 .iter()
209 .position(|record| matches_where(record, &query.where_clauses))
210 {
211 records.remove(index);
212 }
213 Ok(())
214 })
215 }
216
217 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
218 Box::pin(async move {
219 let mut state = self.state.lock().await;
220 let Some(records) = state.records.get_mut(&query.model) else {
221 return Ok(0);
222 };
223 let before = records.len();
224 records.retain(|record| !matches_where(record, &query.where_clauses));
225 Ok((before - records.len()) as u64)
226 })
227 }
228
229 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
230 run_transaction_without_native_support(self, callback)
231 }
232
233 fn create_schema<'a>(
234 &'a self,
235 _schema: &'a DbSchema,
236 _file: Option<&'a str>,
237 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
238 Box::pin(async { Ok(None) })
239 }
240
241 fn run_migrations<'a>(&'a self, _schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
242 Box::pin(async {
243 Err(OpenAuthError::InvalidConfig(
244 "MemoryAdapter does not support migrations".to_owned(),
245 ))
246 })
247 }
248}
249
250fn apply_update(record: &mut DbRecord, data: DbRecord) {
251 for (field, value) in data {
252 record.insert(field, value);
253 }
254}
255
256fn select_record(record: DbRecord, select: &[String]) -> DbRecord {
257 if select.is_empty() {
258 return record;
259 }
260 select
261 .iter()
262 .filter_map(|field| {
263 record
264 .get(field)
265 .cloned()
266 .map(|value| (field.clone(), value))
267 })
268 .collect()
269}
270
271fn matches_where(record: &DbRecord, where_clauses: &[Where]) -> bool {
272 let Some((first, rest)) = where_clauses.split_first() else {
273 return true;
274 };
275 let mut result = matches_clause(record, first);
276 for clause in rest {
277 if clause.connector == super::Connector::Or {
278 result = result || matches_clause(record, clause);
279 } else {
280 result = result && matches_clause(record, clause);
281 }
282 }
283 result
284}
285
286fn matches_clause(record: &DbRecord, clause: &Where) -> bool {
287 let Some(actual) = record.get(&clause.field) else {
288 return false;
289 };
290 match clause.operator {
291 WhereOperator::Eq => values_equal(actual, &clause.value, clause.mode),
292 WhereOperator::Ne => !values_equal(actual, &clause.value, clause.mode),
293 WhereOperator::Lt => compare_values(actual, &clause.value, clause.mode)
294 .is_some_and(|ordering| ordering == Ordering::Less),
295 WhereOperator::Lte => compare_values(actual, &clause.value, clause.mode)
296 .is_some_and(|ordering| ordering != Ordering::Greater),
297 WhereOperator::Gt => compare_values(actual, &clause.value, clause.mode)
298 .is_some_and(|ordering| ordering == Ordering::Greater),
299 WhereOperator::Gte => compare_values(actual, &clause.value, clause.mode)
300 .is_some_and(|ordering| ordering != Ordering::Less),
301 WhereOperator::In => value_in(actual, &clause.value, clause.mode),
302 WhereOperator::NotIn => !value_in(actual, &clause.value, clause.mode),
303 WhereOperator::Contains => {
304 string_predicate(actual, &clause.value, clause.mode, contains_string)
305 }
306 WhereOperator::StartsWith => {
307 string_predicate(actual, &clause.value, clause.mode, starts_with_string)
308 }
309 WhereOperator::EndsWith => {
310 string_predicate(actual, &clause.value, clause.mode, ends_with_string)
311 }
312 }
313}
314
315fn values_equal(left: &DbValue, right: &DbValue, mode: WhereMode) -> bool {
316 match (left, right) {
317 (DbValue::String(left), DbValue::String(right)) => strings_equal(left, right, mode),
318 _ => left == right,
319 }
320}
321
322fn compare_records(left: &DbRecord, right: &DbRecord, field: &str) -> Ordering {
323 match (left.get(field), right.get(field)) {
324 (Some(left), Some(right)) => {
325 compare_values(left, right, WhereMode::Sensitive).unwrap_or(Ordering::Equal)
326 }
327 (Some(_), None) => Ordering::Less,
328 (None, Some(_)) => Ordering::Greater,
329 (None, None) => Ordering::Equal,
330 }
331}
332
333fn compare_values(left: &DbValue, right: &DbValue, mode: WhereMode) -> Option<Ordering> {
334 match (left, right) {
335 (DbValue::String(left), DbValue::String(right)) => Some(compare_strings(left, right, mode)),
336 (DbValue::Number(left), DbValue::Number(right)) => Some(left.cmp(right)),
337 (DbValue::Boolean(left), DbValue::Boolean(right)) => Some(left.cmp(right)),
338 (DbValue::Timestamp(left), DbValue::Timestamp(right)) => left
339 .unix_timestamp_nanos()
340 .partial_cmp(&right.unix_timestamp_nanos()),
341 _ => None,
342 }
343}
344
345fn value_in(actual: &DbValue, expected: &DbValue, mode: WhereMode) -> bool {
346 match expected {
347 DbValue::StringArray(values) => values
348 .iter()
349 .any(|value| values_equal(actual, &DbValue::String(value.clone()), mode)),
350 DbValue::NumberArray(values) => values
351 .iter()
352 .any(|value| values_equal(actual, &DbValue::Number(*value), mode)),
353 DbValue::Json(serde_json::Value::Array(values)) => values.iter().any(|value| {
354 json_value_to_db_value(value)
355 .as_ref()
356 .is_some_and(|candidate| values_equal(actual, candidate, mode))
357 }),
358 _ => false,
359 }
360}
361
362fn json_value_to_db_value(value: &serde_json::Value) -> Option<DbValue> {
363 match value {
364 serde_json::Value::String(value) => Some(DbValue::String(value.clone())),
365 serde_json::Value::Number(value) => value.as_i64().map(DbValue::Number),
366 serde_json::Value::Bool(value) => Some(DbValue::Boolean(*value)),
367 serde_json::Value::Null => Some(DbValue::Null),
368 _ => None,
369 }
370}
371
372fn string_predicate(
373 actual: &DbValue,
374 expected: &DbValue,
375 mode: WhereMode,
376 predicate: fn(&str, &str) -> bool,
377) -> bool {
378 let (DbValue::String(actual), DbValue::String(expected)) = (actual, expected) else {
379 return false;
380 };
381 if mode == WhereMode::Insensitive {
382 return predicate(&actual.to_ascii_lowercase(), &expected.to_ascii_lowercase());
383 }
384 predicate(actual, expected)
385}
386
387fn strings_equal(left: &str, right: &str, mode: WhereMode) -> bool {
388 if mode == WhereMode::Insensitive {
389 return left.eq_ignore_ascii_case(right);
390 }
391 left == right
392}
393
394fn compare_strings(left: &str, right: &str, mode: WhereMode) -> Ordering {
395 if mode == WhereMode::Insensitive {
396 return left.to_ascii_lowercase().cmp(&right.to_ascii_lowercase());
397 }
398 left.cmp(right)
399}
400
/// `Contains` predicate: `true` when `pattern` occurs anywhere in `value`.
fn contains_string(value: &str, pattern: &str) -> bool {
    str::contains(value, pattern)
}
404
/// `StartsWith` predicate: `true` when `value` begins with `pattern`.
fn starts_with_string(value: &str, pattern: &str) -> bool {
    value.strip_prefix(pattern).is_some()
}
408
/// `EndsWith` predicate: `true` when `value` finishes with `pattern`.
fn ends_with_string(value: &str, pattern: &str) -> bool {
    value.strip_suffix(pattern).is_some()
}