1use std::borrow::Cow;
22use std::cell::Cell;
23
24use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter};
25
26use datafusion::sql::parser::{DFParser, Statement};
27use datafusion::sql::sqlparser::dialect::dialect_from_str;
28use datafusion_common::config::Dialect;
29
30use rustyline::completion::{Completer, FilenameCompleter, Pair};
31use rustyline::error::ReadlineError;
32use rustyline::highlight::{CmdKind, Highlighter};
33use rustyline::hint::Hinter;
34use rustyline::validate::{ValidationContext, ValidationResult, Validator};
35use rustyline::{Context, Helper, Result};
36
/// Help hint suggested (in gray) on an empty prompt; the `Hinter` impl hides
/// it once the user starts typing on a line.
const DEFAULT_HINT_SUGGESTION: &str = " \\? for help, \\q to quit";
39
40pub struct CliHelper {
41 completer: FilenameCompleter,
42 dialect: Dialect,
43 highlighter: Box<dyn Highlighter>,
44 show_hint: Cell<bool>,
48}
49
50impl CliHelper {
51 pub fn new(dialect: &Dialect, color: bool) -> Self {
52 let highlighter: Box<dyn Highlighter> = if !color {
53 Box::new(NoSyntaxHighlighter {})
54 } else {
55 Box::new(SyntaxHighlighter::new(dialect))
56 };
57 Self {
58 completer: FilenameCompleter::new(),
59 dialect: *dialect,
60 highlighter,
61 show_hint: Cell::new(true),
62 }
63 }
64
65 pub fn set_dialect(&mut self, dialect: &Dialect) {
66 if *dialect != self.dialect {
67 self.dialect = *dialect;
68 }
69 }
70
71 pub fn reset_hint(&self) {
73 self.show_hint.set(true);
74 }
75
76 fn validate_input(&self, input: &str) -> Result<ValidationResult> {
77 if let Some(sql) = input.strip_suffix(';') {
78 let dialect = match dialect_from_str(self.dialect) {
79 Some(dialect) => dialect,
80 None => {
81 return Ok(ValidationResult::Invalid(Some(format!(
82 " 🤔 Invalid dialect: {}",
83 self.dialect
84 ))));
85 }
86 };
87 let lines = split_from_semicolon(sql);
88 for line in lines {
89 match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
90 Ok(statements) if statements.is_empty() => {
91 return Ok(ValidationResult::Invalid(Some(
92 " 🤔 You entered an empty statement".to_string(),
93 )));
94 }
95 Ok(_statements) => {}
96 Err(err) => {
97 return Ok(ValidationResult::Invalid(Some(format!(
98 " 🤔 Invalid statement: {err}",
99 ))));
100 }
101 }
102 }
103 Ok(ValidationResult::Valid(None))
104 } else if input.starts_with('\\') {
105 Ok(ValidationResult::Valid(None))
107 } else {
108 Ok(ValidationResult::Incomplete)
109 }
110 }
111}
112
113impl Default for CliHelper {
114 fn default() -> Self {
115 Self::new(&Dialect::Generic, false)
116 }
117}
118
119impl Highlighter for CliHelper {
120 fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
121 self.highlighter.highlight(line, pos)
122 }
123
124 fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
125 self.highlighter.highlight_char(line, pos, kind)
126 }
127}
128
129impl Hinter for CliHelper {
130 type Hint = String;
131
132 fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
133 if !line.is_empty() {
134 self.show_hint.set(false);
135 }
136 (self.show_hint.get() && line.trim().is_empty())
137 .then(|| Color::gray(DEFAULT_HINT_SUGGESTION))
138 }
139}
140
141fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
144 let mut sql = line[..pos].to_string();
145 sql.push('\'');
146 DFParser::parse_sql(&sql).is_ok_and(|stmts| {
147 matches!(stmts.back(), Some(Statement::CreateExternalTable(_)))
148 })
149}
150
151impl Completer for CliHelper {
152 type Candidate = Pair;
153
154 fn complete(
155 &self,
156 line: &str,
157 pos: usize,
158 ctx: &Context<'_>,
159 ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
160 if is_open_quote_for_location(line, pos) {
161 self.completer.complete(line, pos, ctx)
162 } else {
163 Ok((0, Vec::with_capacity(0)))
164 }
165 }
166}
167
168impl Validator for CliHelper {
169 fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
170 let input = ctx.input().trim_end();
171 let result = self.validate_input(input);
172 self.reset_hint();
173 result
174 }
175}
176
177impl Helper for CliHelper {}
178
/// Split `sql` into individual statements at every `;` that is outside
/// single- and double-quoted strings.
///
/// Each returned statement is trimmed and re-terminated with `;`. Segments
/// that are empty after trimming are dropped. A trailing segment without a
/// terminating `;` is still returned, with `;` appended.
///
/// Quote tracking is a plain toggle: backslash escapes and SQL comments are
/// not recognised, so a quote character inside a comment still flips state.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    let mut statements: Vec<String> = Vec::new();
    let mut buffer = String::new();
    let (mut single_quoted, mut double_quoted) = (false, false);

    for ch in sql.chars() {
        match ch {
            '\'' if !double_quoted => single_quoted = !single_quoted,
            '"' if !single_quoted => double_quoted = !double_quoted,
            ';' if !single_quoted && !double_quoted => {
                // Statement boundary: emit what we have (if anything) and
                // drop the separator itself.
                let trimmed = buffer.trim();
                if !trimmed.is_empty() {
                    statements.push(format!("{trimmed};"));
                    buffer.clear();
                }
                continue;
            }
            _ => {}
        }
        buffer.push(ch);
    }

    let trimmed = buffer.trim();
    if !trimmed.is_empty() {
        statements.push(format!("{trimmed};"));
    }
    statements
}
209
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Read one line from `reader` and run it through `validate_input`
    /// directly (without the trailing-whitespace trim that `validate` applies).
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut input = String::new();
        match reader.read_line(&mut input)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&input),
        }
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Backslash sequences inside string literals must be accepted verbatim.
        let inputs = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
            r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';",
        ];

        for sql in inputs {
            let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "expected input to be valid: {sql}"
            );
        }

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let mut validator = CliHelper::default();
        let sql = r"select 1 # 2;";

        // Rejected by the default (generic) dialect...
        let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(
            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // ...but accepted once the PostgreSQL dialect is selected.
        validator.set_dialect(&Dialect::PostgreSQL);
        let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("SELECT 1; SELECT 2;", vec!["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, vec![r#"SELECT ";";"#]),
            ("SELECT ';';", vec!["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                vec![
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", vec![]),
            ("SELECT 1", vec!["SELECT 1;"]),
            ("SELECT 1; ", vec!["SELECT 1;"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql), expected, "splitting {sql:?}");
        }
    }
}