1use std::marker::PhantomData;
2
3use sqlx::{
4 error::BoxDynError,
5 mysql::{
6 MySqlArguments, MySqlConnectOptions, MySqlPoolOptions, MySqlQueryResult, MySqlRow,
7 MySqlValueRef,
8 },
9 Arguments, Executor, Pool, Row,
10};
11
12use besu::{DecodeValue, Driver, EncodeArgument};
13
14use crate::MySQL;
15
16#[derive(Debug, Clone)]
18pub struct Sqlx(Pool<sqlx::MySql>);
19
20impl Sqlx {
21 pub async fn new(uri: impl AsRef<str>) -> Result<Self, sqlx::Error> {
23 let uri = uri.as_ref();
24 let opts: MySqlConnectOptions = uri.parse()?;
26
27 MySqlPoolOptions::new().connect_with(opts).await.map(Self)
29 }
30
31 pub fn inner(&self) -> &Pool<sqlx::MySql> {
33 &self.0
34 }
35}
36
37impl Driver for Sqlx {
38 type Error = sqlx::Error;
39 type Dialect = MySQL;
40 type Row = MySqlRow;
41 type Value<'a> = MySqlValueRef<'a>;
42 type Output = MySqlQueryResult;
43
44 type Arguments<'a> = EncodeImpl<'a>;
45 type ValueDecoder = DecodeImpl;
46
47 fn row_len(row: &Self::Row) -> usize {
48 row.len()
49 }
50
51 fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
52 row.try_get_raw(index)
53 }
54
55 fn error_encoding_arguments(
56 err: Box<dyn std::error::Error + Send + Sync + 'static>,
57 ) -> Self::Error {
58 sqlx::Error::Encode(err)
59 }
60
61 fn error_decoding_value(
62 err: Box<dyn std::error::Error + Send + Sync + 'static>,
63 ) -> Self::Error {
64 sqlx::Error::Decode(err)
65 }
66
67 async fn query(
68 &self,
69 sql: &str,
70 args: Self::Arguments<'_>,
71 ) -> Result<Vec<Self::Row>, Self::Error> {
72 self.0.fetch_all((sql, Some(args.0))).await
73 }
74
75 async fn execute(
76 &self,
77 sql: &str,
78 args: Self::Arguments<'_>,
79 ) -> Result<Self::Output, Self::Error> {
80 self.0.execute((sql, Some(args.0))).await
81 }
82}
83
84pub struct DecodeImpl;
85
86impl<T: for<'q> sqlx::Decode<'q, sqlx::MySql>> DecodeValue<Sqlx, T> for DecodeImpl {
87 fn decode<'q>(value: MySqlValueRef<'q>) -> Result<T, BoxDynError> {
88 <T as sqlx::Decode<'q, sqlx::MySql>>::decode(value)
89 }
90}
91
92#[derive(Default)]
93pub struct EncodeImpl<'a>(MySqlArguments, PhantomData<&'a ()>);
94
95impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::MySql> + sqlx::Type<sqlx::MySql>> EncodeArgument<'a, T>
96 for EncodeImpl<'a>
97{
98 fn encode(&mut self, value: T) -> Result<(), BoxDynError>
99 where
100 T: 'a,
101 {
102 self.0.add(value)
103 }
104}