Skip to main content

teaql_runtime/
memory.rs

1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8    Aggregate, AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction,
9    InsertCommand, Record, RecoverCommand, RelationAggregate, SelectQuery, SmartList, SortDirection,
10    UpdateCommand, Value,
11};
12
13use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
14
/// Errors produced by the in-memory query executor.
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the stored rows was poisoned.
    Poisoned,
    /// An expression the in-memory evaluator cannot handle; carries a description.
    UnsupportedExpression(String),
    /// An aggregate that cannot be computed in memory; carries a description.
    UnsupportedAggregate(String),
}

impl std::fmt::Display for MemoryRepositoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Poisoned => f.write_str("memory repository lock poisoned"),
            Self::UnsupportedExpression(detail) => {
                write!(f, "unsupported memory expression: {detail}")
            }
            Self::UnsupportedAggregate(detail) => {
                write!(f, "unsupported memory aggregate: {detail}")
            }
        }
    }
}

impl std::error::Error for MemoryRepositoryError {}
37
38#[derive(Debug, Clone)]
39pub struct MemoryRepository<M = InMemoryMetadataStore> {
40    metadata: M,
41    data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
42}
43
44impl<M> MemoryRepository<M>
45where
46    M: MetadataStore,
47{
48    pub fn new(metadata: M) -> Self {
49        Self {
50            metadata,
51            data: Arc::new(Mutex::new(BTreeMap::new())),
52        }
53    }
54
55    pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
56        self.seed(entity, rows);
57        self
58    }
59
60    pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
61        if let Ok(mut data) = self.data.lock() {
62            data.insert(entity.into(), rows);
63        }
64    }
65
66    pub fn fetch_all(
67        &self,
68        query: &SelectQuery,
69    ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
70        self.require_entity(&query.entity)?;
71        let data = self
72            .data
73            .lock()
74            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
75        let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
76        drop(data);
77
78        if let Some(filter) = &query.filter {
79            rows = rows
80                .into_iter()
81                .filter_map(|row| match eval_filter(filter, &row) {
82                    Ok(true) => Some(Ok(row)),
83                    Ok(false) => None,
84                    Err(err) => Some(Err(err)),
85                })
86                .collect::<Result<Vec<_>, _>>()
87                .map_err(RepositoryError::Executor)?;
88        }
89
90        if !query.aggregates.is_empty() {
91            return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
92        }
93
94        apply_ordering(&mut rows, query);
95        rows = apply_slice(rows, query);
96        if !query.projection.is_empty() || !query.expr_projection.is_empty() {
97            rows = rows
98                .into_iter()
99                .map(|row| project_row(row, query))
100                .collect::<Result<Vec<_>, _>>()
101                .map_err(RepositoryError::Executor)?;
102        }
103        Ok(rows)
104    }
105
106    pub fn fetch_smart_list(
107        &self,
108        query: &SelectQuery,
109    ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
110        self.fetch_all(query).map(SmartList::from)
111    }
112
113    pub fn fetch_entities<T>(
114        &self,
115        query: &SelectQuery,
116    ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
117    where
118        T: Entity,
119    {
120        self.fetch_all(query)?
121            .into_iter()
122            .map(T::from_record)
123            .collect::<Result<Vec<_>, _>>()
124            .map(SmartList::from)
125            .map_err(RepositoryError::Entity)
126    }
127
128    pub fn fetch_all_with_relation_aggregates(
129        &self,
130        query: &SelectQuery,
131        relation_aggregates: &[RelationAggregate],
132    ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
133        let mut rows = self.fetch_all(query)?;
134        self.enhance_relation_aggregates(&query.entity, &mut rows, relation_aggregates)?;
135        Ok(rows)
136    }
137
138    pub fn fetch_smart_list_with_relation_aggregates(
139        &self,
140        query: &SelectQuery,
141        relation_aggregates: &[RelationAggregate],
142    ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
143        self.fetch_all_with_relation_aggregates(query, relation_aggregates)
144            .map(SmartList::from)
145    }
146
147    pub fn fetch_entities_with_relation_aggregates<T>(
148        &self,
149        query: &SelectQuery,
150        relation_aggregates: &[RelationAggregate],
151    ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
152    where
153        T: Entity,
154    {
155        self.fetch_all_with_relation_aggregates(query, relation_aggregates)?
156            .into_iter()
157            .map(T::from_record)
158            .collect::<Result<Vec<_>, _>>()
159            .map(SmartList::from)
160            .map_err(RepositoryError::Entity)
161    }
162
163    pub fn enhance_relation_aggregates(
164        &self,
165        parent_entity: &str,
166        parent_rows: &mut [Record],
167        relation_aggregates: &[RelationAggregate],
168    ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
169        for aggregate in relation_aggregates {
170            self.enhance_relation_aggregate(parent_entity, parent_rows, aggregate)?;
171        }
172        Ok(())
173    }
174
175    fn enhance_relation_aggregate(
176        &self,
177        parent_entity: &str,
178        parent_rows: &mut [Record],
179        aggregate: &RelationAggregate,
180    ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
181        let descriptor = self
182            .metadata
183            .entity(parent_entity)
184            .ok_or_else(|| {
185                RepositoryError::Runtime(RuntimeError::MissingEntity(parent_entity.to_owned()))
186            })?;
187
188        let relation = descriptor
189            .relation_by_name(&aggregate.relation_name)
190            .ok_or_else(|| {
191                RepositoryError::Runtime(RuntimeError::MissingRelation {
192                    entity: parent_entity.to_owned(),
193                    relation: aggregate.relation_name.clone(),
194                })
195            })?;
196
197        let ids = parent_rows
198            .iter()
199            .filter_map(|row| row.get(&relation.local_key).cloned())
200            .collect::<Vec<_>>();
201
202        if ids.is_empty() {
203            let value = if aggregate.single_result {
204                Value::U64(0)
205            } else {
206                Value::List(Vec::new())
207            };
208            for parent in parent_rows.iter_mut() {
209                parent.insert(aggregate.alias.clone(), value.clone());
210            }
211            return Ok(());
212        }
213
214        let mut query = aggregate.query.clone();
215        query.entity = relation.target_entity.clone();
216        query.projection.clear();
217        query.expr_projection.clear();
218        query.order_by.clear();
219        query.slice = None;
220        query.relations.clear();
221        if query.aggregates.is_empty() {
222            let alias = if aggregate.single_result {
223                aggregate.alias.clone()
224            } else {
225                "count".to_owned()
226            };
227            query = query.aggregate(Aggregate::count(alias));
228        }
229        if !query
230            .group_by
231            .iter()
232            .any(|field| field == &relation.foreign_key)
233        {
234            query = query.group_by(relation.foreign_key.clone());
235        }
236        query = query.and_filter(Expr::in_list(relation.foreign_key.clone(), ids));
237
238        let aggregate_rows = self.fetch_all(&query)?;
239
240        let mut buckets: BTreeMap<String, Vec<Record>> = BTreeMap::new();
241        for mut row in aggregate_rows {
242            if let Some(key) = row.remove(&relation.foreign_key) {
243                let bucket_key = local_graph_identity_key(&key);
244                buckets
245                    .entry(bucket_key)
246                    .or_default()
247                    .push(row);
248            }
249        }
250
251        for parent in parent_rows {
252            let value = parent
253                .get(&relation.local_key)
254                .and_then(|local_value| buckets.get(&local_graph_identity_key(local_value)))
255                .map(|rows| {
256                    if aggregate.single_result {
257                        rows.first()
258                            .map(|row| {
259                                if row.len() == 1 {
260                                    row.values().next().cloned().unwrap_or(Value::Null)
261                                } else {
262                                    Value::object(row.clone())
263                                }
264                            })
265                            .unwrap_or(Value::U64(0))
266                    } else {
267                        Value::List(rows.iter().cloned().map(Value::object).collect())
268                    }
269                })
270                .unwrap_or_else(|| {
271                    if aggregate.single_result {
272                        Value::U64(0)
273                    } else {
274                        Value::List(Vec::new())
275                    }
276                });
277            parent.insert(aggregate.alias.clone(), value);
278        }
279
280        Ok(())
281    }
282
283    pub fn insert(
284        &self,
285        command: &InsertCommand,
286    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
287        self.require_entity(&command.entity)?;
288        let mut data = self
289            .data
290            .lock()
291            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
292        data.entry(command.entity.clone())
293            .or_default()
294            .push(command.values.clone());
295        Ok(1)
296    }
297
298    pub fn update(
299        &self,
300        command: &UpdateCommand,
301    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
302        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
303        let mut data = self
304            .data
305            .lock()
306            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
307        let rows = data.entry(command.entity.clone()).or_default();
308        let Some(row) = rows
309            .iter_mut()
310            .find(|row| row.get(id_property) == Some(&command.id))
311        else {
312            return self.maybe_optimistic_conflict(
313                command.expected_version,
314                &command.entity,
315                &command.id,
316            );
317        };
318
319        if let Some(expected) = command.expected_version {
320            if row.get(version_property) != Some(&Value::I64(expected)) {
321                println!("OptimisticLockConflict in memory.rs update! entity={}, id={:?}, expected={}, existing={:?}", command.entity, command.id, expected, row.get(version_property));
322                return Err(RepositoryError::Runtime(
323                    RuntimeError::OptimisticLockConflict {
324                        entity: command.entity.clone(),
325                        id: format!("{:?}", command.id),
326                    },
327                ));
328            }
329            row.insert(
330                version_property.to_owned(),
331                Value::I64(expected + 1),
332            );
333        }
334
335        for (key, value) in &command.values {
336            row.insert(key.clone(), value.clone());
337        }
338        Ok(1)
339    }
340
341    pub fn delete(
342        &self,
343        command: &DeleteCommand,
344    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
345        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
346        let mut data = self
347            .data
348            .lock()
349            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
350        let rows = data.entry(command.entity.clone()).or_default();
351        let Some(index) = rows
352            .iter()
353            .position(|row| row.get(id_property) == Some(&command.id))
354        else {
355            return self.maybe_optimistic_conflict(
356                command.expected_version,
357                &command.entity,
358                &command.id,
359            );
360        };
361
362        if let Some(expected_version) = command.expected_version {
363            if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
364                return Err(RepositoryError::Runtime(
365                    RuntimeError::OptimisticLockConflict {
366                        entity: command.entity.clone(),
367                        id: format!("{:?}", command.id),
368                    },
369                ));
370            }
371        }
372
373        if command.soft_delete {
374            let next_version = command
375                .expected_version
376                .or_else(|| read_i64(rows[index].get(version_property)))
377                .map(|version| -(version.abs() + 1))
378                .unwrap_or(-1);
379            rows[index].insert(version_property.to_owned(), Value::I64(next_version));
380        } else {
381            rows.remove(index);
382        }
383        Ok(1)
384    }
385
386    pub fn recover(
387        &self,
388        command: &RecoverCommand,
389    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
390        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
391        let mut data = self
392            .data
393            .lock()
394            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
395        let rows = data.entry(command.entity.clone()).or_default();
396        let Some(row) = rows
397            .iter_mut()
398            .find(|row| row.get(id_property) == Some(&command.id))
399        else {
400            return Err(RepositoryError::Runtime(
401                RuntimeError::OptimisticLockConflict {
402                    entity: command.entity.clone(),
403                    id: format!("{:?}", command.id),
404                },
405            ));
406        };
407
408        if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
409            return Err(RepositoryError::Runtime(
410                RuntimeError::OptimisticLockConflict {
411                    entity: command.entity.clone(),
412                    id: format!("{:?}", command.id),
413                },
414            ));
415        }
416
417        row.insert(
418            version_property.to_owned(),
419            Value::I64(command.expected_version.abs() + 1),
420        );
421        Ok(1)
422    }
423
424    fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
425        self.metadata
426            .entity(entity)
427            .map(|_| ())
428            .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
429    }
430
431    fn id_and_version_properties(
432        &self,
433        entity: &str,
434    ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
435        let descriptor = self.metadata.entity(entity).ok_or_else(|| {
436            RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
437        })?;
438        let id = descriptor
439            .id_property()
440            .map(|property| property.name.as_str())
441            .unwrap_or("id");
442        let version = descriptor
443            .version_property()
444            .map(|property| property.name.as_str())
445            .unwrap_or("version");
446        Ok((id, version))
447    }
448
449    fn maybe_optimistic_conflict(
450        &self,
451        expected_version: Option<i64>,
452        entity: &str,
453        id: &Value,
454    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
455        if expected_version.is_some() {
456            Err(RepositoryError::Runtime(
457                RuntimeError::OptimisticLockConflict {
458                    entity: entity.to_owned(),
459                    id: format!("{id:?}"),
460                },
461            ))
462        } else {
463            Ok(0)
464        }
465    }
466}
467
468fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
469    match expr {
470        Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
471            value_truthy(&eval_value(expr, row)?)
472        }
473        Expr::Binary { left, op, right } => {
474            let left = eval_value(left, row)?;
475            let right = eval_value(right, row)?;
476            eval_binary(&left, *op, &right)
477        }
478        Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
479            "subquery filters require a SQL executor".to_owned(),
480        )),
481        Expr::Between { expr, lower, upper } => {
482            let value = eval_value(expr, row)?;
483            let lower = eval_value(lower, row)?;
484            let upper = eval_value(upper, row)?;
485            Ok(compare_values(&value, &lower) != Some(Ordering::Less)
486                && compare_values(&value, &upper) != Some(Ordering::Greater))
487        }
488        Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
489        Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
490        Expr::And(parts) => {
491            for part in parts {
492                if !eval_filter(part, row)? {
493                    return Ok(false);
494                }
495            }
496            Ok(true)
497        }
498        Expr::Or(parts) => {
499            for part in parts {
500                if eval_filter(part, row)? {
501                    return Ok(true);
502                }
503            }
504            Ok(false)
505        }
506        Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
507    }
508}
509
510fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
511    match expr {
512        Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
513        Expr::Value(value) => Ok(value.clone()),
514        Expr::Function { function, args } => eval_function(*function, args, row),
515        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
516            "cannot evaluate {other:?} as a scalar value"
517        ))),
518    }
519}
520
521fn eval_function(
522    function: ExprFunction,
523    args: &[Expr],
524    row: &Record,
525) -> Result<Value, MemoryRepositoryError> {
526    match function {
527        ExprFunction::Soundex => {
528            let [arg] = args else {
529                return Err(MemoryRepositoryError::UnsupportedExpression(
530                    "SOUNDEX expects exactly one argument".to_owned(),
531                ));
532            };
533            match eval_value(arg, row)? {
534                Value::Text(value) => Ok(Value::Text(soundex(&value))),
535                Value::Null => Ok(Value::Null),
536                other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
537                    "SOUNDEX expects text, got {other:?}"
538                ))),
539            }
540        }
541        ExprFunction::Gbk => {
542            let [arg] = args else {
543                return Err(MemoryRepositoryError::UnsupportedExpression(
544                    "GBK expects exactly one argument".to_owned(),
545                ));
546            };
547            eval_value(arg, row)
548        }
549        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
550            "function {other:?} is only supported by SQL execution"
551        ))),
552    }
553}
554
555fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
556    match op {
557        BinaryOp::Eq => Ok(values_equal(left, right)),
558        BinaryOp::Ne => Ok(!values_equal(left, right)),
559        BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
560        BinaryOp::Gte => Ok(matches!(
561            compare_values(left, right),
562            Some(Ordering::Greater | Ordering::Equal)
563        )),
564        BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
565        BinaryOp::Lte => Ok(matches!(
566            compare_values(left, right),
567            Some(Ordering::Less | Ordering::Equal)
568        )),
569        BinaryOp::Like => match (left, right) {
570            (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
571            _ => Ok(false),
572        },
573        BinaryOp::NotLike => match (left, right) {
574            (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
575            _ => Ok(true),
576        },
577        BinaryOp::In | BinaryOp::InLarge => match right {
578            Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
579            _ => Err(MemoryRepositoryError::UnsupportedExpression(
580                "IN expects a list value".to_owned(),
581            )),
582        },
583        BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
584            Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
585            _ => Err(MemoryRepositoryError::UnsupportedExpression(
586                "NOT IN expects a list value".to_owned(),
587            )),
588        },
589    }
590}
591
592fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
593    match value {
594        Value::Bool(value) => Ok(*value),
595        Value::Null => Ok(false),
596        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
597            "non-boolean expression result: {other:?}"
598        ))),
599    }
600}
601
602fn values_equal(left: &Value, right: &Value) -> bool {
603    match (left, right) {
604        (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
605        (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
606        _ => left == right,
607    }
608}
609
610fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
611    match (left, right) {
612        (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
613        (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
614        (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
615        (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
616        (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
617        (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
618        (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
619        (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
620        (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
621        _ => None,
622    }
623}
624
/// Matches `value` against a SQL `LIKE` pattern.
///
/// `%` matches any run of characters, including none, and `_` matches exactly one
/// character. Every other character matches itself. Wildcards may appear anywhere in the
/// pattern. Matching is case-sensitive, works on `char`s rather than bytes, and no escape
/// character is recognised.
fn like_matches(value: &str, pattern: &str) -> bool {
    let text: Vec<char> = value.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();
    let (mut t, mut p) = (0, 0);
    // Most recent `%`: its pattern index and the text index it has absorbed up to.
    let mut resume: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('%') => {
                resume = Some((p, t));
                p += 1;
            }
            Some(ch) if ch == '_' || ch == text[t] => {
                t += 1;
                p += 1;
            }
            _ => {
                // Mismatch: let the last `%` absorb one more character and retry after it.
                let Some((star_p, star_t)) = resume else {
                    return false;
                };
                resume = Some((star_p, star_t + 1));
                p = star_p + 1;
                t = star_t + 1;
            }
        }
    }
    // The text is exhausted; the remaining pattern may only consist of `%`.
    pattern[p..].iter().all(|&ch| ch == '%')
}
636
/// Computes the four-character Soundex code of `value`: the first letter followed by
/// three digits.
///
/// Characters that are not ASCII letters are dropped before encoding. After the first
/// letter, a digit is emitted when it is non-zero and differs from the previous letter's
/// digit. Vowels and H/W/Y map to `'0'` and therefore separate repeated digits. The code
/// is right-padded with `'0'`. Input without any ASCII letter yields `"0000"`.
fn soundex(value: &str) -> String {
    let mut letters = value
        .chars()
        .filter(char::is_ascii_alphabetic)
        .map(|ch| ch.to_ascii_uppercase());

    let first = match letters.next() {
        Some(first) => first,
        None => return String::from("0000"),
    };

    let mut code = String::with_capacity(4);
    code.push(first);
    let mut last_digit = soundex_code(first);
    for digit in letters.map(soundex_code) {
        let emit = digit != '0' && digit != last_digit;
        last_digit = digit;
        if emit {
            code.push(digit);
            if code.len() == 4 {
                break;
            }
        }
    }
    format!("{code:0<4}")
}

/// Soundex digit of an uppercase ASCII letter; `'0'` for anything else.
fn soundex_code(ch: char) -> char {
    // Digits for 'A'..='Z' in alphabetical order.
    const DIGITS: &[u8; 26] = b"01230120022455012623010202";
    if ch.is_ascii_uppercase() {
        DIGITS[usize::from(ch as u8 - b'A')] as char
    } else {
        '0'
    }
}
678
679fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
680    for order in query.order_by.iter().rev() {
681        rows.sort_by(|left, right| {
682            let left_value = if let Some(expr) = &order.expr {
683                eval_value(expr, left).ok()
684            } else {
685                left.get(&order.field).cloned()
686            };
687            let right_value = if let Some(expr) = &order.expr {
688                eval_value(expr, right).ok()
689            } else {
690                right.get(&order.field).cloned()
691            };
692            let ordering = match (left_value.as_ref(), right_value.as_ref()) {
693                (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
694                (None, Some(_)) => Ordering::Less,
695                (Some(_), None) => Ordering::Greater,
696                (None, None) => Ordering::Equal,
697            };
698            match order.direction {
699                SortDirection::Asc => ordering,
700                SortDirection::Desc => ordering.reverse(),
701            }
702        });
703    }
704}
705
706fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
707    let Some(slice) = query.slice else {
708        return rows;
709    };
710    let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
711    let limit = slice
712        .limit
713        .and_then(|limit| usize::try_from(limit).ok())
714        .unwrap_or(usize::MAX);
715    rows.into_iter().skip(offset).take(limit).collect()
716}
717
718fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
719    let mut output: Record = query
720        .projection
721        .iter()
722        .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
723        .collect();
724    for projection in &query.expr_projection {
725        output.insert(
726            projection.alias.clone(),
727            eval_value(&projection.expr, &row)?,
728        );
729    }
730    Ok(output)
731}
732
733fn aggregate_rows(
734    query: &SelectQuery,
735    rows: &[Record],
736) -> Result<Vec<Record>, MemoryRepositoryError> {
737    let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
738    if query.group_by.is_empty() {
739        groups.insert(Vec::new(), rows.iter().collect());
740    } else {
741        for row in rows {
742            let key = query
743                .group_by
744                .iter()
745                .map(|field| row.get(field).map(value_key).unwrap_or_default())
746                .collect::<Vec<_>>();
747            groups.entry(key).or_default().push(row);
748        }
749    }
750
751    let rows = groups
752        .into_values()
753        .map(|rows| {
754            let mut output = Record::new();
755            if let Some(first) = rows.first() {
756                for field in &query.group_by {
757                    if let Some(value) = first.get(field) {
758                        output.insert(field.clone(), value.clone());
759                    }
760                }
761            }
762            for aggregate in &query.aggregates {
763                let value = match aggregate.function {
764                    AggregateFunction::Count => {
765                        if aggregate.field == "*" {
766                            Value::U64(rows.len() as u64)
767                        } else {
768                            Value::U64(
769                                rows.iter()
770                                    .filter(|row| {
771                                        !matches!(
772                                            row.get(&aggregate.field),
773                                            None | Some(Value::Null)
774                                        )
775                                    })
776                                    .count() as u64,
777                            )
778                        }
779                    }
780                    AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
781                    AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
782                    AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
783                    AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
784                    AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
785                    AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
786                    AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
787                    AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
788                    AggregateFunction::BitAnd => {
789                        bit_aggregate(&rows, &aggregate.field, BitOp::And)?
790                    }
791                    AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
792                    AggregateFunction::BitXor => {
793                        bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
794                    }
795                };
796                output.insert(aggregate.alias.clone(), value);
797            }
798            for projection in &query.expr_projection {
799                output.insert(
800                    projection.alias.clone(),
801                    eval_value(&projection.expr, &output)?,
802                );
803            }
804            Ok(output)
805        })
806        .collect::<Result<Vec<_>, _>>()?;
807    if let Some(having) = &query.having {
808        rows.into_iter()
809            .filter_map(|row| match eval_filter(having, &row) {
810                Ok(true) => Some(Ok(row)),
811                Ok(false) => None,
812                Err(err) => Some(Err(err)),
813            })
814            .collect()
815    } else {
816        Ok(rows)
817    }
818}
819
820fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
821    let mut decimal_sum = Decimal::ZERO;
822    let mut integer_sum: i128 = 0;
823    let mut saw_decimal = false;
824    for value in rows.iter().filter_map(|row| row.get(field)) {
825        match value {
826            Value::I64(value) => {
827                integer_sum += i128::from(*value);
828                decimal_sum += Decimal::from(*value);
829            }
830            Value::U64(value) => {
831                integer_sum += i128::from(*value);
832                decimal_sum += Decimal::from(*value);
833            }
834            Value::F64(value) => {
835                saw_decimal = true;
836                decimal_sum += decimal_from_f64(*value);
837            }
838            Value::Decimal(value) => {
839                saw_decimal = true;
840                decimal_sum += *value;
841            }
842            Value::Null => {}
843            other => {
844                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
845                    "SUM does not support {other:?}"
846                )));
847            }
848        }
849    }
850    if saw_decimal {
851        Ok(Value::Decimal(decimal_sum))
852    } else if integer_sum >= 0 {
853        Ok(Value::U64(integer_sum as u64))
854    } else {
855        Ok(Value::I64(integer_sum as i64))
856    }
857}
858
859fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
860    let mut sum = Decimal::ZERO;
861    let mut count: u64 = 0;
862    for value in rows.iter().filter_map(|row| row.get(field)) {
863        match value {
864            Value::I64(value) => {
865                sum += Decimal::from(*value);
866                count += 1;
867            }
868            Value::U64(value) => {
869                sum += Decimal::from(*value);
870                count += 1;
871            }
872            Value::F64(value) => {
873                sum += decimal_from_f64(*value);
874                count += 1;
875            }
876            Value::Decimal(value) => {
877                sum += *value;
878                count += 1;
879            }
880            Value::Null => {}
881            other => {
882                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
883                    "AVG does not support {other:?}"
884                )));
885            }
886        }
887    }
888    Ok(if count == 0 {
889        Value::Null
890    } else {
891        Value::Decimal(sum / Decimal::from(count))
892    })
893}
894
895fn decimal_from_f64(value: f64) -> Decimal {
896    Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
897}
898
899fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
900    rows.iter()
901        .filter_map(|row| row.get(field))
902        .filter(|value| !matches!(value, Value::Null))
903        .map(|value| match value {
904            Value::I64(value) => Ok(*value as f64),
905            Value::U64(value) => Ok(*value as f64),
906            Value::F64(value) => Ok(*value),
907            Value::Decimal(value) => value.to_f64().ok_or_else(|| {
908                MemoryRepositoryError::UnsupportedAggregate(format!(
909                    "cannot convert decimal {value} to f64 for statistical aggregate"
910                ))
911            }),
912            other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
913                "numeric aggregate does not support {other:?}"
914            ))),
915        })
916        .collect()
917}
918
919fn numeric_variance(
920    rows: &[&Record],
921    field: &str,
922    sample: bool,
923) -> Result<Value, MemoryRepositoryError> {
924    let values = numeric_values(rows, field)?;
925    let count = values.len();
926    if count == 0 || (sample && count < 2) {
927        return Ok(Value::Null);
928    }
929    let mean = values.iter().sum::<f64>() / count as f64;
930    let sum = values
931        .iter()
932        .map(|value| {
933            let diff = value - mean;
934            diff * diff
935        })
936        .sum::<f64>();
937    let denominator = if sample { count - 1 } else { count } as f64;
938    Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
939}
940
941fn numeric_stddev(
942    rows: &[&Record],
943    field: &str,
944    sample: bool,
945) -> Result<Value, MemoryRepositoryError> {
946    Ok(match numeric_variance(rows, field, sample)? {
947        Value::Decimal(value) => {
948            Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
949        }
950        Value::Null => Value::Null,
951        other => other,
952    })
953}
954
/// Bitwise reduction that `bit_aggregate` applies when folding a column's values.
#[derive(Clone, Copy, Debug)]
enum BitOp {
    /// Bitwise AND (`&`) of all non-null values.
    And,
    /// Bitwise OR (`|`) of all non-null values.
    Or,
    /// Bitwise XOR (`^`) of all non-null values.
    Xor,
}
961
962fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
963    let mut selected: Option<i64> = None;
964    for value in rows.iter().filter_map(|row| row.get(field)) {
965        let value = match value {
966            Value::I64(value) => *value,
967            Value::U64(value) => i64::try_from(*value).map_err(|_| {
968                MemoryRepositoryError::UnsupportedAggregate(format!(
969                    "BIT aggregate u64 {value} exceeds i64 range"
970                ))
971            })?,
972            Value::Null => continue,
973            other => {
974                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
975                    "BIT aggregate does not support {other:?}"
976                )));
977            }
978        };
979        selected = Some(match (selected, op) {
980            (None, _) => value,
981            (Some(current), BitOp::And) => current & value,
982            (Some(current), BitOp::Or) => current | value,
983            (Some(current), BitOp::Xor) => current ^ value,
984        });
985    }
986    Ok(selected.map(Value::I64).unwrap_or(Value::Null))
987}
988
989fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
990    let mut selected: Option<Value> = None;
991    for value in rows.iter().filter_map(|row| row.get(field)) {
992        if matches!(value, Value::Null) {
993            continue;
994        }
995        match &selected {
996            None => selected = Some(value.clone()),
997            Some(current) => {
998                let Some(ordering) = compare_values(value, current) else {
999                    return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
1000                        "MIN/MAX does not support {value:?}"
1001                    )));
1002                };
1003                if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
1004                    selected = Some(value.clone());
1005                }
1006            }
1007        }
1008    }
1009    Ok(selected.unwrap_or(Value::Null))
1010}
1011
1012fn value_key(value: &Value) -> String {
1013    match value {
1014        Value::Null => "null".to_owned(),
1015        Value::Bool(value) => format!("b:{value}"),
1016        Value::I64(value) => format!("i:{value}"),
1017        Value::U64(value) => format!("u:{value}"),
1018        Value::F64(value) => format!("f:{value}"),
1019        Value::Decimal(value) => format!("d:{value}"),
1020        Value::Text(value) => format!("t:{value}"),
1021        Value::Json(value) => format!("j:{value}"),
1022        Value::Date(value) => format!("d:{value}"),
1023        Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
1024        Value::Object(_) => "object".to_owned(),
1025        Value::List(_) => "list".to_owned(),
1026    }
1027}
1028
1029fn read_i64(value: Option<&Value>) -> Option<i64> {
1030    match value {
1031        Some(Value::I64(value)) => Some(*value),
1032        _ => None,
1033    }
1034}
1035
1036fn local_graph_identity_key(value: &Value) -> String {
1037    match value {
1038        Value::I64(val) if *val >= 0 => format!("u:{}", *val as u64),
1039        Value::U64(val) => format!("u:{val}"),
1040        Value::Null => "null".to_owned(),
1041        Value::Bool(v) => format!("b:{v}"),
1042        Value::I64(v) => format!("i:{v}"),
1043        Value::F64(v) => format!("f:{v}"),
1044        Value::Decimal(v) => format!("d:{v}"),
1045        Value::Text(v) => format!("t:{v}"),
1046        Value::Json(v) => format!("j:{v}"),
1047        Value::Date(v) => format!("d:{v}"),
1048        Value::Timestamp(v) => format!("ts:{}", v.to_rfc3339()),
1049        Value::Object(_) => "o".to_owned(),
1050        Value::List(_) => "l".to_owned(),
1051    }
1052}
1053