1use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24
25use datafusion::sql::parser::{DFParser, Statement};
26use datafusion::sql::sqlparser::dialect::dialect_from_str;
27use datafusion_common::config::Dialect;
28
29use rustyline::completion::{Completer, FilenameCompleter, Pair};
30use rustyline::error::ReadlineError;
31use rustyline::highlight::{CmdKind, Highlighter};
32use rustyline::hint::Hinter;
33use rustyline::validate::{ValidationContext, ValidationResult, Validator};
34use rustyline::{Context, Helper, Result};
35
36pub struct CliHelper {
37 completer: FilenameCompleter,
38 dialect: Dialect,
39 highlighter: Box<dyn Highlighter>,
40}
41
42impl CliHelper {
43 pub fn new(dialect: &Dialect, color: bool) -> Self {
44 let highlighter: Box<dyn Highlighter> = if !color {
45 Box::new(NoSyntaxHighlighter {})
46 } else {
47 Box::new(SyntaxHighlighter::new(dialect))
48 };
49 Self {
50 completer: FilenameCompleter::new(),
51 dialect: *dialect,
52 highlighter,
53 }
54 }
55
56 pub fn set_dialect(&mut self, dialect: &Dialect) {
57 if *dialect != self.dialect {
58 self.dialect = *dialect;
59 }
60 }
61
62 fn validate_input(&self, input: &str) -> Result<ValidationResult> {
63 if let Some(sql) = input.strip_suffix(';') {
64 let dialect = match dialect_from_str(self.dialect) {
65 Some(dialect) => dialect,
66 None => {
67 return Ok(ValidationResult::Invalid(Some(format!(
68 " 🤔 Invalid dialect: {}",
69 self.dialect
70 ))))
71 }
72 };
73 let lines = split_from_semicolon(sql);
74 for line in lines {
75 match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
76 Ok(statements) if statements.is_empty() => {
77 return Ok(ValidationResult::Invalid(Some(
78 " 🤔 You entered an empty statement".to_string(),
79 )));
80 }
81 Ok(_statements) => {}
82 Err(err) => {
83 return Ok(ValidationResult::Invalid(Some(format!(
84 " 🤔 Invalid statement: {err}",
85 ))));
86 }
87 }
88 }
89 Ok(ValidationResult::Valid(None))
90 } else if input.starts_with('\\') {
91 Ok(ValidationResult::Valid(None))
93 } else {
94 Ok(ValidationResult::Incomplete)
95 }
96 }
97}
98
99impl Default for CliHelper {
100 fn default() -> Self {
101 Self::new(&Dialect::Generic, false)
102 }
103}
104
105impl Highlighter for CliHelper {
106 fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
107 self.highlighter.highlight(line, pos)
108 }
109
110 fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
111 self.highlighter.highlight_char(line, pos, kind)
112 }
113}
114
/// `Hinter` implementation that only fixes the hint type to `String`.
///
/// No methods are overridden here, so the trait's default behavior applies.
impl Hinter for CliHelper {
    type Hint = String;
}
118
119fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
122 let mut sql = line[..pos].to_string();
123 sql.push('\'');
124 if let Ok(stmts) = DFParser::parse_sql(&sql) {
125 if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
126 return true;
127 }
128 }
129 false
130}
131
132impl Completer for CliHelper {
133 type Candidate = Pair;
134
135 fn complete(
136 &self,
137 line: &str,
138 pos: usize,
139 ctx: &Context<'_>,
140 ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
141 if is_open_quote_for_location(line, pos) {
142 self.completer.complete(line, pos, ctx)
143 } else {
144 Ok((0, Vec::with_capacity(0)))
145 }
146 }
147}
148
149impl Validator for CliHelper {
150 fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
151 let input = ctx.input().trim_end();
152 self.validate_input(input)
153 }
154}
155
/// Empty impl: defines no items of its own and registers `CliHelper` as a
/// rustyline `Helper`.
impl Helper for CliHelper {}
157
/// Splits `sql` into individual statements on `;`.
///
/// Semicolons inside single-quoted or double-quoted regions do not split.
/// Each returned statement is trimmed and terminated with a single `;`.
/// Pieces that are empty or whitespace-only are dropped. Trailing text
/// without a terminating `;` is still returned, with `;` appended.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    /// Quoted region the scanner is currently inside, if any.
    #[derive(Clone, Copy, PartialEq)]
    enum Quote {
        None,
        Single,
        Double,
    }

    /// Moves the trimmed contents of `buffer` into `out` (when non-blank)
    /// and resets the buffer.
    fn flush(buffer: &mut String, out: &mut Vec<String>) {
        let trimmed = buffer.trim();
        if !trimmed.is_empty() {
            out.push(format!("{};", trimmed));
        }
        buffer.clear();
    }

    let mut statements = Vec::new();
    let mut buffer = String::new();
    let mut quote = Quote::None;

    for ch in sql.chars() {
        // A quote character only opens or closes its own kind of region; the
        // other quote character is plain text while a region is open.
        quote = match (ch, quote) {
            ('\'', Quote::None) => Quote::Single,
            ('\'', Quote::Single) => Quote::None,
            ('"', Quote::None) => Quote::Double,
            ('"', Quote::Double) => Quote::None,
            (_, unchanged) => unchanged,
        };

        if ch == ';' && quote == Quote::None {
            flush(&mut buffer, &mut statements);
        } else {
            buffer.push(ch);
        }
    }

    // Whatever follows the last separator is a statement as well.
    flush(&mut buffer, &mut statements);

    statements
}
188
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads one line from `reader` and validates it with `validator`.
    /// End of input is reported as `ReadlineError::Eof`.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut input = String::new();

        match reader.read_line(&mut input)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&input),
        }
    }

    /// Statements containing backslashes and multi-character option values
    /// must all validate under the default helper.
    #[test]
    fn unescape_readline_input() -> Result<()> {
        const VALID_INPUTS: [&str; 8] = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
            r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';",
        ];

        let validator = CliHelper::default();

        for sql in VALID_INPUTS {
            let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "expected input to be valid: {sql}"
            );
        }

        Ok(())
    }

    /// The same statement is rejected by the default dialect and accepted
    /// after switching to PostgreSQL.
    #[test]
    fn sql_dialect() -> Result<()> {
        let sql = r"select 1 # 2;";
        let mut validator = CliHelper::default();

        let rejected = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(
            matches!(rejected, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        validator.set_dialect(&Dialect::PostgreSQL);
        let accepted = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(matches!(accepted, ValidationResult::Valid(None)));

        Ok(())
    }

    /// Table of inputs and the statements they are expected to split into.
    #[test]
    fn test_split_from_semicolon() {
        let cases: [(&str, Vec<&str>); 7] = [
            ("SELECT 1; SELECT 2;", vec!["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, vec![r#"SELECT ";";"#]),
            ("SELECT ';';", vec!["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                vec![
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", vec![]),
            ("SELECT 1", vec!["SELECT 1;"]),
            ("SELECT 1; ", vec!["SELECT 1;"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql), expected);
        }
    }
}