Skip to main content

zql_cli/db/
driver.rs

1use crate::db::context::Context;
2use crate::db::database::Plugin;
3use crate::db::plugin::basic::BasicPlugin;
4use crate::db::plugin::mysql::MySQLPlugin;
5use crate::db::plugin::sqlite::SQLitePlugin;
6use crate::db::plugin::tsql::TSQLPlugin;
7use crate::error::{EngineError, MyError, MyResult};
8use crate::{regex, regex_insensitive};
9use indexmap::IndexMap;
10use itertools::Itertools;
11use regex::Match;
12use serde::{Deserialize, Serialize};
13use std::mem;
14
15#[derive(Debug, Default, Deserialize, PartialEq, Serialize)]
16pub struct Driver {
17    pub default: bool,
18    pub odbc: IndexMap<String, String>,
19}
20
21impl Driver {
22    pub fn new(text: &str) -> MyResult<Self> {
23        let odbc = Self::parse_odbc(text)?;
24        if !odbc.is_empty() {
25            let driver = Self { odbc, default: false };
26            Ok(driver)
27        } else {
28            Err(MyError::Engine(EngineError::ParseDriver(text.to_string())))
29        }
30    }
31
32    pub fn with_default(mut self, default: bool) -> Self {
33        self.default = default;
34        self
35    }
36
37    pub fn with_password(mut self) -> Self {
38        for (key, value) in self.odbc.iter_mut() {
39            // According to www.connectionstrings.com, every database
40            // system which accepts a password in its ODBC string (with
41            // one exception) uses some case variant of "Password" or
42            // "Pwd".
43            if key.eq_ignore_ascii_case("Password") || key.eq_ignore_ascii_case("Pwd") {
44                if value.is_empty() {
45                    if let Ok(mut password) = rpassword::prompt_password("Password? ") {
46                        mem::swap(value, &mut password);
47                    }
48                }
49                break;
50            }
51        }
52        self
53    }
54
55    #[cfg(debug_assertions)]
56    pub fn is_memory(&self) -> bool {
57        if let Some(driver) = self.odbc.get("Driver") {
58            if let Some(database) = self.odbc.get("Database") {
59                return driver == "SQLite3" && database == ":memory:";
60            }
61        }
62        false
63    }
64
65    fn parse_odbc(odbc: &str) -> MyResult<IndexMap<String, String>> {
66        let regex = regex!(r"^(\w+)=(.*)$");
67        let convert = |m: Option<Match>| {
68            m.as_ref().map(Match::as_str).map(String::from).unwrap_or_default()
69        };
70        let split = |token: &str| {
71            match regex.captures(token) {
72                Some(c) => Ok((convert(c.get(1)), convert(c.get(2)))),
73                None => Err(MyError::Engine(EngineError::ParseDriver(odbc.to_string()))),
74            }
75        };
76        odbc.split(';').into_iter().filter(|x| !x.is_empty()).map(split).collect()
77    }
78
79    pub fn format_odbc(&self) -> String {
80        self.odbc.iter().map(|(k, v)| format!("{}={}", k, v)).join(";")
81    }
82
83    // Driver=SQLite3;...
84    // Driver={MySQL ODBC 9.4 Unicode Driver};...
85    // Driver={ODBC Driver 17 for SQL Server};...
86
87    pub fn create_plugin(&self) -> Box<dyn Plugin> {
88        let sqlite_regex = regex_insensitive!(r"\bSQLite3\b");
89        let mysql_regex = regex_insensitive!(r"\bMySQL\b");
90        let tsql_regex = regex_insensitive!(r"\bSQL Server\b");
91        if let Some(driver) = self.odbc.get("Driver") {
92            if sqlite_regex.is_match(driver) {
93                return Box::new(SQLitePlugin::new());
94            } else if mysql_regex.is_match(driver) {
95                return Box::new(MySQLPlugin::new());
96            } else if tsql_regex.is_match(driver) {
97                return Box::new(TSQLPlugin::new());
98            }
99        }
100        Box::new(BasicPlugin::new(";"))
101    }
102
103    pub fn create_prompt(&self, context: Context) -> String {
104        let mut database = None;
105        let mut server = None;
106        for (key, value) in self.odbc.iter() {
107            if key.eq_ignore_ascii_case("Database") {
108                database = Some(value.clone());
109            }
110            if key.eq_ignore_ascii_case("Server") {
111                server = Some(value.clone());
112            }
113        }
114        if let Some(database) = context.database && !database.is_empty() {
115            if let Some(server) = server {
116                format!("zql {}@{}> ", database, server)
117            } else {
118                format!("zql {}> ", database)
119            }
120        } else if let Some(mut database) = database {
121            if let Some(index) = database.rfind(|c| c == '/' || c == '\\') {
122                let index = index + 1;
123                database = database[index..].to_string();
124            }
125            if let Some(server) = server {
126                format!("zql {}@{}> ", database, server)
127            } else {
128                format!("zql {}> ", database)
129            }
130        } else {
131            if let Some(server) = server {
132                format!("zql {}> ", server)
133            } else {
134                String::from("zql> ")
135            }
136        }
137    }
138}
139
// noinspection DuplicatedCode
#[cfg(test)]
mod tests {
    use crate::db::context::Context;
    use crate::db::driver::Driver;
    use crate::error::{MyError, MyResult};
    use indexmap::IndexMap;
    use pretty_assertions::assert_eq;

