Skip to main content

eva_common/
db.rs

//! Database functions (sqlx-based) and a dynamic wrapper over the supported backends.
//!
//! Written against the sqlx 0.8 API (fallible `Encode::encode_by_ref`).
//!
//! Supported databases: SQLite, PostgreSQL
//!
//! For `Value` columns use JSONB (PostgreSQL); in SQLite values are stored as JSON text.
//! For `OID` columns use VARCHAR(1024).
//!
//! For `Time` (feature "time" enabled) use INTEGER (nanosecond timestamps) for SQLite and
//! TIMESTAMP/TIMESTAMPTZ for PostgreSQL.
11#[cfg(feature = "acl")]
12use crate::acl::OIDMask;
13use crate::{EResult, Error};
14use crate::{OID, value::Value};
15use sqlx::ConnectOptions;
16use sqlx::encode::IsNull;
17use sqlx::error::BoxDynError;
18use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgValueRef};
19use sqlx::sqlite::SqliteValueRef;
20use sqlx::sqlite::{self, SqliteConnectOptions, SqlitePoolOptions};
21use sqlx::{Decode, Encode};
22use sqlx::{PgPool, SqlitePool, postgres};
23use sqlx::{Postgres, Sqlite, Type};
24use std::borrow::Cow;
25use std::str::FromStr as _;
26use std::sync::OnceLock;
27use std::time::Duration;
28
29pub mod prelude {
30    pub use super::{DbKind, DbPool, Transaction, db_init, db_pool};
31}
32
33type ResultIsNull = Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>;
34
35static DB_POOL: OnceLock<DbPool> = OnceLock::new();
36
37impl Type<Sqlite> for OID {
38    fn type_info() -> sqlite::SqliteTypeInfo {
39        <str as Type<Sqlite>>::type_info()
40    }
41}
42
43impl Type<Postgres> for OID {
44    fn type_info() -> postgres::PgTypeInfo {
45        <str as Type<Postgres>>::type_info()
46    }
47    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
48        *ty == postgres::PgTypeInfo::with_name("VARCHAR")
49            || *ty == postgres::PgTypeInfo::with_name("TEXT")
50    }
51}
52
53impl postgres::PgHasArrayType for OID {
54    fn array_type_info() -> postgres::PgTypeInfo {
55        postgres::PgTypeInfo::with_name("_TEXT")
56    }
57
58    fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
59        *ty == postgres::PgTypeInfo::with_name("_TEXT")
60            || *ty == postgres::PgTypeInfo::with_name("_VARCHAR")
61    }
62}
63
64impl<'r> Decode<'r, Sqlite> for OID {
65    fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
66        let value = <&str as Decode<Sqlite>>::decode(value)?;
67        value.parse().map_err(Into::into)
68    }
69}
70
71impl<'r> Decode<'r, Postgres> for OID {
72    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
73        let value = <&str as Decode<Postgres>>::decode(value)?;
74        value.parse().map_err(Into::into)
75    }
76}
77
78impl<'q> Encode<'q, Sqlite> for OID {
79    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
80        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
81            self.to_string(),
82        )));
83
84        Ok(IsNull::No)
85    }
86    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
87        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
88            self.to_string(),
89        )));
90        Ok(IsNull::No)
91    }
92
93    fn size_hint(&self) -> usize {
94        self.as_str().len()
95    }
96}
97
98impl Encode<'_, Postgres> for OID {
99    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
100        <&str as Encode<Postgres>>::encode(self.as_str(), buf)
101    }
102    fn size_hint(&self) -> usize {
103        self.as_str().len()
104    }
105}
106
/// In SQLite an [`OIDMask`] lives in a plain TEXT column.
#[cfg(feature = "acl")]
impl Type<Sqlite> for OIDMask {
    fn type_info() -> sqlite::SqliteTypeInfo {
        // Masks are stored as their text form, like ordinary strings.
        <str as Type<Sqlite>>::type_info()
    }
}
113
/// In PostgreSQL an [`OIDMask`] is sent as TEXT and may be read from TEXT or VARCHAR columns.
#[cfg(feature = "acl")]
impl Type<Postgres> for OIDMask {
    fn type_info() -> postgres::PgTypeInfo {
        <str as Type<Postgres>>::type_info()
    }

    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
        ["VARCHAR", "TEXT"]
            .iter()
            .any(|&name| *ty == postgres::PgTypeInfo::with_name(name))
    }
}
124
/// PostgreSQL array support: collections of [`OIDMask`]s are bound as `TEXT[]` and can be
/// read back from `TEXT[]` or `VARCHAR[]` columns.
#[cfg(feature = "acl")]
impl postgres::PgHasArrayType for OIDMask {
    fn array_type_info() -> postgres::PgTypeInfo {
        postgres::PgTypeInfo::with_name("_TEXT")
    }

    fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
        ["_TEXT", "_VARCHAR"]
            .iter()
            .any(|&name| *ty == postgres::PgTypeInfo::with_name(name))
    }
}
136
/// Reads an [`OIDMask`] from a SQLite TEXT value by parsing its string form.
#[cfg(feature = "acl")]
impl<'r> Decode<'r, Sqlite> for OIDMask {
    fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
        let text = <&str as Decode<Sqlite>>::decode(value)?;
        Ok(text.parse::<OIDMask>()?)
    }
}
144
/// Reads an [`OIDMask`] from a PostgreSQL text value by parsing its string form.
#[cfg(feature = "acl")]
impl<'r> Decode<'r, Postgres> for OIDMask {
    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
        let text = <&str as Decode<Postgres>>::decode(value)?;
        Ok(text.parse::<OIDMask>()?)
    }
}
152
/// Binds an [`OIDMask`] to a SQLite statement as an owned TEXT argument.
#[cfg(feature = "acl")]
impl<'q> Encode<'q, Sqlite> for OIDMask {
    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
        // Same argument as the by-reference path.
        <Self as Encode<'q, Sqlite>>::encode_by_ref(&self, args)
    }

    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
        let mask_text: Cow<'q, str> = Cow::Owned(self.to_string());
        args.push(sqlite::SqliteArgumentValue::Text(mask_text));
        Ok(IsNull::No)
    }
}
169
/// Binds an [`OIDMask`] to a PostgreSQL statement using its string form.
#[cfg(feature = "acl")]
impl Encode<'_, Postgres> for OIDMask {
    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
        // The mask has no borrowed text form, so render it once and encode the slice.
        let mask_text = self.to_string();
        <&str as Encode<Postgres>>::encode(mask_text.as_str(), buf)
    }
}
177
178impl Type<Sqlite> for Value {
179    fn type_info() -> sqlite::SqliteTypeInfo {
180        <str as Type<Sqlite>>::type_info()
181    }
182
183    fn compatible(ty: &sqlite::SqliteTypeInfo) -> bool {
184        <&str as Type<Sqlite>>::compatible(ty)
185    }
186}
187
188impl Type<Postgres> for Value {
189    fn type_info() -> postgres::PgTypeInfo {
190        postgres::PgTypeInfo::with_name("JSONB")
191    }
192}
193
194impl Encode<'_, Sqlite> for Value {
195    fn encode_by_ref(&self, buf: &mut Vec<sqlite::SqliteArgumentValue<'_>>) -> ResultIsNull {
196        let json_string_value =
197            serde_json::to_string(self).expect("serde_json failed to convert to string");
198        Encode::<Sqlite>::encode(json_string_value, buf)
199    }
200}
201
202impl<'r> Decode<'r, Sqlite> for Value {
203    fn decode(value: sqlite::SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
204        let string_value = <&str as Decode<Sqlite>>::decode(value)?;
205
206        serde_json::from_str(string_value).map_err(Into::into)
207    }
208}
209
210impl Encode<'_, Postgres> for Value {
211    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
212        buf.push(1);
213        serde_json::to_writer(&mut **buf, &self)
214            .expect("failed to serialize to JSON for encoding on transmission to the database");
215        Ok(IsNull::No)
216    }
217}
218
219impl<'r> Decode<'r, Postgres> for Value {
220    fn decode(value: postgres::PgValueRef<'r>) -> Result<Self, BoxDynError> {
221        let buf = value.as_bytes()?;
222        assert_eq!(buf[0], 1, "unsupported JSONB format version {}", buf[0]);
223        serde_json::from_slice(&buf[1..]).map_err(Into::into)
224    }
225}
226
#[cfg(feature = "time")]
mod time_impl {
    //! Database mapping for [`Time`].
    //!
    //! * SQLite: an INTEGER holding `Time::timestamp_ns()`;
    //! * PostgreSQL: TIMESTAMP/TIMESTAMPTZ in the binary wire format, i.e. microseconds
    //!   relative to 2000-01-01 00:00:00 UTC.
    //!
    //! Out-of-range timestamps are reported as encode/decode errors instead of panicking.
    use super::ResultIsNull;
    use crate::time::Time;
    use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
    use sqlx::sqlite::{SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef};
    use sqlx::{Decode, Encode, Postgres, Sqlite, Type, encode::IsNull, error::BoxDynError};

