postgresql_cst_parser/lib.rs
1#![allow(non_camel_case_types)]
2
3mod lexer;
4
5mod parser;
6
7mod cst;
8pub mod syntax_kind;
9
10pub use cst::NodeOrToken;
11pub use cst::PostgreSQLSyntax;
12pub use cst::ResolvedNode;
13pub use cst::ResolvedToken;
14pub use cst::SyntaxElement;
15pub use cst::SyntaxElementRef;
16pub use cst::SyntaxNode;
17pub use cst::SyntaxToken;
18pub use lexer::parser_error::ParserError;
19pub use lexer::parser_error::ScanReport;
20
21/// Parse SQL and construct a Complete Syntax Tree (CST).
22///
23/// # Examples
24///
25/// ```
26/// use postgresql_cst_parser::{parse, syntax_kind::SyntaxKind};
27///
28/// fn main() {
29/// // Parse SQL query and get the syntax tree
30/// let sql = "SELECT tbl.a as a, tbl.b from TBL tbl WHERE tbl.a > 0;";
31/// let root = parse(sql).unwrap();
32///
33/// // Example 1: Extract all column references from the query
34/// let column_refs: Vec<String> = root
35/// .descendants()
36/// .filter(|node| node.kind() == SyntaxKind::columnref)
37/// .map(|node| node.text().to_string())
38/// .collect();
39///
40/// println!("Column references: {:?}", column_refs); // ["tbl.a", "tbl.b", "tbl.a"]
41///
42/// // Example 2: Find the WHERE condition
43/// if let Some(where_clause) = root
44/// .descendants()
45/// .find(|node| node.kind() == SyntaxKind::where_clause)
46/// {
47/// println!("WHERE condition: {}", where_clause.text());
48/// }
49///
50/// // Example 3: Get the selected table name
51/// if let Some(relation_expr) = root
52/// .descendants()
53/// .find(|node| node.kind() == SyntaxKind::relation_expr)
54/// {
55/// if let Some(name_node) = relation_expr
56/// .descendants()
57/// .find(|node| node.kind() == SyntaxKind::ColId)
58/// {
59/// println!("Table name: {}", name_node.text());
60/// }
61/// }
62///
63/// // Example 4: Parse complex SQL and extract specific nodes
64/// let complex_sql = "WITH data AS (SELECT id, value FROM source WHERE value > 10)
65/// SELECT d.id, d.value, COUNT(*) OVER (PARTITION BY d.id)
66/// FROM data d JOIN other o ON d.id = o.id
67/// ORDER BY d.value DESC LIMIT 10;";
68///
69/// let complex_root = parse(complex_sql).unwrap();
70///
71/// // Extract CTEs (Common Table Expressions)
72/// let ctes: Vec<_> = complex_root
73/// .descendants()
74/// .filter(|node| node.kind() == SyntaxKind::common_table_expr)
75/// .collect();
76///
77/// // Extract window functions
78/// let window_funcs: Vec<_> = complex_root
79/// .descendants()
80/// .filter(|node| node.kind() == SyntaxKind::over_clause)
81/// .collect();
82///
83/// println!("Number of CTEs: {}", ctes.len());
84/// println!("Number of window functions: {}", window_funcs.len());
85/// }
86/// ```
87pub fn parse(input: &str) -> Result<ResolvedNode, ParserError> {
88 cst::parse(input)
89}
90
#[cfg(test)]
mod tests {
    // `super::*` already provides `parse`, `ParserError` and the re-exported
    // `ScanReport`, so no separate import of `ScanReport` is required.
    use super::*;

    /// Asserts that parsing `input` fails with a `ParserError::ScanError`
    /// whose message is exactly `message`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test
    /// rather than at this helper.
    #[track_caller]
    fn assert_scan_error(input: &str, message: &str) {
        let actual = parse(input);

        let expected = Err(ParserError::ScanError {
            message: message.to_string(),
        });

        assert_eq!(actual, expected);
    }

    /// A hexadecimal string literal without a closing quote is rejected.
    #[test]
    fn test_unterminated_hexadecimal_string_literal() {
        assert_scan_error(r#"select x'CC"#, "unterminated hexadecimal string literal");
    }

    /// A bit string literal without a closing quote is rejected.
    #[test]
    fn test_unterminated_bit_string_literal() {
        assert_scan_error(r#"select b'10"#, "unterminated bit string literal");
    }

    /// A truncated `\u` escape yields a `ScanReport` carrying a detail message
    /// and the byte position of the error.
    #[test]
    fn test_xeunicodefail() {
        let input = r#"select e'\uD80"#;
        let actual = parse(input);

        let expected = Err(ParserError::ScanReport(ScanReport {
            message: "invalid Unicode escape".to_string(),
            detail: "Unicode escapes must be \\uXXXX or \\UXXXXXXXX.".to_string(),
            position_in_bytes: 9,
        }));

        assert_eq!(actual, expected);
    }

    /// A `\uD800` escape that is not followed by a second escape is rejected.
    #[test]
    fn test_invalid_unicode_surrogate_pair() {
        assert_scan_error(r#"select e'\uD800"#, "invalid Unicode surrogate pair");
    }

    /// The escape `\u0000` is rejected as an invalid escape value.
    #[test]
    fn test_invalid_unicode_surrogate_first() {
        assert_scan_error(r#"select e'\u0000"#, "invalid Unicode escape value");
    }

    /// `\uD800` followed by `\uD000` is rejected as an invalid surrogate pair.
    #[test]
    fn test_invalid_unicode_surrogate_second() {
        assert_scan_error(
            r#"select e'\uD800\uD000'"#,
            "invalid Unicode surrogate pair",
        );
    }
}