    /// Entries shared by the parsing and formatting tests.
    fn sample_odbc() -> IndexMap<String, String> {
        [("DSN", "datasource"), ("Uid", "hwalters"), ("Pwd", "password")]
            .into_iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Parses `text` and returns the rendered error message, if parsing failed.
    fn parse_error(text: &str) -> Option<String> {
        Driver::new(text).err().as_ref().map(MyError::to_string)
    }

    /// Renders the prompt for `driver` with `database` as the session database.
    fn create_prompt(driver: &Driver, database: Option<&str>) -> String {
        driver.create_prompt(Context::new(database, None))
    }

    /// Parses `text` and checks the prompt for each `(session database, expected)` pair.
    fn assert_prompts(text: &str, cases: &[(Option<&str>, &str)]) -> MyResult<()> {
        let driver = Driver::new(text)?;
        for &(database, expected) in cases {
            assert_eq!(create_prompt(&driver, database), expected);
        }
        Ok(())
    }

    #[test]
    fn test_connection_string_is_parsed_no_semicolons() -> MyResult<()> {
        let parsed = Driver::new("DSN=datasource;Uid=hwalters;Pwd=password")?;
        assert_eq!(parsed.odbc, sample_odbc());
        Ok(())
    }

    #[test]
    fn test_connection_string_is_parsed_with_semicolons() -> MyResult<()> {
        let parsed = Driver::new(";;DSN=datasource;;Uid=hwalters;;Pwd=password;;")?;
        assert_eq!(parsed.odbc, sample_odbc());
        Ok(())
    }

    #[test]
    fn test_connection_string_not_parsed_if_empty() {
        assert_eq!(parse_error(""), Some("Invalid ODBC string \"\"".to_string()));
    }

    #[test]
    fn test_connection_string_not_parsed_if_invalid() {
        assert_eq!(
            parse_error("DSN=datasource;Uid"),
            Some("Invalid ODBC string \"DSN=datasource;Uid\"".to_string())
        );
    }

    #[test]
    fn test_connection_string_is_formatted() {
        let formatted = Driver { odbc: sample_odbc(), default: false }.format_odbc();
        assert_eq!(formatted, "DSN=datasource;Uid=hwalters;Pwd=password");
    }

    #[test]
    fn test_prompt_contains_database_server() -> MyResult<()> {
        assert_prompts(
            "Driver=MySQL;Server=localhost;Database=hugos",
            &[
                (None, "zql hugos@localhost> "),
                (Some(""), "zql hugos@localhost> "),
                (Some("sys"), "zql sys@localhost> "),
            ],
        )
    }

    #[test]
    fn test_prompt_contains_server_only() -> MyResult<()> {
        assert_prompts(
            "DRIVER=MySQL;SERVER=localhost",
            &[
                (None, "zql localhost> "),
                (Some(""), "zql localhost> "),
                (Some("sys"), "zql sys@localhost> "),
            ],
        )
    }

    #[test]
    fn test_prompt_contains_database_only() -> MyResult<()> {
        assert_prompts(
            "DRIVER=MySQL;DATABASE=hugos",
            &[
                (None, "zql hugos> "),
                (Some(""), "zql hugos> "),
                (Some("sys"), "zql sys> "),
            ],
        )
    }

    #[test]
    fn test_prompt_contains_no_extra_fields() -> MyResult<()> {
        assert_prompts(
            "DRIVER=MySQL",
            &[(None, "zql> "), (Some(""), "zql> "), (Some("sys"), "zql sys> ")],
        )
    }

    #[test]
    fn test_prompt_strips_linux_directory() -> MyResult<()> {
        assert_prompts("Driver=SQLite3;Database=/path/to/hugos.db", &[(None, "zql hugos.db> ")])
    }

    #[test]
    fn test_prompt_strips_windows_directory() -> MyResult<()> {
        assert_prompts(r"Driver=SQLite3;Database=C:\Path\To\hugos.db", &[(None, "zql hugos.db> ")])
    }
}