1#![deny(missing_docs)]
2#![allow(clippy::pedantic)]
3#![doc = include_str!("lib.md")]
4
5pub mod sem;
6pub mod syn;
7
8mod schema_dump_file;
9mod sql_file;
10mod trie_map;
11
12use miette::LabeledSpan;
13use sql_fun_core::IVec;
14
15use self::sql_file::read_sql_file;
16pub use self::trie_map::TrieMap;
17
18use std::{io::SeekFrom, string::String as StdString};
19
20use crate::{
21 sem::{ParseContext, SemAst},
22 syn::ParseResult,
23};
24
25pub struct AstAndContextPair<TParseContext>(SemAst, TParseContext)
27where
28 TParseContext: ParseContext;
29
30impl<TParseContext> AstAndContextPair<TParseContext>
31where
32 TParseContext: ParseContext,
33{
34 pub fn new(ast: SemAst, context: TParseContext) -> Self {
36 Self(ast, context)
37 }
38}
39
40pub fn parse(sql: &str) -> Result<ParseResult, pg_query::Error> {
47 let parse_result = ::pg_query::parse(sql)?;
48 Ok(ParseResult::from(parse_result.protobuf))
49}
50
51pub fn scan(sql: &str) -> Result<IVec<crate::syn::ScanToken>, pg_query::Error> {
53 let scan_result = ::pg_query::scan(sql)?;
54 let mut tokens = Vec::new();
55 for token in scan_result.tokens {
56 tokens.push(crate::syn::ScanToken::from(token));
57 }
58 Ok(tokens.into())
59}
60
/// Returns the byte offset of `needle` inside `haystack`, provided `needle`
/// is a sub-slice borrowed from `haystack`'s buffer.
///
/// The lookup compares addresses, not contents: a `needle` with equal text
/// that lives in another allocation yields `None`. The whole of `needle`
/// (`offset..offset + needle.len()`) must lie within `haystack`. An empty
/// `needle` positioned exactly at the end of `haystack` is accepted.
fn offset_in_string(haystack: &str, needle: &str) -> Option<usize> {
    let base = haystack.as_ptr() as usize;
    let start = needle.as_ptr() as usize;
    // `checked_sub` rejects needles that begin before `haystack`.
    let offset = start.checked_sub(base)?;
    // Require the needle's end, not just its start, to be in bounds.
    let end = offset.checked_add(needle.len())?;
    (end <= haystack.len()).then_some(offset)
}
71
72#[derive(Debug, Clone, Copy, Default)]
74pub struct StringSpan {
75 offset: usize,
76 len: usize,
77}
78
79impl StringSpan {
80 #[must_use]
86 pub fn from_str_in_str(container: &StdString, content: &str) -> Self {
87 let offset = offset_in_string(container, content).expect("content in container");
88 let len = content.len();
89 Self { offset, len }
90 }
91
92 pub fn seek_pos(&self) -> SeekFrom {
94 SeekFrom::Start(self.offset as u64)
95 }
96
97 pub fn len(&self) -> usize {
99 self.len
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.len == 0
105 }
106
107 fn new_labeled_span(&self, label: &str) -> LabeledSpan {
108 LabeledSpan::new(Some(String::from(label)), self.offset, self.len)
109 }
110
111 fn from_scan_token(tok: &crate::syn::ScanToken) -> Self {
112 Self {
113 offset: tok.get_start() as usize,
114 len: (tok.get_end() - tok.get_start()) as usize,
115 }
116 }
117
118 pub fn end_pos(&self) -> usize {
120 self.offset + self.len
121 }
122
123 pub fn extend(&mut self, other: &Self) {
125 self.offset = std::cmp::min(self.offset, other.offset);
126 self.len = std::cmp::max(self.end_pos(), other.end_pos()) - self.offset;
127 }
128}
129
130pub use self::schema_dump_file::{ParseDumpFileError, ParsedSchemaDump, parse_schema_file};
131
132#[cfg(test)]
133pub mod test_helpers;
134
#[cfg(test)]
mod tests {

    use std::path::PathBuf;

    use crate::{
        sem::{BaseContext, BaseParseContext},
        syn::{ListOpt, Opt},
    };

    use super::{parse, parse_schema_file};
    use clap::Parser;
    use sql_fun_core::SqlFunArgs;
    use testresult::TestResult;

    /// Marker value returned by the [`enable_stack_overflow_backtrace`]
    /// fixture.
    ///
    /// Taking it as a test argument makes `rstest` run the fixture before the
    /// test body.
    pub struct EnableStackOverflowBacktrace {}

    /// `rstest` fixture that calls `backtrace_on_stack_overflow::enable()`.
    #[rstest::fixture]
    pub fn enable_stack_overflow_backtrace() -> EnableStackOverflowBacktrace {
        // SAFETY: NOTE(review): `backtrace_on_stack_overflow::enable` is an
        // `unsafe` API and is only invoked from test code as a diagnostics
        // aid. Confirm the crate's documented preconditions still hold.
        #[expect(unsafe_code)]
        unsafe {
            backtrace_on_stack_overflow::enable()
        };
        EnableStackOverflowBacktrace {}
    }

