1use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24use crate::iceberg::transform_iceberg_input;
25
26use datafusion::common::sql_datafusion_err;
27use datafusion::error::DataFusionError;
28use datafusion::sql::parser::{DFParser, Statement};
29use datafusion::sql::sqlparser::dialect::dialect_from_str;
30use datafusion::sql::sqlparser::parser::ParserError;
31
32use rustyline::completion::{Completer, FilenameCompleter, Pair};
33use rustyline::error::ReadlineError;
34use rustyline::highlight::Highlighter;
35use rustyline::hint::Hinter;
36use rustyline::validate::{ValidationContext, ValidationResult, Validator};
37use rustyline::{Context, Helper, Result};
38
39pub struct CliHelper {
40 completer: FilenameCompleter,
41 dialect: String,
42 highlighter: Box<dyn Highlighter>,
43}
44
45impl CliHelper {
46 pub fn new(dialect: &str, color: bool) -> Self {
47 let highlighter: Box<dyn Highlighter> = if !color {
48 Box::new(NoSyntaxHighlighter {})
49 } else {
50 Box::new(SyntaxHighlighter::new(dialect))
51 };
52 Self {
53 completer: FilenameCompleter::new(),
54 dialect: dialect.into(),
55 highlighter,
56 }
57 }
58
59 pub fn set_dialect(&mut self, dialect: &str) {
60 if dialect != self.dialect {
61 self.dialect = dialect.to_string();
62 }
63 }
64
65 fn validate_input(&self, input: &str) -> Result<ValidationResult> {
66 if let Some(sql) = input.strip_suffix(';') {
67 let sql = match unescape_input(sql) {
68 Ok(sql) => sql,
69 Err(err) => {
70 return Ok(ValidationResult::Invalid(Some(format!(
71 " 🤔 Invalid statement: {err}",
72 ))))
73 }
74 };
75
76 let dialect = match dialect_from_str(&self.dialect) {
77 Some(dialect) => dialect,
78 None => {
79 return Ok(ValidationResult::Invalid(Some(format!(
80 " 🤔 Invalid dialect: {}",
81 self.dialect
82 ))))
83 }
84 };
85 let lines = split_from_semicolon(sql);
86 for line in lines {
87 match DFParser::parse_sql_with_dialect(
88 &transform_iceberg_input(&line),
89 dialect.as_ref(),
90 ) {
91 Ok(statements) if statements.is_empty() => {
92 return Ok(ValidationResult::Invalid(Some(
93 " 🤔 You entered an empty statement".to_string(),
94 )));
95 }
96 Ok(_statements) => {}
97 Err(err) => {
98 return Ok(ValidationResult::Invalid(Some(format!(
99 " 🤔 Invalid statement: {err}",
100 ))));
101 }
102 }
103 }
104 Ok(ValidationResult::Valid(None))
105 } else if input.starts_with('\\') {
106 Ok(ValidationResult::Valid(None))
108 } else {
109 Ok(ValidationResult::Incomplete)
110 }
111 }
112}
113
114impl Default for CliHelper {
115 fn default() -> Self {
116 Self::new("generic", false)
117 }
118}
119
120impl Highlighter for CliHelper {
121 fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
122 self.highlighter.highlight(line, pos)
123 }
124
125 fn highlight_char(&self, line: &str, pos: usize, forced: bool) -> bool {
126 self.highlighter.highlight_char(line, pos, forced)
127 }
128}
129
130impl Hinter for CliHelper {
131 type Hint = String;
132}
133
134fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
137 let mut sql = line[..pos].to_string();
138 sql.push('\'');
139 if let Ok(stmts) = DFParser::parse_sql(&sql) {
140 if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
141 return true;
142 }
143 }
144 false
145}
146
147impl Completer for CliHelper {
148 type Candidate = Pair;
149
150 fn complete(
151 &self,
152 line: &str,
153 pos: usize,
154 ctx: &Context<'_>,
155 ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
156 if is_open_quote_for_location(line, pos) {
157 self.completer.complete(line, pos, ctx)
158 } else {
159 Ok((0, Vec::with_capacity(0)))
160 }
161 }
162}
163
164impl Validator for CliHelper {
165 fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
166 let input = ctx.input().trim_end();
167 self.validate_input(input)
168 }
169}
170
171impl Helper for CliHelper {}
172
173pub fn unescape_input(input: &str) -> datafusion::error::Result<String> {
177 let mut chars = input.chars();
178
179 let mut result = String::with_capacity(input.len());
180 while let Some(char) = chars.next() {
181 if char == '\\' {
182 if let Some(next_char) = chars.next() {
183 result.push(match next_char {
185 '0' => '\0',
186 'n' => '\n',
187 'r' => '\r',
188 't' => '\t',
189 '\\' => '\\',
190 _ => {
191 return Err(sql_datafusion_err!(ParserError::TokenizerError(
192 format!("unsupported escape char: '\\{}'", next_char)
193 )))
194 }
195 });
196 }
197 } else {
198 result.push(char);
199 }
200 }
201
202 Ok(result)
203}
204
/// Splits `sql` into statements on semicolons that are outside single- or
/// double-quoted text.
///
/// Each returned statement is trimmed and terminated with a single `;`; a final
/// statement without one gets it appended. Segments that are blank after
/// trimming are skipped.
pub(crate) fn split_from_semicolon(sql: String) -> Vec<String> {
    let mut statements = Vec::new();
    let mut buffer = String::new();
    let mut single_quoted = false;
    let mut double_quoted = false;

    for ch in sql.chars() {
        // A quote only toggles state when we are not inside the other quote kind.
        match ch {
            '\'' if !double_quoted => single_quoted = !single_quoted,
            '"' if !single_quoted => double_quoted = !double_quoted,
            _ => {}
        }

        let is_separator = ch == ';' && !(single_quoted || double_quoted);
        if !is_separator {
            buffer.push(ch);
            continue;
        }

        let trimmed = buffer.trim();
        if !trimmed.is_empty() {
            statements.push(format!("{trimmed};"));
            buffer.clear();
        }
    }

    let tail = buffer.trim();
    if !tail.is_empty() {
        statements.push(format!("{tail};"));
    }

    statements
}
235
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads one line from `reader` and validates it, mimicking rustyline after
    /// the user presses enter. Empty input yields `ReadlineError::Eof`.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut line = String::new();
        let bytes_read = reader.read_line(&mut line)?;
        if bytes_read > 0 {
            validator.validate_input(&line)
        } else {
            Err(ReadlineError::Eof)
        }
    }

    /// Validates `sql` as a single line of interactive input.
    fn validate_sql(sql: &str, validator: &CliHelper) -> Result<ValidationResult> {
        readline_direct(Cursor::new(sql.as_bytes()), validator)
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Plain delimiters and every supported escape sequence must validate.
        let accepted = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
        ];
        for sql in accepted {
            let result = validate_sql(sql, &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "expected valid input: {sql}"
            );
        }

        // `\u` is not a supported escape, so validation must reject the input.
        let result = validate_sql(
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\u{07}');",
            &validator,
        )?;
        assert!(matches!(result, ValidationResult::Invalid(Some(_))));

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let mut validator = CliHelper::default();
        let sql = r"select 1 # 2;";

        // `#` is not an operator in the generic dialect...
        let result = validate_sql(sql, &validator)?;
        assert!(
            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // ...but it is in PostgreSQL.
        validator.set_dialect("postgresql");
        let result = validate_sql(sql, &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        fn check(sql: &str, expected: &[&str]) {
            assert_eq!(split_from_semicolon(sql.to_string()), expected, "input: {sql:?}");
        }

        check("SELECT 1; SELECT 2;", &["SELECT 1;", "SELECT 2;"]);
        check(r#"SELECT ";";"#, &[r#"SELECT ";";"#]);
        check("SELECT ';';", &["SELECT ';';"]);
        check(
            r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
            &[
                "SELECT 1;",
                "SELECT 'value;value';",
                r#"SELECT 1 as "text;text";"#,
            ],
        );
        check("", &[]);
        check("SELECT 1", &["SELECT 1;"]);
        check("SELECT 1; ", &["SELECT 1;"]);
    }
}