1use crate::db::context::Context;
2use crate::db::database::Plugin;
3use crate::db::plugin::basic::BasicPlugin;
4use crate::db::plugin::mysql::MySQLPlugin;
5use crate::db::plugin::sqlite::SQLitePlugin;
6use crate::db::plugin::tsql::TSQLPlugin;
7use crate::error::{EngineError, MyError, MyResult};
8use crate::{regex, regex_insensitive};
9use indexmap::IndexMap;
10use itertools::Itertools;
11use regex::Match;
12use serde::{Deserialize, Serialize};
13use std::mem;
14
15#[derive(Debug, Default, Deserialize, PartialEq, Serialize)]
16pub struct Driver {
17 pub default: bool,
18 pub odbc: IndexMap<String, String>,
19}
20
21impl Driver {
22 pub fn new(text: &str) -> MyResult<Self> {
23 let odbc = Self::parse_odbc(text)?;
24 if !odbc.is_empty() {
25 let driver = Self { odbc, default: false };
26 Ok(driver)
27 } else {
28 Err(MyError::Engine(EngineError::ParseDriver(text.to_string())))
29 }
30 }
31
32 pub fn with_default(mut self, default: bool) -> Self {
33 self.default = default;
34 self
35 }
36
37 pub fn with_password(mut self) -> Self {
38 for (key, value) in self.odbc.iter_mut() {
39 if key.eq_ignore_ascii_case("Password") || key.eq_ignore_ascii_case("Pwd") {
44 if value.is_empty() {
45 if let Ok(mut password) = rpassword::prompt_password("Password? ") {
46 mem::swap(value, &mut password);
47 }
48 }
49 break;
50 }
51 }
52 self
53 }
54
55 #[cfg(debug_assertions)]
56 pub fn is_memory(&self) -> bool {
57 if let Some(driver) = self.odbc.get("Driver") {
58 if let Some(database) = self.odbc.get("Database") {
59 return driver == "SQLite3" && database == ":memory:";
60 }
61 }
62 false
63 }
64
65 fn parse_odbc(odbc: &str) -> MyResult<IndexMap<String, String>> {
66 let regex = regex!(r"^(\w+)=(.*)$");
67 let convert = |m: Option<Match>| {
68 m.as_ref().map(Match::as_str).map(String::from).unwrap_or_default()
69 };
70 let split = |token: &str| {
71 match regex.captures(token) {
72 Some(c) => Ok((convert(c.get(1)), convert(c.get(2)))),
73 None => Err(MyError::Engine(EngineError::ParseDriver(odbc.to_string()))),
74 }
75 };
76 odbc.split(';').into_iter().filter(|x| !x.is_empty()).map(split).collect()
77 }
78
79 pub fn format_odbc(&self) -> String {
80 self.odbc.iter().map(|(k, v)| format!("{}={}", k, v)).join(";")
81 }
82
83 pub fn create_plugin(&self) -> Box<dyn Plugin> {
88 let sqlite_regex = regex_insensitive!(r"\bSQLite3\b");
89 let mysql_regex = regex_insensitive!(r"\bMySQL\b");
90 let tsql_regex = regex_insensitive!(r"\bSQL Server\b");
91 if let Some(driver) = self.odbc.get("Driver") {
92 if sqlite_regex.is_match(driver) {
93 return Box::new(SQLitePlugin::new());
94 } else if mysql_regex.is_match(driver) {
95 return Box::new(MySQLPlugin::new());
96 } else if tsql_regex.is_match(driver) {
97 return Box::new(TSQLPlugin::new());
98 }
99 }
100 Box::new(BasicPlugin::new(";"))
101 }
102
103 pub fn create_prompt(&self, context: Context) -> String {
104 let mut database = None;
105 let mut server = None;
106 for (key, value) in self.odbc.iter() {
107 if key.eq_ignore_ascii_case("Database") {
108 database = Some(value.clone());
109 }
110 if key.eq_ignore_ascii_case("Server") {
111 server = Some(value.clone());
112 }
113 }
114 if let Some(database) = context.database && !database.is_empty() {
115 if let Some(server) = server {
116 format!("zql {}@{}> ", database, server)
117 } else {
118 format!("zql {}> ", database)
119 }
120 } else if let Some(mut database) = database {
121 if let Some(index) = database.rfind(|c| c == '/' || c == '\\') {
122 let index = index + 1;
123 database = database[index..].to_string();
124 }
125 if let Some(server) = server {
126 format!("zql {}@{}> ", database, server)
127 } else {
128 format!("zql {}> ", database)
129 }
130 } else {
131 if let Some(server) = server {
132 format!("zql {}> ", server)
133 } else {
134 String::from("zql> ")
135 }
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use crate::db::context::Context;
    use crate::db::driver::Driver;
    use crate::error::{MyError, MyResult};
    use indexmap::IndexMap;
    use pretty_assertions::assert_eq;

    /// Builds an owned, ordered ODBC map from borrowed key/value pairs.
    fn odbc_map(pairs: &[(&str, &str)]) -> IndexMap<String, String> {
        pairs.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect()
    }

    /// The entries shared by the basic parse/format tests.
    fn sample_odbc() -> IndexMap<String, String> {
        odbc_map(&[("DSN", "datasource"), ("Uid", "hwalters"), ("Pwd", "password")])
    }

    #[test]
    fn test_connection_string_is_parsed_no_semicolons() -> MyResult<()> {
        let driver = Driver::new("DSN=datasource;Uid=hwalters;Pwd=password")?;
        assert_eq!(driver.odbc, sample_odbc());
        Ok(())
    }

    #[test]
    fn test_connection_string_is_parsed_with_semicolons() -> MyResult<()> {
        let driver = Driver::new(";;DSN=datasource;;Uid=hwalters;;Pwd=password;;")?;
        assert_eq!(driver.odbc, sample_odbc());
        Ok(())
    }

    #[test]
    fn test_connection_string_keeps_empty_value() -> MyResult<()> {
        // An empty `Pwd=` must survive parsing so `with_password` can detect it.
        let driver = Driver::new("DSN=datasource;Pwd=")?;
        assert_eq!(driver.odbc, odbc_map(&[("DSN", "datasource"), ("Pwd", "")]));
        Ok(())
    }

    #[test]
    fn test_connection_string_value_may_contain_equals() -> MyResult<()> {
        let driver = Driver::new("DSN=datasource;Options=a=b")?;
        assert_eq!(driver.odbc, odbc_map(&[("DSN", "datasource"), ("Options", "a=b")]));
        Ok(())
    }

    #[test]
    fn test_connection_string_not_parsed_if_empty() {
        let driver = Driver::new("");
        let error = driver.err().as_ref().map(MyError::to_string);
        assert_eq!(error, Some("Invalid ODBC string \"\"".to_string()));
    }

    #[test]
    fn test_connection_string_not_parsed_if_invalid() {
        let driver = Driver::new("DSN=datasource;Uid");
        let error = driver.err().as_ref().map(MyError::to_string);
        assert_eq!(error, Some("Invalid ODBC string \"DSN=datasource;Uid\"".to_string()));
    }

    #[test]
    fn test_connection_string_is_formatted() {
        let driver = Driver { odbc: sample_odbc(), default: false };
        assert_eq!(driver.format_odbc(), "DSN=datasource;Uid=hwalters;Pwd=password");
    }

    #[test]
    fn test_driver_is_not_default_until_requested() -> MyResult<()> {
        let driver = Driver::new("DSN=datasource")?;
        assert!(!driver.default);
        assert!(driver.with_default(true).default);
        Ok(())
    }

    // `is_memory` only exists in debug builds.
    #[cfg(debug_assertions)]
    #[test]
    fn test_memory_database_is_detected() -> MyResult<()> {
        assert!(Driver::new("Driver=SQLite3;Database=:memory:")?.is_memory());
        assert!(!Driver::new("Driver=SQLite3;Database=/path/to/hugos.db")?.is_memory());
        assert!(!Driver::new("Driver=MySQL;Server=localhost")?.is_memory());
        Ok(())
    }

    #[test]
    fn test_prompt_contains_database_server() -> MyResult<()> {
        let driver = Driver::new("Driver=MySQL;Server=localhost;Database=hugos")?;
        assert_eq!(create_prompt(&driver, None), "zql hugos@localhost> ");
        assert_eq!(create_prompt(&driver, Some("")), "zql hugos@localhost> ");
        assert_eq!(create_prompt(&driver, Some("sys")), "zql sys@localhost> ");
        Ok(())
    }

    #[test]
    fn test_prompt_contains_server_only() -> MyResult<()> {
        let driver = Driver::new("DRIVER=MySQL;SERVER=localhost")?;
        assert_eq!(create_prompt(&driver, None), "zql localhost> ");
        assert_eq!(create_prompt(&driver, Some("")), "zql localhost> ");
        assert_eq!(create_prompt(&driver, Some("sys")), "zql sys@localhost> ");
        Ok(())
    }

    #[test]
    fn test_prompt_contains_database_only() -> MyResult<()> {
        let driver = Driver::new("DRIVER=MySQL;DATABASE=hugos")?;
        assert_eq!(create_prompt(&driver, None), "zql hugos> ");
        assert_eq!(create_prompt(&driver, Some("")), "zql hugos> ");
        assert_eq!(create_prompt(&driver, Some("sys")), "zql sys> ");
        Ok(())
    }

    #[test]
    fn test_prompt_contains_no_extra_fields() -> MyResult<()> {
        let driver = Driver::new("DRIVER=MySQL")?;
        assert_eq!(create_prompt(&driver, None), "zql> ");
        assert_eq!(create_prompt(&driver, Some("")), "zql> ");
        assert_eq!(create_prompt(&driver, Some("sys")), "zql sys> ");
        Ok(())
    }

    #[test]
    fn test_prompt_strips_linux_directory() -> MyResult<()> {
        let driver = Driver::new("Driver=SQLite3;Database=/path/to/hugos.db")?;
        assert_eq!(create_prompt(&driver, None), "zql hugos.db> ");
        Ok(())
    }

    #[test]
    fn test_prompt_strips_windows_directory() -> MyResult<()> {
        let driver = Driver::new(r"Driver=SQLite3;Database=C:\Path\To\hugos.db")?;
        assert_eq!(create_prompt(&driver, None), "zql hugos.db> ");
        Ok(())
    }

    /// Builds a `Context` with the given current database and renders the prompt.
    fn create_prompt(driver: &Driver, database: Option<&str>) -> String {
        let context = Context::new(database, None);
        driver.create_prompt(context)
    }
}