postgresql_cst_parser/
lib.rs

1#![allow(non_camel_case_types)]
2
3mod lexer;
4
5mod parser;
6
7mod cst;
8pub mod syntax_kind;
9
10pub use cst::NodeOrToken;
11pub use cst::PostgreSQLSyntax;
12pub use cst::ResolvedNode;
13pub use cst::ResolvedToken;
14pub use cst::SyntaxElement;
15pub use cst::SyntaxElementRef;
16pub use cst::SyntaxNode;
17pub use cst::SyntaxToken;
18pub use lexer::parser_error::ParserError;
19pub use lexer::parser_error::ScanReport;
20
21/// Parse SQL and construct a Complete Syntax Tree (CST).
22///
23/// # Examples
24///
25/// ```
26/// use postgresql_cst_parser::{parse, syntax_kind::SyntaxKind};
27///
28/// fn main() {
29///     // Parse SQL query and get the syntax tree
30///     let sql = "SELECT tbl.a as a, tbl.b from TBL tbl WHERE tbl.a > 0;";
31///     let root = parse(sql).unwrap();
32///
33///     // Example 1: Extract all column references from the query
34///     let column_refs: Vec<String> = root
35///         .descendants()
36///         .filter(|node| node.kind() == SyntaxKind::columnref)
37///         .map(|node| node.text().to_string())
38///         .collect();
39///
40///     println!("Column references: {:?}", column_refs); // ["tbl.a", "tbl.b", "tbl.a"]
41///
42///     // Example 2: Find the WHERE condition
43///     if let Some(where_clause) = root
44///         .descendants()
45///         .find(|node| node.kind() == SyntaxKind::where_clause)
46///     {
47///         println!("WHERE condition: {}", where_clause.text());
48///     }
49///
50///     // Example 3: Get the selected table name
51///     if let Some(relation_expr) = root
52///         .descendants()
53///         .find(|node| node.kind() == SyntaxKind::relation_expr)
54///     {
55///         if let Some(name_node) = relation_expr
56///             .descendants()
57///             .find(|node| node.kind() == SyntaxKind::ColId)
58///         {
59///             println!("Table name: {}", name_node.text());
60///         }
61///     }
62///
63///     // Example 4: Parse complex SQL and extract specific nodes
64///     let complex_sql = "WITH data AS (SELECT id, value FROM source WHERE value > 10)
65///                        SELECT d.id, d.value, COUNT(*) OVER (PARTITION BY d.id)
66///                        FROM data d JOIN other o ON d.id = o.id
67///                        ORDER BY d.value DESC LIMIT 10;";
68///
69///     let complex_root = parse(complex_sql).unwrap();
70///
71///     // Extract CTEs (Common Table Expressions)
72///     let ctes: Vec<_> = complex_root
73///         .descendants()
74///         .filter(|node| node.kind() == SyntaxKind::common_table_expr)
75///         .collect();
76///
77///     // Extract window functions
78///     let window_funcs: Vec<_> = complex_root
79///         .descendants()
80///         .filter(|node| node.kind() == SyntaxKind::over_clause)
81///         .collect();
82///
83///     println!("Number of CTEs: {}", ctes.len());
84///     println!("Number of window functions: {}", window_funcs.len());
85/// }
86/// ```
87pub fn parse(input: &str) -> Result<ResolvedNode, ParserError> {
88    cst::parse(input)
89}
90
#[cfg(test)]
mod tests {
    // `super::*` already brings `parse`, `ParserError` and `ScanReport` into
    // scope through the crate-root re-exports.
    use super::*;

    /// Asserts that parsing `input` fails with a `ParserError::ScanError`
    /// whose message is exactly `message`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test
    /// rather than at this helper.
    #[track_caller]
    fn assert_scan_error(input: &str, message: &str) {
        let expected = Err(ParserError::ScanError {
            message: message.to_string(),
        });

        assert_eq!(parse(input), expected);
    }

    /// A hexadecimal string literal (`x'...`) without a closing quote is rejected.
    #[test]
    fn test_unterminated_hexadecimal_string_literal() {
        assert_scan_error(r#"select x'CC"#, "unterminated hexadecimal string literal");
    }

    /// A bit string literal (`b'...`) without a closing quote is rejected.
    #[test]
    fn test_unterminated_bit_string_literal() {
        assert_scan_error(r#"select b'10"#, "unterminated bit string literal");
    }

    /// A `\u` escape with fewer than four hex digits produces a detailed
    /// `ScanReport` rather than a plain `ScanError`.
    #[test]
    fn test_xeunicodefail() {
        let input = r#"select e'\uD80"#;
        let actual = parse(input);

        let expected = Err(ParserError::ScanReport(ScanReport {
            message: "invalid Unicode escape".to_string(),
            detail: "Unicode escapes must be \\uXXXX or \\UXXXXXXXX.".to_string(),
            // Byte offset 9 is the backslash that starts the malformed escape.
            position_in_bytes: 9,
        }));

        assert_eq!(actual, expected);
    }

    /// A leading surrogate (`\uD800`) with nothing after it is rejected.
    #[test]
    fn test_invalid_unicode_surrogate_pair() {
        assert_scan_error(r#"select e'\uD800"#, "invalid Unicode surrogate pair");
    }

    /// The escape `\u0000` is rejected as an invalid escape value.
    #[test]
    fn test_invalid_unicode_escape_value() {
        assert_scan_error(r#"select e'\u0000"#, "invalid Unicode escape value");
    }

    /// A leading surrogate followed by `\uD000` instead of a trailing
    /// surrogate is rejected.
    #[test]
    fn test_invalid_unicode_surrogate_second() {
        assert_scan_error(
            r#"select e'\uD800\uD000'"#,
            "invalid Unicode surrogate pair",
        );
    }
}