1#[cfg(feature = "acl")]
12use crate::acl::OIDMask;
13use crate::{value::Value, EResult, Error, OID};
14use once_cell::sync::OnceCell;
15use sqlx::encode::IsNull;
16use sqlx::error::BoxDynError;
17use sqlx::postgres::{self, PgConnectOptions, PgPool, PgPoolOptions};
18use sqlx::sqlite::{self, SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
19use sqlx::{database, ConnectOptions, Database, Decode, Encode};
20use sqlx::{Postgres, Sqlite, Type};
21use std::borrow::Cow;
22use std::str::FromStr;
23use std::time::Duration;
24
25pub mod prelude {
26 pub use super::{db_init, db_pool, DbKind, DbPool, Transaction};
27}
28
/// Process-wide pool, stored once by `db_init` and handed out by `db_pool`.
static DB_POOL: OnceCell<DbPool> = OnceCell::new();
30
31impl Type<Sqlite> for OID {
32 fn type_info() -> sqlite::SqliteTypeInfo {
33 <str as Type<Sqlite>>::type_info()
34 }
35}
36
37impl Type<Postgres> for OID {
38 fn type_info() -> postgres::PgTypeInfo {
39 <str as Type<Postgres>>::type_info()
40 }
41 fn compatible(ty: &postgres::PgTypeInfo) -> bool {
42 *ty == postgres::PgTypeInfo::with_name("VARCHAR")
43 || *ty == postgres::PgTypeInfo::with_name("TEXT")
44 }
45}
46
47impl postgres::PgHasArrayType for OID {
48 fn array_type_info() -> postgres::PgTypeInfo {
49 postgres::PgTypeInfo::with_name("_TEXT")
50 }
51
52 fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
53 *ty == postgres::PgTypeInfo::with_name("_TEXT")
54 || *ty == postgres::PgTypeInfo::with_name("_VARCHAR")
55 }
56}
57
58impl<'r, DB: Database> Decode<'r, DB> for OID
59where
60 &'r str: Decode<'r, DB>,
61{
62 fn decode(value: <DB as database::HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
63 let value = <&str as Decode<DB>>::decode(value)?;
64 value.parse().map_err(Into::into)
65 }
66}
67
68impl<'q> Encode<'q, Sqlite> for OID {
69 fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
70 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
71 self.to_string(),
72 )));
73
74 IsNull::No
75 }
76 fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
77 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
78 self.to_string(),
79 )));
80 IsNull::No
81 }
82
83 fn size_hint(&self) -> usize {
84 self.as_str().len()
85 }
86}
87
88impl Encode<'_, Postgres> for OID {
89 fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
90 <&str as Encode<Postgres>>::encode(self.as_str(), buf)
91 }
92 fn size_hint(&self) -> usize {
93 self.as_str().len()
94 }
95}
96
/// An [`OIDMask`] uses the same SQLite type info as a string.
#[cfg(feature = "acl")]
impl Type<Sqlite> for OIDMask {
    fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
        // Delegate to `str` instead of naming a SQLite type directly.
        <str as Type<Sqlite>>::type_info()
    }
}
103
/// An [`OIDMask`] is sent to PostgreSQL as text; `VARCHAR` and `TEXT` columns are both accepted.
#[cfg(feature = "acl")]
impl Type<Postgres> for OIDMask {
    fn type_info() -> postgres::PgTypeInfo {
        <str as Type<Postgres>>::type_info()
    }

    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
        // Checked in the same order as before: VARCHAR first, then TEXT.
        ["VARCHAR", "TEXT"]
            .iter()
            .any(|&name| *ty == postgres::PgTypeInfo::with_name(name))
    }
}
114
/// Arrays of [`OIDMask`] are declared as `_TEXT` (`text[]`); `_VARCHAR` arrays are also
/// compatible.
#[cfg(feature = "acl")]
impl postgres::PgHasArrayType for OIDMask {
    fn array_type_info() -> postgres::PgTypeInfo {
        postgres::PgTypeInfo::with_name("_TEXT")
    }

    fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
        ["_TEXT", "_VARCHAR"]
            .iter()
            .any(|&name| *ty == postgres::PgTypeInfo::with_name(name))
    }
}
126
/// Decodes an [`OIDMask`] on any backend able to decode a borrowed `&str`, by parsing that text.
#[cfg(feature = "acl")]
impl<'r, DB: Database> Decode<'r, DB> for OIDMask
where
    &'r str: Decode<'r, DB>,
{
    fn decode(value: <DB as database::HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
        let text: &'r str = <&'r str as Decode<'r, DB>>::decode(value)?;
        // A parse failure is reported to sqlx as a decode error.
        text.parse::<Self>().map_err(Into::into)
    }
}
137
/// Binds an [`OIDMask`] to SQLite as an owned text argument built from its `Display` form.
#[cfg(feature = "acl")]
impl<'q> Encode<'q, Sqlite> for OIDMask {
    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        // The by-value path pushes exactly the same argument as the by-reference one.
        <Self as Encode<'q, Sqlite>>::encode_by_ref(&self, args)
    }

    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        let text = self.to_string();
        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(text)));
        IsNull::No
    }
}
154
/// Binds an [`OIDMask`] to PostgreSQL by encoding its `Display` form as text.
#[cfg(feature = "acl")]
impl Encode<'_, Postgres> for OIDMask {
    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
        let text = self.to_string();
        <&str as Encode<Postgres>>::encode(text.as_str(), buf)
    }
}
162
163impl Type<Sqlite> for Value {
164 fn type_info() -> sqlite::SqliteTypeInfo {
165 <str as Type<Sqlite>>::type_info()
166 }
167
168 fn compatible(ty: &sqlite::SqliteTypeInfo) -> bool {
169 <&str as Type<Sqlite>>::compatible(ty)
170 }
171}
172
173impl Type<Postgres> for Value {
174 fn type_info() -> postgres::PgTypeInfo {
175 postgres::PgTypeInfo::with_name("JSONB")
176 }
177}
178
179impl Encode<'_, Sqlite> for Value {
180 fn encode_by_ref(&self, buf: &mut Vec<sqlite::SqliteArgumentValue<'_>>) -> IsNull {
181 let json_string_value =
182 serde_json::to_string(self).expect("serde_json failed to convert to string");
183 Encode::<Sqlite>::encode(json_string_value, buf)
184 }
185}
186
187impl<'r> Decode<'r, Sqlite> for Value {
188 fn decode(value: sqlite::SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
189 let string_value = <&str as Decode<Sqlite>>::decode(value)?;
190
191 serde_json::from_str(string_value).map_err(Into::into)
192 }
193}
194
195impl Encode<'_, Postgres> for Value {
196 fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
197 buf.push(1);
198 serde_json::to_writer(&mut **buf, &self)
199 .expect("failed to serialize to JSON for encoding on transmission to the database");
200 IsNull::No
201 }
202}
203
204impl<'r> Decode<'r, Postgres> for Value {
205 fn decode(value: postgres::PgValueRef<'r>) -> Result<Self, BoxDynError> {
206 let buf = value.as_bytes()?;
207 assert_eq!(buf[0], 1, "unsupported JSONB format version {}", buf[0]);
208 serde_json::from_slice(&buf[1..]).map_err(Into::into)
209 }
210}
211
/// `sqlx` mappings for `crate::time::Time`, compiled only with the `time` feature.
///
/// * SQLite: an integer column holding `Time::timestamp_ns()`.
/// * PostgreSQL: `TIMESTAMPTZ` (`TIMESTAMP` is also accepted). The binary wire value is an `i64`
///   of microseconds relative to 2000-01-01, so values are shifted by `J2000_EPOCH_US`.
#[cfg(feature = "time")]
mod time_impl {
    use crate::time::Time;
    use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
    use sqlx::sqlite::{SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef};
    use sqlx::{encode::IsNull, error::BoxDynError, Decode, Encode, Postgres, Sqlite, Type};

