bool_tag_expr/
syntax_parse.rs

1//!
2//! Syntactic parsing of boolean expressions
3//!
4
5use crate::{BoolTagExprLexicalParse, ParseError, Tag, Token};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use thiserror::Error;
8
9/// A boolean expression tree.
10///
11/// This is a simple wrapper around a [`Node`].
12#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
13pub struct BoolTagExpr(Node);
14
15/// Error that arises when attempting to use an invalid SQL identifier
16#[derive(Debug, Clone, PartialEq, Eq, Hash, Error)]
17pub enum SqlTableInfoError {
18    /// There is an invalid SQL identifier
19    #[error("Invalid identifiers: '{0}'")]
20    InvalidIdentifier(String),
21}
22
// TODO: consider a macro that validates the identifiers at compile time
/// Names of the database table (and its columns) that the generated SQL runs against.
///
/// The generated SQL treats each row as one tag (name and value) attached to the entity
/// identified in the ID column. Build values with [`DbTableInfo::from`], which only accepts
/// plain SQL identifiers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DbTableInfo {
    /// Table holding the ID, tag name and tag value columns
    table_name: String,

    /// Column identifying the entity a row's tag belongs to (this is what the SQL selects)
    id_column: String,

    /// Column holding the tag name
    tag_name_column: String,

    /// Column holding the tag value
    tag_value_column: String,
}
42
43impl DbTableInfo {
44    /// Create a `DbTableInfo`, ensuring that the values are valid SQL
45    /// identifiers to protect against SQL injection attacks
46    pub fn from(
47        table_name: &str,
48        id_column: &str,
49        tag_name_column: &str,
50        tag_value_column: &str,
51    ) -> Result<Self, SqlTableInfoError> {
52        for identifier in [table_name, id_column, tag_name_column, tag_value_column] {
53            if !is_valid_sql_identifier(identifier) {
54                Err(SqlTableInfoError::InvalidIdentifier(identifier.to_string()))?;
55            }
56        }
57
58        Ok(Self {
59            table_name: table_name.to_string(),
60            id_column: id_column.to_string(),
61            tag_name_column: tag_name_column.to_string(),
62            tag_value_column: tag_value_column.to_string(),
63        })
64    }
65}
66
/// Check whether `s` is a plain SQL identifier: non-empty, starting with an ASCII letter or
/// `_`, and followed only by ASCII letters, digits or `_`.
///
/// This is a blunt guard against SQL injection: anything outside that character set
/// (quotes, whitespace, `;`, non-ASCII characters, ...) is rejected. SQL keywords are not
/// rejected.
fn is_valid_sql_identifier(s: &str) -> bool {
    let may_start = |c: char| c == '_' || c.is_ascii_alphabetic();
    let may_continue = |c: char| c == '_' || c.is_ascii_alphanumeric();

    s.chars()
        .next()
        .map_or(false, |first| may_start(first) && s.chars().skip(1).all(may_continue))
}
81
82// TODO: Needs testing
83impl Serialize for BoolTagExpr {
84    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85    where
86        S: Serializer,
87    {
88        let bool_expr = self.clone().to_boolean_expression();
89        serializer.serialize_str(&bool_expr)
90    }
91}
92
93// TODO: needs testing
94impl<'de> Deserialize<'de> for BoolTagExpr {
95    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96    where
97        D: Deserializer<'de>,
98    {
99        let raw_expr = String::deserialize(deserializer)?;
100        let tree = Self::from(raw_expr);
101        match tree {
102            Ok(tree) => Ok(tree),
103            Err(error) => {
104                // TODO: use the error (impl into?)
105                let err_msg = format!("Boolean expressions is invalid: {error}");
106                Err(serde::de::Error::custom(err_msg))
107            }
108        }
109    }
110}
111
112/// Possible elements of a boolean expression (a boolean expression tree is made
113/// up of these)
114#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
115pub enum Node {
116    And(Box<Node>, Box<Node>),
117    Or(Box<Node>, Box<Node>),
118    Not(Box<Node>),
119    Tag(Tag),
120    Bool(bool),
121}
122
123/// Possible syntax errors
124#[derive(Debug, PartialEq, Clone, Error, Hash, Eq)]
125pub enum SyntaxParseError {
126    /// The boolean tag expression contains no tags (there must be at least 1)
127    #[error("No tags in expression")]
128    NoTags,
129
130    // TODO: store token for error msg?
131    /// The first lexical token of the boolean tag expression is invalid
132    #[error("Invalid opening token")]
133    InvalidOpeningToken,
134
135    // TODO: store token for error msg?
136    /// The last lexical token of the boolean tag expression is invalid
137    #[error("Invalid closing token")]
138    InvalidClosingToken,
139
140    // TODO: store token for error msg?
141    /// There are more closing brakets than opening brackets in the boolean tag
142    /// expression
143    #[error("Unopended brackets")]
144    UnopenedBrackets,
145
146    // TODO: store token for error msg?
147    /// There are more opening brakets than closing brackets in the boolean tag
148    /// expression
149    #[error("Unclosed brackets")]
150    UnclosedBrackets,
151
152    /// There is an invalid token order/sequence in the boolean tag expression
153    #[error("Invalid sequence of tokens: {0} -> {1}")]
154    InvalidSequence(Token, Token),
155}
156
157/// Implementing types can be (lexically and) syntactically parsed to a boolean
158/// expression tree
159pub trait BoolTagExprSyntaxParse<T: BoolTagExprLexicalParse> {
160    /// Lexically and then syntactically parse the value into a boolean
161    /// expression tree
162    fn syntax_parse(self) -> Result<BoolTagExpr, ParseError>;
163}
164
165// TODO: For anything that can be a string (eg str too)
166/// Blanket implementation of syntax parsing for any type that implements
167/// `BoolTagExprLexicalParse`
168impl<T: BoolTagExprLexicalParse> BoolTagExprSyntaxParse<T> for T {
169    fn syntax_parse(self) -> Result<BoolTagExpr, ParseError> {
170        let lexical_tokens = self.lexical_parse()?;
171        validate_token_stream(lexical_tokens.tokens().to_owned())?;
172        Ok(BoolTagExpr(syntax_parse_token_stream(
173            &mut lexical_tokens.tokens().to_owned(),
174        )))
175    }
176}
177
178impl BoolTagExpr {
179    /// Produce a [`BoolTagExpr`] from a type that implements both
180    /// `LexicalParse` and `SyntaxParse<T>`
181    pub fn from<T>(boolean_expr: T) -> Result<Self, ParseError>
182    where
183        T: BoolTagExprLexicalParse + BoolTagExprSyntaxParse<T>,
184    {
185        boolean_expr.syntax_parse()
186    }
187
188    /// Produce a logical boolean expression string from a [`BoolTagExpr`]
189    #[must_use]
190    pub fn to_boolean_expression(self) -> String {
191        boolean_expr_tree_to_logical_expr_string(self.0)
192    }
193
194    /// Produce an SQL boolean statement from a [`BoolTagExpr`] for pulling
195    /// entities out of a database
196    ///
197    /// If the SQL string returned by this function is denoted by `X`, then, as
198    /// an example, we can use it in the following way:
199    ///
200    /// ```sql
201    /// SELECT id, name
202    /// FROM entities
203    /// WHERE X
204    /// ORDER BY name
205    /// ```
206    #[must_use]
207    pub fn to_sql(self, table_info: &DbTableInfo) -> String {
208        boolean_expr_tree_to_sql(self.0, table_info)
209    }
210
211    /// Get the boolean expr tree
212    #[must_use]
213    pub fn into_node(self) -> Node {
214        self.0
215    }
216}
217
218// TODO: check the examples)
219/// Recursively produce an SQL statement from a tree of [`Node`]s
220///
221/// Examples of the SQL output:
222///
223/// - `(tag_value=XYZ AND tag_value=ABC)`
224/// - `((tag_name=QWE AND tag_value=XYZ) OR tag_value=ABC)`
225fn boolean_expr_tree_to_sql(expr: Node, table_info: &DbTableInfo) -> String {
226    match expr {
227        Node::And(l, r) => {
228            let mut sql_fragment = format!(
229                "SELECT {} FROM {} WHERE {} IN (",
230                table_info.id_column, table_info.table_name, table_info.id_column
231            );
232            sql_fragment.push_str(&boolean_expr_tree_to_sql(*l, table_info));
233            sql_fragment.push_str(&format!(") AND {} IN (", table_info.id_column));
234            sql_fragment.push_str(&boolean_expr_tree_to_sql(*r, table_info));
235            sql_fragment.push_str(&format!(") GROUP BY {}", table_info.id_column));
236            sql_fragment
237        }
238        Node::Or(l, r) => {
239            let mut sql_fragment = format!("SELECT {} FROM (", table_info.id_column);
240            sql_fragment.push_str(&boolean_expr_tree_to_sql(*l, table_info));
241            sql_fragment.push_str(" UNION ");
242            sql_fragment.push_str(&boolean_expr_tree_to_sql(*r, table_info));
243            sql_fragment.push(')');
244            sql_fragment
245        }
246        Node::Not(e) => {
247            let mut sql_fragment = format!(
248                "SELECT {} FROM {} WHERE {} NOT IN (",
249                table_info.id_column, table_info.table_name, table_info.id_column
250            );
251            sql_fragment.push_str(&boolean_expr_tree_to_sql(*e, table_info));
252            sql_fragment.push(')');
253            sql_fragment
254        }
255        Node::Tag(tag) => match tag.name {
256            None => format!(
257                "SELECT {} FROM {} WHERE {}='{}'",
258                table_info.id_column, table_info.table_name, table_info.tag_value_column, tag.value
259            ),
260            Some(tag_name) => format!(
261                "SELECT {} FROM {} WHERE {}='{}' AND {}='{}'",
262                table_info.id_column,
263                table_info.table_name,
264                table_info.tag_name_column,
265                tag_name,
266                table_info.tag_value_column,
267                tag.value
268            ),
269        },
270        Node::Bool(_) => panic!(),
271    }
272}
273
274/// Recursively produce a logical boolean expressions from a tree of
275/// [`BooleanTagExpr`]s
276fn boolean_expr_tree_to_logical_expr_string(expr: Node) -> String {
277    match expr {
278        Node::And(l, r) => {
279            let mut sql_fragment = String::from("(");
280            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*l));
281            sql_fragment.push_str(" & ");
282            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*r));
283            sql_fragment.push(')');
284            sql_fragment
285        }
286        Node::Or(l, r) => {
287            let mut sql_fragment = String::from("(");
288            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*l));
289            sql_fragment.push_str(" | ");
290            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*r));
291            sql_fragment.push(')');
292            sql_fragment
293        }
294        Node::Not(e) => {
295            let mut sql_fragment = String::from("!");
296            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*e));
297            sql_fragment
298        }
299        Node::Tag(tag) => match tag.name {
300            None => format!("{}", tag.value),
301            Some(tag_name) => format!("({}={})", tag_name, tag.value),
302        },
303        Node::Bool(_) => panic!(),
304    }
305}
306
307// TODO: check token stream length?
308/// Check a stream of lexical tokens is a valid sequence (part of syntax
309/// parsing)
310fn validate_token_stream(mut tokens: Vec<Token>) -> Result<(), SyntaxParseError> {
311    let mut at_least_1_tag = false;
312    for token in tokens.clone() {
313        if let Token::Tag(_) = token {
314            at_least_1_tag = true;
315            break;
316        }
317    }
318    if !at_least_1_tag {
319        return Err(SyntaxParseError::NoTags);
320    }
321
322    let mut opening_bracket_count = 0;
323    let mut closing_bracket_count = 0;
324
325    let mut previous_token = tokens.remove(0);
326
327    // Check first token
328    match &previous_token {
329        Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
330        _ => return Err(SyntaxParseError::InvalidOpeningToken),
331    }?;
332
333    for token in tokens {
334        match previous_token {
335            Token::OpenBracket => {
336                opening_bracket_count += 1;
337                match token.clone() {
338                    // Valid
339                    Token::OpenBracket | Token::Not | Token::Tag(_) => Ok(()),
340
341                    // Invalid
342                    Token::CloseBracket => {
343                        return Err(SyntaxParseError::InvalidSequence(
344                            Token::OpenBracket,
345                            Token::CloseBracket,
346                        ))
347                    }
348                    Token::And => {
349                        return Err(SyntaxParseError::InvalidSequence(
350                            Token::OpenBracket,
351                            Token::And,
352                        ))
353                    }
354                    Token::Or => {
355                        return Err(SyntaxParseError::InvalidSequence(
356                            Token::OpenBracket,
357                            Token::Or,
358                        ))
359                    }
360                }
361            }
362            Token::CloseBracket => {
363                closing_bracket_count += 1;
364                match token.clone() {
365                    // Valid
366                    Token::CloseBracket | Token::And | Token::Or => Ok(()),
367
368                    // Invalid
369                    Token::OpenBracket => {
370                        return Err(SyntaxParseError::InvalidSequence(
371                            Token::CloseBracket,
372                            Token::OpenBracket,
373                        ))
374                    }
375                    Token::Not => {
376                        return Err(SyntaxParseError::InvalidSequence(
377                            Token::CloseBracket,
378                            Token::Not,
379                        ))
380                    }
381                    Token::Tag(tag) => {
382                        return Err(SyntaxParseError::InvalidSequence(
383                            Token::CloseBracket,
384                            Token::Tag(tag),
385                        ))
386                    }
387                }
388            }
389            Token::Not => match token.clone() {
390                // Valid
391                Token::Tag(_) | Token::OpenBracket => Ok(()),
392
393                // Invalid
394                Token::CloseBracket => {
395                    return Err(SyntaxParseError::InvalidSequence(
396                        Token::Not,
397                        Token::CloseBracket,
398                    ))
399                }
400                Token::Not => {
401                    return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::Not))
402                }
403                Token::And => {
404                    return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::And))
405                }
406                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::Or)),
407            },
408            Token::And => match token.clone() {
409                // Valid
410                Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
411
412                // Invalid
413                Token::CloseBracket => {
414                    return Err(SyntaxParseError::InvalidSequence(
415                        Token::And,
416                        Token::CloseBracket,
417                    ))
418                }
419                Token::And => {
420                    return Err(SyntaxParseError::InvalidSequence(Token::And, Token::And))
421                }
422                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::And, Token::Or)),
423            },
424            Token::Or => match token.clone() {
425                // Valid
426                Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
427
428                // Invalid
429                Token::CloseBracket => {
430                    return Err(SyntaxParseError::InvalidSequence(
431                        Token::Or,
432                        Token::CloseBracket,
433                    ))
434                }
435                Token::And => return Err(SyntaxParseError::InvalidSequence(Token::Or, Token::And)),
436                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::Or, Token::Or)),
437            },
438            Token::Tag(previous_tag) => match token.clone() {
439                // Valid
440                Token::CloseBracket | Token::And | Token::Or => Ok(()),
441
442                // Invalid
443                Token::OpenBracket => {
444                    return Err(SyntaxParseError::InvalidSequence(
445                        Token::Tag(previous_tag),
446                        Token::OpenBracket,
447                    ))
448                }
449                Token::Not => {
450                    return Err(SyntaxParseError::InvalidSequence(
451                        Token::Tag(previous_tag),
452                        Token::Not,
453                    ))
454                }
455                Token::Tag(this_tag) => {
456                    println!("{previous_tag} then {this_tag}, is not allowed");
457                    return Err(SyntaxParseError::InvalidSequence(
458                        Token::Tag(previous_tag),
459                        Token::Tag(this_tag),
460                    ));
461                }
462            },
463        }?;
464
465        previous_token = token;
466    }
467
468    // Check last token
469    match &previous_token {
470        Token::CloseBracket => {
471            closing_bracket_count += 1;
472            Ok(())
473        }
474        Token::Tag(_) => Ok(()),
475        _ => return Err(SyntaxParseError::InvalidClosingToken),
476    }?;
477
478    if closing_bracket_count > opening_bracket_count {
479        return Err(SyntaxParseError::UnopenedBrackets);
480    }
481
482    if closing_bracket_count < opening_bracket_count {
483        return Err(SyntaxParseError::UnclosedBrackets);
484    }
485
486    Ok(())
487}
488
489/// Parse a sequence (of valid tokens) into a boolean expression tree (wrapper
490/// around calling `recursive_syntax_parse()`)
491fn syntax_parse_token_stream(tokens: &mut Vec<Token>) -> Node {
492    let mut expr: Node = recursive_syntax_parse(tokens, None);
493    loop {
494        if tokens.is_empty() {
495            break;
496        }
497        expr = recursive_syntax_parse(tokens, Some(expr));
498    }
499    expr
500}
501
502/// Recursively parse a sequence (of valid tokens) into a boolean expression
503/// tree (called by the wrapper function `syntax_parse_token_stream()`))
504fn recursive_syntax_parse(tokens: &mut Vec<Token>, expr: Option<Node>) -> Node {
505    if tokens.is_empty() {
506        return expr.unwrap();
507    }
508    let token = tokens.remove(0);
509    match token {
510        Token::OpenBracket => {
511            // TODO: Need to pass in?
512            recursive_syntax_parse(tokens, expr)
513        }
514        Token::CloseBracket => expr.unwrap(),
515        Token::And => {
516            let result = recursive_syntax_parse(tokens, None);
517            Node::And(Box::new(expr.unwrap()), Box::new(result))
518        }
519        Token::Or => {
520            let result = recursive_syntax_parse(tokens, None);
521            Node::Or(Box::new(expr.unwrap()), Box::new(result))
522        }
523        Token::Tag(tag) => recursive_syntax_parse(tokens, Some(Node::Tag(tag))),
524        Token::Not => {
525            let result = recursive_syntax_parse(tokens, None);
526            Node::Not(Box::new(result))
527        }
528    }
529}
530
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn syntax_parse_empty() {
        // Must have at least 1 tag
        for expr in ["", "(&)", "(& & &) | (& & &)"] {
            assert_eq!(
                expr.syntax_parse().err(),
                Some(ParseError::Syntax(SyntaxParseError::NoTags)),
                "expression: {expr:?}"
            );
        }
    }

    #[test]
    fn syntax_parse() {
        let invalid = [
            // `&&` and `||`
            "((nationality=american && scientist) || (=british & scientist))  & !man && person",
            // `&&`
            "((nationality=american & scientist) | (=british & scientist))  && !man & person",
            // `||` and `&&`
            "((nationality=american & scientist) || (=british & scientist))  & !man && person",
            // Unmatched brackets
            "(a & b",
            "a & b)",
            // An expression cannot start or end with a binary operator
            "& a",
            "a &",
        ];
        for expr in invalid {
            assert!(expr.syntax_parse().is_err(), "should be rejected: {expr:?}");
        }

        let valid = [
            "((nationality=american & scientist) | (=british & scientist))  & !man & person",
            "(a & b & c)",
        ];
        for expr in valid {
            assert!(expr.syntax_parse().is_ok(), "should be accepted: {expr:?}");
        }
    }

    #[test]
    fn db_table_info_validates_identifiers() {
        assert!(DbTableInfo::from("entities", "id", "tag_name", "tag_value").is_ok());
        assert!(DbTableInfo::from("_entities2", "id", "tag_name", "tag_value").is_ok());

        assert_eq!(
            DbTableInfo::from("entities; DROP TABLE x", "id", "tag_name", "tag_value"),
            Err(SqlTableInfoError::InvalidIdentifier(
                "entities; DROP TABLE x".to_string()
            ))
        );
        assert!(DbTableInfo::from("entities", "1id", "tag_name", "tag_value").is_err());
        assert!(DbTableInfo::from("entities", "id", "", "tag_value").is_err());
        assert!(DbTableInfo::from("entities", "id", "tag_name", "tag'value").is_err());
    }

    // TODO: improve this test (best approach, I think, will be to create an in-memory DB
    // and execute against it)
    //
    // TODO: How to test
    // This functionality should be tested by creating an in-memory SQLite database with a
    // single table with 3 columns: ID, tag name, tag value. The functionality should be
    // tested by extracting matching IDs.
    #[test]
    fn to_sql() {
        let table_info =
            DbTableInfo::from("table_name", "id_column", "tag_name_column", "tag_value_column")
                .unwrap();

        // Should pass
        let sql = "((x=a & b) | (c & b)) & !d"
            .syntax_parse()
            .unwrap()
            .to_sql(&table_info);
        assert!(sql.is_ascii());
    }

    // TODO: should the output string be "((x=a & b) | (c & b)) & !d" for readability?
    #[test]
    fn to_boolean_expression() {
        // Printing, re-parsing and printing again must be stable
        let parsed = "((x=a & b) | (c & b)) & !d"
            .syntax_parse()
            .unwrap()
            .to_boolean_expression();
        let parsed_again = parsed
            .clone()
            .syntax_parse()
            .unwrap()
            .to_boolean_expression();
        assert_eq!(parsed, parsed_again);
    }
}