datafusion_cli/
helper.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Helper for interactive editing in the CLI: multi-line input parsing and
//! validation, plus file-name auto-completion while writing a `CREATE EXTERNAL TABLE`
//! statement.
20
21use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24
25use datafusion::sql::parser::{DFParser, Statement};
26use datafusion::sql::sqlparser::dialect::dialect_from_str;
27use datafusion_common::config::Dialect;
28
29use rustyline::completion::{Completer, FilenameCompleter, Pair};
30use rustyline::error::ReadlineError;
31use rustyline::highlight::{CmdKind, Highlighter};
32use rustyline::hint::Hinter;
33use rustyline::validate::{ValidationContext, ValidationResult, Validator};
34use rustyline::{Context, Helper, Result};
35
36pub struct CliHelper {
37    completer: FilenameCompleter,
38    dialect: Dialect,
39    highlighter: Box<dyn Highlighter>,
40}
41
42impl CliHelper {
43    pub fn new(dialect: &Dialect, color: bool) -> Self {
44        let highlighter: Box<dyn Highlighter> = if !color {
45            Box::new(NoSyntaxHighlighter {})
46        } else {
47            Box::new(SyntaxHighlighter::new(dialect))
48        };
49        Self {
50            completer: FilenameCompleter::new(),
51            dialect: *dialect,
52            highlighter,
53        }
54    }
55
56    pub fn set_dialect(&mut self, dialect: &Dialect) {
57        if *dialect != self.dialect {
58            self.dialect = *dialect;
59        }
60    }
61
62    fn validate_input(&self, input: &str) -> Result<ValidationResult> {
63        if let Some(sql) = input.strip_suffix(';') {
64            let dialect = match dialect_from_str(self.dialect) {
65                Some(dialect) => dialect,
66                None => {
67                    return Ok(ValidationResult::Invalid(Some(format!(
68                        "  🤔 Invalid dialect: {}",
69                        self.dialect
70                    ))))
71                }
72            };
73            let lines = split_from_semicolon(sql);
74            for line in lines {
75                match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
76                    Ok(statements) if statements.is_empty() => {
77                        return Ok(ValidationResult::Invalid(Some(
78                            "  🤔 You entered an empty statement".to_string(),
79                        )));
80                    }
81                    Ok(_statements) => {}
82                    Err(err) => {
83                        return Ok(ValidationResult::Invalid(Some(format!(
84                            "  🤔 Invalid statement: {err}",
85                        ))));
86                    }
87                }
88            }
89            Ok(ValidationResult::Valid(None))
90        } else if input.starts_with('\\') {
91            // command
92            Ok(ValidationResult::Valid(None))
93        } else {
94            Ok(ValidationResult::Incomplete)
95        }
96    }
97}
98
99impl Default for CliHelper {
100    fn default() -> Self {
101        Self::new(&Dialect::Generic, false)
102    }
103}
104
105impl Highlighter for CliHelper {
106    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
107        self.highlighter.highlight(line, pos)
108    }
109
110    fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
111        self.highlighter.highlight_char(line, pos, kind)
112    }
113}
114
115impl Hinter for CliHelper {
116    type Hint = String;
117}
118
119/// returns true if the current position is after the open quote for
120/// creating an external table.
121fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
122    let mut sql = line[..pos].to_string();
123    sql.push('\'');
124    if let Ok(stmts) = DFParser::parse_sql(&sql) {
125        if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
126            return true;
127        }
128    }
129    false
130}
131
132impl Completer for CliHelper {
133    type Candidate = Pair;
134
135    fn complete(
136        &self,
137        line: &str,
138        pos: usize,
139        ctx: &Context<'_>,
140    ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
141        if is_open_quote_for_location(line, pos) {
142            self.completer.complete(line, pos, ctx)
143        } else {
144            Ok((0, Vec::with_capacity(0)))
145        }
146    }
147}
148
149impl Validator for CliHelper {
150    fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
151        let input = ctx.input().trim_end();
152        self.validate_input(input)
153    }
154}
155
156impl Helper for CliHelper {}
157
/// Splits `sql`, which may contain several statements, into individual statements.
///
/// Splitting happens on `;` characters that are outside of:
/// * single-quoted strings (`'...'`, with `''` as an embedded quote),
/// * double-quoted identifiers (`"..."`),
/// * `--` line comments (up to the next `\n`), and
/// * `/* ... */` block comments.
///
/// Each returned statement is trimmed and terminated with a single `;`. Statements
/// that are empty after trimming are dropped.
///
/// This is a lightweight scanner, not a full SQL tokenizer. Backslash escapes,
/// dollar-quoted strings and nested block comments are not recognized.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    /// Lexical context of the character currently being scanned.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Lex {
        Code,
        SingleQuote,
        DoubleQuote,
        LineComment,
        BlockComment,
    }

    /// Appends `command` (trimmed, with a trailing `;`) unless it is blank.
    fn push_command(commands: &mut Vec<String>, command: &str) {
        let command = command.trim();
        if !command.is_empty() {
            commands.push(format!("{command};"));
        }
    }

    let mut commands = Vec::new();
    let mut current = String::new();
    let mut state = Lex::Code;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        if state == Lex::Code && c == ';' {
            push_command(&mut commands, &current);
            current.clear();
            continue;
        }

        current.push(c);
        match state {
            Lex::Code => match c {
                '\'' => state = Lex::SingleQuote,
                '"' => state = Lex::DoubleQuote,
                // The second delimiter character is consumed right away, so
                // sequences such as `/*/` are not read as an opening-and-closing pair.
                '-' => {
                    if let Some(next) = chars.next_if_eq(&'-') {
                        current.push(next);
                        state = Lex::LineComment;
                    }
                }
                '/' => {
                    if let Some(next) = chars.next_if_eq(&'*') {
                        current.push(next);
                        state = Lex::BlockComment;
                    }
                }
                _ => {}
            },
            // A doubled `''` simply leaves and re-enters the string.
            Lex::SingleQuote if c == '\'' => state = Lex::Code,
            Lex::DoubleQuote if c == '"' => state = Lex::Code,
            Lex::LineComment if c == '\n' => state = Lex::Code,
            Lex::BlockComment if c == '*' => {
                if let Some(next) = chars.next_if_eq(&'/') {
                    current.push(next);
                    state = Lex::Code;
                }
            }
            _ => {}
        }
    }

    push_command(&mut commands, &current);
    commands
}
188
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads a single line from `reader` and runs it through the helper's
    /// validation, returning `ReadlineError::Eof` when nothing could be read.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut input = String::new();
        match reader.read_line(&mut input)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&input),
        }
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Every statement here is expected to validate successfully.
        let inputs = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
            r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';",
        ];

        for sql in inputs {
            let result = readline_direct(Cursor::new(sql.as_bytes()), &validator)?;
            assert!(matches!(result, ValidationResult::Valid(None)));
        }

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        const SQL: &str = r"select 1 # 2;";
        let mut validator = CliHelper::default();

        // Rejected by the generic dialect...
        let result = readline_direct(Cursor::new(SQL.as_bytes()), &validator)?;
        assert!(
            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // ...but accepted once the PostgreSQL dialect is selected.
        validator.set_dialect(&Dialect::PostgreSQL);
        let result = readline_direct(Cursor::new(SQL.as_bytes()), &validator)?;
        assert!(matches!(result, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("SELECT 1; SELECT 2;", vec!["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, vec![r#"SELECT ";";"#]),
            ("SELECT ';';", vec!["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                vec![
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", vec![]),
            ("SELECT 1", vec!["SELECT 1;"]),
            ("SELECT 1;   ", vec!["SELECT 1;"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql), expected);
        }
    }
}