Skip to main content

acts_next/store/db/mem/
collect.rs

1use serde::de::DeserializeOwned;
2use tracing::debug;
3
4use crate::store::query::CondType;
5use crate::store::{Cond, Expr, ExprOp, PageData, map_db_err};
6use crate::{ActError, DbCollection, Result, ShareLock, store::query::*};
7use serde_json::Value as JsonValue;
8use std::cmp::Ordering;
9use std::collections::{BTreeMap, HashMap, HashSet};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12use std::sync::{Arc, RwLock};
13
14use super::DbDocument;
15
16#[derive(Debug)]
17pub struct Collect<T> {
18    name: String,
19    db: ShareLock<BTreeMap<String, HashMap<String, JsonValue>>>,
20    _t: PhantomData<T>,
21}
22
23impl<T> Collect<T> {
24    pub fn new(name: &str) -> Self {
25        Self {
26            name: name.to_string(),
27            db: Arc::new(RwLock::new(BTreeMap::new())),
28            _t: PhantomData,
29        }
30    }
31}
32
33impl<T> DbCollection for Collect<T>
34where
35    T: DbDocument + Send + Sync + Clone + Debug,
36{
37    type Item = T;
38
39    fn exists(&self, id: &str) -> crate::Result<bool> {
40        debug!("mem::{}.exists({:?})", self.name, id);
41        Ok(self.db.read().unwrap().contains_key(id))
42    }
43
44    fn find(&self, id: &str) -> Result<Self::Item> {
45        debug!("mem::{}.find({:?})", self.name, id);
46        self.db
47            .read()
48            .unwrap()
49            .get(id)
50            .map(|iter| map_to_model::<Self::Item>(iter).unwrap())
51            .ok_or(ActError::Store(format!(
52                "cannot find {} by '{}'",
53                self.name, id
54            )))
55    }
56
57    fn query(&self, q: &Query) -> crate::Result<PageData<Self::Item>> {
58        debug!("mem::{}.query({:?})", self.name, q);
59        let db = self.db.read().unwrap();
60        #[allow(unused_assignments)]
61        let mut rows = vec![];
62        if !q.is_cond() {
63            rows = db.values().collect::<Vec<_>>();
64        } else {
65            let mut q = q.clone();
66            for cond in q.queries_mut() {
67                for expr in cond.conds().clone().iter() {
68                    let mut result = HashSet::new();
69                    for (k, v) in db.iter() {
70                        let prop_value = v.get(expr.key()).ok_or(ActError::Store(format!(
71                            "cannot find key `{}` in {}",
72                            expr.key(),
73                            self.name
74                        )))?;
75                        let cond_value = expr.value();
76
77                        if expr.op(prop_value, cond_value) {
78                            result.insert(k.as_bytes().to_vec().into_boxed_slice());
79                        }
80                    }
81                    cond.calc(&result);
82                }
83            }
84
85            let items = q.calc();
86            #[allow(unused_assignments)]
87            {
88                rows = db
89                    .iter()
90                    .filter_map(|(k, v)| {
91                        if items.contains(&k.as_bytes().to_vec().into_boxed_slice()) {
92                            return Some(v);
93                        }
94                        None
95                    })
96                    .collect::<Vec<_>>();
97            }
98        }
99
100        // order the rows
101        if !q.order_by().is_empty() {
102            rows.sort_by(|a, b| {
103                let mut ret = Ordering::Equal;
104                for (order, rev) in q.order_by() {
105                    if *rev {
106                        ret = ret.then(
107                            b.get(order)
108                                .unwrap()
109                                .to_string()
110                                .cmp(&a.get(order).unwrap().to_string()),
111                        );
112                    } else {
113                        ret = ret.then(
114                            a.get(order)
115                                .unwrap()
116                                .to_string()
117                                .cmp(&b.get(order).unwrap().to_string()),
118                        );
119                    }
120                }
121
122                ret
123            });
124        }
125
126        let count = rows.len();
127        let page_count = count.div_ceil(q.limit());
128        let page_num = q.offset() / q.limit() + 1;
129        let data = PageData {
130            count,
131            page_size: q.limit(),
132            page_num,
133            page_count,
134            rows: rows
135                .iter()
136                .skip(q.offset())
137                .take(q.limit())
138                .map(|row| map_to_model::<Self::Item>(row).unwrap())
139                .collect::<Vec<_>>(),
140        };
141        Ok(data)
142    }
143
144    fn create(&self, data: &Self::Item) -> Result<bool> {
145        debug!("mem::{}.create({:?})", self.name, data);
146        self.db
147            .write()
148            .unwrap()
149            .insert(data.id().to_string(), data.doc()?);
150        Ok(true)
151    }
152
153    fn update(&self, data: &Self::Item) -> Result<bool> {
154        debug!("mem::{}.update({:?})", self.name, data);
155        self.db
156            .write()
157            .unwrap()
158            .entry(data.id().to_string())
159            .and_modify(|iter| *iter = data.doc().unwrap());
160        Ok(true)
161    }
162
163    fn delete(&self, id: &str) -> crate::Result<bool> {
164        debug!("mem::{}.delete({:?})", self.name, id);
165        self.db.write().unwrap().remove(id);
166        Ok(true)
167    }
168}
169
170impl Cond {
171    pub fn calc(&mut self, v: &HashSet<Box<[u8]>>) {
172        match self.r#type {
173            CondType::And => {
174                if self.result.is_empty() {
175                    self.result = v.clone();
176                } else {
177                    self.result = self.result.intersection(v).cloned().collect::<HashSet<_>>()
178                }
179            }
180            CondType::Or => {
181                if self.result.is_empty() {
182                    self.result = v.clone();
183                } else {
184                    self.result = self.result.union(v).cloned().collect::<HashSet<_>>()
185                }
186            }
187        }
188    }
189}
190
191impl Expr {
192    pub fn op(&self, l: &serde_json::Value, r: &serde_json::Value) -> bool {
193        debug!("Expr.op op={:?}, l={l}, r={r}", self.op);
194        match &self.op {
195            ExprOp::EQ => l == r,
196            ExprOp::NE => l != r,
197            ExprOp::LT => {
198                if let (serde_json::Value::Number(v1), serde_json::Value::Number(v2)) = (l, r) {
199                    if v1.is_f64() {
200                        return v1.as_f64().unwrap() < v2.as_f64().unwrap_or_default();
201                    } else if v1.is_i64() {
202                        return v1.as_i64().unwrap() < v2.as_i64().unwrap_or_default();
203                    } else if v1.is_u64() {
204                        return v1.as_u64().unwrap() < v2.as_u64().unwrap_or_default();
205                    }
206                }
207                false
208            }
209            ExprOp::LE => {
210                if let (serde_json::Value::Number(v1), serde_json::Value::Number(v2)) = (l, r) {
211                    if v1.is_f64() {
212                        return v1.as_f64().unwrap() <= v2.as_f64().unwrap_or_default();
213                    } else if v1.is_i64() {
214                        return v1.as_i64().unwrap() <= v2.as_i64().unwrap_or_default();
215                    } else if v1.is_u64() {
216                        return v1.as_u64().unwrap() <= v2.as_u64().unwrap_or_default();
217                    }
218                }
219                false
220            }
221            ExprOp::GT => {
222                if let (serde_json::Value::Number(v1), serde_json::Value::Number(v2)) = (l, r) {
223                    if v1.is_f64() {
224                        return v1.as_f64().unwrap() > v2.as_f64().unwrap_or_default();
225                    } else if v1.is_i64() {
226                        return v1.as_i64().unwrap() > v2.as_i64().unwrap_or_default();
227                    } else if v1.is_u64() {
228                        return v1.as_u64().unwrap() > v2.as_u64().unwrap_or_default();
229                    }
230                }
231                false
232            }
233            ExprOp::GE => {
234                if let (serde_json::Value::Number(v1), serde_json::Value::Number(v2)) = (l, r) {
235                    if v1.is_f64() {
236                        return v1.as_f64().unwrap() >= v2.as_f64().unwrap_or_default();
237                    } else if v1.is_i64() {
238                        return v1.as_i64().unwrap() >= v2.as_i64().unwrap_or_default();
239                    } else if v1.is_u64() {
240                        return v1.as_u64().unwrap() >= v2.as_u64().unwrap_or_default();
241                    }
242                }
243                false
244            }
245        }
246    }
247}
248
249fn map_to_model<T>(map: &HashMap<String, JsonValue>) -> Result<T>
250where
251    T: DeserializeOwned,
252{
253    let mut value = serde_json::Map::new();
254    for (k, v) in map {
255        value.insert(k.to_string(), v.clone());
256    }
257    serde_json::from_value(JsonValue::Object(value)).map_err(map_db_err)
258}