    /// `rstest` fixture: `SqlFunArgs` parsed from a fixed command line.
    ///
    /// The command line uses `sql_fun.metadata.toml` and the compile-time
    /// value of `SQL_FUN_HOME`.
    #[rstest::fixture]
    pub fn context_args() -> SqlFunArgs {
        SqlFunArgs::try_parse_from([
            "sqlfun",
            "--metadata-file",
            "sql_fun.metadata.toml",
            "--sql-fun-home",
            env!("SQL_FUN_HOME"),
            "subcmd",
        ])
        .expect("fixed test command line must be accepted by SqlFunArgs")
    }

    /// Installs a `tracing` subscriber for the test binary (run via `ctor`
    /// before the tests).
    ///
    /// The filter comes from the environment via `EnvFilter`, falling back to
    /// `debug`. Only the `message` field of each event is printed.
    #[ctor::ctor]
    fn init_tracing() {
        use tracing_subscriber::{
            EnvFilter,
            fmt::format::{FmtSpan, debug_fn},
        };

        let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("debug"));

        // The `try_init` error is ignored on purpose: it only fails when a
        // global subscriber is already installed.
        let _ = tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_test_writer()
            .with_span_events(FmtSpan::ACTIVE)
            .fmt_fields(debug_fn(|writer, field, value| {
                // Print only the event message; all other fields are dropped.
                if field.name() == "message" {
                    use core::fmt::Write as _;
                    write!(writer, "{value:?}")
                } else {
                    Ok(())
                }
            }))
            .try_init();
    }

    /// Parsing a simple `SELECT` exposes the `FROM` relation name.
    #[test]
    fn test_parse_simple_query() -> TestResult {
        let result = parse("select * from users where id=0")?;
        let Some(relname) = result
            .get_stmts()
            .get(0)
            .get_stmt()
            .as_select_stmt()
            .get_from_clause()
            .get(0)
            .as_range_var()
            .get_relname()
        else {
            panic!("select relname failed returns None")
        };

        assert_eq!(&relname, "users");

        Ok(())
    }

    /// `CREATE TYPE ... AS (...)` parses into a composite type statement.
    #[test]
    fn test_create_composit_type() -> TestResult {
        let result = parse("CREATE TYPE compfoo AS (f1 int, f2 text);")?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_composite_type_stmt().as_inner() else {
            eprintln!("{stmt:?}");
            panic!();
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    /// `CREATE TYPE ... AS RANGE (...)` parses into a create-range statement.
    #[test]
    fn test_create_range_type() -> TestResult {
        let result = parse(
            "CREATE TYPE float8_range AS RANGE (subtype = float8, subtype_diff = float8mi);",
        )?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_create_range_stmt().as_inner() else {
            eprintln!("{stmt:?}");
            panic!();
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    /// A base-type `CREATE TYPE name (...)` parses into a define statement.
    #[test]
    fn test_create_base_type() -> TestResult {
        let result = parse(
            "CREATE TYPE box (
 INTERNALLENGTH = 16,
 INPUT = my_box_in_function,
 OUTPUT = my_box_out_function
);",
        )?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_define_stmt().as_inner() else {
            eprintln!("{stmt:?}");
            panic!();
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    /// Parses the adventure-works example schema (ignored by default).
    ///
    /// The schema is parsed on top of the builtin PostgreSQL 17 schema and the
    /// `tablefunc` extension, both loaded from `SQL_FUN_HOME`.
    #[ignore]
    #[rstest::rstest]
    fn parse_adventure_works_schema(
        context_args: SqlFunArgs,
        _enable_stack_overflow_backtrace: EnableStackOverflowBacktrace,
    ) -> TestResult {
        let mut analyze_context = BaseContext::new(context_args.clone())?;

        let builtin_context = BaseContext::new(context_args.clone())?;
        let home_path = PathBuf::from(env!("SQL_FUN_HOME"));
        let builtin = home_path.join("postgres/17/schema.sql");
        let builtin = parse_schema_file(&builtin, builtin_context)?;
        analyze_context.extend(Box::new(builtin));

        let tablefunc_ext = home_path.join("postgres/17/extension/tablefunc--1.0.sql");
        // Last use of `context_args`: move it instead of cloning.
        let input_context = BaseContext::new(context_args)?;
        let tablefunc_ext = parse_schema_file(&tablefunc_ext, input_context)?;

        analyze_context.extend(Box::new(tablefunc_ext));
        let file = PathBuf::from("../examples/adventure-works/schema.develop.sql");
        let _dump = parse_schema_file(&file, analyze_context)?;
        Ok(())
    }
}
292}