Skip to main content

graphile_worker_database/
lib.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use chrono::{DateTime, Local, Utc};
9use futures::Stream;
10use serde_json::Value;
11use thiserror::Error;
12
13pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14pub type NotificationStream = Pin<Box<dyn Stream<Item = Result<Notification, DbError>> + Send>>;
15
16#[derive(Clone, Debug)]
17pub enum DbValue {
18    Bool(bool),
19    BoolOpt(Option<bool>),
20    I16(i16),
21    I16Opt(Option<i16>),
22    I32(i32),
23    I32Opt(Option<i32>),
24    I64(i64),
25    I64Opt(Option<i64>),
26    Json(Value),
27    JsonOpt(Option<Value>),
28    Text(String),
29    TextOpt(Option<String>),
30    TextArray(Vec<String>),
31    TextArrayOpt(Option<Vec<String>>),
32    I32Array(Vec<i32>),
33    I64Array(Vec<i64>),
34    TimestampTz(DateTime<Utc>),
35    TimestampTzOpt(Option<DateTime<Utc>>),
36}
37
38#[derive(Clone, Debug, Default)]
39pub struct DbParams(Vec<DbValue>);
40
41impl DbParams {
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    pub fn push(&mut self, value: DbValue) {
47        self.0.push(value);
48    }
49
50    pub fn values(&self) -> &[DbValue] {
51        &self.0
52    }
53}
54
55impl From<Vec<DbValue>> for DbParams {
56    fn from(value: Vec<DbValue>) -> Self {
57        Self(value)
58    }
59}
60
61#[derive(Clone, Debug)]
62pub enum DbCell {
63    Null,
64    Bool(bool),
65    I16(i16),
66    I32(i32),
67    I64(i64),
68    Json(Value),
69    Text(String),
70    TimestampTz(DateTime<Utc>),
71}
72
73#[derive(Clone, Debug, Default)]
74pub struct DbRow {
75    cells: HashMap<String, DbCell>,
76}
77
78impl DbRow {
79    pub fn new(cells: HashMap<String, DbCell>) -> Self {
80        Self { cells }
81    }
82
83    pub fn try_get<T: FromDbCell>(&self, name: &str) -> Result<T, DbError> {
84        let cell = self.cells.get(name).ok_or_else(|| {
85            DbError::new(format!("column `{name}` was not present in query result"))
86        })?;
87        T::from_cell(name, cell)
88    }
89}
90
91pub trait FromDbCell: Sized {
92    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError>;
93}
94
95fn type_error(name: &str, expected: &str, cell: &DbCell) -> DbError {
96    DbError::new(format!(
97        "column `{name}` could not be decoded as {expected}; actual value was {cell:?}"
98    ))
99}
100
101impl FromDbCell for bool {
102    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
103        match cell {
104            DbCell::Bool(value) => Ok(*value),
105            _ => Err(type_error(name, "bool", cell)),
106        }
107    }
108}
109
110impl FromDbCell for i16 {
111    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
112        match cell {
113            DbCell::I16(value) => Ok(*value),
114            _ => Err(type_error(name, "i16", cell)),
115        }
116    }
117}
118
119impl FromDbCell for i32 {
120    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
121        match cell {
122            DbCell::I32(value) => Ok(*value),
123            _ => Err(type_error(name, "i32", cell)),
124        }
125    }
126}
127
128impl FromDbCell for i64 {
129    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
130        match cell {
131            DbCell::I64(value) => Ok(*value),
132            _ => Err(type_error(name, "i64", cell)),
133        }
134    }
135}
136
137impl FromDbCell for String {
138    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
139        match cell {
140            DbCell::Text(value) => Ok(value.clone()),
141            _ => Err(type_error(name, "String", cell)),
142        }
143    }
144}
145
146impl FromDbCell for Value {
147    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
148        match cell {
149            DbCell::Json(value) => Ok(value.clone()),
150            _ => Err(type_error(name, "serde_json::Value", cell)),
151        }
152    }
153}
154
155impl FromDbCell for DateTime<Utc> {
156    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
157        match cell {
158            DbCell::TimestampTz(value) => Ok(*value),
159            _ => Err(type_error(name, "DateTime<Utc>", cell)),
160        }
161    }
162}
163
164impl FromDbCell for DateTime<Local> {
165    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
166        let value = DateTime::<Utc>::from_cell(name, cell)?;
167        Ok(value.with_timezone(&Local))
168    }
169}
170
171impl<T: FromDbCell> FromDbCell for Option<T> {
172    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
173        if matches!(cell, DbCell::Null) {
174            return Ok(None);
175        }
176
177        T::from_cell(name, cell).map(Some)
178    }
179}
180
/// Item type carried by a [`NotificationStream`].
#[derive(Clone, Debug)]
pub struct Notification {
    /// Name of the channel the notification belongs to.
    pub channel: String,
    /// Payload text attached to the notification.
    pub payload: String,
}
186
187#[derive(Debug, Error, Clone)]
188#[error("{message}")]
189pub struct DbError {
190    message: String,
191    code: Option<String>,
192}
193
194impl DbError {
195    pub fn new(message: impl Into<String>) -> Self {
196        Self {
197            message: message.into(),
198            code: None,
199        }
200    }
201
202    pub fn with_code(message: impl Into<String>, code: impl Into<String>) -> Self {
203        Self {
204            message: message.into(),
205            code: Some(code.into()),
206        }
207    }
208
209    pub fn code(&self) -> Option<&str> {
210        self.code.as_deref()
211    }
212}
213
214pub trait DbExecutor: Send + Sync {
215    fn execute<'a>(&'a self, sql: &'a str, params: DbParams)
216        -> BoxFuture<'a, Result<u64, DbError>>;
217
218    fn fetch_all<'a>(
219        &'a self,
220        sql: &'a str,
221        params: DbParams,
222    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>>;
223
224    fn fetch_optional<'a>(
225        &'a self,
226        sql: &'a str,
227        params: DbParams,
228    ) -> BoxFuture<'a, Result<Option<DbRow>, DbError>> {
229        Box::pin(async move {
230            let rows = self.fetch_all(sql, params).await?;
231            Ok(rows.into_iter().next())
232        })
233    }
234
235    fn fetch_one<'a>(
236        &'a self,
237        sql: &'a str,
238        params: DbParams,
239    ) -> BoxFuture<'a, Result<DbRow, DbError>> {
240        Box::pin(async move {
241            self.fetch_optional(sql, params).await?.ok_or_else(|| {
242                DbError::new("query returned no rows when exactly one row was expected")
243            })
244        })
245    }
246}
247
248pub trait DbExecutorArg: Send {
249    fn execute<'a>(
250        &'a mut self,
251        sql: &'a str,
252        params: DbParams,
253    ) -> BoxFuture<'a, Result<u64, DbError>>;
254
255    fn fetch_all<'a>(
256        &'a mut self,
257        sql: &'a str,
258        params: DbParams,
259    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>>;
260
261    fn fetch_optional<'a>(
262        &'a mut self,
263        sql: &'a str,
264        params: DbParams,
265    ) -> BoxFuture<'a, Result<Option<DbRow>, DbError>>
266    where
267        Self: Send + 'a,
268    {
269        Box::pin(async move {
270            let rows = self.fetch_all(sql, params).await?;
271            Ok(rows.into_iter().next())
272        })
273    }
274
275    fn fetch_one<'a>(
276        &'a mut self,
277        sql: &'a str,
278        params: DbParams,
279    ) -> BoxFuture<'a, Result<DbRow, DbError>>
280    where
281        Self: Send + 'a,
282    {
283        Box::pin(async move {
284            self.fetch_optional(sql, params).await?.ok_or_else(|| {
285                DbError::new("query returned no rows when exactly one row was expected")
286            })
287        })
288    }
289}
290
291impl<T: DbExecutor + ?Sized> DbExecutorArg for &T {
292    fn execute<'a>(
293        &'a mut self,
294        sql: &'a str,
295        params: DbParams,
296    ) -> BoxFuture<'a, Result<u64, DbError>> {
297        DbExecutor::execute(*self, sql, params)
298    }
299
300    fn fetch_all<'a>(
301        &'a mut self,
302        sql: &'a str,
303        params: DbParams,
304    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
305        DbExecutor::fetch_all(*self, sql, params)
306    }
307}
308
309impl<T: DbExecutorArg + ?Sized> DbExecutorArg for &mut T {
310    fn execute<'a>(
311        &'a mut self,
312        sql: &'a str,
313        params: DbParams,
314    ) -> BoxFuture<'a, Result<u64, DbError>> {
315        (**self).execute(sql, params)
316    }
317
318    fn fetch_all<'a>(
319        &'a mut self,
320        sql: &'a str,
321        params: DbParams,
322    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
323        (**self).fetch_all(sql, params)
324    }
325}
326
327pub trait DatabaseDriver: DbExecutor + fmt::Debug + Any {
328    fn as_any(&self) -> &dyn Any;
329
330    fn begin<'a>(&'a self) -> BoxFuture<'a, Result<DbTransaction, DbError>>;
331
332    fn listen<'a>(
333        &'a self,
334        channel: &'a str,
335    ) -> BoxFuture<'a, Result<Option<NotificationStream>, DbError>>;
336}
337
338pub trait TransactionDriver: DbExecutor {
339    fn commit(self: Box<Self>) -> BoxFuture<'static, Result<(), DbError>>;
340}
341
342#[derive(Clone)]
343pub struct Database {
344    inner: Arc<dyn DatabaseDriver>,
345}
346
347impl Database {
348    pub fn new(driver: impl DatabaseDriver + 'static) -> Self {
349        Self {
350            inner: Arc::new(driver),
351        }
352    }
353
354    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
355        self.inner.as_any().downcast_ref()
356    }
357
358    pub async fn begin(&self) -> Result<DbTransaction, DbError> {
359        self.inner.begin().await
360    }
361
362    pub async fn listen(&self, channel: &str) -> Result<Option<NotificationStream>, DbError> {
363        self.inner.listen(channel).await
364    }
365}
366
367impl From<&Database> for Database {
368    fn from(database: &Database) -> Self {
369        database.clone()
370    }
371}
372
373impl fmt::Debug for Database {
374    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375        f.debug_struct("Database").finish_non_exhaustive()
376    }
377}
378
379impl DbExecutor for Database {
380    fn execute<'a>(
381        &'a self,
382        sql: &'a str,
383        params: DbParams,
384    ) -> BoxFuture<'a, Result<u64, DbError>> {
385        self.inner.execute(sql, params)
386    }
387
388    fn fetch_all<'a>(
389        &'a self,
390        sql: &'a str,
391        params: DbParams,
392    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
393        self.inner.fetch_all(sql, params)
394    }
395}
396
397pub struct DbTransaction {
398    inner: Box<dyn TransactionDriver>,
399}
400
401impl DbTransaction {
402    pub fn new(inner: Box<dyn TransactionDriver>) -> Self {
403        Self { inner }
404    }
405
406    pub async fn commit(self) -> Result<(), DbError> {
407        self.inner.commit().await
408    }
409}
410
411impl DbExecutor for DbTransaction {
412    fn execute<'a>(
413        &'a self,
414        sql: &'a str,
415        params: DbParams,
416    ) -> BoxFuture<'a, Result<u64, DbError>> {
417        self.inner.execute(sql, params)
418    }
419
420    fn fetch_all<'a>(
421        &'a self,
422        sql: &'a str,
423        params: DbParams,
424    ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
425        self.inner.fetch_all(sql, params)
426    }
427}
428
429pub mod row_mapping {
430    use super::*;
431
432    pub fn cells(values: impl IntoIterator<Item = (impl Into<String>, DbCell)>) -> DbRow {
433        DbRow::new(
434            values
435                .into_iter()
436                .map(|(name, value)| (name.into(), value))
437                .collect(),
438        )
439    }
440}
441
// Compiled only when the `driver-sqlx` Cargo feature is enabled.
#[cfg(feature = "driver-sqlx")]
pub mod sqlx;

// Compiled only when the `driver-tokio-postgres` Cargo feature is enabled.
#[cfg(feature = "driver-tokio-postgres")]
pub mod tokio_postgres;