firebase_rs_sdk/firestore/remote/datastore/
in_memory.rs1use std::collections::BTreeMap;
2
3use super::Datastore;
4use crate::firestore::api::query::{
5 Bound, FieldFilter, FilterOperator, OrderBy, OrderDirection, QueryDefinition,
6};
7use crate::firestore::api::{DocumentSnapshot, SnapshotMetadata};
8use crate::firestore::error::FirestoreResult;
9use crate::firestore::model::{DocumentKey, FieldPath};
10use crate::firestore::value::{FirestoreValue, MapValue, ValueKind};
11
12#[derive(Clone, Default)]
13pub struct InMemoryDatastore {
14 documents: std::sync::Arc<std::sync::Mutex<BTreeMap<String, MapValue>>>,
15}
16
17impl InMemoryDatastore {
18 pub fn new() -> Self {
19 Self::default()
20 }
21}
22
23impl Datastore for InMemoryDatastore {
24 fn get_document(&self, key: &DocumentKey) -> FirestoreResult<DocumentSnapshot> {
25 let store = self.documents.lock().unwrap();
26 let data = store.get(&key.path().canonical_string()).cloned();
27 Ok(DocumentSnapshot::new(
28 key.clone(),
29 data,
30 SnapshotMetadata::new(true, false),
31 ))
32 }
33
34 fn set_document(&self, key: &DocumentKey, data: MapValue, _merge: bool) -> FirestoreResult<()> {
35 let mut store = self.documents.lock().unwrap();
36 store.insert(key.path().canonical_string(), data);
37 Ok(())
38 }
39
40 fn run_query(&self, query: &QueryDefinition) -> FirestoreResult<Vec<DocumentSnapshot>> {
41 let store = self.documents.lock().unwrap();
42 let mut documents = Vec::new();
43
44 for (path, data) in store.iter() {
45 let key = DocumentKey::from_string(path)?;
46 if !query.matches_collection(&key) {
47 continue;
48 }
49
50 let snapshot =
51 DocumentSnapshot::new(key, Some(data.clone()), SnapshotMetadata::new(true, false));
52
53 if document_satisfies_filters(&snapshot, query.filters()) {
54 documents.push(snapshot);
55 }
56 }
57
58 documents.sort_by(|left, right| compare_snapshots(left, right, query.result_order_by()));
59
60 if let Some(bound) = query.result_start_at() {
61 documents.retain(|snapshot| {
62 !is_before_start_bound(snapshot, bound, query.result_order_by())
63 });
64 }
65
66 if let Some(bound) = query.result_end_at() {
67 documents
68 .retain(|snapshot| !is_after_end_bound(snapshot, bound, query.result_order_by()));
69 }
70
71 if let Some(limit) = query.limit() {
72 let limit = limit as usize;
73 match query.limit_type() {
74 crate::firestore::api::query::LimitType::First => {
75 if documents.len() > limit {
76 documents.truncate(limit);
77 }
78 }
79 crate::firestore::api::query::LimitType::Last => {
80 if documents.len() > limit {
81 let start = documents.len() - limit;
82 documents.drain(0..start);
83 }
84 }
85 }
86 }
87
88 Ok(documents)
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::firestore::value::FirestoreValue;

    /// A document written with `set_document` is read back by `get_document`
    /// with its fields intact.
    #[test]
    fn in_memory_get_set() {
        let store = InMemoryDatastore::new();
        let sf = DocumentKey::from_string("cities/sf").unwrap();

        let mut fields = BTreeMap::new();
        fields.insert(String::from("name"), FirestoreValue::from_string("SF"));
        store.set_document(&sf, MapValue::new(fields), false).unwrap();

        let fetched = store.get_document(&sf).unwrap();
        assert!(fetched.exists());

        let expected = FirestoreValue::from_string("SF");
        assert_eq!(fetched.data().unwrap().get("name"), Some(&expected));
    }
}
113
114fn document_satisfies_filters(snapshot: &DocumentSnapshot, filters: &[FieldFilter]) -> bool {
115 filters
116 .iter()
117 .all(|filter| match get_field_value(snapshot, filter.field()) {
118 Some(value) => evaluate_filter(filter, &value),
119 None => {
120 filter.operator() == FilterOperator::NotEqual
121 && evaluate_filter(filter, &FirestoreValue::null())
122 }
123 })
124}
125
126fn evaluate_filter(filter: &FieldFilter, value: &FirestoreValue) -> bool {
127 match filter.operator() {
128 FilterOperator::Equal => value == filter.value(),
129 FilterOperator::NotEqual => value != filter.value(),
130 FilterOperator::LessThan => {
131 compare_values(value, filter.value()) == Some(std::cmp::Ordering::Less)
132 }
133 FilterOperator::LessThanOrEqual => match compare_values(value, filter.value()) {
134 Some(std::cmp::Ordering::Less) | Some(std::cmp::Ordering::Equal) => true,
135 _ => false,
136 },
137 FilterOperator::GreaterThan => {
138 compare_values(value, filter.value()) == Some(std::cmp::Ordering::Greater)
139 }
140 FilterOperator::GreaterThanOrEqual => match compare_values(value, filter.value()) {
141 Some(std::cmp::Ordering::Greater) | Some(std::cmp::Ordering::Equal) => true,
142 _ => false,
143 },
144 FilterOperator::NotIn
145 | FilterOperator::ArrayContains
146 | FilterOperator::ArrayContainsAny
147 | FilterOperator::In => false,
148 }
149}
150
151fn get_field_value(snapshot: &DocumentSnapshot, field: &FieldPath) -> Option<FirestoreValue> {
152 if field == &FieldPath::document_id() {
153 let key = snapshot.document_key();
154 return Some(FirestoreValue::from_string(key.path().canonical_string()));
155 }
156
157 let map = snapshot.map_value()?;
158 find_in_map(map, field.segments()).cloned()
159}
160
161fn find_in_map<'a>(map: &'a MapValue, segments: &'a [String]) -> Option<&'a FirestoreValue> {
162 let (first, rest) = segments.split_first()?;
163 let value = map.fields().get(first)?;
164 if rest.is_empty() {
165 Some(value)
166 } else if let ValueKind::Map(child) = value.kind() {
167 find_in_map(child, rest)
168 } else {
169 None
170 }
171}
172
173fn compare_snapshots(
174 left: &DocumentSnapshot,
175 right: &DocumentSnapshot,
176 order_by: &[OrderBy],
177) -> std::cmp::Ordering {
178 for order in order_by {
179 let left_value = get_field_value(left, order.field()).unwrap_or_else(FirestoreValue::null);
180 let right_value =
181 get_field_value(right, order.field()).unwrap_or_else(FirestoreValue::null);
182
183 let mut ordering =
184 compare_values(&left_value, &right_value).unwrap_or(std::cmp::Ordering::Equal);
185 if order.direction() == OrderDirection::Descending {
186 ordering = ordering.reverse();
187 }
188 if ordering != std::cmp::Ordering::Equal {
189 return ordering;
190 }
191 }
192 std::cmp::Ordering::Equal
193}
194
195fn compare_values(left: &FirestoreValue, right: &FirestoreValue) -> Option<std::cmp::Ordering> {
196 match (left.kind(), right.kind()) {
197 (ValueKind::Null, ValueKind::Null) => Some(std::cmp::Ordering::Equal),
198 (ValueKind::Boolean(a), ValueKind::Boolean(b)) => Some(a.cmp(b)),
199 (ValueKind::Integer(a), ValueKind::Integer(b)) => Some(a.cmp(b)),
200 (ValueKind::Double(a), ValueKind::Double(b)) => a.partial_cmp(b),
201 (ValueKind::Integer(a), ValueKind::Double(b)) => (*a as f64).partial_cmp(b),
202 (ValueKind::Double(a), ValueKind::Integer(b)) => a.partial_cmp(&(*b as f64)),
203 (ValueKind::String(a), ValueKind::String(b)) => Some(a.cmp(b)),
204 (ValueKind::Reference(a), ValueKind::Reference(b)) => Some(a.cmp(b)),
205 _ => None,
206 }
207}
208
209fn is_before_start_bound(snapshot: &DocumentSnapshot, bound: &Bound, order_by: &[OrderBy]) -> bool {
210 let ordering = compare_snapshot_to_bound(snapshot, bound, order_by);
211 if bound.inclusive() {
212 ordering == std::cmp::Ordering::Less
213 } else {
214 ordering != std::cmp::Ordering::Greater
215 }
216}
217
218fn is_after_end_bound(snapshot: &DocumentSnapshot, bound: &Bound, order_by: &[OrderBy]) -> bool {
219 let ordering = compare_snapshot_to_bound(snapshot, bound, order_by);
220 if bound.inclusive() {
221 ordering == std::cmp::Ordering::Greater
222 } else {
223 ordering != std::cmp::Ordering::Less
224 }
225}
226
227fn compare_snapshot_to_bound(
228 snapshot: &DocumentSnapshot,
229 bound: &Bound,
230 order_by: &[OrderBy],
231) -> std::cmp::Ordering {
232 for (index, order) in order_by.iter().enumerate() {
233 if index >= bound.values().len() {
234 break;
235 }
236
237 let bound_value = &bound.values()[index];
238 let snapshot_value =
239 get_field_value(snapshot, order.field()).unwrap_or_else(FirestoreValue::null);
240
241 let mut ordering =
242 compare_values(&snapshot_value, bound_value).unwrap_or(std::cmp::Ordering::Equal);
243 if order.direction() == OrderDirection::Descending {
244 ordering = ordering.reverse();
245 }
246
247 if ordering != std::cmp::Ordering::Equal {
248 return ordering;
249 }
250 }
251 std::cmp::Ordering::Equal
252}