eva_common/
db.rs

//! Database functions (sqlx-based) and a dynamic wrapper
//!
//! Only sqlx 0.6 is currently supported.
//!
//! Supported databases: SQLite, PostgreSQL
//!
//! For the `Value` type use JSONB only.
//! For `OID` use VARCHAR(1024).
//!
//! For `Time` (feature "time" enabled) use INTEGER for SQLite and TIMESTAMP/TIMESTAMPTZ for
//! Postgres.
11#[cfg(feature = "acl")]
12use crate::acl::OIDMask;
13use crate::{value::Value, EResult, Error, OID};
14use once_cell::sync::OnceCell;
15use sqlx::encode::IsNull;
16use sqlx::error::BoxDynError;
17use sqlx::postgres::{self, PgConnectOptions, PgPool, PgPoolOptions};
18use sqlx::sqlite::{self, SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
19use sqlx::{database, ConnectOptions, Database, Decode, Encode};
20use sqlx::{Postgres, Sqlite, Type};
21use std::borrow::Cow;
22use std::str::FromStr;
23use std::time::Duration;
24
25pub mod prelude {
26    pub use super::{db_init, db_pool, DbKind, DbPool, Transaction};
27}
28
/// Module-wide pool: stored by `db_init` (which fails if it is already set) and read by
/// `db_pool`
static DB_POOL: OnceCell<DbPool> = OnceCell::new();
30
31impl Type<Sqlite> for OID {
32    fn type_info() -> sqlite::SqliteTypeInfo {
33        <str as Type<Sqlite>>::type_info()
34    }
35}
36
37impl Type<Postgres> for OID {
38    fn type_info() -> postgres::PgTypeInfo {
39        <str as Type<Postgres>>::type_info()
40    }
41    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
42        *ty == postgres::PgTypeInfo::with_name("VARCHAR")
43            || *ty == postgres::PgTypeInfo::with_name("TEXT")
44    }
45}
46
47impl postgres::PgHasArrayType for OID {
48    fn array_type_info() -> postgres::PgTypeInfo {
49        postgres::PgTypeInfo::with_name("_TEXT")
50    }
51
52    fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
53        *ty == postgres::PgTypeInfo::with_name("_TEXT")
54            || *ty == postgres::PgTypeInfo::with_name("_VARCHAR")
55    }
56}
57
58impl<'r, DB: Database> Decode<'r, DB> for OID
59where
60    &'r str: Decode<'r, DB>,
61{
62    fn decode(value: <DB as database::HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
63        let value = <&str as Decode<DB>>::decode(value)?;
64        value.parse().map_err(Into::into)
65    }
66}
67
68impl<'q> Encode<'q, Sqlite> for OID {
69    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
70        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
71            self.to_string(),
72        )));
73
74        IsNull::No
75    }
76    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
77        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
78            self.to_string(),
79        )));
80        IsNull::No
81    }
82
83    fn size_hint(&self) -> usize {
84        self.as_str().len()
85    }
86}
87
88impl Encode<'_, Postgres> for OID {
89    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
90        <&str as Encode<Postgres>>::encode(self.as_str(), buf)
91    }
92    fn size_hint(&self) -> usize {
93        self.as_str().len()
94    }
95}
96
/// OID masks use the same SQLite type info as string slices (TEXT)
#[cfg(feature = "acl")]
impl Type<Sqlite> for OIDMask {
    fn type_info() -> sqlite::SqliteTypeInfo {
        <&str as Type<Sqlite>>::type_info()
    }
}
103
#[cfg(feature = "acl")]
impl Type<Postgres> for OIDMask {
    /// OID masks are bound with the same type info as `str` (TEXT)
    fn type_info() -> postgres::PgTypeInfo {
        <&str as Type<Postgres>>::type_info()
    }

