rsv_lib/utils/
filter.rs

1use super::{
2    math_expr_parser::{CompiledExpr, AST},
3    row_split::CsvRowSplitter,
4};
5use crate::utils::util::werr_exit;
6use regex::Regex;
7use std::{
8    fs::File,
9    io::{BufRead, BufReader},
10    path::Path,
11};
12
/// Comparison operator of a single filter clause.
///
/// The variants correspond to the operator tokens recognised by
/// `Filter::parse_one`: `=`, `!=`, `>`, `>=`, `<` and `<=`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Op {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
}

impl Op {
    /// Applies the operator with `a` on the left-hand side and `b` on the
    /// right-hand side.
    ///
    /// Plain `f64` comparison semantics apply, so every comparison involving
    /// NaN is `false` except `NotEqual`, which is `true`.
    fn evaluate(&self, a: f64, b: f64) -> bool {
        match self {
            Op::Equal => a == b,
            Op::NotEqual => a != b,
            Op::Gt => a > b,
            Op::Ge => a >= b,
            Op::Lt => a < b,
            Op::Le => a <= b,
        }
    }
}
35
/// One parsed filter clause (`<column><operator><value>`), as built by
/// `Filter::parse_one`.
///
/// Only the value field that matches the clause kind is filled in; the
/// others keep their placeholder value (`0.0`, empty string, empty vector,
/// or an AST parsed from the empty string).
struct FilterItem {
    /// Zero-based index of the cell within a row that the clause tests.
    col: usize,
    /// `true` when the column specifier ended in `n` or `N`, i.e. the cell is
    /// compared as a number rather than as text.
    is_numeric: bool,
    /// Comparison operator of the clause.
    op: Op,
    /// Numeric operand for `>`, `>=`, `<`, `<=` on a numeric column.
    f64_value: f64,
    /// Text operand for `>`, `>=`, `<`, `<=` on a text column.
    str_value: String,
    /// Comma-separated numeric operands for `=` / `!=` on a numeric column.
    f64_values: Vec<f64>,
    /// Comma-separated text operands for `=` / `!=` on a text column.
    str_values: Vec<String>,
    /// `true` when the right-hand side was classified as a math expression;
    /// in that case `ast` is evaluated and the value fields are unused.
    is_math_expr: bool,
    /// Compiled right-hand side expression (meaningful only when
    /// `is_math_expr` is set).
    ast: Box<CompiledExpr>,
}
47
/// Row filter built from a raw filter string whose clauses are separated
/// by `&`.
///
/// Construct it with `Filter::new`, optionally supply the column count
/// (`total_col`) or a file to derive it from (`total_col_of`), then call
/// `parse` before testing rows.
pub struct Filter<'a> {
    /// Unparsed filter string as given by the caller.
    raw: &'a str,
    /// Number of columns, used to resolve negative column specifiers;
    /// `None` until supplied or derived from `path`.
    total: Option<usize>,
    /// File whose first line is read to count columns when `total` is unknown.
    path: Option<&'a Path>,
    /// Field separator used when splitting the first line of `path`.
    sep: char,
    /// Quote character used when splitting the first line of `path`.
    quote: char,
    /// Parsed clauses; a row is accepted only if every clause accepts it.
    filters: Vec<FilterItem>,
    /// Set to `true` by `parse`.
    pub parsed: bool,
}
57
58fn parse_col_usize(col: &str) -> usize {
59    col.parse().unwrap_or_else(|_| {
60        werr_exit!(
61            "{}",
62            "Column syntax error: can be something like 0 (first column), -1 (last column)."
63        );
64    })
65}
66
67fn parse_i32(col: &str) -> i32 {
68    col.parse().unwrap_or_else(|_| {
69        werr_exit!(
70            "{}",
71            "Column syntax error: can be something like 0 (first column), -1 (last column)."
72        );
73    })
74}
75
76impl<'a> Filter<'a> {
77    pub fn new(raw: &str) -> Filter {
78        Filter {
79            raw,
80            total: None,
81            path: None,
82            sep: ',',
83            quote: '"',
84            filters: vec![],
85            parsed: false,
86        }
87    }
88
    /// Returns `true` when no filter clause is stored, either because
    /// `parse` has not run yet or because the raw string held no clause.
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }
92
93    pub fn total_col(mut self, total: usize) -> Self {
94        self.total = Some(total);
95        self
96    }
97
98    pub fn total_col_of(mut self, path: &'a Path, sep: char, quote: char) -> Self {
99        self.path = Some(path);
100        self.sep = sep;
101        self.quote = quote;
102        self
103    }
104
105    fn true_col(&mut self, col: &str) -> usize {
106        if col.starts_with('-') {
107            if self.total.is_none() {
108                let mut first_line = String::new();
109                let f = File::open(self.path.unwrap()).expect("unable to open file.");
110                BufReader::new(f)
111                    .read_line(&mut first_line)
112                    .expect("read error.");
113                self.total = Some(CsvRowSplitter::new(&first_line, self.sep, self.quote).count());
114            }
115            let i = (self.total.unwrap() as i32) + parse_i32(col);
116            if i < 0 {
117                werr_exit!("Column {} does not exist.", col);
118            }
119            i as usize
120        } else {
121            parse_col_usize(col)
122        }
123    }
124
125    pub fn parse(mut self) -> Self {
126        self.parsed = true;
127
128        if self.raw.is_empty() {
129            return self;
130        }
131
132        self.raw
133            .split('&')
134            .filter(|&i| !i.is_empty())
135            .for_each(|one| self.parse_one(one));
136
137        self
138    }
139
140    fn parse_one(&mut self, one: &str) {
141        // matching order is important
142        let re = Regex::new("!=|>=|<=|=|>|<").unwrap();
143        let v = re.split(one).collect::<Vec<_>>();
144
145        if v.len() != 2 {
146            werr_exit!("Error: Filter syntax is wrong, run <rsv select -h> for help.");
147        }
148
149        // parse column
150        let mut col = v[0].to_owned();
151        let is_numeric = col.ends_with(['n', 'N']);
152        if is_numeric {
153            col.pop();
154        }
155        let col = { self.true_col(&col) };
156
157        // check whether rhs is a math expr
158        // @1 or c1 represents first column
159        let is_math_expr = v[1].contains(['+', '*', '/', '%', '^', '(', '@', 'c'])
160            || v[1].rfind('-').unwrap_or(0) > 0;
161
162        let mut item = FilterItem {
163            col,
164            is_numeric,
165            op: Op::NotEqual,
166            f64_value: 0.0,
167            f64_values: vec![],
168            str_value: String::new(),
169            str_values: vec![],
170            is_math_expr,
171            ast: Box::new(AST::parse("")),
172        };
173
174        // parse filter, matching order is important
175        match (is_math_expr, is_numeric) {
176            (true, _) => {
177                item.op = if one.contains("!=") {
178                    Op::NotEqual
179                } else if one.contains(">=") {
180                    Op::Ge
181                } else if one.contains("<=") {
182                    Op::Le
183                } else if one.contains('=') {
184                    Op::Equal
185                } else if one.contains('>') {
186                    Op::Gt
187                } else if one.contains('<') {
188                    Op::Lt
189                } else {
190                    Op::NotEqual
191                };
192                item.ast = Box::new(AST::parse(v[1]));
193            }
194            (false, true) => {
195                if one.contains("!=") {
196                    item.op = Op::NotEqual;
197                    item.f64_values = parse_f64_vec(v[1]);
198                } else if one.contains(">=") {
199                    item.op = Op::Ge;
200                    item.f64_value = parse_f64(v[1]);
201                } else if one.contains("<=") {
202                    item.op = Op::Le;
203                    item.f64_value = parse_f64(v[1]);
204                } else if one.contains('=') {
205                    item.op = Op::Equal;
206                    item.f64_values = parse_f64_vec(v[1]);
207                } else if one.contains('>') {
208                    item.op = Op::Gt;
209                    item.f64_value = parse_f64(v[1]);
210                } else if one.contains('<') {
211                    item.op = Op::Lt;
212                    item.f64_value = parse_f64(v[1]);
213                }
214            }
215            (false, false) => {
216                if one.contains("!=") {
217                    item.op = Op::NotEqual;
218                    item.str_values = v[1].split(',').map(String::from).collect::<Vec<_>>();
219                } else if one.contains(">=") {
220                    item.op = Op::Ge;
221                    item.str_value = v[1].to_owned();
222                } else if one.contains("<=") {
223                    item.op = Op::Le;
224                    item.str_value = v[1].to_owned();
225                } else if one.contains('=') {
226                    item.op = Op::Equal;
227                    item.str_values = v[1].split(',').map(String::from).collect::<Vec<_>>();
228                } else if one.contains('>') {
229                    item.op = Op::Gt;
230                    item.str_value = v[1].to_owned();
231                } else if one.contains('<') {
232                    item.op = Op::Lt;
233                    item.str_value = v[1].to_owned()
234                }
235            }
236        }
237
238        self.filters.push(item);
239    }
240
241    // todo
242    pub fn record_is_valid<T: AsRef<str>>(&self, row: &[T]) -> bool {
243        self.filters.iter().all(|item| item.record_is_valid(row))
244    }
245
246    pub fn record_valid_map<'b>(
247        &self,
248        row: &'b str,
249        sep: char,
250        quote: char,
251    ) -> Option<(Option<&'b str>, Option<Vec<&'b str>>)> {
252        if self.is_empty() {
253            return Some((Some(row), None));
254        }
255
256        let v = CsvRowSplitter::new(row, sep, quote).collect::<Vec<_>>();
257        if self.record_is_valid(&v) {
258            Some((Some(row), Some(v)))
259        } else {
260            None
261        }
262    }
263
264    pub fn excel_record_is_valid<T: AsRef<str>>(&self, row: &[T]) -> bool {
265        if self.is_empty() {
266            return true;
267        }
268        self.filters.iter().all(|item| item.record_is_valid(row))
269    }
270}
271
272pub fn parse_f64(s: &str) -> f64 {
273    s.parse::<f64>().unwrap_or_else(|_| {
274        werr_exit!("Error: <{s}> is not a valid number, run <rsv select -h> for help.");
275    })
276}
277
278fn parse_f64_vec(s: &str) -> Vec<f64> {
279    s.split(',')
280        .map(|i| {
281            i.parse::<f64>().unwrap_or_else(|_| {
282                werr_exit!("Error: <{i}> is not a number, run <rsv select -h> for help.")
283            })
284        })
285        .collect()
286}
287
288impl FilterItem {
289    fn record_is_valid<T: AsRef<str>>(&self, row: &[T]) -> bool {
290        match (self.is_math_expr, self.is_numeric, &self.op) {
291            (true, _, _) => match row[self.col].as_ref().parse::<f64>() {
292                Ok(v) => match self.ast.max_column() {
293                    0 => self.op.evaluate(v, self.ast.evaluate(None)),
294                    _ => {
295                        let f64_vec: Vec<f64> = (0..=self.ast.max_column())
296                            .map(|i| match self.ast.contains_column(&i) {
297                                true => parse_f64(row[i].as_ref()),
298                                false => 0.0,
299                            })
300                            .collect();
301                        self.op.evaluate(v, self.ast.evaluate(Some(&f64_vec)))
302                    }
303                },
304                Err(_) => false,
305            },
306            (false, true, Op::Equal) => match row[self.col].as_ref().parse::<f64>() {
307                Ok(v) => self.f64_values.contains(&v),
308                Err(_) => false,
309            },
310            (false, true, Op::NotEqual) => match row[self.col].as_ref().parse::<f64>() {
311                Ok(v) => !self.f64_values.contains(&v),
312                Err(_) => true,
313            },
314            (false, true, _) => match row[self.col].as_ref().parse::<f64>() {
315                Ok(v) => self.op.evaluate(v, self.f64_value),
316                Err(_) => false,
317            },
318            (false, false, Op::Equal) => self
319                .str_values
320                .iter()
321                .any(|i| i.as_str() == row[self.col].as_ref()),
322            (false, false, Op::NotEqual) => !self
323                .str_values
324                .iter()
325                .any(|i| i.as_str() == row[self.col].as_ref()),
326            (false, false, Op::Ge) => row[self.col].as_ref() >= &self.str_value,
327            (false, false, Op::Gt) => row[self.col].as_ref() > &self.str_value,
328            (false, false, Op::Le) => row[self.col].as_ref() <= &self.str_value,
329            (false, false, Op::Lt) => row[self.col].as_ref() < &self.str_value,
330        }
331    }
332}