Skip to main content

datafusion_cli/
helper.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Helper that helps with interactive editing, including multi-line parsing and validation,
19//! and auto-completion for file name during creating external table.
20
21use std::borrow::Cow;
22use std::cell::Cell;
23
24use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter};
25
26use datafusion::sql::parser::{DFParser, Statement};
27use datafusion::sql::sqlparser::dialect::dialect_from_str;
28use datafusion_common::config::Dialect;
29
30use rustyline::completion::{Completer, FilenameCompleter, Pair};
31use rustyline::error::ReadlineError;
32use rustyline::highlight::{CmdKind, Highlighter};
33use rustyline::hint::Hinter;
34use rustyline::validate::{ValidationContext, ValidationResult, Validator};
35use rustyline::{Context, Helper, Result};
36
/// Default suggestion shown when the input line is empty.
///
/// The escaped backslashes produce the literal text ` \? for help, \q to quit`.
/// It is returned (wrapped by `Color::gray`) from the `Hinter` implementation
/// for `CliHelper`.
const DEFAULT_HINT_SUGGESTION: &str = " \\? for help, \\q to quit";
39
/// Line-editor helper for the interactive CLI.
///
/// Bundles the pieces the trait implementations in this file delegate to:
/// file-name completion, SQL validation for the configured dialect,
/// highlighting, and the default hint.
pub struct CliHelper {
    /// File-name completer that `Completer::complete` delegates to.
    completer: FilenameCompleter,
    /// SQL dialect used when validating submitted input.
    dialect: Dialect,
    /// Highlighter chosen in `new` based on the `color` flag.
    highlighter: Box<dyn Highlighter>,
    /// Tracks whether to show the default hint. Set to `false` once the user
    /// types anything, so the hint doesn't reappear after deleting back to
    /// an empty line. Reset to `true` when the line is submitted.
    ///
    /// Stored in a `Cell` because it is updated from methods that only
    /// receive `&self`.
    show_hint: Cell<bool>,
}
49
50impl CliHelper {
51    pub fn new(dialect: &Dialect, color: bool) -> Self {
52        let highlighter: Box<dyn Highlighter> = if !color {
53            Box::new(NoSyntaxHighlighter {})
54        } else {
55            Box::new(SyntaxHighlighter::new(dialect))
56        };
57        Self {
58            completer: FilenameCompleter::new(),
59            dialect: *dialect,
60            highlighter,
61            show_hint: Cell::new(true),
62        }
63    }
64
65    pub fn set_dialect(&mut self, dialect: &Dialect) {
66        if *dialect != self.dialect {
67            self.dialect = *dialect;
68        }
69    }
70
71    /// Re-enable the default hint for the next prompt.
72    pub fn reset_hint(&self) {
73        self.show_hint.set(true);
74    }
75
76    fn validate_input(&self, input: &str) -> Result<ValidationResult> {
77        if let Some(sql) = input.strip_suffix(';') {
78            let dialect = match dialect_from_str(self.dialect) {
79                Some(dialect) => dialect,
80                None => {
81                    return Ok(ValidationResult::Invalid(Some(format!(
82                        "  🤔 Invalid dialect: {}",
83                        self.dialect
84                    ))));
85                }
86            };
87            let lines = split_from_semicolon(sql);
88            for line in lines {
89                match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
90                    Ok(statements) if statements.is_empty() => {
91                        return Ok(ValidationResult::Invalid(Some(
92                            "  🤔 You entered an empty statement".to_string(),
93                        )));
94                    }
95                    Ok(_statements) => {}
96                    Err(err) => {
97                        return Ok(ValidationResult::Invalid(Some(format!(
98                            "  🤔 Invalid statement: {err}",
99                        ))));
100                    }
101                }
102            }
103            Ok(ValidationResult::Valid(None))
104        } else if input.starts_with('\\') {
105            // command
106            Ok(ValidationResult::Valid(None))
107        } else {
108            Ok(ValidationResult::Incomplete)
109        }
110    }
111}
112
113impl Default for CliHelper {
114    fn default() -> Self {
115        Self::new(&Dialect::Generic, false)
116    }
117}
118
119impl Highlighter for CliHelper {
120    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
121        self.highlighter.highlight(line, pos)
122    }
123
124    fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
125        self.highlighter.highlight_char(line, pos, kind)
126    }
127}
128
129impl Hinter for CliHelper {
130    type Hint = String;
131
132    fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
133        if !line.is_empty() {
134            self.show_hint.set(false);
135        }
136        (self.show_hint.get() && line.trim().is_empty())
137            .then(|| Color::gray(DEFAULT_HINT_SUGGESTION))
138    }
139}
140
141/// returns true if the current position is after the open quote for
142/// creating an external table.
143fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
144    let mut sql = line[..pos].to_string();
145    sql.push('\'');
146    DFParser::parse_sql(&sql).is_ok_and(|stmts| {
147        matches!(stmts.back(), Some(Statement::CreateExternalTable(_)))
148    })
149}
150
151impl Completer for CliHelper {
152    type Candidate = Pair;
153
154    fn complete(
155        &self,
156        line: &str,
157        pos: usize,
158        ctx: &Context<'_>,
159    ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
160        if is_open_quote_for_location(line, pos) {
161            self.completer.complete(line, pos, ctx)
162        } else {
163            Ok((0, Vec::with_capacity(0)))
164        }
165    }
166}
167
168impl Validator for CliHelper {
169    fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
170        let input = ctx.input().trim_end();
171        let result = self.validate_input(input);
172        self.reset_hint();
173        result
174    }
175}
176
/// Marker implementation with an empty body; the behavior comes from the
/// `Completer`, `Hinter`, `Highlighter` and `Validator` implementations for
/// `CliHelper` in this file.
impl Helper for CliHelper {}
178
/// Splits a string which consists of multiple queries.
///
/// The input is cut at every `;` that acts as a statement separator. Each
/// piece is trimmed, empty pieces are dropped, and a single `;` is appended to
/// every returned statement (including a final piece that had none).
///
/// A `;` is *not* a separator when it appears inside:
/// * a single-quoted string (`'...'`),
/// * a double-quoted identifier (`"..."`),
/// * a line comment (`-- ...` up to the end of the line), or
/// * a block comment (`/* ... */`, not nested).
///
/// Quote characters inside comments are ignored, so an apostrophe in a comment
/// cannot swallow the separators that follow it. Comments are kept in the
/// returned statements. Dialect-specific syntax such as backtick quoting,
/// dollar quoting or `#` comments is not recognized.
pub(crate) fn split_from_semicolon(sql: &str) -> Vec<String> {
    /// Lexical context the scanner is currently in.
    #[derive(Clone, Copy)]
    enum Scan {
        Code,
        SingleQuoted,
        DoubleQuoted,
        LineComment,
        BlockComment,
    }

    /// Moves the trimmed, non-empty contents of `current` into `commands`
    /// (terminated by `;`) and empties `current`.
    fn flush(current: &mut String, commands: &mut Vec<String>) {
        let statement = current.trim();
        if !statement.is_empty() {
            commands.push(format!("{statement};"));
        }
        current.clear();
    }

    let mut commands = Vec::new();
    let mut current = String::new();
    let mut state = Scan::Code;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match state {
            Scan::Code => match c {
                ';' => {
                    // Statement separator: emit what we have, drop the `;`.
                    flush(&mut current, &mut commands);
                    continue;
                }
                '\'' => state = Scan::SingleQuoted,
                '"' => state = Scan::DoubleQuoted,
                '-' if chars.peek() == Some(&'-') => state = Scan::LineComment,
                '/' if chars.peek() == Some(&'*') => {
                    // Consume the `*` too so that `/*/` is not read as an
                    // opener immediately followed by a closer.
                    chars.next();
                    current.push_str("/*");
                    state = Scan::BlockComment;
                    continue;
                }
                _ => {}
            },
            Scan::SingleQuoted => {
                if c == '\'' {
                    state = Scan::Code;
                }
            }
            Scan::DoubleQuoted => {
                if c == '"' {
                    state = Scan::Code;
                }
            }
            Scan::LineComment => {
                if c == '\n' {
                    state = Scan::Code;
                }
            }
            Scan::BlockComment => {
                if c == '*' && chars.peek() == Some(&'/') {
                    chars.next();
                    current.push_str("*/");
                    state = Scan::Code;
                    continue;
                }
            }
        }
        current.push(c);
    }

    // Whatever follows the last separator is still a statement.
    flush(&mut current, &mut commands);

    commands
}
209
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads one line from `reader` and passes it to
    /// `CliHelper::validate_input`; a reader that yields no bytes produces
    /// `ReadlineError::Eof`.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut input = String::new();
        match reader.read_line(&mut input)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&input),
        }
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // Every statement below should be valid.
        let statements = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
            r"select '\', '\\', '\\\\\', 'dsdsds\\\\', '\t', '\0', '\n';",
        ];

        for statement in statements {
            let result = readline_direct(Cursor::new(statement.as_bytes()), &validator)?;
            assert!(
                matches!(result, ValidationResult::Valid(None)),
                "expected valid input: {statement}"
            );
        }

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let query = r"select 1 # 2;";
        let mut validator = CliHelper::default();

        // should be invalid in generic dialect
        let generic = readline_direct(Cursor::new(query.as_bytes()), &validator)?;
        assert!(
            matches!(generic, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // valid in postgresql dialect
        validator.set_dialect(&Dialect::PostgreSQL);
        let postgres = readline_direct(Cursor::new(query.as_bytes()), &validator)?;
        assert!(matches!(postgres, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        // (input, expected statements)
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("SELECT 1; SELECT 2;", vec!["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, vec![r#"SELECT ";";"#]),
            ("SELECT ';';", vec!["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                vec![
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", vec![]),
            ("SELECT 1", vec!["SELECT 1;"]),
            ("SELECT 1;   ", vec!["SELECT 1;"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql), expected);
        }
    }
}