    /// Only VARCHAR and TEXT columns are accepted as OID masks
    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
        let is_named = |name: &'static str| *ty == postgres::PgTypeInfo::with_name(name);
        is_named("VARCHAR") || is_named("TEXT")
    }
}
114
/// Array type info used by sqlx when OID masks are bound or read as PostgreSQL arrays
#[cfg(feature = "acl")]
impl postgres::PgHasArrayType for OIDMask {
    fn array_type_info() -> postgres::PgTypeInfo {
        postgres::PgTypeInfo::with_name("_TEXT")
    }

    /// TEXT[] and VARCHAR[] columns are accepted
    fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
        ["_TEXT", "_VARCHAR"]
            .iter()
            .any(|&name| *ty == postgres::PgTypeInfo::with_name(name))
    }
}
126
/// Decodes an OID mask on any backend able to decode `&str`: the raw text is parsed with
/// `FromStr` and parse errors are returned as decode errors
#[cfg(feature = "acl")]
impl<'r, DB: Database> Decode<'r, DB> for OIDMask
where
    &'r str: Decode<'r, DB>,
{
    fn decode(value: <DB as database::HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
        let raw = <&'r str as Decode<'r, DB>>::decode(value)?;
        raw.parse::<OIDMask>().map_err(Into::into)
    }
}
137
/// OID masks are bound to SQLite as owned TEXT built from their `to_string` form
#[cfg(feature = "acl")]
impl<'q> Encode<'q, Sqlite> for OIDMask {
    // the consuming variant produces exactly the same argument as the borrowing one
    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        <Self as Encode<'q, Sqlite>>::encode_by_ref(&self, args)
    }

    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        let text = Cow::Owned(self.to_string());
        args.push(sqlite::SqliteArgumentValue::Text(text));
        IsNull::No
    }
}
154
#[cfg(feature = "acl")]
impl Encode<'_, Postgres> for OIDMask {
    /// Formats the mask with `to_string` and writes it as a `&str` parameter
    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
        let text = self.to_string();
        <&str as Encode<Postgres>>::encode(text.as_str(), buf)
    }
}
162
163impl Type<Sqlite> for Value {
164    fn type_info() -> sqlite::SqliteTypeInfo {
165        <str as Type<Sqlite>>::type_info()
166    }
167
168    fn compatible(ty: &sqlite::SqliteTypeInfo) -> bool {
169        <&str as Type<Sqlite>>::compatible(ty)
170    }
171}
172
173impl Type<Postgres> for Value {
174    fn type_info() -> postgres::PgTypeInfo {
175        postgres::PgTypeInfo::with_name("JSONB")
176    }
177}
178
179impl Encode<'_, Sqlite> for Value {
180    fn encode_by_ref(&self, buf: &mut Vec<sqlite::SqliteArgumentValue<'_>>) -> IsNull {
181        let json_string_value =
182            serde_json::to_string(self).expect("serde_json failed to convert to string");
183        Encode::<Sqlite>::encode(json_string_value, buf)
184    }
185}
186
187impl<'r> Decode<'r, Sqlite> for Value {
188    fn decode(value: sqlite::SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
189        let string_value = <&str as Decode<Sqlite>>::decode(value)?;
190
191        serde_json::from_str(string_value).map_err(Into::into)
192    }
193}
194
195impl Encode<'_, Postgres> for Value {
196    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
197        buf.push(1);
198        serde_json::to_writer(&mut **buf, &self)
199            .expect("failed to serialize to JSON for encoding on transmission to the database");
200        IsNull::No
201    }
202}
203
204impl<'r> Decode<'r, Postgres> for Value {
205    fn decode(value: postgres::PgValueRef<'r>) -> Result<Self, BoxDynError> {
206        let buf = value.as_bytes()?;
207        assert_eq!(buf[0], 1, "unsupported JSONB format version {}", buf[0]);
208        serde_json::from_slice(&buf[1..]).map_err(Into::into)
209    }
210}
211
#[cfg(feature = "time")]
/// sqlx codecs for `crate::time::Time`
///
/// SQLite: the value of `Time::timestamp_ns` is stored as a 64-bit INTEGER.
///
/// PostgreSQL: TIMESTAMP/TIMESTAMPTZ in binary form, i.e. a 64-bit microsecond count relative
/// to 2000-01-01T00:00:00Z, obtained by shifting `Time::timestamp_us` by `J2000_EPOCH_US`.
mod time_impl {
    use crate::time::Time;
    use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
    use sqlx::sqlite::{SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef};
    use sqlx::{encode::IsNull, error::BoxDynError, Decode, Encode, Postgres, Sqlite, Type};

