1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8 Aggregate, AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction,
9 InsertCommand, Record, RecoverCommand, RelationAggregate, SelectQuery, SmartList, SortDirection,
10 UpdateCommand, Value,
11};
12
13use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
14
/// Failures produced by the in-memory executor itself (as opposed to metadata
/// or entity-conversion failures reported through other error types).
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the row store could not be locked because it was poisoned.
    Poisoned,
    /// An expression could not be evaluated in memory; the payload explains why.
    UnsupportedExpression(String),
    /// An aggregate could not be computed in memory; the payload explains why.
    UnsupportedAggregate(String),
}

impl std::fmt::Display for MemoryRepositoryError {
    /// Renders a short human-readable description of the failure.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedAggregate(message) => {
                write!(formatter, "unsupported memory aggregate: {message}")
            }
            Self::UnsupportedExpression(message) => {
                write!(formatter, "unsupported memory expression: {message}")
            }
            Self::Poisoned => formatter.write_str("memory repository lock poisoned"),
        }
    }
}

impl std::error::Error for MemoryRepositoryError {}
37
38#[derive(Debug, Clone)]
39pub struct MemoryRepository<M = InMemoryMetadataStore> {
40 metadata: M,
41 data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
42}
43
44impl<M> MemoryRepository<M>
45where
46 M: MetadataStore,
47{
48 pub fn new(metadata: M) -> Self {
49 Self {
50 metadata,
51 data: Arc::new(Mutex::new(BTreeMap::new())),
52 }
53 }
54
55 pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
56 self.seed(entity, rows);
57 self
58 }
59
60 pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
61 if let Ok(mut data) = self.data.lock() {
62 data.insert(entity.into(), rows);
63 }
64 }
65
66 pub fn fetch_all(
67 &self,
68 query: &SelectQuery,
69 ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
70 self.require_entity(&query.entity)?;
71 let data = self
72 .data
73 .lock()
74 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
75 let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
76 drop(data);
77
78 if let Some(filter) = &query.filter {
79 rows = rows
80 .into_iter()
81 .filter_map(|row| match eval_filter(filter, &row) {
82 Ok(true) => Some(Ok(row)),
83 Ok(false) => None,
84 Err(err) => Some(Err(err)),
85 })
86 .collect::<Result<Vec<_>, _>>()
87 .map_err(RepositoryError::Executor)?;
88 }
89
90 if !query.aggregates.is_empty() {
91 return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
92 }
93
94 apply_ordering(&mut rows, query);
95 rows = apply_slice(rows, query);
96 if !query.projection.is_empty() || !query.expr_projection.is_empty() {
97 rows = rows
98 .into_iter()
99 .map(|row| project_row(row, query))
100 .collect::<Result<Vec<_>, _>>()
101 .map_err(RepositoryError::Executor)?;
102 }
103 Ok(rows)
104 }
105
106 pub fn fetch_smart_list(
107 &self,
108 query: &SelectQuery,
109 ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
110 self.fetch_all(query).map(SmartList::from)
111 }
112
113 pub fn fetch_entities<T>(
114 &self,
115 query: &SelectQuery,
116 ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
117 where
118 T: Entity,
119 {
120 self.fetch_all(query)?
121 .into_iter()
122 .map(T::from_record)
123 .collect::<Result<Vec<_>, _>>()
124 .map(SmartList::from)
125 .map_err(RepositoryError::Entity)
126 }
127
128 pub fn fetch_all_with_relation_aggregates(
129 &self,
130 query: &SelectQuery,
131 relation_aggregates: &[RelationAggregate],
132 ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
133 let mut rows = self.fetch_all(query)?;
134 self.enhance_relation_aggregates(&query.entity, &mut rows, relation_aggregates)?;
135 Ok(rows)
136 }
137
138 pub fn fetch_smart_list_with_relation_aggregates(
139 &self,
140 query: &SelectQuery,
141 relation_aggregates: &[RelationAggregate],
142 ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
143 self.fetch_all_with_relation_aggregates(query, relation_aggregates)
144 .map(SmartList::from)
145 }
146
147 pub fn fetch_entities_with_relation_aggregates<T>(
148 &self,
149 query: &SelectQuery,
150 relation_aggregates: &[RelationAggregate],
151 ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
152 where
153 T: Entity,
154 {
155 self.fetch_all_with_relation_aggregates(query, relation_aggregates)?
156 .into_iter()
157 .map(T::from_record)
158 .collect::<Result<Vec<_>, _>>()
159 .map(SmartList::from)
160 .map_err(RepositoryError::Entity)
161 }
162
163 pub fn enhance_relation_aggregates(
164 &self,
165 parent_entity: &str,
166 parent_rows: &mut [Record],
167 relation_aggregates: &[RelationAggregate],
168 ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
169 for aggregate in relation_aggregates {
170 self.enhance_relation_aggregate(parent_entity, parent_rows, aggregate)?;
171 }
172 Ok(())
173 }
174
175 fn enhance_relation_aggregate(
176 &self,
177 parent_entity: &str,
178 parent_rows: &mut [Record],
179 aggregate: &RelationAggregate,
180 ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
181 let descriptor = self
182 .metadata
183 .entity(parent_entity)
184 .ok_or_else(|| {
185 RepositoryError::Runtime(RuntimeError::MissingEntity(parent_entity.to_owned()))
186 })?;
187
188 let relation = descriptor
189 .relation_by_name(&aggregate.relation_name)
190 .ok_or_else(|| {
191 RepositoryError::Runtime(RuntimeError::MissingRelation {
192 entity: parent_entity.to_owned(),
193 relation: aggregate.relation_name.clone(),
194 })
195 })?;
196
197 let ids = parent_rows
198 .iter()
199 .filter_map(|row| row.get(&relation.local_key).cloned())
200 .collect::<Vec<_>>();
201
202 if ids.is_empty() {
203 let value = if aggregate.single_result {
204 Value::U64(0)
205 } else {
206 Value::List(Vec::new())
207 };
208 for parent in parent_rows.iter_mut() {
209 parent.insert(aggregate.alias.clone(), value.clone());
210 }
211 return Ok(());
212 }
213
214 let mut query = aggregate.query.clone();
215 query.entity = relation.target_entity.clone();
216 query.projection.clear();
217 query.expr_projection.clear();
218 query.order_by.clear();
219 query.slice = None;
220 query.relations.clear();
221 if query.aggregates.is_empty() {
222 let alias = if aggregate.single_result {
223 aggregate.alias.clone()
224 } else {
225 "count".to_owned()
226 };
227 query = query.aggregate(Aggregate::count(alias));
228 }
229 if !query
230 .group_by
231 .iter()
232 .any(|field| field == &relation.foreign_key)
233 {
234 query = query.group_by(relation.foreign_key.clone());
235 }
236 query = query.and_filter(Expr::in_list(relation.foreign_key.clone(), ids));
237
238 let aggregate_rows = self.fetch_all(&query)?;
239
240 let mut buckets: BTreeMap<String, Vec<Record>> = BTreeMap::new();
241 for mut row in aggregate_rows {
242 if let Some(key) = row.remove(&relation.foreign_key) {
243 let bucket_key = local_graph_identity_key(&key);
244 buckets
245 .entry(bucket_key)
246 .or_default()
247 .push(row);
248 }
249 }
250
251 for parent in parent_rows {
252 let value = parent
253 .get(&relation.local_key)
254 .and_then(|local_value| buckets.get(&local_graph_identity_key(local_value)))
255 .map(|rows| {
256 if aggregate.single_result {
257 rows.first()
258 .map(|row| {
259 if row.len() == 1 {
260 row.values().next().cloned().unwrap_or(Value::Null)
261 } else {
262 Value::object(row.clone())
263 }
264 })
265 .unwrap_or(Value::U64(0))
266 } else {
267 Value::List(rows.iter().cloned().map(Value::object).collect())
268 }
269 })
270 .unwrap_or_else(|| {
271 if aggregate.single_result {
272 Value::U64(0)
273 } else {
274 Value::List(Vec::new())
275 }
276 });
277 parent.insert(aggregate.alias.clone(), value);
278 }
279
280 Ok(())
281 }
282
283 pub fn insert(
284 &self,
285 command: &InsertCommand,
286 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
287 self.require_entity(&command.entity)?;
288 let mut data = self
289 .data
290 .lock()
291 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
292 data.entry(command.entity.clone())
293 .or_default()
294 .push(command.values.clone());
295 Ok(1)
296 }
297
298 pub fn update(
299 &self,
300 command: &UpdateCommand,
301 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
302 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
303 let mut data = self
304 .data
305 .lock()
306 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
307 let rows = data.entry(command.entity.clone()).or_default();
308 let Some(row) = rows
309 .iter_mut()
310 .find(|row| row.get(id_property) == Some(&command.id))
311 else {
312 return self.maybe_optimistic_conflict(
313 command.expected_version,
314 &command.entity,
315 &command.id,
316 );
317 };
318
319 if let Some(expected) = command.expected_version {
320 if row.get(version_property) != Some(&Value::I64(expected)) {
321 println!("OptimisticLockConflict in memory.rs update! entity={}, id={:?}, expected={}, existing={:?}", command.entity, command.id, expected, row.get(version_property));
322 return Err(RepositoryError::Runtime(
323 RuntimeError::OptimisticLockConflict {
324 entity: command.entity.clone(),
325 id: format!("{:?}", command.id),
326 },
327 ));
328 }
329 row.insert(
330 version_property.to_owned(),
331 Value::I64(expected + 1),
332 );
333 }
334
335 for (key, value) in &command.values {
336 row.insert(key.clone(), value.clone());
337 }
338 Ok(1)
339 }
340
341 pub fn delete(
342 &self,
343 command: &DeleteCommand,
344 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
345 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
346 let mut data = self
347 .data
348 .lock()
349 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
350 let rows = data.entry(command.entity.clone()).or_default();
351 let Some(index) = rows
352 .iter()
353 .position(|row| row.get(id_property) == Some(&command.id))
354 else {
355 return self.maybe_optimistic_conflict(
356 command.expected_version,
357 &command.entity,
358 &command.id,
359 );
360 };
361
362 if let Some(expected_version) = command.expected_version {
363 if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
364 return Err(RepositoryError::Runtime(
365 RuntimeError::OptimisticLockConflict {
366 entity: command.entity.clone(),
367 id: format!("{:?}", command.id),
368 },
369 ));
370 }
371 }
372
373 if command.soft_delete {
374 let next_version = command
375 .expected_version
376 .or_else(|| read_i64(rows[index].get(version_property)))
377 .map(|version| -(version.abs() + 1))
378 .unwrap_or(-1);
379 rows[index].insert(version_property.to_owned(), Value::I64(next_version));
380 } else {
381 rows.remove(index);
382 }
383 Ok(1)
384 }
385
386 pub fn recover(
387 &self,
388 command: &RecoverCommand,
389 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
390 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
391 let mut data = self
392 .data
393 .lock()
394 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
395 let rows = data.entry(command.entity.clone()).or_default();
396 let Some(row) = rows
397 .iter_mut()
398 .find(|row| row.get(id_property) == Some(&command.id))
399 else {
400 return Err(RepositoryError::Runtime(
401 RuntimeError::OptimisticLockConflict {
402 entity: command.entity.clone(),
403 id: format!("{:?}", command.id),
404 },
405 ));
406 };
407
408 if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
409 return Err(RepositoryError::Runtime(
410 RuntimeError::OptimisticLockConflict {
411 entity: command.entity.clone(),
412 id: format!("{:?}", command.id),
413 },
414 ));
415 }
416
417 row.insert(
418 version_property.to_owned(),
419 Value::I64(command.expected_version.abs() + 1),
420 );
421 Ok(1)
422 }
423
424 fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
425 self.metadata
426 .entity(entity)
427 .map(|_| ())
428 .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
429 }
430
431 fn id_and_version_properties(
432 &self,
433 entity: &str,
434 ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
435 let descriptor = self.metadata.entity(entity).ok_or_else(|| {
436 RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
437 })?;
438 let id = descriptor
439 .id_property()
440 .map(|property| property.name.as_str())
441 .unwrap_or("id");
442 let version = descriptor
443 .version_property()
444 .map(|property| property.name.as_str())
445 .unwrap_or("version");
446 Ok((id, version))
447 }
448
449 fn maybe_optimistic_conflict(
450 &self,
451 expected_version: Option<i64>,
452 entity: &str,
453 id: &Value,
454 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
455 if expected_version.is_some() {
456 Err(RepositoryError::Runtime(
457 RuntimeError::OptimisticLockConflict {
458 entity: entity.to_owned(),
459 id: format!("{id:?}"),
460 },
461 ))
462 } else {
463 Ok(0)
464 }
465 }
466}
467
468fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
469 match expr {
470 Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
471 value_truthy(&eval_value(expr, row)?)
472 }
473 Expr::Binary { left, op, right } => {
474 let left = eval_value(left, row)?;
475 let right = eval_value(right, row)?;
476 eval_binary(&left, *op, &right)
477 }
478 Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
479 "subquery filters require a SQL executor".to_owned(),
480 )),
481 Expr::Between { expr, lower, upper } => {
482 let value = eval_value(expr, row)?;
483 let lower = eval_value(lower, row)?;
484 let upper = eval_value(upper, row)?;
485 Ok(compare_values(&value, &lower) != Some(Ordering::Less)
486 && compare_values(&value, &upper) != Some(Ordering::Greater))
487 }
488 Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
489 Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
490 Expr::And(parts) => {
491 for part in parts {
492 if !eval_filter(part, row)? {
493 return Ok(false);
494 }
495 }
496 Ok(true)
497 }
498 Expr::Or(parts) => {
499 for part in parts {
500 if eval_filter(part, row)? {
501 return Ok(true);
502 }
503 }
504 Ok(false)
505 }
506 Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
507 }
508}
509
510fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
511 match expr {
512 Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
513 Expr::Value(value) => Ok(value.clone()),
514 Expr::Function { function, args } => eval_function(*function, args, row),
515 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
516 "cannot evaluate {other:?} as a scalar value"
517 ))),
518 }
519}
520
521fn eval_function(
522 function: ExprFunction,
523 args: &[Expr],
524 row: &Record,
525) -> Result<Value, MemoryRepositoryError> {
526 match function {
527 ExprFunction::Soundex => {
528 let [arg] = args else {
529 return Err(MemoryRepositoryError::UnsupportedExpression(
530 "SOUNDEX expects exactly one argument".to_owned(),
531 ));
532 };
533 match eval_value(arg, row)? {
534 Value::Text(value) => Ok(Value::Text(soundex(&value))),
535 Value::Null => Ok(Value::Null),
536 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
537 "SOUNDEX expects text, got {other:?}"
538 ))),
539 }
540 }
541 ExprFunction::Gbk => {
542 let [arg] = args else {
543 return Err(MemoryRepositoryError::UnsupportedExpression(
544 "GBK expects exactly one argument".to_owned(),
545 ));
546 };
547 eval_value(arg, row)
548 }
549 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
550 "function {other:?} is only supported by SQL execution"
551 ))),
552 }
553}
554
555fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
556 match op {
557 BinaryOp::Eq => Ok(values_equal(left, right)),
558 BinaryOp::Ne => Ok(!values_equal(left, right)),
559 BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
560 BinaryOp::Gte => Ok(matches!(
561 compare_values(left, right),
562 Some(Ordering::Greater | Ordering::Equal)
563 )),
564 BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
565 BinaryOp::Lte => Ok(matches!(
566 compare_values(left, right),
567 Some(Ordering::Less | Ordering::Equal)
568 )),
569 BinaryOp::Like => match (left, right) {
570 (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
571 _ => Ok(false),
572 },
573 BinaryOp::NotLike => match (left, right) {
574 (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
575 _ => Ok(true),
576 },
577 BinaryOp::In | BinaryOp::InLarge => match right {
578 Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
579 _ => Err(MemoryRepositoryError::UnsupportedExpression(
580 "IN expects a list value".to_owned(),
581 )),
582 },
583 BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
584 Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
585 _ => Err(MemoryRepositoryError::UnsupportedExpression(
586 "NOT IN expects a list value".to_owned(),
587 )),
588 },
589 }
590}
591
592fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
593 match value {
594 Value::Bool(value) => Ok(*value),
595 Value::Null => Ok(false),
596 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
597 "non-boolean expression result: {other:?}"
598 ))),
599 }
600}
601
602fn values_equal(left: &Value, right: &Value) -> bool {
603 match (left, right) {
604 (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
605 (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
606 _ => left == right,
607 }
608}
609
610fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
611 match (left, right) {
612 (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
613 (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
614 (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
615 (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
616 (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
617 (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
618 (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
619 (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
620 (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
621 _ => None,
622 }
623}
624
/// Tests `value` against a SQL `LIKE` pattern.
///
/// `%` matches any run of characters (including none) and `_` matches exactly
/// one character; both may appear anywhere in the pattern. Matching is
/// case-sensitive, works on Unicode scalar values, and has no escape syntax,
/// so literal `%` or `_` cannot be expressed.
fn like_matches(value: &str, pattern: &str) -> bool {
    let value: Vec<char> = value.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut value_pos = 0;
    let mut pattern_pos = 0;
    // Most recent `%`: (pattern index just after it, value index it currently
    // absorbs up to). On a mismatch we let that `%` swallow one more character.
    let mut resume: Option<(usize, usize)> = None;

    while value_pos < value.len() {
        match pattern.get(pattern_pos) {
            Some('%') => {
                pattern_pos += 1;
                resume = Some((pattern_pos, value_pos));
            }
            Some(&expected) if expected == '_' || expected == value[value_pos] => {
                pattern_pos += 1;
                value_pos += 1;
            }
            _ => match resume {
                Some((after_wildcard, absorbed_to)) => {
                    pattern_pos = after_wildcard;
                    value_pos = absorbed_to + 1;
                    resume = Some((after_wildcard, value_pos));
                }
                None => return false,
            },
        }
    }

    // The value is exhausted; only trailing `%` may remain in the pattern.
    pattern[pattern_pos..].iter().all(|&ch| ch == '%')
}
636
/// Computes a four-character soundex code for `value`.
///
/// Only ASCII letters are considered (case-insensitively). The first letter is
/// kept verbatim; each following letter contributes its digit unless the digit
/// is `0` or equals the digit of the letter immediately before it. The result
/// is right-padded with `0`; input without any ASCII letter yields `"0000"`.
fn soundex(value: &str) -> String {
    const CODE_LEN: usize = 4;

    let mut upper = value
        .chars()
        .filter(char::is_ascii_alphabetic)
        .map(|ch| ch.to_ascii_uppercase());

    let initial = match upper.next() {
        Some(letter) => letter,
        None => return "0000".to_owned(),
    };

    let mut encoded = String::from(initial);
    let mut last_code = soundex_code(initial);

    for letter in upper {
        let code = soundex_code(letter);
        let emit = code != '0' && code != last_code;
        // The comparison baseline always advances, even for skipped letters.
        last_code = code;
        if emit {
            encoded.push(code);
            if encoded.len() == CODE_LEN {
                return encoded;
            }
        }
    }

    let missing = CODE_LEN.saturating_sub(encoded.len());
    encoded.extend(std::iter::repeat('0').take(missing));
    encoded
}

/// Maps an uppercase ASCII letter to its soundex digit; anything else
/// (vowels, `H`, `W`, `Y`, lowercase letters, non-letters) maps to `'0'`.
fn soundex_code(ch: char) -> char {
    const GROUPS: [(&str, char); 6] = [
        ("BFPV", '1'),
        ("CGJKQSXZ", '2'),
        ("DT", '3'),
        ("L", '4'),
        ("MN", '5'),
        ("R", '6'),
    ];
    GROUPS
        .iter()
        .find(|(letters, _)| letters.contains(ch))
        .map_or('0', |&(_, code)| code)
}
678
679fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
680 for order in query.order_by.iter().rev() {
681 rows.sort_by(|left, right| {
682 let left_value = if let Some(expr) = &order.expr {
683 eval_value(expr, left).ok()
684 } else {
685 left.get(&order.field).cloned()
686 };
687 let right_value = if let Some(expr) = &order.expr {
688 eval_value(expr, right).ok()
689 } else {
690 right.get(&order.field).cloned()
691 };
692 let ordering = match (left_value.as_ref(), right_value.as_ref()) {
693 (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
694 (None, Some(_)) => Ordering::Less,
695 (Some(_), None) => Ordering::Greater,
696 (None, None) => Ordering::Equal,
697 };
698 match order.direction {
699 SortDirection::Asc => ordering,
700 SortDirection::Desc => ordering.reverse(),
701 }
702 });
703 }
704}
705
706fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
707 let Some(slice) = query.slice else {
708 return rows;
709 };
710 let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
711 let limit = slice
712 .limit
713 .and_then(|limit| usize::try_from(limit).ok())
714 .unwrap_or(usize::MAX);
715 rows.into_iter().skip(offset).take(limit).collect()
716}
717
718fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
719 let mut output: Record = query
720 .projection
721 .iter()
722 .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
723 .collect();
724 for projection in &query.expr_projection {
725 output.insert(
726 projection.alias.clone(),
727 eval_value(&projection.expr, &row)?,
728 );
729 }
730 Ok(output)
731}
732
733fn aggregate_rows(
734 query: &SelectQuery,
735 rows: &[Record],
736) -> Result<Vec<Record>, MemoryRepositoryError> {
737 let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
738 if query.group_by.is_empty() {
739 groups.insert(Vec::new(), rows.iter().collect());
740 } else {
741 for row in rows {
742 let key = query
743 .group_by
744 .iter()
745 .map(|field| row.get(field).map(value_key).unwrap_or_default())
746 .collect::<Vec<_>>();
747 groups.entry(key).or_default().push(row);
748 }
749 }
750
751 let rows = groups
752 .into_values()
753 .map(|rows| {
754 let mut output = Record::new();
755 if let Some(first) = rows.first() {
756 for field in &query.group_by {
757 if let Some(value) = first.get(field) {
758 output.insert(field.clone(), value.clone());
759 }
760 }
761 }
762 for aggregate in &query.aggregates {
763 let value = match aggregate.function {
764 AggregateFunction::Count => {
765 if aggregate.field == "*" {
766 Value::U64(rows.len() as u64)
767 } else {
768 Value::U64(
769 rows.iter()
770 .filter(|row| {
771 !matches!(
772 row.get(&aggregate.field),
773 None | Some(Value::Null)
774 )
775 })
776 .count() as u64,
777 )
778 }
779 }
780 AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
781 AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
782 AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
783 AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
784 AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
785 AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
786 AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
787 AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
788 AggregateFunction::BitAnd => {
789 bit_aggregate(&rows, &aggregate.field, BitOp::And)?
790 }
791 AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
792 AggregateFunction::BitXor => {
793 bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
794 }
795 };
796 output.insert(aggregate.alias.clone(), value);
797 }
798 for projection in &query.expr_projection {
799 output.insert(
800 projection.alias.clone(),
801 eval_value(&projection.expr, &output)?,
802 );
803 }
804 Ok(output)
805 })
806 .collect::<Result<Vec<_>, _>>()?;
807 if let Some(having) = &query.having {
808 rows.into_iter()
809 .filter_map(|row| match eval_filter(having, &row) {
810 Ok(true) => Some(Ok(row)),
811 Ok(false) => None,
812 Err(err) => Some(Err(err)),
813 })
814 .collect()
815 } else {
816 Ok(rows)
817 }
818}
819
820fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
821 let mut decimal_sum = Decimal::ZERO;
822 let mut integer_sum: i128 = 0;
823 let mut saw_decimal = false;
824 for value in rows.iter().filter_map(|row| row.get(field)) {
825 match value {
826 Value::I64(value) => {
827 integer_sum += i128::from(*value);
828 decimal_sum += Decimal::from(*value);
829 }
830 Value::U64(value) => {
831 integer_sum += i128::from(*value);
832 decimal_sum += Decimal::from(*value);
833 }
834 Value::F64(value) => {
835 saw_decimal = true;
836 decimal_sum += decimal_from_f64(*value);
837 }
838 Value::Decimal(value) => {
839 saw_decimal = true;
840 decimal_sum += *value;
841 }
842 Value::Null => {}
843 other => {
844 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
845 "SUM does not support {other:?}"
846 )));
847 }
848 }
849 }
850 if saw_decimal {
851 Ok(Value::Decimal(decimal_sum))
852 } else if integer_sum >= 0 {
853 Ok(Value::U64(integer_sum as u64))
854 } else {
855 Ok(Value::I64(integer_sum as i64))
856 }
857}
858
859fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
860 let mut sum = Decimal::ZERO;
861 let mut count: u64 = 0;
862 for value in rows.iter().filter_map(|row| row.get(field)) {
863 match value {
864 Value::I64(value) => {
865 sum += Decimal::from(*value);
866 count += 1;
867 }
868 Value::U64(value) => {
869 sum += Decimal::from(*value);
870 count += 1;
871 }
872 Value::F64(value) => {
873 sum += decimal_from_f64(*value);
874 count += 1;
875 }
876 Value::Decimal(value) => {
877 sum += *value;
878 count += 1;
879 }
880 Value::Null => {}
881 other => {
882 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
883 "AVG does not support {other:?}"
884 )));
885 }
886 }
887 }
888 Ok(if count == 0 {
889 Value::Null
890 } else {
891 Value::Decimal(sum / Decimal::from(count))
892 })
893}
894
895fn decimal_from_f64(value: f64) -> Decimal {
896 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
897}
898
899fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
900 rows.iter()
901 .filter_map(|row| row.get(field))
902 .filter(|value| !matches!(value, Value::Null))
903 .map(|value| match value {
904 Value::I64(value) => Ok(*value as f64),
905 Value::U64(value) => Ok(*value as f64),
906 Value::F64(value) => Ok(*value),
907 Value::Decimal(value) => value.to_f64().ok_or_else(|| {
908 MemoryRepositoryError::UnsupportedAggregate(format!(
909 "cannot convert decimal {value} to f64 for statistical aggregate"
910 ))
911 }),
912 other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
913 "numeric aggregate does not support {other:?}"
914 ))),
915 })
916 .collect()
917}
918
919fn numeric_variance(
920 rows: &[&Record],
921 field: &str,
922 sample: bool,
923) -> Result<Value, MemoryRepositoryError> {
924 let values = numeric_values(rows, field)?;
925 let count = values.len();
926 if count == 0 || (sample && count < 2) {
927 return Ok(Value::Null);
928 }
929 let mean = values.iter().sum::<f64>() / count as f64;
930 let sum = values
931 .iter()
932 .map(|value| {
933 let diff = value - mean;
934 diff * diff
935 })
936 .sum::<f64>();
937 let denominator = if sample { count - 1 } else { count } as f64;
938 Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
939}
940
941fn numeric_stddev(
942 rows: &[&Record],
943 field: &str,
944 sample: bool,
945) -> Result<Value, MemoryRepositoryError> {
946 Ok(match numeric_variance(rows, field, sample)? {
947 Value::Decimal(value) => {
948 Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
949 }
950 Value::Null => Value::Null,
951 other => other,
952 })
953}
954
955#[derive(Debug, Clone, Copy)]
956enum BitOp {
957 And,
958 Or,
959 Xor,
960}
961
962fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
963 let mut selected: Option<i64> = None;
964 for value in rows.iter().filter_map(|row| row.get(field)) {
965 let value = match value {
966 Value::I64(value) => *value,
967 Value::U64(value) => i64::try_from(*value).map_err(|_| {
968 MemoryRepositoryError::UnsupportedAggregate(format!(
969 "BIT aggregate u64 {value} exceeds i64 range"
970 ))
971 })?,
972 Value::Null => continue,
973 other => {
974 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
975 "BIT aggregate does not support {other:?}"
976 )));
977 }
978 };
979 selected = Some(match (selected, op) {
980 (None, _) => value,
981 (Some(current), BitOp::And) => current & value,
982 (Some(current), BitOp::Or) => current | value,
983 (Some(current), BitOp::Xor) => current ^ value,
984 });
985 }
986 Ok(selected.map(Value::I64).unwrap_or(Value::Null))
987}
988
989fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
990 let mut selected: Option<Value> = None;
991 for value in rows.iter().filter_map(|row| row.get(field)) {
992 if matches!(value, Value::Null) {
993 continue;
994 }
995 match &selected {
996 None => selected = Some(value.clone()),
997 Some(current) => {
998 let Some(ordering) = compare_values(value, current) else {
999 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
1000 "MIN/MAX does not support {value:?}"
1001 )));
1002 };
1003 if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
1004 selected = Some(value.clone());
1005 }
1006 }
1007 }
1008 }
1009 Ok(selected.unwrap_or(Value::Null))
1010}
1011
1012fn value_key(value: &Value) -> String {
1013 match value {
1014 Value::Null => "null".to_owned(),
1015 Value::Bool(value) => format!("b:{value}"),
1016 Value::I64(value) => format!("i:{value}"),
1017 Value::U64(value) => format!("u:{value}"),
1018 Value::F64(value) => format!("f:{value}"),
1019 Value::Decimal(value) => format!("d:{value}"),
1020 Value::Text(value) => format!("t:{value}"),
1021 Value::Json(value) => format!("j:{value}"),
1022 Value::Date(value) => format!("d:{value}"),
1023 Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
1024 Value::Object(_) => "object".to_owned(),
1025 Value::List(_) => "list".to_owned(),
1026 }
1027}
1028
1029fn read_i64(value: Option<&Value>) -> Option<i64> {
1030 match value {
1031 Some(Value::I64(value)) => Some(*value),
1032 _ => None,
1033 }
1034}
1035
1036fn local_graph_identity_key(value: &Value) -> String {
1037 match value {
1038 Value::I64(val) if *val >= 0 => format!("u:{}", *val as u64),
1039 Value::U64(val) => format!("u:{val}"),
1040 Value::Null => "null".to_owned(),
1041 Value::Bool(v) => format!("b:{v}"),
1042 Value::I64(v) => format!("i:{v}"),
1043 Value::F64(v) => format!("f:{v}"),
1044 Value::Decimal(v) => format!("d:{v}"),
1045 Value::Text(v) => format!("t:{v}"),
1046 Value::Json(v) => format!("j:{v}"),
1047 Value::Date(v) => format!("d:{v}"),
1048 Value::Timestamp(v) => format!("ts:{}", v.to_rfc3339()),
1049 Value::Object(_) => "o".to_owned(),
1050 Value::List(_) => "l".to_owned(),
1051 }
1052}
1053