frostbow_cli/
helper.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Helper for interactive editing: multi-line parsing and validation of input,
//! and file-name auto-completion when creating an external table.
20
21use std::borrow::Cow;
22
23use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
24use crate::iceberg::transform_iceberg_input;
25
26use datafusion::common::sql_datafusion_err;
27use datafusion::error::DataFusionError;
28use datafusion::sql::parser::{DFParser, Statement};
29use datafusion::sql::sqlparser::dialect::dialect_from_str;
30use datafusion::sql::sqlparser::parser::ParserError;
31
32use rustyline::completion::{Completer, FilenameCompleter, Pair};
33use rustyline::error::ReadlineError;
34use rustyline::highlight::Highlighter;
35use rustyline::hint::Hinter;
36use rustyline::validate::{ValidationContext, ValidationResult, Validator};
37use rustyline::{Context, Helper, Result};
38
/// Rustyline helper for the interactive CLI.
///
/// Bundles the pieces the line editor needs: file-name completion, validation
/// of the typed SQL against the configured dialect, and highlighting.
pub struct CliHelper {
    /// File-name completer that [`Completer::complete`] delegates to.
    completer: FilenameCompleter,
    /// Name of the SQL dialect; resolved with `dialect_from_str` on every validation.
    dialect: String,
    /// Highlighter selected in [`CliHelper::new`]: `SyntaxHighlighter` when color
    /// is enabled, `NoSyntaxHighlighter` otherwise.
    highlighter: Box<dyn Highlighter>,
}
44
45impl CliHelper {
46    pub fn new(dialect: &str, color: bool) -> Self {
47        let highlighter: Box<dyn Highlighter> = if !color {
48            Box::new(NoSyntaxHighlighter {})
49        } else {
50            Box::new(SyntaxHighlighter::new(dialect))
51        };
52        Self {
53            completer: FilenameCompleter::new(),
54            dialect: dialect.into(),
55            highlighter,
56        }
57    }
58
59    pub fn set_dialect(&mut self, dialect: &str) {
60        if dialect != self.dialect {
61            self.dialect = dialect.to_string();
62        }
63    }
64
65    fn validate_input(&self, input: &str) -> Result<ValidationResult> {
66        if let Some(sql) = input.strip_suffix(';') {
67            let sql = match unescape_input(sql) {
68                Ok(sql) => sql,
69                Err(err) => {
70                    return Ok(ValidationResult::Invalid(Some(format!(
71                        "  🤔 Invalid statement: {err}",
72                    ))))
73                }
74            };
75
76            let dialect = match dialect_from_str(&self.dialect) {
77                Some(dialect) => dialect,
78                None => {
79                    return Ok(ValidationResult::Invalid(Some(format!(
80                        "  🤔 Invalid dialect: {}",
81                        self.dialect
82                    ))))
83                }
84            };
85            let lines = split_from_semicolon(sql);
86            for line in lines {
87                match DFParser::parse_sql_with_dialect(
88                    &transform_iceberg_input(&line),
89                    dialect.as_ref(),
90                ) {
91                    Ok(statements) if statements.is_empty() => {
92                        return Ok(ValidationResult::Invalid(Some(
93                            "  🤔 You entered an empty statement".to_string(),
94                        )));
95                    }
96                    Ok(_statements) => {}
97                    Err(err) => {
98                        return Ok(ValidationResult::Invalid(Some(format!(
99                            "  🤔 Invalid statement: {err}",
100                        ))));
101                    }
102                }
103            }
104            Ok(ValidationResult::Valid(None))
105        } else if input.starts_with('\\') {
106            // command
107            Ok(ValidationResult::Valid(None))
108        } else {
109            Ok(ValidationResult::Incomplete)
110        }
111    }
112}
113
impl Default for CliHelper {
    /// Builds a helper for the `"generic"` dialect with color disabled.
    fn default() -> Self {
        Self::new("generic", false)
    }
}
119
/// Highlighting is delegated unchanged to the boxed highlighter that was
/// selected in [`CliHelper::new`].
impl Highlighter for CliHelper {
    /// Forwards `line` and `pos` to the inner highlighter and returns its result.
    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
        self.highlighter.highlight(line, pos)
    }

    /// Forwards the query to the inner highlighter and returns its answer.
    fn highlight_char(&self, line: &str, pos: usize, forced: bool) -> bool {
        self.highlighter.highlight_char(line, pos, forced)
    }
}
129
/// Only the associated hint type is declared; no `Hinter` method is
/// overridden, so the trait's default behaviour applies
/// (presumably "no hint" — confirm against the rustyline version in use).
impl Hinter for CliHelper {
    type Hint = String;
}
133
134/// returns true if the current position is after the open quote for
135/// creating an external table.
136fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
137    let mut sql = line[..pos].to_string();
138    sql.push('\'');
139    if let Ok(stmts) = DFParser::parse_sql(&sql) {
140        if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
141            return true;
142        }
143    }
144    false
145}
146
147impl Completer for CliHelper {
148    type Candidate = Pair;
149
150    fn complete(
151        &self,
152        line: &str,
153        pos: usize,
154        ctx: &Context<'_>,
155    ) -> std::result::Result<(usize, Vec<Pair>), ReadlineError> {
156        if is_open_quote_for_location(line, pos) {
157            self.completer.complete(line, pos, ctx)
158        } else {
159            Ok((0, Vec::with_capacity(0)))
160        }
161    }
162}
163
164impl Validator for CliHelper {
165    fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
166        let input = ctx.input().trim_end();
167        self.validate_input(input)
168    }
169}
170
/// Empty `Helper` impl: it adds no items of its own, the behaviour lives in the
/// `Completer`, `Hinter`, `Highlighter` and `Validator` impls in this file.
impl Helper for CliHelper {}
172
173/// Unescape input string from readline.
174///
175/// The data read from stdio will be escaped, so we need to unescape the input before executing the input
176pub fn unescape_input(input: &str) -> datafusion::error::Result<String> {
177    let mut chars = input.chars();
178
179    let mut result = String::with_capacity(input.len());
180    while let Some(char) = chars.next() {
181        if char == '\\' {
182            if let Some(next_char) = chars.next() {
183                // https://static.rust-lang.org/doc/master/reference.html#literals
184                result.push(match next_char {
185                    '0' => '\0',
186                    'n' => '\n',
187                    'r' => '\r',
188                    't' => '\t',
189                    '\\' => '\\',
190                    _ => {
191                        return Err(sql_datafusion_err!(ParserError::TokenizerError(
192                            format!("unsupported escape char: '\\{}'", next_char)
193                        )))
194                    }
195                });
196            }
197        } else {
198            result.push(char);
199        }
200    }
201
202    Ok(result)
203}
204
/// Splits a string that may hold several queries into individual statements.
///
/// A `;` separates statements only when it is outside single- and
/// double-quoted text. Each statement is trimmed and returned with a single
/// trailing `;`; blank statements are skipped. Text after the last separator
/// (or input without any separator) is returned as a final statement.
pub(crate) fn split_from_semicolon(sql: String) -> Vec<String> {
    let mut statements = Vec::new();
    let mut pending = String::new();
    // The quote character we are currently inside, if any. A quote of the
    // other kind is ordinary text until the opening one is closed.
    let mut open_quote: Option<char> = None;

    for ch in sql.chars() {
        match (ch, open_quote) {
            ('\'' | '"', None) => open_quote = Some(ch),
            (_, Some(opener)) if ch == opener => open_quote = None,
            _ => {}
        }

        let is_separator = ch == ';' && open_quote.is_none();
        if !is_separator {
            pending.push(ch);
            continue;
        }

        let statement = pending.trim();
        if !statement.is_empty() {
            statements.push(format!("{};", statement));
            pending.clear();
        }
    }

    // Whatever is left after the last separator is a statement of its own.
    let trailing = pending.trim();
    if !trailing.is_empty() {
        statements.push(format!("{};", trailing));
    }

    statements
}
235
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Cursor};

    use super::*;

    /// Reads a single line from `reader` and validates it with `validator`.
    ///
    /// Fails with `ReadlineError::Eof` when the reader has no more data.
    fn readline_direct(
        mut reader: impl BufRead,
        validator: &CliHelper,
    ) -> Result<ValidationResult> {
        let mut buffer = String::new();

        match reader.read_line(&mut buffer)? {
            0 => Err(ReadlineError::Eof),
            _ => validator.validate_input(&buffer),
        }
    }

    /// Validates one in-memory statement as if it had been read from stdin.
    fn validate_str(sql: &str, validator: &CliHelper) -> Result<ValidationResult> {
        readline_direct(Cursor::new(sql.as_bytes()), validator)
    }

    #[test]
    fn unescape_readline_input() -> Result<()> {
        let validator = CliHelper::default();

        // should be valid
        let accepted = [
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');",
            r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');",
        ];
        for &sql in &accepted {
            let outcome = validate_str(sql, &validator)?;
            assert!(
                matches!(outcome, ValidationResult::Valid(None)),
                "expected a valid statement: {sql}"
            );
        }

        // should be invalid
        let rejected = r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\u{07}');";
        let outcome = validate_str(rejected, &validator)?;
        assert!(matches!(outcome, ValidationResult::Invalid(Some(_))));

        Ok(())
    }

    #[test]
    fn sql_dialect() -> Result<()> {
        let query = r"select 1 # 2;";
        let mut validator = CliHelper::default();

        // should be invalid in generic dialect
        let generic = validate_str(query, &validator)?;
        assert!(
            matches!(generic, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
        );

        // valid in postgresql dialect
        validator.set_dialect("postgresql");
        let postgres = validate_str(query, &validator)?;
        assert!(matches!(postgres, ValidationResult::Valid(None)));

        Ok(())
    }

    #[test]
    fn test_split_from_semicolon() {
        // (input, expected statements)
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("SELECT 1; SELECT 2;", vec!["SELECT 1;", "SELECT 2;"]),
            (r#"SELECT ";";"#, vec![r#"SELECT ";";"#]),
            ("SELECT ';';", vec!["SELECT ';';"]),
            (
                r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#,
                vec![
                    "SELECT 1;",
                    "SELECT 'value;value';",
                    r#"SELECT 1 as "text;text";"#,
                ],
            ),
            ("", vec![]),
            ("SELECT 1", vec!["SELECT 1;"]),
            ("SELECT 1;   ", vec!["SELECT 1;"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(split_from_semicolon(sql.to_string()), expected);
        }
    }
}
383}