firebase_rs_sdk/firestore/remote/datastore/
in_memory.rs

1use std::collections::BTreeMap;
2
3use super::Datastore;
4use crate::firestore::api::query::{
5    Bound, FieldFilter, FilterOperator, OrderBy, OrderDirection, QueryDefinition,
6};
7use crate::firestore::api::{DocumentSnapshot, SnapshotMetadata};
8use crate::firestore::error::FirestoreResult;
9use crate::firestore::model::{DocumentKey, FieldPath};
10use crate::firestore::value::{FirestoreValue, MapValue, ValueKind};
11
12#[derive(Clone, Default)]
13pub struct InMemoryDatastore {
14    documents: std::sync::Arc<std::sync::Mutex<BTreeMap<String, MapValue>>>,
15}
16
17impl InMemoryDatastore {
18    pub fn new() -> Self {
19        Self::default()
20    }
21}
22
23impl Datastore for InMemoryDatastore {
24    fn get_document(&self, key: &DocumentKey) -> FirestoreResult<DocumentSnapshot> {
25        let store = self.documents.lock().unwrap();
26        let data = store.get(&key.path().canonical_string()).cloned();
27        Ok(DocumentSnapshot::new(
28            key.clone(),
29            data,
30            SnapshotMetadata::new(true, false),
31        ))
32    }
33
34    fn set_document(&self, key: &DocumentKey, data: MapValue, _merge: bool) -> FirestoreResult<()> {
35        let mut store = self.documents.lock().unwrap();
36        store.insert(key.path().canonical_string(), data);
37        Ok(())
38    }
39
40    fn run_query(&self, query: &QueryDefinition) -> FirestoreResult<Vec<DocumentSnapshot>> {
41        let store = self.documents.lock().unwrap();
42        let mut documents = Vec::new();
43
44        for (path, data) in store.iter() {
45            let key = DocumentKey::from_string(path)?;
46            if !query.matches_collection(&key) {
47                continue;
48            }
49
50            let snapshot =
51                DocumentSnapshot::new(key, Some(data.clone()), SnapshotMetadata::new(true, false));
52
53            if document_satisfies_filters(&snapshot, query.filters()) {
54                documents.push(snapshot);
55            }
56        }
57
58        documents.sort_by(|left, right| compare_snapshots(left, right, query.result_order_by()));
59
60        if let Some(bound) = query.result_start_at() {
61            documents.retain(|snapshot| {
62                !is_before_start_bound(snapshot, bound, query.result_order_by())
63            });
64        }
65
66        if let Some(bound) = query.result_end_at() {
67            documents
68                .retain(|snapshot| !is_after_end_bound(snapshot, bound, query.result_order_by()));
69        }
70
71        if let Some(limit) = query.limit() {
72            let limit = limit as usize;
73            match query.limit_type() {
74                crate::firestore::api::query::LimitType::First => {
75                    if documents.len() > limit {
76                        documents.truncate(limit);
77                    }
78                }
79                crate::firestore::api::query::LimitType::Last => {
80                    if documents.len() > limit {
81                        let start = documents.len() - limit;
82                        documents.drain(0..start);
83                    }
84                }
85            }
86        }
87
88        Ok(documents)
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::firestore::value::FirestoreValue;

    /// A document written with `set_document` is returned by `get_document`.
    #[test]
    fn in_memory_get_set() {
        let key = DocumentKey::from_string("cities/sf").unwrap();
        let fields = MapValue::new(BTreeMap::from([(
            "name".to_string(),
            FirestoreValue::from_string("SF"),
        )]));

        let datastore = InMemoryDatastore::new();
        datastore.set_document(&key, fields, false).unwrap();

        let snapshot = datastore.get_document(&key).unwrap();
        assert!(snapshot.exists());

        let stored = snapshot.data().unwrap();
        assert_eq!(
            stored.get("name"),
            Some(&FirestoreValue::from_string("SF"))
        );
    }
}
113
114fn document_satisfies_filters(snapshot: &DocumentSnapshot, filters: &[FieldFilter]) -> bool {
115    filters
116        .iter()
117        .all(|filter| match get_field_value(snapshot, filter.field()) {
118            Some(value) => evaluate_filter(filter, &value),
119            None => {
120                filter.operator() == FilterOperator::NotEqual
121                    && evaluate_filter(filter, &FirestoreValue::null())
122            }
123        })
124}
125
126fn evaluate_filter(filter: &FieldFilter, value: &FirestoreValue) -> bool {
127    match filter.operator() {
128        FilterOperator::Equal => value == filter.value(),
129        FilterOperator::NotEqual => value != filter.value(),
130        FilterOperator::LessThan => {
131            compare_values(value, filter.value()) == Some(std::cmp::Ordering::Less)
132        }
133        FilterOperator::LessThanOrEqual => match compare_values(value, filter.value()) {
134            Some(std::cmp::Ordering::Less) | Some(std::cmp::Ordering::Equal) => true,
135            _ => false,
136        },
137        FilterOperator::GreaterThan => {
138            compare_values(value, filter.value()) == Some(std::cmp::Ordering::Greater)
139        }
140        FilterOperator::GreaterThanOrEqual => match compare_values(value, filter.value()) {
141            Some(std::cmp::Ordering::Greater) | Some(std::cmp::Ordering::Equal) => true,
142            _ => false,
143        },
144        FilterOperator::NotIn
145        | FilterOperator::ArrayContains
146        | FilterOperator::ArrayContainsAny
147        | FilterOperator::In => false,
148    }
149}
150
151fn get_field_value(snapshot: &DocumentSnapshot, field: &FieldPath) -> Option<FirestoreValue> {
152    if field == &FieldPath::document_id() {
153        let key = snapshot.document_key();
154        return Some(FirestoreValue::from_string(key.path().canonical_string()));
155    }
156
157    let map = snapshot.map_value()?;
158    find_in_map(map, field.segments()).cloned()
159}
160
161fn find_in_map<'a>(map: &'a MapValue, segments: &'a [String]) -> Option<&'a FirestoreValue> {
162    let (first, rest) = segments.split_first()?;
163    let value = map.fields().get(first)?;
164    if rest.is_empty() {
165        Some(value)
166    } else if let ValueKind::Map(child) = value.kind() {
167        find_in_map(child, rest)
168    } else {
169        None
170    }
171}
172
173fn compare_snapshots(
174    left: &DocumentSnapshot,
175    right: &DocumentSnapshot,
176    order_by: &[OrderBy],
177) -> std::cmp::Ordering {
178    for order in order_by {
179        let left_value = get_field_value(left, order.field()).unwrap_or_else(FirestoreValue::null);
180        let right_value =
181            get_field_value(right, order.field()).unwrap_or_else(FirestoreValue::null);
182
183        let mut ordering =
184            compare_values(&left_value, &right_value).unwrap_or(std::cmp::Ordering::Equal);
185        if order.direction() == OrderDirection::Descending {
186            ordering = ordering.reverse();
187        }
188        if ordering != std::cmp::Ordering::Equal {
189            return ordering;
190        }
191    }
192    std::cmp::Ordering::Equal
193}
194
195fn compare_values(left: &FirestoreValue, right: &FirestoreValue) -> Option<std::cmp::Ordering> {
196    match (left.kind(), right.kind()) {
197        (ValueKind::Null, ValueKind::Null) => Some(std::cmp::Ordering::Equal),
198        (ValueKind::Boolean(a), ValueKind::Boolean(b)) => Some(a.cmp(b)),
199        (ValueKind::Integer(a), ValueKind::Integer(b)) => Some(a.cmp(b)),
200        (ValueKind::Double(a), ValueKind::Double(b)) => a.partial_cmp(b),
201        (ValueKind::Integer(a), ValueKind::Double(b)) => (*a as f64).partial_cmp(b),
202        (ValueKind::Double(a), ValueKind::Integer(b)) => a.partial_cmp(&(*b as f64)),
203        (ValueKind::String(a), ValueKind::String(b)) => Some(a.cmp(b)),
204        (ValueKind::Reference(a), ValueKind::Reference(b)) => Some(a.cmp(b)),
205        _ => None,
206    }
207}
208
209fn is_before_start_bound(snapshot: &DocumentSnapshot, bound: &Bound, order_by: &[OrderBy]) -> bool {
210    let ordering = compare_snapshot_to_bound(snapshot, bound, order_by);
211    if bound.inclusive() {
212        ordering == std::cmp::Ordering::Less
213    } else {
214        ordering != std::cmp::Ordering::Greater
215    }
216}
217
218fn is_after_end_bound(snapshot: &DocumentSnapshot, bound: &Bound, order_by: &[OrderBy]) -> bool {
219    let ordering = compare_snapshot_to_bound(snapshot, bound, order_by);
220    if bound.inclusive() {
221        ordering == std::cmp::Ordering::Greater
222    } else {
223        ordering != std::cmp::Ordering::Less
224    }
225}
226
227fn compare_snapshot_to_bound(
228    snapshot: &DocumentSnapshot,
229    bound: &Bound,
230    order_by: &[OrderBy],
231) -> std::cmp::Ordering {
232    for (index, order) in order_by.iter().enumerate() {
233        if index >= bound.values().len() {
234            break;
235        }
236
237        let bound_value = &bound.values()[index];
238        let snapshot_value =
239            get_field_value(snapshot, order.field()).unwrap_or_else(FirestoreValue::null);
240
241        let mut ordering =
242            compare_values(&snapshot_value, bound_value).unwrap_or(std::cmp::Ordering::Equal);
243        if order.direction() == OrderDirection::Descending {
244            ordering = ordering.reverse();
245        }
246
247        if ordering != std::cmp::Ordering::Equal {
248            return ordering;
249        }
250    }
251    std::cmp::Ordering::Equal
252}