    /// 2000-01-01T00:00:00Z expressed in microseconds since the Unix epoch
    const J2000_EPOCH_US: i64 = 946_684_800_000_000;

    impl Type<Sqlite> for Time {
        fn type_info() -> SqliteTypeInfo {
            <i64 as Type<Sqlite>>::type_info()
        }

        /// Accepts any column whose type info matches one of the signed integer types
        fn compatible(ty: &SqliteTypeInfo) -> bool {
            *ty == <i64 as Type<Sqlite>>::type_info()
                || *ty == <i32 as Type<Sqlite>>::type_info()
                || *ty == <i16 as Type<Sqlite>>::type_info()
                || *ty == <i8 as Type<Sqlite>>::type_info()
        }
    }

    impl<'q> Encode<'q, Sqlite> for Time {
        /// Binds the nanosecond timestamp as a 64-bit integer
        ///
        /// # Panics
        ///
        /// Panics if the nanosecond timestamp does not fit into `i64`
        fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
            args.push(SqliteArgumentValue::Int64(
                i64::try_from(self.timestamp_ns()).expect("timestamp too large"),
            ));

            IsNull::No
        }
    }

    impl<'r> Decode<'r, Sqlite> for Time {
        fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
            let value = <i64 as Decode<Sqlite>>::decode(value)?;
            // values that cannot be converted to the timestamp type fall back to its default
            Ok(Time::from_timestamp_ns(
                value.try_into().unwrap_or_default(),
            ))
        }
    }

    impl Type<Postgres> for Time {
        fn type_info() -> PgTypeInfo {
            PgTypeInfo::with_name("TIMESTAMPTZ")
        }
        fn compatible(ty: &PgTypeInfo) -> bool {
            *ty == PgTypeInfo::with_name("TIMESTAMPTZ") || *ty == PgTypeInfo::with_name("TIMESTAMP")
        }
    }

    impl Encode<'_, Postgres> for Time {
        /// Writes microseconds relative to 2000-01-01T00:00:00Z as a 64-bit integer
        ///
        /// # Panics
        ///
        /// Panics if the microsecond timestamp does not fit into `i64`
        fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
            let us =
                i64::try_from(self.timestamp_us()).expect("timestamp too large") - J2000_EPOCH_US;
            Encode::<Postgres>::encode(us, buf)
        }

        fn size_hint(&self) -> usize {
            std::mem::size_of::<i64>()
        }
    }

    impl<'r> Decode<'r, Postgres> for Time {
        fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
            let us: i64 = Decode::<Postgres>::decode(value)?;
            // PostgreSQL represents 'infinity' as i64::MAX, so a plain `+` can overflow (a
            // panic in debug builds). Overflowing or otherwise unconvertible values fall back
            // to the default timestamp, like the conversion failure case always did.
            let unix_us = us
                .checked_add(J2000_EPOCH_US)
                .and_then(|v| v.try_into().ok())
                .unwrap_or_default();
            Ok(Time::from_timestamp_us(unix_us))
        }
    }
}
283
284/// # Panics
285///
286/// Will panic if not initialized
287#[allow(clippy::module_name_repetitions)]
288#[inline]
289pub fn db_pool() -> &'static DbPool {
290    DB_POOL.get().unwrap()
291}
292
293#[allow(clippy::module_name_repetitions)]
294pub enum DbPool {
295    Sqlite(SqlitePool),
296    Postgres(PgPool),
297}
298
/// The database backend behind a pool or a transaction
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum DbKind {
    /// SQLite backend
    Sqlite,
    /// PostgreSQL backend
    Postgres,
}
305
306impl DbPool {
307    pub async fn begin(&self) -> Result<Transaction<'_>, sqlx::Error> {
308        match self {
309            DbPool::Sqlite(p) => Ok(Transaction::Sqlite(p.begin().await?)),
310            DbPool::Postgres(p) => Ok(Transaction::Postgres(p.begin().await?)),
311        }
312    }
313    pub fn kind(&self) -> DbKind {
314        match self {
315            DbPool::Sqlite(_) => DbKind::Sqlite,
316            DbPool::Postgres(_) => DbKind::Postgres,
317        }
318    }
319    pub async fn execute(&self, q: &str) -> EResult<()> {
320        match self {
321            DbPool::Sqlite(ref p) => {
322                sqlx::query(q).execute(p).await?;
323            }
324            DbPool::Postgres(ref p) => {
325                sqlx::query(q).execute(p).await?;
326            }
327        }
328        Ok(())
329    }
330}
331
332pub enum Transaction<'c> {
333    Sqlite(sqlx::Transaction<'c, sqlx::sqlite::Sqlite>),
334    Postgres(sqlx::Transaction<'c, sqlx::postgres::Postgres>),
335}
336
337impl<'c> Transaction<'c> {
338    pub async fn commit(self) -> Result<(), sqlx::Error> {
339        match self {
340            Transaction::Sqlite(tx) => tx.commit().await,
341            Transaction::Postgres(tx) => tx.commit().await,
342        }
343    }
344    pub fn kind(&self) -> DbKind {
345        match self {
346            Transaction::Sqlite(_) => DbKind::Sqlite,
347            Transaction::Postgres(_) => DbKind::Postgres,
348        }
349    }
350    pub async fn execute(&mut self, q: &str) -> EResult<()> {
351        match self {
352            Transaction::Sqlite(ref mut p) => {
353                sqlx::query(q).execute(p).await?;
354            }
355            Transaction::Postgres(ref mut p) => {
356                sqlx::query(q).execute(p).await?;
357            }
358        }
359        Ok(())
360    }
361}
362
363/// Initialize database, must be called first and only once,
364/// enables module-wide pool
365#[allow(clippy::module_name_repetitions)]
366pub async fn db_init(conn: &str, pool_size: u32, timeout: Duration) -> EResult<()> {
367    DB_POOL
368        .set(create_pool(conn, pool_size, timeout).await?)
369        .map_err(|_| Error::core("unable to set DB_POOL"))?;
370    Ok(())
371}
372
373/// Creates a pool to use it without the module
374pub async fn create_pool(conn: &str, pool_size: u32, timeout: Duration) -> EResult<DbPool> {
375    if conn.starts_with("sqlite://") {
376        let mut opts = SqliteConnectOptions::from_str(conn)?
377            .create_if_missing(true)
378            .synchronous(sqlx::sqlite::SqliteSynchronous::Extra)
379            .busy_timeout(timeout);
380        opts.log_statements(log::LevelFilter::Trace)
381            .log_slow_statements(log::LevelFilter::Warn, timeout);
382        Ok(DbPool::Sqlite(
383            SqlitePoolOptions::new()
384                .max_connections(pool_size)
385                .acquire_timeout(timeout)
386                .connect_with(opts)
387                .await?,
388        ))
389    } else if conn.starts_with("postgres://") {
390        let mut opts = PgConnectOptions::from_str(conn)?;
391        opts.log_statements(log::LevelFilter::Trace)
392            .log_slow_statements(log::LevelFilter::Warn, timeout);
393        Ok(DbPool::Postgres(
394            PgPoolOptions::new()
395                .max_connections(pool_size)
396                .acquire_timeout(timeout)
397                .connect_with(opts)
398                .await?,
399        ))
400    } else {
401        Err(Error::unsupported("Unsupported database kind"))
402    }
403}