1use std::cell::RefCell;
23use std::rc::Rc;
24
25use mysql as my;
26use mysql_common::constants::ColumnType;
27
28use rdbc;
29
30use sqlparser::dialect::MySqlDialect;
31use sqlparser::tokenizer::{Token, Tokenizer, Word};
32
33fn to_rdbc_err(e: &my::error::Error) -> rdbc::Error {
35 rdbc::Error::General(format!("{:?}", e))
36}
37
38pub struct MySQLDriver {}
39
40impl MySQLDriver {
41 pub fn new() -> Self {
42 MySQLDriver {}
43 }
44}
45
46impl rdbc::Driver for MySQLDriver {
47 fn connect(&self, url: &str) -> rdbc::Result<Rc<RefCell<dyn rdbc::Connection + 'static>>> {
48 let opts = my::Opts::from_url(&url).expect("DATABASE_URL invalid");
49 my::Conn::new(opts)
50 .map_err(|e| to_rdbc_err(&e))
51 .map(|conn| {
52 Rc::new(RefCell::new(MySQLConnection { conn })) as Rc<RefCell<dyn rdbc::Connection>>
53 })
54 }
55}
56
57struct MySQLConnection {
58 conn: my::Conn,
59}
60
61impl rdbc::Connection for MySQLConnection {
62 fn create(&mut self, sql: &str) -> rdbc::Result<Rc<RefCell<dyn rdbc::Statement + '_>>> {
63 Ok(Rc::new(RefCell::new(MySQLStatement {
64 conn: &mut self.conn,
65 sql: sql.to_owned(),
66 })) as Rc<RefCell<dyn rdbc::Statement>>)
67 }
68
69 fn prepare(&mut self, sql: &str) -> rdbc::Result<Rc<RefCell<dyn rdbc::Statement + '_>>> {
70 self.conn
71 .prepare(&sql)
72 .and_then(|stmt| {
73 Ok(Rc::new(RefCell::new(MySQLPreparedStatement { stmt }))
74 as Rc<RefCell<dyn rdbc::Statement>>)
75 })
76 .map_err(|e| to_rdbc_err(&e))
77 }
78}
79
80struct MySQLStatement<'a> {
81 conn: &'a mut my::Conn,
82 sql: String,
83}
84
85impl<'a> rdbc::Statement for MySQLStatement<'a> {
86 fn execute_query(
87 &mut self,
88 params: &[rdbc::Value],
89 ) -> rdbc::Result<Rc<RefCell<dyn rdbc::ResultSet + '_>>> {
90 let sql = rewrite(&self.sql, params)?;
91 self.conn
92 .query(&sql)
93 .map_err(|e| to_rdbc_err(&e))
94 .map(|result| {
95 Rc::new(RefCell::new(MySQLResultSet { result, row: None }))
96 as Rc<RefCell<dyn rdbc::ResultSet>>
97 })
98 }
99
100 fn execute_update(&mut self, params: &[rdbc::Value]) -> rdbc::Result<u64> {
101 let sql = rewrite(&self.sql, params)?;
102 self.conn
103 .query(&sql)
104 .map_err(|e| to_rdbc_err(&e))
105 .map(|result| result.affected_rows())
106 }
107}
108
109struct MySQLPreparedStatement<'a> {
110 stmt: my::Stmt<'a>,
111}
112
113impl<'a> rdbc::Statement for MySQLPreparedStatement<'a> {
114 fn execute_query(
115 &mut self,
116 params: &[rdbc::Value],
117 ) -> rdbc::Result<Rc<RefCell<dyn rdbc::ResultSet + '_>>> {
118 self.stmt
119 .execute(to_my_params(params))
120 .map_err(|e| to_rdbc_err(&e))
121 .map(|result| {
122 Rc::new(RefCell::new(MySQLResultSet { result, row: None }))
123 as Rc<RefCell<dyn rdbc::ResultSet>>
124 })
125 }
126
127 fn execute_update(&mut self, params: &[rdbc::Value]) -> rdbc::Result<u64> {
128 self.stmt
129 .execute(to_my_params(params))
130 .map_err(|e| to_rdbc_err(&e))
131 .map(|result| result.affected_rows())
132 }
133}
134
135pub struct MySQLResultSet<'a> {
136 result: my::QueryResult<'a>,
137 row: Option<my::Result<my::Row>>,
138}
139
140impl<'a> rdbc::ResultSet for MySQLResultSet<'a> {
141 fn meta_data(&self) -> rdbc::Result<Rc<dyn rdbc::ResultSetMetaData>> {
142 let meta: Vec<rdbc::Column> = self
143 .result
144 .columns_ref()
145 .iter()
146 .map(|c| rdbc::Column::new(&c.name_str(), to_rdbc_type(&c.column_type())))
147 .collect();
148 Ok(Rc::new(meta))
149 }
150
151 fn next(&mut self) -> bool {
152 self.row = self.result.next();
153 self.row.is_some()
154 }
155
156 fn get_i8(&self, i: u64) -> rdbc::Result<Option<i8>> {
157 match &self.row {
158 Some(Ok(row)) => Ok(row.get(i as usize)),
159 _ => Ok(None),
160 }
161 }
162
163 fn get_i16(&self, i: u64) -> rdbc::Result<Option<i16>> {
164 match &self.row {
165 Some(Ok(row)) => Ok(row.get(i as usize)),
166 _ => Ok(None),
167 }
168 }
169
170 fn get_i32(&self, i: u64) -> rdbc::Result<Option<i32>> {
171 match &self.row {
172 Some(Ok(row)) => Ok(row.get(i as usize)),
173 _ => Ok(None),
174 }
175 }
176
177 fn get_i64(&self, i: u64) -> rdbc::Result<Option<i64>> {
178 match &self.row {
179 Some(Ok(row)) => Ok(row.get(i as usize)),
180 _ => Ok(None),
181 }
182 }
183
184 fn get_f32(&self, i: u64) -> rdbc::Result<Option<f32>> {
185 match &self.row {
186 Some(Ok(row)) => Ok(row.get(i as usize)),
187 _ => Ok(None),
188 }
189 }
190
191 fn get_f64(&self, i: u64) -> rdbc::Result<Option<f64>> {
192 match &self.row {
193 Some(Ok(row)) => Ok(row.get(i as usize)),
194 _ => Ok(None),
195 }
196 }
197
198 fn get_string(&self, i: u64) -> rdbc::Result<Option<String>> {
199 match &self.row {
200 Some(Ok(row)) => Ok(row.get(i as usize)),
201 _ => Ok(None),
202 }
203 }
204
205 fn get_bytes(&self, i: u64) -> rdbc::Result<Option<Vec<u8>>> {
206 match &self.row {
207 Some(Ok(row)) => Ok(row.get(i as usize)),
208 _ => Ok(None),
209 }
210 }
211}
212
213fn to_rdbc_type(t: &ColumnType) -> rdbc::DataType {
214 match t {
215 ColumnType::MYSQL_TYPE_FLOAT => rdbc::DataType::Float,
216 ColumnType::MYSQL_TYPE_DOUBLE => rdbc::DataType::Double,
217 _ => rdbc::DataType::Utf8,
219 }
220}
221
222fn to_my_value(v: &rdbc::Value) -> my::Value {
223 match v {
224 rdbc::Value::Int32(n) => my::Value::Int(*n as i64),
225 rdbc::Value::UInt32(n) => my::Value::Int(*n as i64),
226 rdbc::Value::String(s) => my::Value::from(s),
227 }
229}
230
231fn to_my_params(params: &[rdbc::Value]) -> my::Params {
233 my::Params::Positional(params.iter().map(|v| to_my_value(v)).collect())
234}
235
236fn rewrite(sql: &str, params: &[rdbc::Value]) -> rdbc::Result<String> {
237 let dialect = MySqlDialect {};
238 let mut tokenizer = Tokenizer::new(&dialect, sql);
239 tokenizer
240 .tokenize()
241 .and_then(|tokens| {
242 let mut i = 0;
243
244 let tokens: Vec<Token> = tokens
245 .iter()
246 .map(|t| match t {
247 Token::Char(c) if *c == '?' => {
248 let param = ¶ms[i];
249 i += 1;
250 Token::Word(Word {
251 value: param.to_string(),
252 quote_style: None,
253 keyword: "".to_owned(),
254 })
255 }
256 _ => t.clone(),
257 })
258 .collect();
259
260 let sql = tokens
261 .iter()
262 .map(|t| format!("{}", t))
263 .collect::<Vec<String>>()
264 .join("");
265
266 Ok(sql)
267 })
268 .map_err(|e| rdbc::Error::General(format!("{:?}", e)))
269}
270
#[cfg(test)]
mod tests {

