1use std::path::Path;
2
3use sqlx::{
4 error::BoxDynError,
5 sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteValueRef},
6 Arguments, Executor, Pool, Row,
7};
8
9use besu::{DecodeValue, Driver, EncodeArgument};
10
11use crate::SQLite;
12
/// SQLite [`Driver`] implementation backed by an sqlx connection pool.
///
/// Wraps a `Pool<sqlx::Sqlite>`; construct it with [`Sqlx::new`] and reach the
/// underlying pool through [`Sqlx::inner`]. `Clone` clones the wrapped `Pool`
/// value (NOTE(review): sqlx documents pool clones as handles to the same pool).
#[derive(Debug, Clone)]
pub struct Sqlx(Pool<sqlx::Sqlite>);
16
17impl Sqlx {
18 pub async fn new(path: impl AsRef<Path>) -> Result<Self, sqlx::Error> {
20 let path = path.as_ref();
21 let mut pool = sqlx::sqlite::SqlitePoolOptions::new();
22
23 if path.to_str() == Some(":memory:") {
25 pool = pool
26 .max_connections(1)
27 .idle_timeout(None)
28 .max_lifetime(None);
29 }
30
31 pool.connect_with(
32 sqlx::sqlite::SqliteConnectOptions::new()
34 .create_if_missing(true)
35 .foreign_keys(true)
36 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
37 .filename(path),
38 )
39 .await
40 .map(Self)
41 }
42
43 pub fn inner(&self) -> &Pool<sqlx::Sqlite> {
45 &self.0
46 }
47}
48
49impl Driver for Sqlx {
50 type Error = sqlx::Error;
51 type Dialect = SQLite;
52 type Row = SqliteRow;
53 type Value<'a> = SqliteValueRef<'a>;
54 type Output = SqliteQueryResult;
55
56 type Arguments<'a> = EncodeImpl<'a>;
57 type ValueDecoder = DecodeImpl;
58
59 fn row_len(row: &Self::Row) -> usize {
60 row.len()
61 }
62
63 fn get_value(row: &Self::Row, index: usize) -> Result<Self::Value<'_>, Self::Error> {
64 row.try_get_raw(index)
65 }
66
67 fn error_encoding_arguments(
68 err: Box<dyn std::error::Error + Send + Sync + 'static>,
69 ) -> Self::Error {
70 sqlx::Error::Encode(err)
71 }
72
73 fn error_decoding_value(
74 err: Box<dyn std::error::Error + Send + Sync + 'static>,
75 ) -> Self::Error {
76 sqlx::Error::Decode(err)
77 }
78
79 async fn query(
80 &self,
81 sql: &str,
82 args: Self::Arguments<'_>,
83 ) -> Result<Vec<Self::Row>, Self::Error> {
84 self.0.fetch_all((sql, Some(args.0))).await
85 }
86
87 async fn execute(
88 &self,
89 sql: &str,
90 args: Self::Arguments<'_>,
91 ) -> Result<Self::Output, Self::Error> {
92 self.0.execute((sql, Some(args.0))).await
93 }
94}
95
/// Stateless decoder used as the `ValueDecoder` of [`Sqlx`]; decoding is
/// delegated to each target type's `sqlx::Decode` implementation.
pub struct DecodeImpl;
97
98impl<T: for<'q> sqlx::Decode<'q, sqlx::Sqlite>> DecodeValue<Sqlx, T> for DecodeImpl {
99 fn decode<'q>(value: SqliteValueRef<'q>) -> Result<T, BoxDynError> {
100 <T as sqlx::Decode<'q, sqlx::Sqlite>>::decode(value)
101 }
102}
103
/// Bound-argument buffer used as the `Arguments` type of [`Sqlx`].
///
/// Wraps sqlx's `SqliteArguments`; `Default` delegates to
/// `SqliteArguments::default()`.
#[derive(Default)]
pub struct EncodeImpl<'a>(SqliteArguments<'a>);
106
107impl<'a, T: for<'q> sqlx::Encode<'q, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>> EncodeArgument<'a, T>
108 for EncodeImpl<'a>
109{
110 fn encode(&mut self, value: T) -> Result<(), BoxDynError>
111 where
112 T: 'a,
113 {
114 self.0.add(value)
115 }
116}