besu_mysql/
sqlx.rs

1use std::marker::PhantomData;
2
3use sqlx::{
4    error::BoxDynError,
5    mysql::{
6        MySqlArguments, MySqlConnectOptions, MySqlPoolOptions, MySqlQueryResult, MySqlRow,
7        MySqlValueRef,
8    },
9    Arguments, Executor, Pool, Row,
10};
11
12use besu::{DecodeValue, Driver, EncodeArgument};
13
14use crate::MySQL;
15
16/// MySQl driver implementation based on [`sqlx`].
17#[derive(Debug, Clone)]
18pub struct Sqlx(Pool<sqlx::MySql>);
19
20impl Sqlx {
21    /// Construct a new [`Sqlx`] driver from a URI.
22    pub async fn new(uri: impl AsRef<str>) -> Result<Self, sqlx::Error> {
23        let uri = uri.as_ref();
24        // TODO: Expose these params so the user can override them???
25        let opts: MySqlConnectOptions = uri.parse()?;
26
27        // TODO: Expose these params so the user can override them???
28        MySqlPoolOptions::new().connect_with(opts).await.map(Self)
29    }
30
31    /// Get a reference to the inner [`Pool`] of the driver.
32    pub fn inner(&self) -> &Pool<sqlx::MySql> {
33        &self.0
34    }
35}
36
37impl Driver for Sqlx {
38    type Error = sqlx::Error;
39    type Dialect = MySQL;
40    type Row = MySqlRow;
41    type Value<'a> = MySqlValueRef<'a>;
42    type Output = MySqlQueryResult;
43
44    type Arguments<'a> = EncodeImpl<'a>;
45    type ValueDecoder = DecodeImpl;
46
47    fn row_len(row: &Self::Row) -> usize {
48        row.len()
49    }
50
51    fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
52        row.try_get_raw(index)
53    }
54
55    fn error_encoding_arguments(
56        err: Box<dyn std::error::Error + Send + Sync + 'static>,
57    ) -> Self::Error {
58        sqlx::Error::Encode(err)
59    }
60
61    fn error_decoding_value(
62        err: Box<dyn std::error::Error + Send + Sync + 'static>,
63    ) -> Self::Error {
64        sqlx::Error::Decode(err)
65    }
66
67    async fn query(
68        &self,
69        sql: &str,
70        args: Self::Arguments<'_>,
71    ) -> Result<Vec<Self::Row>, Self::Error> {
72        self.0.fetch_all((sql, Some(args.0))).await
73    }
74
75    async fn execute(
76        &self,
77        sql: &str,
78        args: Self::Arguments<'_>,
79    ) -> Result<Self::Output, Self::Error> {
80        self.0.execute((sql, Some(args.0))).await
81    }
82}
83
84pub struct DecodeImpl;
85
86impl<T: for<'q> sqlx::Decode<'q, sqlx::MySql>> DecodeValue<Sqlx, T> for DecodeImpl {
87    fn decode<'q>(value: MySqlValueRef<'q>) -> Result<T, BoxDynError> {
88        <T as sqlx::Decode<'q, sqlx::MySql>>::decode(value)
89    }
90}
91
92#[derive(Default)]
93pub struct EncodeImpl<'a>(MySqlArguments, PhantomData<&'a ()>);
94
95impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::MySql> + sqlx::Type<sqlx::MySql>> EncodeArgument<'a, T>
96    for EncodeImpl<'a>
97{
98    fn encode(&mut self, value: T) -> Result<(), BoxDynError>
99    where
100        T: 'a,
101    {
102        self.0.add(value)
103    }
104}