    /// Microseconds between the Unix epoch and the PostgreSQL epoch (2000-01-01 UTC).
    const J2000_EPOCH_US: i64 = 946_684_800_000_000;

    impl Type<Sqlite> for Time {
        fn type_info() -> SqliteTypeInfo {
            <i64 as Type<Sqlite>>::type_info()
        }

        fn compatible(ty: &SqliteTypeInfo) -> bool {
            *ty == <i64 as Type<Sqlite>>::type_info()
                || *ty == <i32 as Type<Sqlite>>::type_info()
                || *ty == <i16 as Type<Sqlite>>::type_info()
                || *ty == <i8 as Type<Sqlite>>::type_info()
        }
    }

    impl<'q> Encode<'q, Sqlite> for Time {
        fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> ResultIsNull {
            // A timestamp that does not fit an SQLite INTEGER is an error, not a panic.
            let ns = i64::try_from(self.timestamp_ns()).map_err(|_| "timestamp too large")?;
            args.push(SqliteArgumentValue::Int64(ns));
            Ok(IsNull::No)
        }
    }

    impl<'r> Decode<'r, Sqlite> for Time {
        fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
            let ns = <i64 as Decode<Sqlite>>::decode(value)?;
            // Values that do not fit the constructor's argument type (e.g. negative ones)
            // fall back to 0 instead of failing.
            Ok(Time::from_timestamp_ns(ns.try_into().unwrap_or_default()))
        }
    }

    impl Type<Postgres> for Time {
        fn type_info() -> PgTypeInfo {
            PgTypeInfo::with_name("TIMESTAMPTZ")
        }
        fn compatible(ty: &PgTypeInfo) -> bool {
            *ty == PgTypeInfo::with_name("TIMESTAMPTZ") || *ty == PgTypeInfo::with_name("TIMESTAMP")
        }
    }

    impl Encode<'_, Postgres> for Time {
        fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> ResultIsNull {
            // Shift from the Unix epoch to the PostgreSQL epoch without risking overflow.
            let us = i64::try_from(self.timestamp_us())
                .ok()
                .and_then(|unix_us| unix_us.checked_sub(J2000_EPOCH_US))
                .ok_or("timestamp too large")?;
            Encode::<Postgres>::encode(us, buf)
        }

        fn size_hint(&self) -> usize {
            std::mem::size_of::<i64>()
        }
    }

    impl<'r> Decode<'r, Postgres> for Time {
        fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
            let us: i64 = Decode::<Postgres>::decode(value)?;
            // PostgreSQL sends `infinity` as i64::MAX, so the epoch shift must be checked:
            // a plain `+` would panic in debug builds and wrap in release builds.
            let unix_us = us
                .checked_add(J2000_EPOCH_US)
                .ok_or("timestamp out of range")?;
            // Negative (pre-1970) values fall back to 0, as before.
            Ok(Time::from_timestamp_us(unix_us.try_into().unwrap_or_default()))
        }
    }
}
299
300/// # Panics
301///
302/// Will panic if not initialized
303#[allow(clippy::module_name_repetitions)]
304#[inline]
305pub fn db_pool() -> &'static DbPool {
306    DB_POOL.get().unwrap()
307}
308
309#[allow(clippy::module_name_repetitions)]
310pub enum DbPool {
311    Sqlite(SqlitePool),
312    Postgres(PgPool),
313}
314
/// Backend kind of a database pool or transaction.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DbKind {
    /// SQLite database.
    Sqlite,
    /// PostgreSQL database.
    Postgres,
}
321
322impl DbPool {
323    pub async fn begin(&self) -> Result<Transaction<'_>, sqlx::Error> {
324        match self {
325            DbPool::Sqlite(p) => Ok(Transaction::Sqlite(p.begin().await?)),
326            DbPool::Postgres(p) => Ok(Transaction::Postgres(p.begin().await?)),
327        }
328    }
329    pub fn kind(&self) -> DbKind {
330        match self {
331            DbPool::Sqlite(_) => DbKind::Sqlite,
332            DbPool::Postgres(_) => DbKind::Postgres,
333        }
334    }
335    pub async fn execute(&self, q: &str) -> EResult<()> {
336        match self {
337            DbPool::Sqlite(p) => {
338                sqlx::query(q).execute(p).await?;
339            }
340            DbPool::Postgres(p) => {
341                sqlx::query(q).execute(p).await?;
342            }
343        }
344        Ok(())
345    }
346}
347
348pub enum Transaction<'c> {
349    Sqlite(sqlx::Transaction<'c, sqlx::sqlite::Sqlite>),
350    Postgres(sqlx::Transaction<'c, sqlx::postgres::Postgres>),
351}
352
353impl Transaction<'_> {
354    pub async fn commit(self) -> Result<(), sqlx::Error> {
355        match self {
356            Transaction::Sqlite(tx) => tx.commit().await,
357            Transaction::Postgres(tx) => tx.commit().await,
358        }
359    }
360    pub fn kind(&self) -> DbKind {
361        match self {
362            Transaction::Sqlite(_) => DbKind::Sqlite,
363            Transaction::Postgres(_) => DbKind::Postgres,
364        }
365    }
366    pub async fn execute(&mut self, q: &str) -> EResult<()> {
367        match self {
368            Transaction::Sqlite(p) => {
369                sqlx::query(q).execute(&mut **p).await?;
370            }
371            Transaction::Postgres(p) => {
372                sqlx::query(q).execute(&mut **p).await?;
373            }
374        }
375        Ok(())
376    }
377}
378
379/// Initialize database, must be called first and only once,
380/// enables module-wide pool
381#[allow(clippy::module_name_repetitions)]
382pub async fn db_init(conn: &str, pool_size: u32, timeout: Duration) -> EResult<()> {
383    DB_POOL
384        .set(create_pool(conn, pool_size, timeout).await?)
385        .map_err(|_| Error::core("unable to set DB_POOL"))?;
386    Ok(())
387}
388
389/// Creates a pool to use it without the module
390pub async fn create_pool(conn: &str, pool_size: u32, timeout: Duration) -> EResult<DbPool> {
391    if conn.starts_with("sqlite://") {
392        let opts = SqliteConnectOptions::from_str(conn)?
393            .create_if_missing(true)
394            .synchronous(sqlx::sqlite::SqliteSynchronous::Extra)
395            .busy_timeout(timeout)
396            .log_statements(log::LevelFilter::Trace)
397            .log_slow_statements(log::LevelFilter::Warn, timeout);
398        Ok(DbPool::Sqlite(
399            SqlitePoolOptions::new()
400                .max_connections(pool_size)
401                .acquire_timeout(timeout)
402                .connect_with(opts)
403                .await?,
404        ))
405    } else if conn.starts_with("postgres://") {
406        let opts = PgConnectOptions::from_str(conn)?
407            .log_statements(log::LevelFilter::Trace)
408            .log_slow_statements(log::LevelFilter::Warn, timeout);
409        Ok(DbPool::Postgres(
410            PgPoolOptions::new()
411                .max_connections(pool_size)
412                .acquire_timeout(timeout)
413                .connect_with(opts)
414                .await?,
415        ))
416    } else {
417        Err(Error::unsupported("Unsupported database kind"))
418    }
419}