    /// Microseconds from the Unix epoch to 2000-01-01T00:00:00Z (the PostgreSQL epoch).
    const J2000_EPOCH_US: i64 = 946_684_800_000_000;

    /// Declared as a 64-bit integer; narrower integer column types are also accepted.
    impl Type<Sqlite> for Time {
        fn type_info() -> SqliteTypeInfo {
            <i64 as Type<Sqlite>>::type_info()
        }

        fn compatible(ty: &SqliteTypeInfo) -> bool {
            *ty == <i64 as Type<Sqlite>>::type_info()
                || *ty == <i32 as Type<Sqlite>>::type_info()
                || *ty == <i16 as Type<Sqlite>>::type_info()
                || *ty == <i8 as Type<Sqlite>>::type_info()
        }
    }

    impl<'q> Encode<'q, Sqlite> for Time {
        /// Binds the nanosecond timestamp as an `INTEGER`.
        ///
        /// Panics if the timestamp does not fit in an `i64` (`Encode` cannot return an error).
        fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
            args.push(SqliteArgumentValue::Int64(
                i64::try_from(self.timestamp_ns()).expect("timestamp too large"),
            ));

            IsNull::No
        }
    }

    impl<'r> Decode<'r, Sqlite> for Time {
        fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
            let value = <i64 as Decode<Sqlite>>::decode(value)?;
            // NOTE(review): a stored value that does not convert to `from_timestamp_ns`'s
            // argument type (presumably unsigned, i.e. any negative value) silently becomes that
            // type's default instead of an error.
            Ok(Time::from_timestamp_ns(
                value.try_into().unwrap_or_default(),
            ))
        }
    }

    impl Type<Postgres> for Time {
        fn type_info() -> PgTypeInfo {
            PgTypeInfo::with_name("TIMESTAMPTZ")
        }
        fn compatible(ty: &PgTypeInfo) -> bool {
            *ty == PgTypeInfo::with_name("TIMESTAMPTZ") || *ty == PgTypeInfo::with_name("TIMESTAMP")
        }
    }

    impl Encode<'_, Postgres> for Time {
        /// Binds the timestamp as microseconds since 2000-01-01.
        ///
        /// Panics if the microsecond timestamp does not fit in an `i64`.
        fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
            let us =
                i64::try_from(self.timestamp_us()).expect("timestamp too large") - J2000_EPOCH_US;
            Encode::<Postgres>::encode(us, buf)
        }

        fn size_hint(&self) -> usize {
            std::mem::size_of::<i64>()
        }
    }

    impl<'r> Decode<'r, Postgres> for Time {
        fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
            let us: i64 = Decode::<Postgres>::decode(value)?;
            // PostgreSQL sends `infinity` as `i64::MAX`, so shifting to the Unix epoch can
            // overflow. The plain `+` panicked in debug builds and wrapped in release ones.
            // Treat overflow like any other out-of-range value: fall back to the default,
            // consistent with the SQLite path.
            Ok(Time::from_timestamp_us(
                us.checked_add(J2000_EPOCH_US)
                    .and_then(|unix_us| unix_us.try_into().ok())
                    .unwrap_or_default(),
            ))
        }
    }
}
283
284#[allow(clippy::module_name_repetitions)]
288#[inline]
289pub fn db_pool() -> &'static DbPool {
290 DB_POOL.get().unwrap()
291}
292
/// A connection pool for one of the supported database backends.
#[allow(clippy::module_name_repetitions)]
pub enum DbPool {
    /// SQLite connection pool.
    Sqlite(SqlitePool),
    /// PostgreSQL connection pool.
    Postgres(PgPool),
}
298
/// Identifies the backend behind a [`DbPool`] or [`Transaction`] (see their `kind` methods).
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DbKind {
    /// SQLite backend.
    Sqlite,
    /// PostgreSQL backend.
    Postgres,
}
305
306impl DbPool {
307 pub async fn begin(&self) -> Result<Transaction<'_>, sqlx::Error> {
308 match self {
309 DbPool::Sqlite(p) => Ok(Transaction::Sqlite(p.begin().await?)),
310 DbPool::Postgres(p) => Ok(Transaction::Postgres(p.begin().await?)),
311 }
312 }
313 pub fn kind(&self) -> DbKind {
314 match self {
315 DbPool::Sqlite(_) => DbKind::Sqlite,
316 DbPool::Postgres(_) => DbKind::Postgres,
317 }
318 }
319 pub async fn execute(&self, q: &str) -> EResult<()> {
320 match self {
321 DbPool::Sqlite(ref p) => {
322 sqlx::query(q).execute(p).await?;
323 }
324 DbPool::Postgres(ref p) => {
325 sqlx::query(q).execute(p).await?;
326 }
327 }
328 Ok(())
329 }
330}
331
/// An open transaction on one of the supported backends.
///
/// `'c` is the lifetime parameter of the wrapped `sqlx::Transaction`.
pub enum Transaction<'c> {
    /// Transaction on a SQLite connection.
    Sqlite(sqlx::Transaction<'c, sqlx::sqlite::Sqlite>),
    /// Transaction on a PostgreSQL connection.
    Postgres(sqlx::Transaction<'c, sqlx::postgres::Postgres>),
}
336
337impl<'c> Transaction<'c> {
338 pub async fn commit(self) -> Result<(), sqlx::Error> {
339 match self {
340 Transaction::Sqlite(tx) => tx.commit().await,
341 Transaction::Postgres(tx) => tx.commit().await,
342 }
343 }
344 pub fn kind(&self) -> DbKind {
345 match self {
346 Transaction::Sqlite(_) => DbKind::Sqlite,
347 Transaction::Postgres(_) => DbKind::Postgres,
348 }
349 }
350 pub async fn execute(&mut self, q: &str) -> EResult<()> {
351 match self {
352 Transaction::Sqlite(ref mut p) => {
353 sqlx::query(q).execute(p).await?;
354 }
355 Transaction::Postgres(ref mut p) => {
356 sqlx::query(q).execute(p).await?;
357 }
358 }
359 Ok(())
360 }
361}
362
363#[allow(clippy::module_name_repetitions)]
366pub async fn db_init(conn: &str, pool_size: u32, timeout: Duration) -> EResult<()> {
367 DB_POOL
368 .set(create_pool(conn, pool_size, timeout).await?)
369 .map_err(|_| Error::core("unable to set DB_POOL"))?;
370 Ok(())
371}
372
373pub async fn create_pool(conn: &str, pool_size: u32, timeout: Duration) -> EResult<DbPool> {
375 if conn.starts_with("sqlite://") {
376 let mut opts = SqliteConnectOptions::from_str(conn)?
377 .create_if_missing(true)
378 .synchronous(sqlx::sqlite::SqliteSynchronous::Extra)
379 .busy_timeout(timeout);
380 opts.log_statements(log::LevelFilter::Trace)
381 .log_slow_statements(log::LevelFilter::Warn, timeout);
382 Ok(DbPool::Sqlite(
383 SqlitePoolOptions::new()
384 .max_connections(pool_size)
385 .acquire_timeout(timeout)
386 .connect_with(opts)
387 .await?,
388 ))
389 } else if conn.starts_with("postgres://") {
390 let mut opts = PgConnectOptions::from_str(conn)?;
391 opts.log_statements(log::LevelFilter::Trace)
392 .log_slow_statements(log::LevelFilter::Warn, timeout);
393 Ok(DbPool::Postgres(
394 PgPoolOptions::new()
395 .max_connections(pool_size)
396 .acquire_timeout(timeout)
397 .connect_with(opts)
398 .await?,
399 ))
400 } else {
401 Err(Error::unsupported("Unsupported database kind"))
402 }
403}