1use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24
25use datafusion::sql::parser::{DFParser, Statement};
26use datafusion::sql::sqlparser::dialect::dialect_from_str;
27
28use rustyline::completion::{Completer, FilenameCompleter, Pair};
29use rustyline::error::ReadlineError;
30use rustyline::highlight::{CmdKind, Highlighter};
31use rustyline::hint::Hinter;
32use rustyline::validate::{ValidationContext, ValidationResult, Validator};
33use rustyline::{Context, Helper, Result};
34
35pub struct CliHelper {
36 completer: FilenameCompleter,
37 dialect: String,
38 highlighter: Box<dyn Highlighter>,
39}
40
41impl CliHelper {
42 pub fn new(dialect: &str, color: bool) -> Self {
43 let highlighter: Box<dyn Highlighter> = if !color {
44 Box::new(NoSyntaxHighlighter {})
45 } else {
46 Box::new(SyntaxHighlighter::new(dialect))
47 };
48 Self {
49 completer: FilenameCompleter::new(),
50 dialect: dialect.into(),
51 highlighter,
52 }
53 }
54
55 pub fn set_dialect(&mut self, dialect: &str) {
56 if dialect != self.dialect {
57 self.dialect = dialect.to_string();
58 }
59 }
60
61 fn validate_input(&self, input: &str) -> Result<ValidationResult> {
62 if let Some(sql) = input.strip_suffix(';') {
63 let dialect = match dialect_from_str(&self.dialect) {
64 Some(dialect) => dialect,
65 None => {
66 return Ok(ValidationResult::Invalid(Some(format!(
67 " 🤔 Invalid dialect: {}",
68 self.dialect
69 ))))
70 }
71 };
72 let lines = split_from_semicolon(sql);
73 for line in lines {
74 match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
75 Ok(statements) if statements.is_empty() => {
76 return Ok(ValidationResult::Invalid(Some(
77 " 🤔 You entered an empty statement".to_string(),
78 )));
79 }
80 Ok(_statements) => {}
81 Err(err) => {
82 return Ok(ValidationResult::Invalid(Some(format!(
83 " 🤔 Invalid statement: {err}",
84 ))));
85 }
86 }
87 }
88 Ok(ValidationResult::Valid(None))
89 } else if input.starts_with('\\') {
90 Ok(ValidationResult::Valid(None))
92 } else {
93 Ok(ValidationResult::Incomplete)
94 }
95 }
96}
97
98impl Default for CliHelper {
99 fn default() -> Self {
100 Self::new("generic", false)
101 }
102}
103
104impl Highlighter for CliHelper {
105 fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
106 self.highlighter.highlight(line, pos)
107 }
108
109 fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
110 self.highlighter.highlight_char(line, pos, kind)
111 }
112}
113
114impl Hinter for CliHelper {
115 type Hint = String;
116}
117
118fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
121 let mut sql = line[..pos].to_string();
122 sql.push('\'');
123 if let Ok(stmts) = DFParser::parse_sql(&sql) {
124 if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
125 return true;
126 }
127 }
128 false
129}
130
131impl Completer for CliHelper {
132 type Candidate = Pair;
133
134 fn complete(
135 &self,
136 line: &str,
137 pos: usize,
138 ctx: &Context<'_>,
139 ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
140 if is_open_quote_for_location(line, pos) {
141 self.completer.complete(line, pos, ctx)
142 } else {
143 Ok((0, Vec::with_capacity(0)))
144 }
145 }
146}
147
148impl Validator for CliHelper {
149 fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
150 let input = ctx.input().trim_end();
151 self.validate_input(input)
152 }
153}
154
155impl Helper for CliHelper {}
156
/// Splits `sql` into individual statements at top-level semicolons.
///
/// The following do not split a statement:
/// * semicolons inside single-quoted strings or double-quoted identifiers;
/// * semicolons inside `--` line comments (which end at a newline);
/// * semicolons inside `/* ... */` block comments.
///
/// Quote characters inside comments are ignored, so `-- it's` does not open a string.
/// Each returned statement is trimmed and terminated with a single `;`.
/// Whitespace-only segments are dropped, so `""` yields an empty vector and
/// `"SELECT 1"` yields `["SELECT 1;"]`. Comment text is kept as-is so the
/// parser sees the original input.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    /// Lexical context of the character currently being scanned.
    #[derive(Clone, Copy)]
    enum Scan {
        Code,
        SingleQuote,
        DoubleQuote,
        LineComment,
        BlockComment,
    }

    /// Pushes the trimmed, `;`-terminated `current` onto `commands` if it is
    /// non-blank, then resets `current`.
    fn flush(commands: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            commands.push(format!("{trimmed};"));
        }
        current.clear();
    }

    let mut commands = Vec::new();
    let mut current = String::new();
    let mut state = Scan::Code;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match state {
            Scan::Code => match c {
                ';' => {
                    // Statement separator: not copied into the statement text.
                    flush(&mut commands, &mut current);
                    continue;
                }
                '\'' => state = Scan::SingleQuote,
                '"' => state = Scan::DoubleQuote,
                '-' if chars.peek() == Some(&'-') => state = Scan::LineComment,
                '/' => {
                    // Consume the `*` now so `/*/` is not also read as a closing `*/`.
                    if let Some(star) = chars.next_if_eq(&'*') {
                        current.push(c);
                        current.push(star);
                        state = Scan::BlockComment;
                        continue;
                    }
                }
                _ => {}
            },
            Scan::SingleQuote if c == '\'' => state = Scan::Code,
            Scan::DoubleQuote if c == '"' => state = Scan::Code,
            Scan::LineComment if c == '\n' => state = Scan::Code,
            Scan::BlockComment if c == '*' => {
                if let Some(slash) = chars.next_if_eq(&'/') {
                    current.push(c);
                    current.push(slash);
                    state = Scan::Code;
                    continue;
                }
            }
            _ => {}
        }
        current.push(c);
    }

    flush(&mut commands, &mut current);
    commands
}
187
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads one line from `reader` and runs it through the helper's validator,
    /// the way rustyline would for a single line of interactive input.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut buf = String::new();
        match reader.read_line(&mut buf)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&buf),
        }
    }

    /// Validates `sql` by feeding it to `readline_direct` from an in-memory buffer.
    fn validate_str(sql: &str, validator: &CliHelper) -> Result<ValidationResult> {
        readline_direct(Cursor::new(sql.as_bytes()), validator)
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Delimiter literals, including backslash sequences, that must pass validation.
        let delimiters = [",", r"\0", r"\n", r"\r", r"\t", r"\\", ",,"];
        for delimiter in delimiters {
            let sql = format!(
                "create external table test stored as csv location 'data.csv' options ('format.delimiter' '{delimiter}');"
            );
            let result = validate_str(&sql, &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "delimiter {delimiter:?} should validate"
            );
        }

        let result = validate_str(
            r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';",
            &validator,
        )?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let mut validator = CliHelper::default();
        let sql = r"select 1 # 2;";

        // Rejected by the default (generic) dialect...
        let result = validate_str(sql, &validator)?;
        assert!(
            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // ...but accepted once the dialect is switched to PostgreSQL.
        validator.set_dialect("postgresql");
        let result = validate_str(sql, &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        let cases: [(&str, &[&str]); 7] = [
            ("SELECT 1; SELECT 2;", &["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, &[r#"SELECT ";";"#]),
            ("SELECT ';';", &["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                &[
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", &[]),
            ("SELECT 1", &["SELECT 1;"]),
            ("SELECT 1; ", &["SELECT 1;"]),
        ];
        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql), expected, "input: {sql:?}");
        }
    }
}