besu_sqlite/
sqlx.rs

1use std::path::Path;
2
3use sqlx::{
4    error::BoxDynError,
5    sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteValueRef},
6    Arguments, Executor, Pool, Row,
7};
8
9use besu::{DecodeValue, Driver, EncodeArgument};
10
11use crate::SQLite;
12
/// SQLite driver implementation based on [`sqlx`].
///
/// This is a newtype around a [`Pool`] of `sqlx::Sqlite` connections. Build one with
/// [`Sqlx::new`]; the wrapped pool stays reachable through [`Sqlx::inner`].
#[derive(Debug, Clone)]
pub struct Sqlx(Pool<sqlx::Sqlite>);
16
17impl Sqlx {
18    /// Construct a new [`Sqlx`] driver from a path.
19    pub async fn new(path: impl AsRef<Path>) -> Result<Self, sqlx::Error> {
20        let path = path.as_ref();
21        let mut pool = sqlx::sqlite::SqlitePoolOptions::new();
22
23        // ':memory:' doesn't play nice with the connection pool - https://github.com/launchbadge/sqlx/issues/362#issuecomment-636661146
24        if path.to_str() == Some(":memory:") {
25            pool = pool
26                .max_connections(1)
27                .idle_timeout(None)
28                .max_lifetime(None);
29        }
30
31        pool.connect_with(
32            // TODO: Expose these params so the user can override them???
33            sqlx::sqlite::SqliteConnectOptions::new()
34                .create_if_missing(true)
35                .foreign_keys(true)
36                .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
37                .filename(path),
38        )
39        .await
40        .map(Self)
41    }
42
43    /// Get a reference to the inner [`Pool`] of the driver.
44    pub fn inner(&self) -> &Pool<sqlx::Sqlite> {
45        &self.0
46    }
47}
48
49impl Driver for Sqlx {
50    type Error = sqlx::Error;
51    type Dialect = SQLite;
52    type Row = SqliteRow;
53    type Value<'a> = SqliteValueRef<'a>;
54    type Output = SqliteQueryResult;
55
56    type Arguments<'a> = EncodeImpl<'a>;
57    type ValueDecoder = DecodeImpl;
58
59    fn row_len(row: &Self::Row) -> usize {
60        row.len()
61    }
62
63    fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
64        row.try_get_raw(index)
65    }
66
67    fn error_encoding_arguments(
68        err: Box<dyn std::error::Error + Send + Sync + 'static>,
69    ) -> Self::Error {
70        sqlx::Error::Encode(err)
71    }
72
73    fn error_decoding_value(
74        err: Box<dyn std::error::Error + Send + Sync + 'static>,
75    ) -> Self::Error {
76        sqlx::Error::Decode(err)
77    }
78
79    async fn query(
80        &self,
81        sql: &str,
82        args: Self::Arguments<'_>,
83    ) -> Result<Vec<Self::Row>, Self::Error> {
84        self.0.fetch_all((sql, Some(args.0))).await
85    }
86
87    async fn execute(
88        &self,
89        sql: &str,
90        args: Self::Arguments<'_>,
91    ) -> Result<Self::Output, Self::Error> {
92        self.0.execute((sql, Some(args.0))).await
93    }
94}
95
/// Value decoder used by the [`Sqlx`] driver (its `Driver::ValueDecoder`).
///
/// Decoding is delegated to the `sqlx::Decode` implementation of the target type.
pub struct DecodeImpl;
97
98impl<T: for<'q> sqlx::Decode<'q, sqlx::Sqlite>> DecodeValue<Sqlx, T> for DecodeImpl {
99    fn decode<'q>(value: SqliteValueRef<'q>) -> Result<T, BoxDynError> {
100        <T as sqlx::Decode<'q, sqlx::Sqlite>>::decode(value)
101    }
102}
103
/// Argument buffer used by the [`Sqlx`] driver (its `Driver::Arguments`).
///
/// Wraps sqlx's [`SqliteArguments`]; values are appended through `EncodeArgument::encode`.
#[derive(Default)]
pub struct EncodeImpl<'a>(SqliteArguments<'a>);
106
107impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>> EncodeArgument<'a, T>
108    for EncodeImpl<'a>
109{
110    fn encode(&mut self, value: T) -> Result<(), BoxDynError>
111    where
112        T: 'a,
113    {
114        self.0.add(value)
115    }
116}