datafusion_cli/
helper.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Helper for interactive line editing: multi-line input parsing and validation,
//! and file-name auto-completion while typing a quoted location in a
//! `CREATE EXTERNAL TABLE` statement.
20
21use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24
25use datafusion::sql::parser::{DFParser, Statement};
26use datafusion::sql::sqlparser::dialect::dialect_from_str;
27
28use rustyline::completion::{Completer, FilenameCompleter, Pair};
29use rustyline::error::ReadlineError;
30use rustyline::highlight::{CmdKind, Highlighter};
31use rustyline::hint::Hinter;
32use rustyline::validate::{ValidationContext, ValidationResult, Validator};
33use rustyline::{Context, Helper, Result};
34
35pub struct CliHelper {
36    completer: FilenameCompleter,
37    dialect: String,
38    highlighter: Box<dyn Highlighter>,
39}
40
41impl CliHelper {
42    pub fn new(dialect: &str, color: bool) -> Self {
43        let highlighter: Box<dyn Highlighter> = if !color {
44            Box::new(NoSyntaxHighlighter {})
45        } else {
46            Box::new(SyntaxHighlighter::new(dialect))
47        };
48        Self {
49            completer: FilenameCompleter::new(),
50            dialect: dialect.into(),
51            highlighter,
52        }
53    }
54
55    pub fn set_dialect(&mut self, dialect: &str) {
56        if dialect != self.dialect {
57            self.dialect = dialect.to_string();
58        }
59    }
60
61    fn validate_input(&self, input: &str) -> Result<ValidationResult> {
62        if let Some(sql) = input.strip_suffix(';') {
63            let dialect = match dialect_from_str(&self.dialect) {
64                Some(dialect) => dialect,
65                None => {
66                    return Ok(ValidationResult::Invalid(Some(format!(
67                        "  🤔 Invalid dialect: {}",
68                        self.dialect
69                    ))))
70                }
71            };
72            let lines = split_from_semicolon(sql);
73            for line in lines {
74                match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
75                    Ok(statements) if statements.is_empty() => {
76                        return Ok(ValidationResult::Invalid(Some(
77                            "  🤔 You entered an empty statement".to_string(),
78                        )));
79                    }
80                    Ok(_statements) => {}
81                    Err(err) => {
82                        return Ok(ValidationResult::Invalid(Some(format!(
83                            "  🤔 Invalid statement: {err}",
84                        ))));
85                    }
86                }
87            }
88            Ok(ValidationResult::Valid(None))
89        } else if input.starts_with('\\') {
90            // command
91            Ok(ValidationResult::Valid(None))
92        } else {
93            Ok(ValidationResult::Incomplete)
94        }
95    }
96}
97
98impl Default for CliHelper {
99    fn default() -> Self {
100        Self::new("generic", false)
101    }
102}
103
104impl Highlighter for CliHelper {
105    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
106        self.highlighter.highlight(line, pos)
107    }
108
109    fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
110        self.highlighter.highlight_char(line, pos, kind)
111    }
112}
113
114impl Hinter for CliHelper {
115    type Hint = String;
116}
117
118/// returns true if the current position is after the open quote for
119/// creating an external table.
120fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
121    let mut sql = line[..pos].to_string();
122    sql.push('\'');
123    if let Ok(stmts) = DFParser::parse_sql(&sql) {
124        if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
125            return true;
126        }
127    }
128    false
129}
130
131impl Completer for CliHelper {
132    type Candidate = Pair;
133
134    fn complete(
135        &self,
136        line: &str,
137        pos: usize,
138        ctx: &Context<'_>,
139    ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
140        if is_open_quote_for_location(line, pos) {
141            self.completer.complete(line, pos, ctx)
142        } else {
143            Ok((0, Vec::with_capacity(0)))
144        }
145    }
146}
147
148impl Validator for CliHelper {
149    fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
150        let input = ctx.input().trim_end();
151        self.validate_input(input)
152    }
153}
154
155impl Helper for CliHelper {}
156
/// Splits `sql`, which may hold several statements, on `;` separators.
///
/// Semicolons inside single-quoted (`'...'`) or double-quoted (`"..."`) text
/// stay part of the current statement. Quote tracking is a simple toggle:
/// backslash escapes and SQL comments are not interpreted.
///
/// Each returned statement is trimmed and has `;` appended. Whitespace-only
/// segments are skipped, and a trailing statement without `;` is still returned.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    let mut statements: Vec<String> = Vec::new();
    let mut pending = String::new();
    let (mut in_single, mut in_double) = (false, false);

    for ch in sql.chars() {
        // A quote character only counts when not inside the other kind of quote.
        match ch {
            '\'' if !in_double => in_single = !in_single,
            '"' if !in_single => in_double = !in_double,
            _ => {}
        }

        let is_separator = ch == ';' && !(in_single || in_double);
        if !is_separator {
            pending.push(ch);
            continue;
        }

        // The separator itself is dropped. Whitespace-only text is left in
        // `pending`, where the next trim removes it.
        let stmt = pending.trim();
        if !stmt.is_empty() {
            statements.push(format!("{stmt};"));
            pending.clear();
        }
    }

    let tail = pending.trim();
    if !tail.is_empty() {
        statements.push(format!("{tail};"));
    }
    statements
}
187
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads one line from `reader` and runs it through
    /// `CliHelper::validate_input`, as if it were a single submitted REPL line.
    /// Returns `ReadlineError::Eof` when the reader has nothing left.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut line = String::new();
        match reader.read_line(&mut line)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&line),
        }
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Raw strings: backslash sequences reach the SQL parser verbatim.
        let delimiters = [r",", r"\0", r"\n", r"\r", r"\t", r"\\", r",,"];
        for delimiter in delimiters.iter() {
            let sql = format!(
                "create external table test stored as csv location 'data.csv' options ('format.delimiter' '{delimiter}');"
            );
            let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "expected valid input: {sql}"
            );
        }

        let select = r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';";
        let result = readline_direct(Cursor::new(select.as_bytes()), &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let mut validator = CliHelper::default();
        let sql = r"select 1 # 2;";

        // Rejected by the generic dialect...
        let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(
            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // ...but accepted once the helper switches to the postgresql dialect.
        validator.set_dialect("postgresql");
        let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        fn check(sql: &str, expected: &[&str]) {
            assert_eq!(split_from_semicolon(sql), expected, "input: {sql:?}");
        }

        check("SELECT 1; SELECT 2;", &["SELECT 1;", "SELECT 2;"]);
        // Semicolons inside quoted text do not split statements.
        check(r#"SELECT ";";"#, &[r#"SELECT ";";"#]);
        check("SELECT ';';", &["SELECT ';';"]);
        check(
            r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
            &["SELECT 1;", "SELECT 'value;value';", r#"SELECT 1 as "text;text";"#],
        );
        // Empty input yields nothing; missing or padded terminators are normalized.
        check("", &[]);
        check("SELECT 1", &["SELECT 1;"]);
        check("SELECT 1;   ", &["SELECT 1;"]);
    }
}