besu_postgres/
sqlx.rs

1use std::marker::PhantomData;
2
3use sqlx::{
4    error::BoxDynError,
5    postgres::{PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, PgValueRef},
6    Arguments, Executor, Pool, Row,
7};
8
9use besu::{DecodeValue, Driver, EncodeArgument};
10
11use crate::Postgres;
12
13/// Postgres driver implementation based on [`sqlx`].
14#[derive(Debug, Clone)]
15pub struct Sqlx(Pool<sqlx::Postgres>);
16
17impl Sqlx {
18    /// Construct a new [`Sqlx`] driver from a URI.
19    pub async fn new(uri: impl AsRef<str>) -> Result<Self, sqlx::Error> {
20        let uri = uri.as_ref();
21        // TODO: Expose these params so the user can override them???
22        let opts: PgConnectOptions = uri.parse()?;
23
24        // TODO: Expose these params so the user can override them???
25        PgPoolOptions::new().connect_with(opts).await.map(Self)
26    }
27
28    /// Get a reference to the inner [`Pool`] of the driver.
29    pub fn inner(&self) -> &Pool<sqlx::Postgres> {
30        &self.0
31    }
32}
33
34impl Driver for Sqlx {
35    type Error = sqlx::Error;
36    type Dialect = Postgres;
37    type Row = PgRow;
38    type Value<'a> = PgValueRef<'a>;
39    type Output = PgQueryResult;
40
41    type Arguments<'a> = EncodeImpl<'a>;
42    type ValueDecoder = DecodeImpl;
43
44    fn row_len(row: &Self::Row) -> usize {
45        row.len()
46    }
47
48    fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
49        row.try_get_raw(index)
50    }
51
52    fn error_encoding_arguments(
53        err: Box<dyn std::error::Error + Send + Sync + 'static>,
54    ) -> Self::Error {
55        sqlx::Error::Encode(err)
56    }
57
58    fn error_decoding_value(
59        err: Box<dyn std::error::Error + Send + Sync + 'static>,
60    ) -> Self::Error {
61        sqlx::Error::Decode(err)
62    }
63
64    async fn query(
65        &self,
66        sql: &str,
67        args: Self::Arguments<'_>,
68    ) -> Result<Vec<Self::Row>, Self::Error> {
69        self.0.fetch_all((sql, Some(args.0))).await
70    }
71
72    async fn execute(
73        &self,
74        sql: &str,
75        args: Self::Arguments<'_>,
76    ) -> Result<Self::Output, Self::Error> {
77        self.0.execute((sql, Some(args.0))).await
78    }
79}
80
/// Value decoder used by the [`Sqlx`] driver.
///
/// It is a stateless marker type: decoding is delegated to each target type's
/// [`sqlx::Decode`] implementation for Postgres.
#[derive(Debug, Clone, Copy, Default)]
pub struct DecodeImpl;
82
83impl<T: for<'q> sqlx::Decode<'q, sqlx::Postgres>> DecodeValue<Sqlx, T> for DecodeImpl {
84    fn decode<'q>(value: PgValueRef<'q>) -> Result<T, BoxDynError> {
85        <T as sqlx::Decode<'q, sqlx::Postgres>>::decode(value)
86    }
87}
88
89#[derive(Default)]
90pub struct EncodeImpl<'a>(PgArguments, PhantomData<&'a ()>);
91
92impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>>
93    EncodeArgument<'a, T> for EncodeImpl<'a>
94{
95    fn encode(&mut self, value: T) -> Result<(), BoxDynError>
96    where
97        T: 'a,
98    {
99        self.0.add(value)
100    }
101}