1use std::marker::PhantomData;
2
3use sqlx::{
4 error::BoxDynError,
5 postgres::{PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, PgValueRef},
6 Arguments, Executor, Pool, Row,
7};
8
9use besu::{DecodeValue, Driver, EncodeArgument};
10
11use crate::Postgres;
12
13#[derive(Debug, Clone)]
15pub struct Sqlx(Pool<sqlx::Postgres>);
16
17impl Sqlx {
18 pub async fn new(uri: impl AsRef<str>) -> Result<Self, sqlx::Error> {
20 let uri = uri.as_ref();
21 let opts: PgConnectOptions = uri.parse()?;
23
24 PgPoolOptions::new().connect_with(opts).await.map(Self)
26 }
27
28 pub fn inner(&self) -> &Pool<sqlx::Postgres> {
30 &self.0
31 }
32}
33
34impl Driver for Sqlx {
35 type Error = sqlx::Error;
36 type Dialect = Postgres;
37 type Row = PgRow;
38 type Value<'a> = PgValueRef<'a>;
39 type Output = PgQueryResult;
40
41 type Arguments<'a> = EncodeImpl<'a>;
42 type ValueDecoder = DecodeImpl;
43
44 fn row_len(row: &Self::Row) -> usize {
45 row.len()
46 }
47
48 fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
49 row.try_get_raw(index)
50 }
51
52 fn error_encoding_arguments(
53 err: Box<dyn std::error::Error + Send + Sync + 'static>,
54 ) -> Self::Error {
55 sqlx::Error::Encode(err)
56 }
57
58 fn error_decoding_value(
59 err: Box<dyn std::error::Error + Send + Sync + 'static>,
60 ) -> Self::Error {
61 sqlx::Error::Decode(err)
62 }
63
64 async fn query(
65 &self,
66 sql: &str,
67 args: Self::Arguments<'_>,
68 ) -> Result<Vec<Self::Row>, Self::Error> {
69 self.0.fetch_all((sql, Some(args.0))).await
70 }
71
72 async fn execute(
73 &self,
74 sql: &str,
75 args: Self::Arguments<'_>,
76 ) -> Result<Self::Output, Self::Error> {
77 self.0.execute((sql, Some(args.0))).await
78 }
79}
80
81pub struct DecodeImpl;
82
83impl<T: for<'q> sqlx::Decode<'q, sqlx::Postgres>> DecodeValue<Sqlx, T> for DecodeImpl {
84 fn decode<'q>(value: PgValueRef<'q>) -> Result<T, BoxDynError> {
85 <T as sqlx::Decode<'q, sqlx::Postgres>>::decode(value)
86 }
87}
88
89#[derive(Default)]
90pub struct EncodeImpl<'a>(PgArguments, PhantomData<&'a ()>);
91
92impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>>
93 EncodeArgument<'a, T> for EncodeImpl<'a>
94{
95 fn encode(&mut self, value: T) -> Result<(), BoxDynError>
96 where
97 T: 'a,
98 {
99 self.0.add(value)
100 }
101}