datastore_mysql/
mysql.rs

1use std::convert::Infallible;
2use std::fmt::{Debug, Write as _};
3
4use crate::{Comparator, Condition, Error, ErrorKind, Query, QueryKind};
5
6use async_trait::async_trait;
7use datastore::{DataDescriptor, DataQuery, Reader, Store, StoreData, TypeWriter, Write, Writer};
8use futures::TryStreamExt;
9use sqlx::{mysql::MySqlRow, MySql, Pool, Row};
10
11/// A pooled [`Store`] for the MySQL database.
12#[derive(Clone, Debug)]
13pub struct MySqlStore {
14    pool: Pool<MySql>,
15}
16
17#[async_trait]
18impl Store for MySqlStore {
19    type DataStore = Self;
20    type Error = Error;
21
22    async fn connect(uri: &str) -> Result<Self, Self::Error> {
23        let pool = Pool::connect(uri)
24            .await
25            .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
26
27        Ok(Self { pool })
28    }
29
30    async fn create<T, D>(&self, descriptor: D) -> Result<(), Self::Error>
31    where
32        T: StoreData<Self> + Send + Sync + 'static,
33        D: DataDescriptor<T, Self> + Send + Sync,
34    {
35        let table = descriptor.ident();
36        let mut writer = MySqlTypeWriter::new(table, QueryKind::Create);
37        descriptor.write(&mut writer).unwrap();
38
39        let sql = writer.sql();
40        log::debug!("Executing sql CREATE query: \"{}\"", sql);
41
42        sqlx::query(&sql)
43            .execute(&self.pool)
44            .await
45            .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
46        Ok(())
47    }
48
49    async fn delete<T, D, Q>(&self, descriptor: D, query: Q) -> Result<(), Self::Error>
50    where
51        T: StoreData<Self::DataStore> + Send + Sync + 'static,
52        D: DataDescriptor<T, Self::DataStore> + Send,
53        Q: DataQuery<T, Self::DataStore> + Send,
54    {
55        let table = descriptor.ident();
56        let mut writer = MySqlWriter::new(table, QueryKind::Delete);
57        writer.write_conditions = true;
58        query.write(&mut writer).unwrap();
59
60        let sql = writer.sql();
61        log::debug!("Executing sql DELETE query: \"{}\"", sql);
62
63        sqlx::query(&sql)
64            .execute(&self.pool)
65            .await
66            .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
67        Ok(())
68    }
69
70    async fn get<T, D, Q>(&self, descriptor: D, query: Q) -> Result<Vec<T>, Self::Error>
71    where
72        T: StoreData<Self::DataStore> + Send + Sync + 'static,
73        D: DataDescriptor<T, Self::DataStore> + Send,
74        Q: DataQuery<T, Self::DataStore> + Send,
75    {
76        let table = descriptor.ident();
77
78        let mut writer = MySqlWriter::new(table, QueryKind::Select);
79        descriptor.write(&mut writer).unwrap();
80
81        writer.write_conditions = true;
82        query.write(&mut writer).unwrap();
83
84        let sql = writer.sql();
85        log::debug!("Executing sql SELECT query: \"{}\"", sql);
86
87        let mut rows = sqlx::query(&sql).fetch(&self.pool);
88
89        let mut entries = Vec::new();
90        while let Some(row) = rows
91            .try_next()
92            .await
93            .map_err(|err| Error(ErrorKind::Sqlx(err)))?
94        {
95            let mut reader = MySqlReader::new(row);
96            let data = T::read(&mut reader).unwrap();
97
98            entries.push(data);
99        }
100
101        Ok(entries)
102    }
103
104    async fn get_all<T, D>(&self, descriptor: D) -> Result<Vec<T>, Self::Error>
105    where
106        T: StoreData<Self::DataStore> + Send + Sync + 'static,
107        D: DataDescriptor<T, Self::DataStore> + Send + Sync,
108    {
109        let table = descriptor.ident();
110        let mut writer = MySqlTypeWriter::new(table, QueryKind::Select);
111        descriptor.write(&mut writer).unwrap();
112
113        let sql = writer.sql();
114        log::debug!("Executing sql SELECT query: \"{}\"", sql);
115
116        let mut rows = sqlx::query(&sql).fetch(&self.pool);
117
118        let mut entries = Vec::new();
119        while let Some(row) = rows
120            .try_next()
121            .await
122            .map_err(|err| Error(ErrorKind::Sqlx(err)))?
123        {
124            let mut reader = MySqlReader::new(row);
125            let data = T::read(&mut reader).map_err(|err| Error(ErrorKind::Sqlx(err)))?;
126
127            entries.push(data);
128        }
129
130        Ok(entries)
131    }
132
133    async fn get_one<T, D, Q>(&self, descriptor: D, query: Q) -> Result<Option<T>, Self::Error>
134    where
135        T: StoreData<Self::DataStore> + Send + Sync + 'static,
136        D: DataDescriptor<T, Self::DataStore> + Send,
137        Q: DataQuery<T, Self::DataStore> + Send,
138    {
139        let table = descriptor.ident();
140
141        let mut writer = MySqlWriter::new(table, QueryKind::Select);
142        descriptor.write(&mut writer).unwrap();
143
144        writer.write_conditions = true;
145        query.write(&mut writer).unwrap();
146
147        let sql = writer.sql();
148        log::debug!("Executing sql SELECT query: \"{}\"", sql);
149
150        let row = match sqlx::query(&sql).fetch_one(&self.pool).await {
151            Ok(row) => row,
152            Err(sqlx::Error::RowNotFound) => return Ok(None),
153            Err(err) => return Err(Error(ErrorKind::Sqlx(err))),
154        };
155
156        let mut reader = MySqlReader::new(row);
157        let data = T::read(&mut reader).map_err(|err| Error(ErrorKind::Sqlx(err)))?;
158
159        Ok(Some(data))
160    }
161
162    async fn insert<T, D>(&self, descriptor: D, data: T) -> Result<(), Self::Error>
163    where
164        T: StoreData<Self::DataStore> + Send + Sync + 'static,
165        D: DataDescriptor<T, Self::DataStore> + Send,
166    {
167        let table = descriptor.ident();
168
169        let mut writer = MySqlWriter::new(table, QueryKind::Insert);
170        data.write(&mut writer).unwrap();
171
172        let sql = writer.sql();
173        log::debug!("Executing sql INSERT query: \"{}\"", sql);
174
175        sqlx::query(&sql)
176            .execute(&self.pool)
177            .await
178            .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
179        Ok(())
180    }
181}
182
183#[derive(Debug)]
184struct MySqlWriter<'a> {
185    query: Query<'a>,
186    key: &'static str,
187    write_conditions: bool,
188}
189
190impl<'a> MySqlWriter<'a> {
191    fn new(table: &'a str, kind: QueryKind) -> Self {
192        Self {
193            query: Query::new(table, kind),
194            key: "",
195            write_conditions: false,
196        }
197    }
198
199    fn sql(&self) -> String {
200        self.query.to_string()
201    }
202
203    fn write<T>(&mut self, val: T) -> Result<(), <Self as Writer<MySqlStore>>::Error>
204    where
205        T: ToString,
206    {
207        if self.write_conditions {
208            self.query.push_condition(Condition::new(
209                self.key.to_owned(),
210                val.to_string(),
211                Comparator::Eq,
212            ));
213        } else {
214            self.query.push(self.key.to_owned(), val.to_string());
215        }
216        Ok(())
217    }
218}
219
220impl<'a> Writer<MySqlStore> for MySqlWriter<'a> {
221    type Error = Infallible;
222
223    fn write_bool(&mut self, v: bool) -> Result<(), Self::Error> {
224        self.write(match v {
225            false => "FALSE",
226            true => "TRUE",
227        })
228    }
229
230    fn write_i8(&mut self, v: i8) -> Result<(), Self::Error> {
231        self.write(v)
232    }
233
234    fn write_i16(&mut self, v: i16) -> Result<(), Self::Error> {
235        self.write(v)
236    }
237
238    fn write_i32(&mut self, v: i32) -> Result<(), Self::Error> {
239        self.write(v)
240    }
241
242    fn write_i64(&mut self, v: i64) -> Result<(), Self::Error> {
243        self.write(v)
244    }
245
246    fn write_u8(&mut self, v: u8) -> Result<(), Self::Error> {
247        self.write(v)
248    }
249
250    fn write_u16(&mut self, v: u16) -> Result<(), Self::Error> {
251        self.write(v)
252    }
253
254    fn write_u32(&mut self, v: u32) -> Result<(), Self::Error> {
255        self.write(v)
256    }
257
258    fn write_u64(&mut self, v: u64) -> Result<(), Self::Error> {
259        self.write(v)
260    }
261
262    fn write_f32(&mut self, v: f32) -> Result<(), Self::Error> {
263        self.write(v)
264    }
265
266    fn write_f64(&mut self, v: f64) -> Result<(), Self::Error> {
267        self.write(v)
268    }
269
270    fn write_bytes(&mut self, v: &[u8]) -> Result<(), Self::Error> {
271        let mut string = String::with_capacity(2 * v.len() + "0x".len());
272        string.push_str("0x");
273        for byte in v {
274            let _ = write!(string, "{:02x}", byte);
275        }
276
277        self.write(string)
278    }
279
280    fn write_str(&mut self, v: &str) -> Result<(), Self::Error> {
281        self.write(format!("'{}'", v.replace('\'', "\'")))
282    }
283
284    fn write_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
285    where
286        T: ?Sized + Write<MySqlStore>,
287    {
288        self.key = key;
289        value.write(self)
290    }
291}
292
293impl<'a> TypeWriter<MySqlStore> for MySqlWriter<'a> {
294    type Error = Infallible;
295
296    fn write_bool(&mut self) -> Result<(), Self::Error> {
297        self.write("BOOLEAN")
298    }
299
300    fn write_i8(&mut self) -> Result<(), Self::Error> {
301        self.write("TINYINT")
302    }
303
304    fn write_i16(&mut self) -> Result<(), Self::Error> {
305        self.write("SMALLINT")
306    }
307
308    fn write_i32(&mut self) -> Result<(), Self::Error> {
309        self.write("INT")
310    }
311
312    fn write_i64(&mut self) -> Result<(), Self::Error> {
313        self.write("BIGINT")
314    }
315
316    fn write_u8(&mut self) -> Result<(), Self::Error> {
317        self.write("TINYINT UNSIGNED")
318    }
319
320    fn write_u16(&mut self) -> Result<(), Self::Error> {
321        self.write("SMALLINT UNSIGNED")
322    }
323
324    fn write_u32(&mut self) -> Result<(), Self::Error> {
325        self.write("INT UNSIGNED")
326    }
327
328    fn write_u64(&mut self) -> Result<(), Self::Error> {
329        self.write("BIGINT UNSIGNED")
330    }
331
332    fn write_f32(&mut self) -> Result<(), Self::Error> {
333        self.write("FLOAT")
334    }
335
336    fn write_f64(&mut self) -> Result<(), Self::Error> {
337        self.write("DOUBLE")
338    }
339
340    fn write_bytes(&mut self) -> Result<(), Self::Error> {
341        self.write("BLOB")
342    }
343
344    fn write_str(&mut self) -> Result<(), Self::Error> {
345        self.write("TEXT")
346    }
347
348    fn write_field<T>(&mut self, key: &'static str) -> Result<(), Self::Error>
349    where
350        T: ?Sized + Write<MySqlStore>,
351    {
352        self.key = key;
353        T::write_type(self)
354    }
355}
356
357struct MySqlTypeWriter<'a> {
358    query: Query<'a>,
359    key: &'static str,
360    write_conditions: bool,
361}
362
363impl<'a> MySqlTypeWriter<'a> {
364    fn new(table: &'a str, kind: QueryKind) -> Self {
365        Self {
366            query: Query::new(table, kind),
367            key: "",
368            write_conditions: false,
369        }
370    }
371
372    fn sql(&self) -> String {
373        self.query.to_string()
374    }
375
376    fn write<T>(&mut self, value: T) -> Result<(), <Self as TypeWriter<MySqlStore>>::Error>
377    where
378        T: ToString,
379    {
380        if !self.write_conditions {
381            self.query.push(self.key.to_owned(), value.to_string());
382        } else {
383            self.query.push_condition(Condition::new(
384                self.key.to_owned(),
385                value.to_string(),
386                Comparator::Eq,
387            ));
388        }
389        Ok(())
390    }
391}
392
393impl<'a> TypeWriter<MySqlStore> for MySqlTypeWriter<'a> {
394    type Error = Infallible;
395
396    fn write_bool(&mut self) -> Result<(), Self::Error> {
397        self.write("BOOLEAN")
398    }
399
400    fn write_i8(&mut self) -> Result<(), Self::Error> {
401        self.write("TINYINT")
402    }
403
404    fn write_i16(&mut self) -> Result<(), Self::Error> {
405        self.write("SMALLINT")
406    }
407
408    fn write_i32(&mut self) -> Result<(), Self::Error> {
409        self.write("INT")
410    }
411
412    fn write_i64(&mut self) -> Result<(), Self::Error> {
413        self.write("BIGINT")
414    }
415
416    fn write_u8(&mut self) -> Result<(), Self::Error> {
417        self.write("TINYINT UNSIGNED")
418    }
419
420    fn write_u16(&mut self) -> Result<(), Self::Error> {
421        self.write("SMALLINT UNSIGNED")
422    }
423
424    fn write_u32(&mut self) -> Result<(), Self::Error> {
425        self.write("INT UNSIGNED")
426    }
427
428    fn write_u64(&mut self) -> Result<(), Self::Error> {
429        self.write("BIGINT UNSIGNED")
430    }
431
432    fn write_f32(&mut self) -> Result<(), Self::Error> {
433        self.write("FLOAT")
434    }
435
436    fn write_f64(&mut self) -> Result<(), Self::Error> {
437        self.write("DOUBLE")
438    }
439
440    fn write_bytes(&mut self) -> Result<(), Self::Error> {
441        self.write("BLOB")
442    }
443
444    fn write_str(&mut self) -> Result<(), Self::Error> {
445        self.write("TEXT")
446    }
447
448    fn write_field<T>(&mut self, key: &'static str) -> Result<(), Self::Error>
449    where
450        T: ?Sized + Write<MySqlStore>,
451    {
452        self.key = key;
453        T::write_type(self)
454    }
455}
456
457struct MySqlReader {
458    row: MySqlRow,
459    column: Option<&'static str>,
460}
461
462impl MySqlReader {
463    fn new(row: MySqlRow) -> Self {
464        Self { row, column: None }
465    }
466
467    fn read<'r, T>(&'r mut self) -> Result<T, <Self as Reader<MySqlStore>>::Error>
468    where
469        T: sqlx::Decode<'r, MySql> + sqlx::Type<MySql>,
470    {
471        self.row.try_get(self.column.unwrap())
472    }
473}
474
475impl Reader<MySqlStore> for MySqlReader {
476    type Error = sqlx::Error;
477
478    fn read_bool(&mut self) -> Result<bool, Self::Error> {
479        self.read()
480    }
481
482    fn read_i8(&mut self) -> Result<i8, Self::Error> {
483        self.read()
484    }
485
486    fn read_i16(&mut self) -> Result<i16, Self::Error> {
487        self.read()
488    }
489
490    fn read_i32(&mut self) -> Result<i32, Self::Error> {
491        self.read()
492    }
493
494    fn read_i64(&mut self) -> Result<i64, Self::Error> {
495        self.read()
496    }
497
498    fn read_u8(&mut self) -> Result<u8, Self::Error> {
499        self.read()
500    }
501
502    fn read_u16(&mut self) -> Result<u16, Self::Error> {
503        self.read()
504    }
505
506    fn read_u32(&mut self) -> Result<u32, Self::Error> {
507        self.read()
508    }
509
510    fn read_u64(&mut self) -> Result<u64, Self::Error> {
511        self.read()
512    }
513
514    fn read_f32(&mut self) -> Result<f32, Self::Error> {
515        self.read()
516    }
517
518    fn read_f64(&mut self) -> Result<f64, Self::Error> {
519        self.read()
520    }
521
522    fn read_byte_buf(&mut self) -> Result<Vec<u8>, Self::Error> {
523        self.read()
524    }
525
526    fn read_string(&mut self) -> Result<String, Self::Error> {
527        self.read()
528    }
529
530    fn read_field<T>(&mut self, key: &'static str) -> Result<T, Self::Error>
531    where
532        T: Sized + datastore::Read<MySqlStore>,
533    {
534        self.column = Some(key);
535        T::read(self)
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::{MySqlStore, MySqlTypeWriter, MySqlWriter};
    use crate::QueryKind;

    use datastore::{TypeWriter, Writer};

    // Writes the concrete value `$val` for column `$key` through the `Writer` impl.
    // (Named so it does not shadow `std::write!`.)
    macro_rules! write_value {
        ($writer:expr, $key:expr, $val:expr) => {
            <MySqlWriter as Writer<MySqlStore>>::write_field(&mut $writer, $key, $val).unwrap();
        };
    }

    // Writes the column type `$val` for column `$key` through the `TypeWriter` impl.
    macro_rules! write_column_type {
        ($writer:expr, $key:expr, $val:ty) => {
            <MySqlWriter as TypeWriter<MySqlStore>>::write_field::<$val>(&mut $writer, $key)
                .unwrap();
        };
    }

    /// Returns a `MySqlWriter` for table `test` that records values as conditions.
    fn condition_writer(kind: QueryKind) -> MySqlWriter<'static> {
        let mut writer = MySqlWriter::new("test", kind);
        writer.write_conditions = true;
        writer
    }

    #[test]
    fn test_writer_create() {
        let mut single = MySqlTypeWriter::new("test", QueryKind::Create);
        single.write_field::<i32>("id").unwrap();
        assert_eq!(single.sql(), "CREATE TABLE IF NOT EXISTS test (id INT)");

        let mut pair = MySqlTypeWriter::new("test", QueryKind::Create);
        pair.write_field::<i32>("id").unwrap();
        pair.write_field::<str>("name").unwrap();
        assert_eq!(
            pair.sql(),
            "CREATE TABLE IF NOT EXISTS test (id INT,name TEXT)"
        );
    }

    #[test]
    fn test_writer_delete() {
        let mut single = condition_writer(QueryKind::Delete);
        write_value!(single, "id", &3_i32);
        assert_eq!(single.sql(), "DELETE FROM test WHERE id = 3");

        let mut pair = condition_writer(QueryKind::Delete);
        write_value!(pair, "id", &3_i32);
        write_value!(pair, "name", "hello");
        assert_eq!(
            pair.sql(),
            "DELETE FROM test WHERE id = 3 AND name = 'hello'"
        );
    }

    #[test]
    fn test_writer_insert() {
        let mut single = MySqlWriter::new("test", QueryKind::Insert);
        write_value!(single, "id", &3_i32);
        assert_eq!(single.sql(), "INSERT INTO test (id) VALUES (3)");

        let mut pair = MySqlWriter::new("test", QueryKind::Insert);
        write_value!(pair, "id", &3_i32);
        write_value!(pair, "name", "hello");
        assert_eq!(
            pair.sql(),
            "INSERT INTO test (id,name) VALUES (3,'hello')"
        );
    }

    #[test]
    fn test_writer_select() {
        let mut single = MySqlWriter::new("test", QueryKind::Select);
        write_column_type!(single, "id", i32);
        assert_eq!(single.sql(), "SELECT id FROM test");

        let mut pair = MySqlWriter::new("test", QueryKind::Select);
        write_column_type!(pair, "id", i32);
        write_column_type!(pair, "name", str);
        assert_eq!(pair.sql(), "SELECT id,name FROM test");

        let mut filtered = MySqlWriter::new("test", QueryKind::Select);
        write_column_type!(filtered, "id", i32);
        write_column_type!(filtered, "name", str);
        // Switch to condition mode only after the column list has been written.
        filtered.write_conditions = true;
        write_value!(filtered, "id", &3_i32);
        assert_eq!(filtered.sql(), "SELECT id,name FROM test WHERE id = 3");
    }
}