1#[cfg(feature = "acl")]
12use crate::acl::OIDMask;
13use crate::{EResult, Error};
14use crate::{OID, value::Value};
15use sqlx::ConnectOptions;
16use sqlx::encode::IsNull;
17use sqlx::error::BoxDynError;
18use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgValueRef};
19use sqlx::sqlite::SqliteValueRef;
20use sqlx::sqlite::{self, SqliteConnectOptions, SqlitePoolOptions};
21use sqlx::{Decode, Encode};
22use sqlx::{PgPool, SqlitePool, postgres};
23use sqlx::{Postgres, Sqlite, Type};
24use std::borrow::Cow;
25use std::str::FromStr as _;
26use std::sync::OnceLock;
27use std::time::Duration;
28
29pub mod prelude {
30 pub use super::{DbKind, DbPool, Transaction, db_init, db_pool};
31}
32
33type ResultIsNull = Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>>;
34
35static DB_POOL: OnceLock<DbPool> = OnceLock::new();
36
37impl Type<Sqlite> for OID {
38 fn type_info() -> sqlite::SqliteTypeInfo {
39 <str as Type<Sqlite>>::type_info()
40 }
41}
42
43impl Type<Postgres> for OID {
44 fn type_info() -> postgres::PgTypeInfo {
45 <str as Type<Postgres>>::type_info()
46 }
47 fn compatible(ty: &postgres::PgTypeInfo) -> bool {
48 *ty == postgres::PgTypeInfo::with_name("VARCHAR")
49 || *ty == postgres::PgTypeInfo::with_name("TEXT")
50 }
51}
52
53impl postgres::PgHasArrayType for OID {
54 fn array_type_info() -> postgres::PgTypeInfo {
55 postgres::PgTypeInfo::with_name("_TEXT")
56 }
57
58 fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
59 *ty == postgres::PgTypeInfo::with_name("_TEXT")
60 || *ty == postgres::PgTypeInfo::with_name("_VARCHAR")
61 }
62}
63
64impl<'r> Decode<'r, Sqlite> for OID {
65 fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
66 let value = <&str as Decode<Sqlite>>::decode(value)?;
67 value.parse().map_err(Into::into)
68 }
69}
70
71impl<'r> Decode<'r, Postgres> for OID {
72 fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
73 let value = <&str as Decode<Postgres>>::decode(value)?;
74 value.parse().map_err(Into::into)
75 }
76}
77
78impl<'q> Encode<'q, Sqlite> for OID {
79 fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
80 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
81 self.to_string(),
82 )));
83
84 Ok(IsNull::No)
85 }
86 fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
87 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
88 self.to_string(),
89 )));
90 Ok(IsNull::No)
91 }
92
93 fn size_hint(&self) -> usize {
94 self.as_str().len()
95 }
96}
97
98impl Encode<'_, Postgres> for OID {
99 fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
100 <&str as Encode<Postgres>>::encode(self.as_str(), buf)
101 }
102 fn size_hint(&self) -> usize {
103 self.as_str().len()
104 }
105}
106
107#[cfg(feature = "acl")]
108impl Type<Sqlite> for OIDMask {
109 fn type_info() -> sqlite::SqliteTypeInfo {
110 <str as Type<Sqlite>>::type_info()
111 }
112}
113
114#[cfg(feature = "acl")]
115impl Type<Postgres> for OIDMask {
116 fn type_info() -> postgres::PgTypeInfo {
117 <str as Type<Postgres>>::type_info()
118 }
119 fn compatible(ty: &postgres::PgTypeInfo) -> bool {
120 *ty == postgres::PgTypeInfo::with_name("VARCHAR")
121 || *ty == postgres::PgTypeInfo::with_name("TEXT")
122 }
123}
124
125#[cfg(feature = "acl")]
126impl postgres::PgHasArrayType for OIDMask {
127 fn array_type_info() -> postgres::PgTypeInfo {
128 postgres::PgTypeInfo::with_name("_TEXT")
129 }
130
131 fn array_compatible(ty: &postgres::PgTypeInfo) -> bool {
132 *ty == postgres::PgTypeInfo::with_name("_TEXT")
133 || *ty == postgres::PgTypeInfo::with_name("_VARCHAR")
134 }
135}
136
137#[cfg(feature = "acl")]
138impl<'r> Decode<'r, Sqlite> for OIDMask {
139 fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
140 let value = <&str as Decode<Sqlite>>::decode(value)?;
141 value.parse().map_err(Into::into)
142 }
143}
144
145#[cfg(feature = "acl")]
146impl<'r> Decode<'r, Postgres> for OIDMask {
147 fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
148 let value = <&str as Decode<Postgres>>::decode(value)?;
149 value.parse().map_err(Into::into)
150 }
151}
152
153#[cfg(feature = "acl")]
154impl<'q> Encode<'q, Sqlite> for OIDMask {
155 fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
156 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
157 self.to_string(),
158 )));
159
160 Ok(IsNull::No)
161 }
162 fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> ResultIsNull {
163 args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(
164 self.to_string(),
165 )));
166 Ok(IsNull::No)
167 }
168}
169
170#[cfg(feature = "acl")]
171impl Encode<'_, Postgres> for OIDMask {
172 #[allow(clippy::needless_borrows_for_generic_args)]
173 fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
174 <&str as Encode<Postgres>>::encode(&self.to_string(), buf)
175 }
176}
177
178impl Type<Sqlite> for Value {
179 fn type_info() -> sqlite::SqliteTypeInfo {
180 <str as Type<Sqlite>>::type_info()
181 }
182
183 fn compatible(ty: &sqlite::SqliteTypeInfo) -> bool {
184 <&str as Type<Sqlite>>::compatible(ty)
185 }
186}
187
188impl Type<Postgres> for Value {
189 fn type_info() -> postgres::PgTypeInfo {
190 postgres::PgTypeInfo::with_name("JSONB")
191 }
192}
193
194impl Encode<'_, Sqlite> for Value {
195 fn encode_by_ref(&self, buf: &mut Vec<sqlite::SqliteArgumentValue<'_>>) -> ResultIsNull {
196 let json_string_value =
197 serde_json::to_string(self).expect("serde_json failed to convert to string");
198 Encode::<Sqlite>::encode(json_string_value, buf)
199 }
200}
201
202impl<'r> Decode<'r, Sqlite> for Value {
203 fn decode(value: sqlite::SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
204 let string_value = <&str as Decode<Sqlite>>::decode(value)?;
205
206 serde_json::from_str(string_value).map_err(Into::into)
207 }
208}
209
210impl Encode<'_, Postgres> for Value {
211 fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> ResultIsNull {
212 buf.push(1);
213 serde_json::to_writer(&mut **buf, &self)
214 .expect("failed to serialize to JSON for encoding on transmission to the database");
215 Ok(IsNull::No)
216 }
217}
218
219impl<'r> Decode<'r, Postgres> for Value {
220 fn decode(value: postgres::PgValueRef<'r>) -> Result<Self, BoxDynError> {
221 let buf = value.as_bytes()?;
222 assert_eq!(buf[0], 1, "unsupported JSONB format version {}", buf[0]);
223 serde_json::from_slice(&buf[1..]).map_err(Into::into)
224 }
225}
226
227#[cfg(feature = "time")]
228mod time_impl {
229 use super::ResultIsNull;
230 use crate::time::Time;
231 use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
232 use sqlx::sqlite::{SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef};
233 use sqlx::{Decode, Encode, Postgres, Sqlite, Type, encode::IsNull, error::BoxDynError};
234
235 const J2000_EPOCH_US: i64 = 946_684_800_000_000;
236
237 impl Type<Sqlite> for Time {
238 fn type_info() -> SqliteTypeInfo {
239 <i64 as Type<Sqlite>>::type_info()
240 }
241
242 fn compatible(ty: &SqliteTypeInfo) -> bool {
243 *ty == <i64 as Type<Sqlite>>::type_info()
244 || *ty == <i32 as Type<Sqlite>>::type_info()
245 || *ty == <i16 as Type<Sqlite>>::type_info()
246 || *ty == <i8 as Type<Sqlite>>::type_info()
247 }
248 }
249
250 impl<'q> Encode<'q, Sqlite> for Time {
251 fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> ResultIsNull {
252 args.push(SqliteArgumentValue::Int64(
253 i64::try_from(self.timestamp_ns()).expect("timestamp too large"),
254 ));
255
256 Ok(IsNull::No)
257 }
258 }
259
260 impl<'r> Decode<'r, Sqlite> for Time {
261 fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
262 let value = <i64 as Decode<Sqlite>>::decode(value)?;
263 Ok(Time::from_timestamp_ns(
264 value.try_into().unwrap_or_default(),
265 ))
266 }
267 }
268
269 impl Type<Postgres> for Time {
270 fn type_info() -> PgTypeInfo {
271 PgTypeInfo::with_name("TIMESTAMPTZ")
272 }
273 fn compatible(ty: &PgTypeInfo) -> bool {
274 *ty == PgTypeInfo::with_name("TIMESTAMPTZ") || *ty == PgTypeInfo::with_name("TIMESTAMP")
275 }
276 }
277
278 impl Encode<'_, Postgres> for Time {
279 fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> ResultIsNull {
280 let us =
281 i64::try_from(self.timestamp_us()).expect("timestamp too large") - J2000_EPOCH_US;
282 Encode::<Postgres>::encode(us, buf)
283 }
284
285 fn size_hint(&self) -> usize {
286 std::mem::size_of::<i64>()
287 }
288 }
289
290 impl<'r> Decode<'r, Postgres> for Time {
291 fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
292 let us: i64 = Decode::<Postgres>::decode(value)?;
293 Ok(Time::from_timestamp_us(
294 (us + J2000_EPOCH_US).try_into().unwrap_or_default(),
295 ))
296 }
297 }
298}
299
300#[allow(clippy::module_name_repetitions)]
304#[inline]
305pub fn db_pool() -> &'static DbPool {
306 DB_POOL.get().unwrap()
307}
308
309#[allow(clippy::module_name_repetitions)]
310pub enum DbPool {
311 Sqlite(SqlitePool),
312 Postgres(PgPool),
313}
314
315#[allow(clippy::module_name_repetitions)]
316#[derive(Debug, Copy, Clone, Eq, PartialEq)]
317pub enum DbKind {
318 Sqlite,
319 Postgres,
320}
321
322impl DbPool {
323 pub async fn begin(&self) -> Result<Transaction<'_>, sqlx::Error> {
324 match self {
325 DbPool::Sqlite(p) => Ok(Transaction::Sqlite(p.begin().await?)),
326 DbPool::Postgres(p) => Ok(Transaction::Postgres(p.begin().await?)),
327 }
328 }
329 pub fn kind(&self) -> DbKind {
330 match self {
331 DbPool::Sqlite(_) => DbKind::Sqlite,
332 DbPool::Postgres(_) => DbKind::Postgres,
333 }
334 }
335 pub async fn execute(&self, q: &str) -> EResult<()> {
336 match self {
337 DbPool::Sqlite(p) => {
338 sqlx::query(q).execute(p).await?;
339 }
340 DbPool::Postgres(p) => {
341 sqlx::query(q).execute(p).await?;
342 }
343 }
344 Ok(())
345 }
346}
347
348pub enum Transaction<'c> {
349 Sqlite(sqlx::Transaction<'c, sqlx::sqlite::Sqlite>),
350 Postgres(sqlx::Transaction<'c, sqlx::postgres::Postgres>),
351}
352
353impl Transaction<'_> {
354 pub async fn commit(self) -> Result<(), sqlx::Error> {
355 match self {
356 Transaction::Sqlite(tx) => tx.commit().await,
357 Transaction::Postgres(tx) => tx.commit().await,
358 }
359 }
360 pub fn kind(&self) -> DbKind {
361 match self {
362 Transaction::Sqlite(_) => DbKind::Sqlite,
363 Transaction::Postgres(_) => DbKind::Postgres,
364 }
365 }
366 pub async fn execute(&mut self, q: &str) -> EResult<()> {
367 match self {
368 Transaction::Sqlite(p) => {
369 sqlx::query(q).execute(&mut **p).await?;
370 }
371 Transaction::Postgres(p) => {
372 sqlx::query(q).execute(&mut **p).await?;
373 }
374 }
375 Ok(())
376 }
377}
378
379#[allow(clippy::module_name_repetitions)]
382pub async fn db_init(conn: &str, pool_size: u32, timeout: Duration) -> EResult<()> {
383 DB_POOL
384 .set(create_pool(conn, pool_size, timeout).await?)
385 .map_err(|_| Error::core("unable to set DB_POOL"))?;
386 Ok(())
387}
388
389pub async fn create_pool(conn: &str, pool_size: u32, timeout: Duration) -> EResult<DbPool> {
391 if conn.starts_with("sqlite://") {
392 let opts = SqliteConnectOptions::from_str(conn)?
393 .create_if_missing(true)
394 .synchronous(sqlx::sqlite::SqliteSynchronous::Extra)
395 .busy_timeout(timeout)
396 .log_statements(log::LevelFilter::Trace)
397 .log_slow_statements(log::LevelFilter::Warn, timeout);
398 Ok(DbPool::Sqlite(
399 SqlitePoolOptions::new()
400 .max_connections(pool_size)
401 .acquire_timeout(timeout)
402 .connect_with(opts)
403 .await?,
404 ))
405 } else if conn.starts_with("postgres://") {
406 let opts = PgConnectOptions::from_str(conn)?
407 .log_statements(log::LevelFilter::Trace)
408 .log_slow_statements(log::LevelFilter::Warn, timeout);
409 Ok(DbPool::Postgres(
410 PgPoolOptions::new()
411 .max_connections(pool_size)
412 .acquire_timeout(timeout)
413 .connect_with(opts)
414 .await?,
415 ))
416 } else {
417 Err(Error::unsupported("Unsupported database kind"))
418 }
419}