    use super::*;
    use std::sync::Arc;

    /// URL of the MySQL server these integration tests expect to be reachable.
    const DATABASE_URL: &str = "mysql://root:secret@127.0.0.1:3307/mysql";

    /// Creates a table and inserts a row through plain statements (exercising `?` rewriting),
    /// then reads the row back through a prepared query.
    #[test]
    fn execute_query() -> rdbc::Result<()> {
        execute("DROP TABLE IF EXISTS test", &[])?;
        execute("CREATE TABLE test (a INT NOT NULL)", &[])?;
        execute("INSERT INTO test (a) VALUES (?)", &[rdbc::Value::Int32(123)])?;

        let driver: Arc<dyn rdbc::Driver> = Arc::new(MySQLDriver::new());
        let conn = driver.connect(DATABASE_URL)?;
        let mut conn = conn.as_ref().borrow_mut();
        let stmt = conn.prepare("SELECT a FROM test")?;
        let mut stmt = stmt.borrow_mut();
        let rs = stmt.execute_query(&[])?;

        let mut rs = rs.as_ref().borrow_mut();

        assert!(rs.next());
        assert_eq!(Some(123), rs.get_i32(0)?);
        assert!(!rs.next());

        Ok(())
    }

    /// Opens a fresh connection and runs `sql` as a plain (non-prepared) statement with
    /// `values` substituted for its `?` placeholders, returning the affected row count.
    fn execute(sql: &str, values: &[rdbc::Value]) -> rdbc::Result<u64> {
        println!("Executing '{}' with {} params", sql, values.len());
        let driver: Arc<dyn rdbc::Driver> = Arc::new(MySQLDriver::new());
        let conn = driver.connect(DATABASE_URL)?;
        let mut conn = conn.as_ref().borrow_mut();
        let stmt = conn.create(sql)?;
        let mut stmt = stmt.borrow_mut();
        stmt.execute_update(values)
    }
}