Skip to main content

bool_tag_expr/
syntax_parse.rs

1//!
2//! Syntactic parsing of boolean expressions
3//!
4
5use crate::{BoolTagExprLexicalParse, ParseError, Tag, Tags, Token};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use thiserror::Error;
8
/// A boolean expression tree.
///
/// This is a simple wrapper around a [`Node`]. The wrapped node is private to
/// this module: values are produced by parsing (see
/// [`BoolTagExprSyntaxParse::syntax_parse`] and [`BoolTagExpr::from`]), and the
/// tree can be taken back out with [`BoolTagExpr::into_node`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BoolTagExpr(Node);
14
15/// Error that arises when attempting to use an invalid SQL identifier
16#[derive(Debug, Clone, PartialEq, Eq, Hash, Error)]
17pub enum SqlTableInfoError {
18    /// There is an invalid SQL identifier
19    #[error("Invalid identifiers: '{0}'")]
20    InvalidIdentifier(String),
21}
22
// TODO: write a macro for compile time checking?
/// Holds information about the table against which the SQL will be run - this
/// is used to produce the SQL.
///
/// The fields are private; construct a value with [`DbTableInfo::from`], which
/// only accepts names that pass the SQL identifier check. The names are later
/// interpolated verbatim into the generated SQL, so that check is what stops
/// them from injecting SQL.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DbTableInfo {
    /// The name of the table that holds the tag name & value columns as well as
    /// a column for identifying the target
    table_name: String,

    /// The name of the column holding data used to ID some entity being
    /// selected
    id_column: String,

    /// The name of the column holding the tag name
    tag_name_column: String,

    /// The name of the column holding the tag value
    tag_value_column: String,
}
42
43impl DbTableInfo {
44    /// Create a `DbTableInfo`, ensuring that the values are valid SQL
45    /// identifiers to protect against SQL injection attacks
46    pub fn from(
47        table_name: &str,
48        id_column: &str,
49        tag_name_column: &str,
50        tag_value_column: &str,
51    ) -> Result<Self, SqlTableInfoError> {
52        for identifier in [table_name, id_column, tag_name_column, tag_value_column] {
53            if !is_valid_sql_identifier(identifier) {
54                Err(SqlTableInfoError::InvalidIdentifier(identifier.to_string()))?;
55            }
56        }
57
58        Ok(Self {
59            table_name: table_name.to_string(),
60            id_column: id_column.to_string(),
61            tag_name_column: tag_name_column.to_string(),
62            tag_value_column: tag_value_column.to_string(),
63        })
64    }
65}
66
/// Return `true` when `s` is safe to splice into SQL as an identifier.
///
/// This is deliberately blunt protection against SQL injection. Only ASCII
/// letters, ASCII digits and `_` are accepted. The string must be non-empty,
/// and its first character may not be a digit.
fn is_valid_sql_identifier(s: &str) -> bool {
    let is_identifier_start = |c: char| c == '_' || c.is_ascii_alphabetic();
    let is_identifier_char = |c: char| c == '_' || c.is_ascii_alphanumeric();

    // A valid start character is also a valid identifier character, so checking
    // every character (including the first) with `all` is equivalent.
    s.chars().next().map_or(false, is_identifier_start) && s.chars().all(is_identifier_char)
}
81
82// TODO: Needs testing
83impl Serialize for BoolTagExpr {
84    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85    where
86        S: Serializer,
87    {
88        let bool_expr = self.clone().to_boolean_expression();
89        serializer.serialize_str(&bool_expr)
90    }
91}
92
93// TODO: needs testing
94impl<'de> Deserialize<'de> for BoolTagExpr {
95    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96    where
97        D: Deserializer<'de>,
98    {
99        let raw_expr = String::deserialize(deserializer)?;
100        let tree = Self::from(raw_expr);
101        match tree {
102            Ok(tree) => Ok(tree),
103            Err(error) => {
104                // TODO: use the error (impl into?)
105                let err_msg = format!("Boolean expressions is invalid: {error}");
106                Err(serde::de::Error::custom(err_msg))
107            }
108        }
109    }
110}
111
/// Possible elements of a boolean expression (a boolean expression tree is made
/// up of these)
// NOTE: variant order is significant for the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Node {
    /// Conjunction (`l & r`): both sub-expressions must hold
    And(Box<Node>, Box<Node>),
    /// Disjunction (`l | r`): at least one sub-expression must hold
    Or(Box<Node>, Box<Node>),
    /// Negation (`!e`)
    Not(Box<Node>),
    /// A single tag to look for
    Tag(Tag),
    /// A literal boolean value. NOTE(review): the syntax parser in this module
    /// never constructs this variant.
    Bool(bool),
}
122
123/// Possible syntax errors
124#[derive(Debug, PartialEq, Clone, Error, Hash, Eq)]
125pub enum SyntaxParseError {
126    /// The boolean tag expression contains no tags (there must be at least 1)
127    #[error("No tags in expression")]
128    NoTags,
129
130    // TODO: store token for error msg?
131    /// The first lexical token of the boolean tag expression is invalid
132    #[error("Invalid opening token")]
133    InvalidOpeningToken,
134
135    // TODO: store token for error msg?
136    /// The last lexical token of the boolean tag expression is invalid
137    #[error("Invalid closing token")]
138    InvalidClosingToken,
139
140    // TODO: store token for error msg?
141    /// There are more closing brakets than opening brackets in the boolean tag
142    /// expression
143    #[error("Unopended brackets")]
144    UnopenedBrackets,
145
146    // TODO: store token for error msg?
147    /// There are more opening brakets than closing brackets in the boolean tag
148    /// expression
149    #[error("Unclosed brackets")]
150    UnclosedBrackets,
151
152    /// There is an invalid token order/sequence in the boolean tag expression
153    #[error("Invalid sequence of tokens: {0} -> {1}")]
154    InvalidSequence(Token, Token),
155}
156
/// Implementing types can be (lexically and) syntactically parsed to a boolean
/// expression tree
///
/// A blanket implementation is provided for every type that implements
/// `BoolTagExprLexicalParse`.
pub trait BoolTagExprSyntaxParse<T: BoolTagExprLexicalParse> {
    /// Lexically and then syntactically parse the value into a boolean
    /// expression tree
    ///
    /// # Errors
    ///
    /// Returns a [`ParseError`] if lexical parsing fails, or if the resulting
    /// token stream is not a syntactically valid expression.
    fn syntax_parse(self) -> Result<BoolTagExpr, ParseError>;
}
164
165// TODO: For anything that can be a string (eg str too)
166/// Blanket implementation of syntax parsing for any type that implements
167/// `BoolTagExprLexicalParse`
168impl<T: BoolTagExprLexicalParse> BoolTagExprSyntaxParse<T> for T {
169    fn syntax_parse(self) -> Result<BoolTagExpr, ParseError> {
170        let lexical_tokens = self.lexical_parse()?;
171        validate_token_stream(lexical_tokens.tokens().to_owned())?;
172        Ok(BoolTagExpr(syntax_parse_token_stream(
173            &mut lexical_tokens.tokens().to_owned(),
174        )))
175    }
176}
177
178impl BoolTagExpr {
179    /// Produce a [`BoolTagExpr`] from a type that implements both
180    /// `LexicalParse` and `SyntaxParse<T>`
181    pub fn from<T>(boolean_expr: T) -> Result<Self, ParseError>
182    where
183        T: BoolTagExprLexicalParse + BoolTagExprSyntaxParse<T>,
184    {
185        boolean_expr.syntax_parse()
186    }
187
188    /// Produce a logical boolean expression string from a [`BoolTagExpr`]
189    #[must_use]
190    pub fn to_boolean_expression(self) -> String {
191        boolean_expr_tree_to_logical_expr_string(self.0)
192    }
193
194    /// Produce an SQL boolean statement from a [`BoolTagExpr`] for pulling
195    /// entities out of a database
196    ///
197    /// If the SQL string returned by this function is denoted by `X`, then, as
198    /// an example, we can use it in the following way:
199    ///
200    /// ```sql
201    /// SELECT id, name
202    /// FROM entities
203    /// WHERE X
204    /// ORDER BY name
205    /// ```
206    #[must_use]
207    pub fn to_sql(self, table_info: &DbTableInfo) -> String {
208        boolean_expr_tree_to_sql(self.0, table_info)
209    }
210
211    /// Get the boolean expr tree
212    #[must_use]
213    pub fn into_node(self) -> Node {
214        self.0
215    }
216
217    /// Evaluate the expression against a list of `Tags`
218    #[must_use]
219    pub fn matches(&self, tags: &Tags) -> bool {
220        recusively_evaluate_expr_against_tags(&self.0, tags)
221    }
222}
223
224/// Evaluate a `BooleanTagExpr` tree against a list of `Tags`
225fn recusively_evaluate_expr_against_tags(expr: &Node, tags: &Tags) -> bool {
226    match expr {
227        Node::And(l, r) => {
228            recusively_evaluate_expr_against_tags(l, tags)
229                && recusively_evaluate_expr_against_tags(r, tags)
230        }
231        Node::Or(l, r) => {
232            recusively_evaluate_expr_against_tags(l, tags)
233                || recusively_evaluate_expr_against_tags(r, tags)
234        }
235        Node::Not(e) => !recusively_evaluate_expr_against_tags(e, tags),
236        Node::Tag(tag) => tags.contains(&tag),
237        Node::Bool(_) => panic!(),
238    }
239}
240
241// TODO: check the examples)
242/// Recursively produce an SQL statement from a tree of [`Node`]s
243///
244/// Examples of the SQL output:
245///
246/// - `(tag_value=XYZ AND tag_value=ABC)`
247/// - `((tag_name=QWE AND tag_value=XYZ) OR tag_value=ABC)`
248fn boolean_expr_tree_to_sql(expr: Node, table_info: &DbTableInfo) -> String {
249    match expr {
250        Node::And(l, r) => {
251            let mut sql_fragment = format!(
252                "SELECT {} FROM {} WHERE {} IN (",
253                table_info.id_column, table_info.table_name, table_info.id_column
254            );
255            sql_fragment.push_str(&boolean_expr_tree_to_sql(*l, table_info));
256            sql_fragment.push_str(&format!(") AND {} IN (", table_info.id_column));
257            sql_fragment.push_str(&boolean_expr_tree_to_sql(*r, table_info));
258            sql_fragment.push_str(&format!(") GROUP BY {}", table_info.id_column));
259            sql_fragment
260        }
261        Node::Or(l, r) => {
262            let mut sql_fragment = format!("SELECT {} FROM (", table_info.id_column);
263            sql_fragment.push_str(&boolean_expr_tree_to_sql(*l, table_info));
264            sql_fragment.push_str(" UNION ");
265            sql_fragment.push_str(&boolean_expr_tree_to_sql(*r, table_info));
266            sql_fragment.push(')');
267            sql_fragment
268        }
269        Node::Not(e) => {
270            let mut sql_fragment = format!(
271                "SELECT {} FROM {} WHERE {} NOT IN (",
272                table_info.id_column, table_info.table_name, table_info.id_column
273            );
274            sql_fragment.push_str(&boolean_expr_tree_to_sql(*e, table_info));
275            sql_fragment.push(')');
276            sql_fragment
277        }
278        Node::Tag(tag) => match tag.name {
279            None => format!(
280                "SELECT {} FROM {} WHERE {}='{}'",
281                table_info.id_column, table_info.table_name, table_info.tag_value_column, tag.value
282            ),
283            Some(tag_name) => format!(
284                "SELECT {} FROM {} WHERE {}='{}' AND {}='{}'",
285                table_info.id_column,
286                table_info.table_name,
287                table_info.tag_name_column,
288                tag_name,
289                table_info.tag_value_column,
290                tag.value
291            ),
292        },
293        Node::Bool(_) => panic!(),
294    }
295}
296
297/// Recursively produce a logical boolean expressions from a tree of
298/// [`BooleanTagExpr`]s
299fn boolean_expr_tree_to_logical_expr_string(expr: Node) -> String {
300    match expr {
301        Node::And(l, r) => {
302            let mut sql_fragment = String::from("(");
303            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*l));
304            sql_fragment.push_str(" & ");
305            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*r));
306            sql_fragment.push(')');
307            sql_fragment
308        }
309        Node::Or(l, r) => {
310            let mut sql_fragment = String::from("(");
311            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*l));
312            sql_fragment.push_str(" | ");
313            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*r));
314            sql_fragment.push(')');
315            sql_fragment
316        }
317        Node::Not(e) => {
318            let mut sql_fragment = String::from("!");
319            sql_fragment.push_str(&boolean_expr_tree_to_logical_expr_string(*e));
320            sql_fragment
321        }
322        Node::Tag(tag) => match tag.name {
323            None => format!("{}", tag.value),
324            Some(tag_name) => format!("({}={})", tag_name, tag.value),
325        },
326        Node::Bool(_) => panic!(),
327    }
328}
329
330// TODO: check token stream length?
331/// Check a stream of lexical tokens is a valid sequence (part of syntax
332/// parsing)
333fn validate_token_stream(mut tokens: Vec<Token>) -> Result<(), SyntaxParseError> {
334    let mut at_least_1_tag = false;
335    for token in tokens.clone() {
336        if let Token::Tag(_) = token {
337            at_least_1_tag = true;
338            break;
339        }
340    }
341    if !at_least_1_tag {
342        return Err(SyntaxParseError::NoTags);
343    }
344
345    let mut opening_bracket_count = 0;
346    let mut closing_bracket_count = 0;
347
348    let mut previous_token = tokens.remove(0);
349
350    // Check first token
351    match &previous_token {
352        Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
353        _ => return Err(SyntaxParseError::InvalidOpeningToken),
354    }?;
355
356    for token in tokens {
357        match previous_token {
358            Token::OpenBracket => {
359                opening_bracket_count += 1;
360                match token.clone() {
361                    // Valid
362                    Token::OpenBracket | Token::Not | Token::Tag(_) => Ok(()),
363
364                    // Invalid
365                    Token::CloseBracket => {
366                        return Err(SyntaxParseError::InvalidSequence(
367                            Token::OpenBracket,
368                            Token::CloseBracket,
369                        ))
370                    }
371                    Token::And => {
372                        return Err(SyntaxParseError::InvalidSequence(
373                            Token::OpenBracket,
374                            Token::And,
375                        ))
376                    }
377                    Token::Or => {
378                        return Err(SyntaxParseError::InvalidSequence(
379                            Token::OpenBracket,
380                            Token::Or,
381                        ))
382                    }
383                }
384            }
385            Token::CloseBracket => {
386                closing_bracket_count += 1;
387                match token.clone() {
388                    // Valid
389                    Token::CloseBracket | Token::And | Token::Or => Ok(()),
390
391                    // Invalid
392                    Token::OpenBracket => {
393                        return Err(SyntaxParseError::InvalidSequence(
394                            Token::CloseBracket,
395                            Token::OpenBracket,
396                        ))
397                    }
398                    Token::Not => {
399                        return Err(SyntaxParseError::InvalidSequence(
400                            Token::CloseBracket,
401                            Token::Not,
402                        ))
403                    }
404                    Token::Tag(tag) => {
405                        return Err(SyntaxParseError::InvalidSequence(
406                            Token::CloseBracket,
407                            Token::Tag(tag),
408                        ))
409                    }
410                }
411            }
412            Token::Not => match token.clone() {
413                // Valid
414                Token::Tag(_) | Token::OpenBracket => Ok(()),
415
416                // Invalid
417                Token::CloseBracket => {
418                    return Err(SyntaxParseError::InvalidSequence(
419                        Token::Not,
420                        Token::CloseBracket,
421                    ))
422                }
423                Token::Not => {
424                    return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::Not))
425                }
426                Token::And => {
427                    return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::And))
428                }
429                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::Not, Token::Or)),
430            },
431            Token::And => match token.clone() {
432                // Valid
433                Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
434
435                // Invalid
436                Token::CloseBracket => {
437                    return Err(SyntaxParseError::InvalidSequence(
438                        Token::And,
439                        Token::CloseBracket,
440                    ))
441                }
442                Token::And => {
443                    return Err(SyntaxParseError::InvalidSequence(Token::And, Token::And))
444                }
445                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::And, Token::Or)),
446            },
447            Token::Or => match token.clone() {
448                // Valid
449                Token::Not | Token::OpenBracket | Token::Tag(_) => Ok(()),
450
451                // Invalid
452                Token::CloseBracket => {
453                    return Err(SyntaxParseError::InvalidSequence(
454                        Token::Or,
455                        Token::CloseBracket,
456                    ))
457                }
458                Token::And => return Err(SyntaxParseError::InvalidSequence(Token::Or, Token::And)),
459                Token::Or => return Err(SyntaxParseError::InvalidSequence(Token::Or, Token::Or)),
460            },
461            Token::Tag(previous_tag) => match token.clone() {
462                // Valid
463                Token::CloseBracket | Token::And | Token::Or => Ok(()),
464
465                // Invalid
466                Token::OpenBracket => {
467                    return Err(SyntaxParseError::InvalidSequence(
468                        Token::Tag(previous_tag),
469                        Token::OpenBracket,
470                    ))
471                }
472                Token::Not => {
473                    return Err(SyntaxParseError::InvalidSequence(
474                        Token::Tag(previous_tag),
475                        Token::Not,
476                    ))
477                }
478                Token::Tag(this_tag) => {
479                    println!("{previous_tag} then {this_tag}, is not allowed");
480                    return Err(SyntaxParseError::InvalidSequence(
481                        Token::Tag(previous_tag),
482                        Token::Tag(this_tag),
483                    ));
484                }
485            },
486        }?;
487
488        previous_token = token;
489    }
490
491    // Check last token
492    match &previous_token {
493        Token::CloseBracket => {
494            closing_bracket_count += 1;
495            Ok(())
496        }
497        Token::Tag(_) => Ok(()),
498        _ => return Err(SyntaxParseError::InvalidClosingToken),
499    }?;
500
501    if closing_bracket_count > opening_bracket_count {
502        return Err(SyntaxParseError::UnopenedBrackets);
503    }
504
505    if closing_bracket_count < opening_bracket_count {
506        return Err(SyntaxParseError::UnclosedBrackets);
507    }
508
509    Ok(())
510}
511
512/// Parse a sequence (of valid tokens) into a boolean expression tree (wrapper
513/// around calling `recursive_syntax_parse()`)
514fn syntax_parse_token_stream(tokens: &mut Vec<Token>) -> Node {
515    let mut expr: Node = recursive_syntax_parse(tokens, None);
516    loop {
517        if tokens.is_empty() {
518            break;
519        }
520        expr = recursive_syntax_parse(tokens, Some(expr));
521    }
522    expr
523}
524
525/// Recursively parse a sequence (of valid tokens) into a boolean expression
526/// tree (called by the wrapper function `syntax_parse_token_stream()`))
527fn recursive_syntax_parse(tokens: &mut Vec<Token>, expr: Option<Node>) -> Node {
528    if tokens.is_empty() {
529        return expr.unwrap();
530    }
531    let token = tokens.remove(0);
532    match token {
533        Token::OpenBracket => {
534            // TODO: Need to pass in?
535            recursive_syntax_parse(tokens, expr)
536        }
537        Token::CloseBracket => expr.unwrap(),
538        Token::And => {
539            let result = recursive_syntax_parse(tokens, None);
540            Node::And(Box::new(expr.unwrap()), Box::new(result))
541        }
542        Token::Or => {
543            let result = recursive_syntax_parse(tokens, None);
544            Node::Or(Box::new(expr.unwrap()), Box::new(result))
545        }
546        Token::Tag(tag) => recursive_syntax_parse(tokens, Some(Node::Tag(tag))),
547        Token::Not => {
548            let result = recursive_syntax_parse(tokens, None);
549            Node::Not(Box::new(result))
550        }
551    }
552}
553
// Unit tests for syntax parsing, SQL generation, logical-expression rendering
// and tag matching.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{TagName, TagValue};

    #[test]
    fn syntax_parse_empty() -> anyhow::Result<()> {
        // Must have at least 1 tag
        let a = "";
        assert!(a.syntax_parse().err().unwrap() == ParseError::Syntax(SyntaxParseError::NoTags));

        let a = "(&)";
        assert!(a.syntax_parse().err().unwrap() == ParseError::Syntax(SyntaxParseError::NoTags));

        let a = "(& & &) | (& & &)";
        assert!(a.syntax_parse().err().unwrap() == ParseError::Syntax(SyntaxParseError::NoTags));

        Ok(())
    }

    #[test]
    fn syntax_parse() -> anyhow::Result<()> {
        // Should fail because of `&&` and `||`
        let a = "((nationality=american && scientist) || (=british & scientist))  & !man && person";
        assert!(a.syntax_parse().is_err());

        // Should fail because of `&&`
        let a = "((nationality=american & scientist) | (=british & scientist))  && !man & person";
        assert!(a.syntax_parse().is_err());

        // Should fail because of `||` (this input also has a trailing `&&`)
        let a = "((nationality=american & scientist) || (=british & scientist))  & !man && person";
        assert!(a.syntax_parse().is_err());

        // Should pass
        let a = "((nationality=american & scientist) | (=british & scientist))  & !man & person";
        assert!(a.syntax_parse().is_ok());

        // Should fail because of unmatched brackets
        let a = "(a & b";
        assert!(a.syntax_parse().is_err());

        // Should pass
        let a = "(a & b & c)";
        assert!(a.syntax_parse().is_ok());

        Ok(())
    }

    // TODO: improve this test (best approach, I think, will be to create an
    // in-memory DB and execute against it)
    //
    // TODO: How to test
    // This functionality should be tested by creating an in-memory SQLite
    // database with a single table with 3 columns: ID, tag name, tag value.
    // The functionality should be tested by extracting matching IDs.
    //
    // For now this only checks that SQL generation succeeds and yields ASCII.
    #[test]
    fn to_sql() -> anyhow::Result<()> {
        let table_info = DbTableInfo::from(
            &"table_name",
            &"id_column",
            &"tag_name_column",
            &"tag_value_column",
        )?;

        // Should pass
        let a = "((x=a & b) | (c & b)) & !d";
        assert!(a.syntax_parse()?.to_sql(&table_info).is_ascii());

        Ok(())
    }

    // TODO: should the output string be "((x=a & b) | (c & b)) & !d" for readability?
    // Round trip: rendering, re-parsing and rendering again must give the same
    // string.
    #[test]
    fn to_boolean_expression() -> anyhow::Result<()> {
        // Should pass
        let a = "((x=a & b) | (c & b)) & !d";
        let parsed = a.syntax_parse()?.to_boolean_expression();
        let parsed_again = parsed.clone().syntax_parse()?.to_boolean_expression();
        assert_eq!(parsed, parsed_again);

        Ok(())
    }

    #[test]
    fn matches() -> anyhow::Result<()> {
        // Shouldn't match
        let expr = BoolTagExpr::from("!d")?;
        let tags = Tags::from([Tag::from(None, TagValue::from("d")?)]);
        assert!(!expr.matches(&tags));

        // Should match
        let expr = BoolTagExpr::from("d")?;
        let tags = Tags::from([Tag::from(None, TagValue::from("d")?)]);
        assert!(expr.matches(&tags));

        // Complex
        let expr_str = "((x=a & b) | (c & b)) & !d";
        let expr = BoolTagExpr::from(expr_str)?;
        {
            // Should match
            let tags = Tags::from([
                Tag::from(None, TagValue::from("c")?),
                Tag::from(None, TagValue::from("b")?),
            ]);
            assert!(expr.matches(&tags));
        }
        {
            // Shouldn't match
            let tags = Tags::from([
                Tag::from(None, TagValue::from("c")?),
                Tag::from(None, TagValue::from("b")?),
                Tag::from(None, TagValue::from("d")?),
            ]);
            assert!(!expr.matches(&tags));
        }
        {
            // Should match
            let tags = Tags::from([
                Tag::from(Some(TagName::from("x")?), TagValue::from("a")?),
                Tag::from(None, TagValue::from("b")?),
            ]);
            assert!(expr.matches(&tags));
        }
        {
            // Shouldn't match
            let tags = Tags::from([
                Tag::from(Some(TagName::from("x")?), TagValue::from("a")?),
                Tag::from(None, TagValue::from("b")?),
                Tag::from(None, TagValue::from("d")?),
            ]);
            assert!(!expr.matches(&tags));
        }

        Ok(())
    }
}