1use std::convert::Infallible;
2use std::fmt::{Debug, Write as _};
3
4use crate::{Comparator, Condition, Error, ErrorKind, Query, QueryKind};
5
6use async_trait::async_trait;
7use datastore::{DataDescriptor, DataQuery, Reader, Store, StoreData, TypeWriter, Write, Writer};
8use futures::TryStreamExt;
9use sqlx::{mysql::MySqlRow, MySql, Pool, Row};
10
11#[derive(Clone, Debug)]
13pub struct MySqlStore {
14 pool: Pool<MySql>,
15}
16
17#[async_trait]
18impl Store for MySqlStore {
19 type DataStore = Self;
20 type Error = Error;
21
22 async fn connect(uri: &str) -> Result<Self, Self::Error> {
23 let pool = Pool::connect(uri)
24 .await
25 .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
26
27 Ok(Self { pool })
28 }
29
30 async fn create<T, D>(&self, descriptor: D) -> Result<(), Self::Error>
31 where
32 T: StoreData<Self> + Send + Sync + 'static,
33 D: DataDescriptor<T, Self> + Send + Sync,
34 {
35 let table = descriptor.ident();
36 let mut writer = MySqlTypeWriter::new(table, QueryKind::Create);
37 descriptor.write(&mut writer).unwrap();
38
39 let sql = writer.sql();
40 log::debug!("Executing sql CREATE query: \"{}\"", sql);
41
42 sqlx::query(&sql)
43 .execute(&self.pool)
44 .await
45 .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
46 Ok(())
47 }
48
49 async fn delete<T, D, Q>(&self, descriptor: D, query: Q) -> Result<(), Self::Error>
50 where
51 T: StoreData<Self::DataStore> + Send + Sync + 'static,
52 D: DataDescriptor<T, Self::DataStore> + Send,
53 Q: DataQuery<T, Self::DataStore> + Send,
54 {
55 let table = descriptor.ident();
56 let mut writer = MySqlWriter::new(table, QueryKind::Delete);
57 writer.write_conditions = true;
58 query.write(&mut writer).unwrap();
59
60 let sql = writer.sql();
61 log::debug!("Executing sql DELETE query: \"{}\"", sql);
62
63 sqlx::query(&sql)
64 .execute(&self.pool)
65 .await
66 .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
67 Ok(())
68 }
69
70 async fn get<T, D, Q>(&self, descriptor: D, query: Q) -> Result<Vec<T>, Self::Error>
71 where
72 T: StoreData<Self::DataStore> + Send + Sync + 'static,
73 D: DataDescriptor<T, Self::DataStore> + Send,
74 Q: DataQuery<T, Self::DataStore> + Send,
75 {
76 let table = descriptor.ident();
77
78 let mut writer = MySqlWriter::new(table, QueryKind::Select);
79 descriptor.write(&mut writer).unwrap();
80
81 writer.write_conditions = true;
82 query.write(&mut writer).unwrap();
83
84 let sql = writer.sql();
85 log::debug!("Executing sql SELECT query: \"{}\"", sql);
86
87 let mut rows = sqlx::query(&sql).fetch(&self.pool);
88
89 let mut entries = Vec::new();
90 while let Some(row) = rows
91 .try_next()
92 .await
93 .map_err(|err| Error(ErrorKind::Sqlx(err)))?
94 {
95 let mut reader = MySqlReader::new(row);
96 let data = T::read(&mut reader).unwrap();
97
98 entries.push(data);
99 }
100
101 Ok(entries)
102 }
103
104 async fn get_all<T, D>(&self, descriptor: D) -> Result<Vec<T>, Self::Error>
105 where
106 T: StoreData<Self::DataStore> + Send + Sync + 'static,
107 D: DataDescriptor<T, Self::DataStore> + Send + Sync,
108 {
109 let table = descriptor.ident();
110 let mut writer = MySqlTypeWriter::new(table, QueryKind::Select);
111 descriptor.write(&mut writer).unwrap();
112
113 let sql = writer.sql();
114 log::debug!("Executing sql SELECT query: \"{}\"", sql);
115
116 let mut rows = sqlx::query(&sql).fetch(&self.pool);
117
118 let mut entries = Vec::new();
119 while let Some(row) = rows
120 .try_next()
121 .await
122 .map_err(|err| Error(ErrorKind::Sqlx(err)))?
123 {
124 let mut reader = MySqlReader::new(row);
125 let data = T::read(&mut reader).map_err(|err| Error(ErrorKind::Sqlx(err)))?;
126
127 entries.push(data);
128 }
129
130 Ok(entries)
131 }
132
133 async fn get_one<T, D, Q>(&self, descriptor: D, query: Q) -> Result<Option<T>, Self::Error>
134 where
135 T: StoreData<Self::DataStore> + Send + Sync + 'static,
136 D: DataDescriptor<T, Self::DataStore> + Send,
137 Q: DataQuery<T, Self::DataStore> + Send,
138 {
139 let table = descriptor.ident();
140
141 let mut writer = MySqlWriter::new(table, QueryKind::Select);
142 descriptor.write(&mut writer).unwrap();
143
144 writer.write_conditions = true;
145 query.write(&mut writer).unwrap();
146
147 let sql = writer.sql();
148 log::debug!("Executing sql SELECT query: \"{}\"", sql);
149
150 let row = match sqlx::query(&sql).fetch_one(&self.pool).await {
151 Ok(row) => row,
152 Err(sqlx::Error::RowNotFound) => return Ok(None),
153 Err(err) => return Err(Error(ErrorKind::Sqlx(err))),
154 };
155
156 let mut reader = MySqlReader::new(row);
157 let data = T::read(&mut reader).map_err(|err| Error(ErrorKind::Sqlx(err)))?;
158
159 Ok(Some(data))
160 }
161
162 async fn insert<T, D>(&self, descriptor: D, data: T) -> Result<(), Self::Error>
163 where
164 T: StoreData<Self::DataStore> + Send + Sync + 'static,
165 D: DataDescriptor<T, Self::DataStore> + Send,
166 {
167 let table = descriptor.ident();
168
169 let mut writer = MySqlWriter::new(table, QueryKind::Insert);
170 data.write(&mut writer).unwrap();
171
172 let sql = writer.sql();
173 log::debug!("Executing sql INSERT query: \"{}\"", sql);
174
175 sqlx::query(&sql)
176 .execute(&self.pool)
177 .await
178 .map_err(|err| Error(ErrorKind::Sqlx(err)))?;
179 Ok(())
180 }
181}
182
183#[derive(Debug)]
184struct MySqlWriter<'a> {
185 query: Query<'a>,
186 key: &'static str,
187 write_conditions: bool,
188}
189
190impl<'a> MySqlWriter<'a> {
191 fn new(table: &'a str, kind: QueryKind) -> Self {
192 Self {
193 query: Query::new(table, kind),
194 key: "",
195 write_conditions: false,
196 }
197 }
198
199 fn sql(&self) -> String {
200 self.query.to_string()
201 }
202
203 fn write<T>(&mut self, val: T) -> Result<(), <Self as Writer<MySqlStore>>::Error>
204 where
205 T: ToString,
206 {
207 if self.write_conditions {
208 self.query.push_condition(Condition::new(
209 self.key.to_owned(),
210 val.to_string(),
211 Comparator::Eq,
212 ));
213 } else {
214 self.query.push(self.key.to_owned(), val.to_string());
215 }
216 Ok(())
217 }
218}
219
220impl<'a> Writer<MySqlStore> for MySqlWriter<'a> {
221 type Error = Infallible;
222
223 fn write_bool(&mut self, v: bool) -> Result<(), Self::Error> {
224 self.write(match v {
225 false => "FALSE",
226 true => "TRUE",
227 })
228 }
229
230 fn write_i8(&mut self, v: i8) -> Result<(), Self::Error> {
231 self.write(v)
232 }
233
234 fn write_i16(&mut self, v: i16) -> Result<(), Self::Error> {
235 self.write(v)
236 }
237
238 fn write_i32(&mut self, v: i32) -> Result<(), Self::Error> {
239 self.write(v)
240 }
241
242 fn write_i64(&mut self, v: i64) -> Result<(), Self::Error> {
243 self.write(v)
244 }
245
246 fn write_u8(&mut self, v: u8) -> Result<(), Self::Error> {
247 self.write(v)
248 }
249
250 fn write_u16(&mut self, v: u16) -> Result<(), Self::Error> {
251 self.write(v)
252 }
253
254 fn write_u32(&mut self, v: u32) -> Result<(), Self::Error> {
255 self.write(v)
256 }
257
258 fn write_u64(&mut self, v: u64) -> Result<(), Self::Error> {
259 self.write(v)
260 }
261
262 fn write_f32(&mut self, v: f32) -> Result<(), Self::Error> {
263 self.write(v)
264 }
265
266 fn write_f64(&mut self, v: f64) -> Result<(), Self::Error> {
267 self.write(v)
268 }
269
270 fn write_bytes(&mut self, v: &[u8]) -> Result<(), Self::Error> {
271 let mut string = String::with_capacity(2 * v.len() + "0x".len());
272 string.push_str("0x");
273 for byte in v {
274 let _ = write!(string, "{:02x}", byte);
275 }
276
277 self.write(string)
278 }
279
280 fn write_str(&mut self, v: &str) -> Result<(), Self::Error> {
281 self.write(format!("'{}'", v.replace('\'', "\'")))
282 }
283
284 fn write_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
285 where
286 T: ?Sized + Write<MySqlStore>,
287 {
288 self.key = key;
289 value.write(self)
290 }
291}
292
293impl<'a> TypeWriter<MySqlStore> for MySqlWriter<'a> {
294 type Error = Infallible;
295
296 fn write_bool(&mut self) -> Result<(), Self::Error> {
297 self.write("BOOLEAN")
298 }
299
300 fn write_i8(&mut self) -> Result<(), Self::Error> {
301 self.write("TINYINT")
302 }
303
304 fn write_i16(&mut self) -> Result<(), Self::Error> {
305 self.write("SMALLINT")
306 }
307
308 fn write_i32(&mut self) -> Result<(), Self::Error> {
309 self.write("INT")
310 }
311
312 fn write_i64(&mut self) -> Result<(), Self::Error> {
313 self.write("BIGINT")
314 }
315
316 fn write_u8(&mut self) -> Result<(), Self::Error> {
317 self.write("TINYINT UNSIGNED")
318 }
319
320 fn write_u16(&mut self) -> Result<(), Self::Error> {
321 self.write("SMALLINT UNSIGNED")
322 }
323
324 fn write_u32(&mut self) -> Result<(), Self::Error> {
325 self.write("INT UNSIGNED")
326 }
327
328 fn write_u64(&mut self) -> Result<(), Self::Error> {
329 self.write("BIGINT UNSIGNED")
330 }
331
332 fn write_f32(&mut self) -> Result<(), Self::Error> {
333 self.write("FLOAT")
334 }
335
336 fn write_f64(&mut self) -> Result<(), Self::Error> {
337 self.write("DOUBLE")
338 }
339
340 fn write_bytes(&mut self) -> Result<(), Self::Error> {
341 self.write("BLOB")
342 }
343
344 fn write_str(&mut self) -> Result<(), Self::Error> {
345 self.write("TEXT")
346 }
347
348 fn write_field<T>(&mut self, key: &'static str) -> Result<(), Self::Error>
349 where
350 T: ?Sized + Write<MySqlStore>,
351 {
352 self.key = key;
353 T::write_type(self)
354 }
355}
356
357struct MySqlTypeWriter<'a> {
358 query: Query<'a>,
359 key: &'static str,
360 write_conditions: bool,
361}
362
363impl<'a> MySqlTypeWriter<'a> {
364 fn new(table: &'a str, kind: QueryKind) -> Self {
365 Self {
366 query: Query::new(table, kind),
367 key: "",
368 write_conditions: false,
369 }
370 }
371
372 fn sql(&self) -> String {
373 self.query.to_string()
374 }
375
376 fn write<T>(&mut self, value: T) -> Result<(), <Self as TypeWriter<MySqlStore>>::Error>
377 where
378 T: ToString,
379 {
380 if !self.write_conditions {
381 self.query.push(self.key.to_owned(), value.to_string());
382 } else {
383 self.query.push_condition(Condition::new(
384 self.key.to_owned(),
385 value.to_string(),
386 Comparator::Eq,
387 ));
388 }
389 Ok(())
390 }
391}
392
393impl<'a> TypeWriter<MySqlStore> for MySqlTypeWriter<'a> {
394 type Error = Infallible;
395
396 fn write_bool(&mut self) -> Result<(), Self::Error> {
397 self.write("BOOLEAN")
398 }
399
400 fn write_i8(&mut self) -> Result<(), Self::Error> {
401 self.write("TINYINT")
402 }
403
404 fn write_i16(&mut self) -> Result<(), Self::Error> {
405 self.write("SMALLINT")
406 }
407
408 fn write_i32(&mut self) -> Result<(), Self::Error> {
409 self.write("INT")
410 }
411
412 fn write_i64(&mut self) -> Result<(), Self::Error> {
413 self.write("BIGINT")
414 }
415
416 fn write_u8(&mut self) -> Result<(), Self::Error> {
417 self.write("TINYINT UNSIGNED")
418 }
419
420 fn write_u16(&mut self) -> Result<(), Self::Error> {
421 self.write("SMALLINT UNSIGNED")
422 }
423
424 fn write_u32(&mut self) -> Result<(), Self::Error> {
425 self.write("INT UNSIGNED")
426 }
427
428 fn write_u64(&mut self) -> Result<(), Self::Error> {
429 self.write("BIGINT UNSIGNED")
430 }
431
432 fn write_f32(&mut self) -> Result<(), Self::Error> {
433 self.write("FLOAT")
434 }
435
436 fn write_f64(&mut self) -> Result<(), Self::Error> {
437 self.write("DOUBLE")
438 }
439
440 fn write_bytes(&mut self) -> Result<(), Self::Error> {
441 self.write("BLOB")
442 }
443
444 fn write_str(&mut self) -> Result<(), Self::Error> {
445 self.write("TEXT")
446 }
447
448 fn write_field<T>(&mut self, key: &'static str) -> Result<(), Self::Error>
449 where
450 T: ?Sized + Write<MySqlStore>,
451 {
452 self.key = key;
453 T::write_type(self)
454 }
455}
456
457struct MySqlReader {
458 row: MySqlRow,
459 column: Option<&'static str>,
460}
461
462impl MySqlReader {
463 fn new(row: MySqlRow) -> Self {
464 Self { row, column: None }
465 }
466
467 fn read<'r, T>(&'r mut self) -> Result<T, <Self as Reader<MySqlStore>>::Error>
468 where
469 T: sqlx::Decode<'r, MySql> + sqlx::Type<MySql>,
470 {
471 self.row.try_get(self.column.unwrap())
472 }
473}
474
475impl Reader<MySqlStore> for MySqlReader {
476 type Error = sqlx::Error;
477
478 fn read_bool(&mut self) -> Result<bool, Self::Error> {
479 self.read()
480 }
481
482 fn read_i8(&mut self) -> Result<i8, Self::Error> {
483 self.read()
484 }
485
486 fn read_i16(&mut self) -> Result<i16, Self::Error> {
487 self.read()
488 }
489
490 fn read_i32(&mut self) -> Result<i32, Self::Error> {
491 self.read()
492 }
493
494 fn read_i64(&mut self) -> Result<i64, Self::Error> {
495 self.read()
496 }
497
498 fn read_u8(&mut self) -> Result<u8, Self::Error> {
499 self.read()
500 }
501
502 fn read_u16(&mut self) -> Result<u16, Self::Error> {
503 self.read()
504 }
505
506 fn read_u32(&mut self) -> Result<u32, Self::Error> {
507 self.read()
508 }
509
510 fn read_u64(&mut self) -> Result<u64, Self::Error> {
511 self.read()
512 }
513
514 fn read_f32(&mut self) -> Result<f32, Self::Error> {
515 self.read()
516 }
517
518 fn read_f64(&mut self) -> Result<f64, Self::Error> {
519 self.read()
520 }
521
522 fn read_byte_buf(&mut self) -> Result<Vec<u8>, Self::Error> {
523 self.read()
524 }
525
526 fn read_string(&mut self) -> Result<String, Self::Error> {
527 self.read()
528 }
529
530 fn read_field<T>(&mut self, key: &'static str) -> Result<T, Self::Error>
531 where
532 T: Sized + datastore::Read<MySqlStore>,
533 {
534 self.column = Some(key);
535 T::read(self)
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::{MySqlStore, MySqlTypeWriter, MySqlWriter};
    use crate::QueryKind;

    use datastore::{TypeWriter, Writer};

    /// Feeds `value` for column `key` through the value-level `Writer` impl.
    fn put<T>(writer: &mut MySqlWriter<'_>, key: &'static str, value: &T)
    where
        T: ?Sized + datastore::Write<MySqlStore>,
    {
        <MySqlWriter as Writer<MySqlStore>>::write_field(writer, key, value).unwrap();
    }

    /// Feeds the column type of `T` for column `key` through the `TypeWriter` impl.
    fn put_type<T>(writer: &mut MySqlWriter<'_>, key: &'static str)
    where
        T: ?Sized + datastore::Write<MySqlStore>,
    {
        <MySqlWriter as TypeWriter<MySqlStore>>::write_field::<T>(writer, key).unwrap();
    }

    #[test]
    fn test_writer_create() {
        let mut single = MySqlTypeWriter::new("test", QueryKind::Create);
        single.write_field::<i32>("id").unwrap();
        assert_eq!(single.sql(), "CREATE TABLE IF NOT EXISTS test (id INT)");

        let mut pair = MySqlTypeWriter::new("test", QueryKind::Create);
        pair.write_field::<i32>("id").unwrap();
        pair.write_field::<str>("name").unwrap();
        assert_eq!(
            pair.sql(),
            "CREATE TABLE IF NOT EXISTS test (id INT,name TEXT)"
        );
    }

    #[test]
    fn test_writer_delete() {
        let mut single = MySqlWriter::new("test", QueryKind::Delete);
        single.write_conditions = true;
        put(&mut single, "id", &3_i32);
        assert_eq!(single.sql(), "DELETE FROM test WHERE id = 3");

        let mut pair = MySqlWriter::new("test", QueryKind::Delete);
        pair.write_conditions = true;
        put(&mut pair, "id", &3_i32);
        put(&mut pair, "name", "hello");
        assert_eq!(
            pair.sql(),
            "DELETE FROM test WHERE id = 3 AND name = 'hello'"
        );
    }

    #[test]
    fn test_writer_insert() {
        let mut single = MySqlWriter::new("test", QueryKind::Insert);
        put(&mut single, "id", &3_i32);
        assert_eq!(single.sql(), "INSERT INTO test (id) VALUES (3)");

        let mut pair = MySqlWriter::new("test", QueryKind::Insert);
        put(&mut pair, "id", &3_i32);
        put(&mut pair, "name", "hello");
        assert_eq!(
            pair.sql(),
            "INSERT INTO test (id,name) VALUES (3,'hello')"
        );
    }

    #[test]
    fn test_writer_select() {
        let mut one_column = MySqlWriter::new("test", QueryKind::Select);
        put_type::<i32>(&mut one_column, "id");
        assert_eq!(one_column.sql(), "SELECT id FROM test");

        let mut two_columns = MySqlWriter::new("test", QueryKind::Select);
        put_type::<i32>(&mut two_columns, "id");
        put_type::<str>(&mut two_columns, "name");
        assert_eq!(two_columns.sql(), "SELECT id,name FROM test");

        let mut filtered = MySqlWriter::new("test", QueryKind::Select);
        put_type::<i32>(&mut filtered, "id");
        put_type::<str>(&mut filtered, "name");
        filtered.write_conditions = true;
        put(&mut filtered, "id", &3_i32);
        assert_eq!(filtered.sql(), "SELECT id,name FROM test WHERE id = 3");
    }
}