1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use chrono::{DateTime, Local, Utc};
9use futures::Stream;
10use serde_json::Value;
11use thiserror::Error;
12
13pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14pub type NotificationStream = Pin<Box<dyn Stream<Item = Result<Notification, DbError>> + Send>>;
15
/// A bind-parameter value passed to a query through [`DbParams`].
///
/// Most scalar types have a plain variant and an `…Opt` variant wrapping an `Option`.
/// NOTE(review): presumably `None` in an `…Opt` variant is bound as SQL `NULL` by the
/// driver — confirm against the driver modules.
/// Of the array variants, only `TextArray` has an optional form.
#[derive(Clone, Debug)]
pub enum DbValue {
    Bool(bool),
    BoolOpt(Option<bool>),
    I16(i16),
    I16Opt(Option<i16>),
    I32(i32),
    I32Opt(Option<i32>),
    I64(i64),
    I64Opt(Option<i64>),
    /// Arbitrary JSON (`serde_json::Value`).
    Json(Value),
    JsonOpt(Option<Value>),
    Text(String),
    TextOpt(Option<String>),
    TextArray(Vec<String>),
    TextArrayOpt(Option<Vec<String>>),
    I32Array(Vec<i32>),
    I64Array(Vec<i64>),
    /// A UTC timestamp. NOTE(review): the name suggests a `timestamptz` column — confirm.
    TimestampTz(DateTime<Utc>),
    TimestampTzOpt(Option<DateTime<Utc>>),
}
37
38#[derive(Clone, Debug, Default)]
39pub struct DbParams(Vec<DbValue>);
40
41impl DbParams {
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn push(&mut self, value: DbValue) {
47 self.0.push(value);
48 }
49
50 pub fn values(&self) -> &[DbValue] {
51 &self.0
52 }
53}
54
55impl From<Vec<DbValue>> for DbParams {
56 fn from(value: Vec<DbValue>) -> Self {
57 Self(value)
58 }
59}
60
/// A single decoded column value stored in a [`DbRow`].
///
/// Unlike [`DbValue`] there are no `…Opt` variants. An absent value is [`DbCell::Null`],
/// which the `Option<T>` implementation of [`FromDbCell`] decodes as `None`.
#[derive(Clone, Debug)]
pub enum DbCell {
    Null,
    Bool(bool),
    I16(i16),
    I32(i32),
    I64(i64),
    /// Arbitrary JSON (`serde_json::Value`).
    Json(Value),
    Text(String),
    /// A UTC timestamp.
    TimestampTz(DateTime<Utc>),
}
72
73#[derive(Clone, Debug, Default)]
74pub struct DbRow {
75 cells: HashMap<String, DbCell>,
76}
77
78impl DbRow {
79 pub fn new(cells: HashMap<String, DbCell>) -> Self {
80 Self { cells }
81 }
82
83 pub fn try_get<T: FromDbCell>(&self, name: &str) -> Result<T, DbError> {
84 let cell = self.cells.get(name).ok_or_else(|| {
85 DbError::new(format!("column `{name}` was not present in query result"))
86 })?;
87 T::from_cell(name, cell)
88 }
89}
90
/// Conversion from a [`DbCell`] into a Rust value; used by [`DbRow::try_get`].
pub trait FromDbCell: Sized {
    /// Decodes `cell`. In this module's implementations, `name` (the column name) is used
    /// only to build error messages.
    ///
    /// # Errors
    /// The implementations in this module return a [`DbError`] when the cell's variant does
    /// not match the target type.
    fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError>;
}
94
95fn type_error(name: &str, expected: &str, cell: &DbCell) -> DbError {
96 DbError::new(format!(
97 "column `{name}` could not be decoded as {expected}; actual value was {cell:?}"
98 ))
99}
100
101impl FromDbCell for bool {
102 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
103 match cell {
104 DbCell::Bool(value) => Ok(*value),
105 _ => Err(type_error(name, "bool", cell)),
106 }
107 }
108}
109
110impl FromDbCell for i16 {
111 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
112 match cell {
113 DbCell::I16(value) => Ok(*value),
114 _ => Err(type_error(name, "i16", cell)),
115 }
116 }
117}
118
119impl FromDbCell for i32 {
120 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
121 match cell {
122 DbCell::I32(value) => Ok(*value),
123 _ => Err(type_error(name, "i32", cell)),
124 }
125 }
126}
127
128impl FromDbCell for i64 {
129 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
130 match cell {
131 DbCell::I64(value) => Ok(*value),
132 _ => Err(type_error(name, "i64", cell)),
133 }
134 }
135}
136
137impl FromDbCell for String {
138 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
139 match cell {
140 DbCell::Text(value) => Ok(value.clone()),
141 _ => Err(type_error(name, "String", cell)),
142 }
143 }
144}
145
146impl FromDbCell for Value {
147 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
148 match cell {
149 DbCell::Json(value) => Ok(value.clone()),
150 _ => Err(type_error(name, "serde_json::Value", cell)),
151 }
152 }
153}
154
155impl FromDbCell for DateTime<Utc> {
156 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
157 match cell {
158 DbCell::TimestampTz(value) => Ok(*value),
159 _ => Err(type_error(name, "DateTime<Utc>", cell)),
160 }
161 }
162}
163
164impl FromDbCell for DateTime<Local> {
165 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
166 let value = DateTime::<Utc>::from_cell(name, cell)?;
167 Ok(value.with_timezone(&Local))
168 }
169}
170
171impl<T: FromDbCell> FromDbCell for Option<T> {
172 fn from_cell(name: &str, cell: &DbCell) -> Result<Self, DbError> {
173 if matches!(cell, DbCell::Null) {
174 return Ok(None);
175 }
176
177 T::from_cell(name, cell).map(Some)
178 }
179}
180
/// An item yielded (inside `Ok`) by a [`NotificationStream`], as returned from
/// [`Database::listen`].
#[derive(Clone, Debug)]
pub struct Notification {
    /// Name of the channel the notification belongs to.
    pub channel: String,
    /// Text payload carried by the notification.
    pub payload: String,
}
186
187#[derive(Debug, Error, Clone)]
188#[error("{message}")]
189pub struct DbError {
190 message: String,
191 code: Option<String>,
192}
193
194impl DbError {
195 pub fn new(message: impl Into<String>) -> Self {
196 Self {
197 message: message.into(),
198 code: None,
199 }
200 }
201
202 pub fn with_code(message: impl Into<String>, code: impl Into<String>) -> Self {
203 Self {
204 message: message.into(),
205 code: Some(code.into()),
206 }
207 }
208
209 pub fn code(&self) -> Option<&str> {
210 self.code.as_deref()
211 }
212}
213
214pub trait DbExecutor: Send + Sync {
215 fn execute<'a>(&'a self, sql: &'a str, params: DbParams)
216 -> BoxFuture<'a, Result<u64, DbError>>;
217
218 fn fetch_all<'a>(
219 &'a self,
220 sql: &'a str,
221 params: DbParams,
222 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>>;
223
224 fn fetch_optional<'a>(
225 &'a self,
226 sql: &'a str,
227 params: DbParams,
228 ) -> BoxFuture<'a, Result<Option<DbRow>, DbError>> {
229 Box::pin(async move {
230 let rows = self.fetch_all(sql, params).await?;
231 Ok(rows.into_iter().next())
232 })
233 }
234
235 fn fetch_one<'a>(
236 &'a self,
237 sql: &'a str,
238 params: DbParams,
239 ) -> BoxFuture<'a, Result<DbRow, DbError>> {
240 Box::pin(async move {
241 self.fetch_optional(sql, params).await?.ok_or_else(|| {
242 DbError::new("query returned no rows when exactly one row was expected")
243 })
244 })
245 }
246}
247
248pub trait DbExecutorArg: Send {
249 fn execute<'a>(
250 &'a mut self,
251 sql: &'a str,
252 params: DbParams,
253 ) -> BoxFuture<'a, Result<u64, DbError>>;
254
255 fn fetch_all<'a>(
256 &'a mut self,
257 sql: &'a str,
258 params: DbParams,
259 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>>;
260
261 fn fetch_optional<'a>(
262 &'a mut self,
263 sql: &'a str,
264 params: DbParams,
265 ) -> BoxFuture<'a, Result<Option<DbRow>, DbError>>
266 where
267 Self: Send + 'a,
268 {
269 Box::pin(async move {
270 let rows = self.fetch_all(sql, params).await?;
271 Ok(rows.into_iter().next())
272 })
273 }
274
275 fn fetch_one<'a>(
276 &'a mut self,
277 sql: &'a str,
278 params: DbParams,
279 ) -> BoxFuture<'a, Result<DbRow, DbError>>
280 where
281 Self: Send + 'a,
282 {
283 Box::pin(async move {
284 self.fetch_optional(sql, params).await?.ok_or_else(|| {
285 DbError::new("query returned no rows when exactly one row was expected")
286 })
287 })
288 }
289}
290
291impl<T: DbExecutor + ?Sized> DbExecutorArg for &T {
292 fn execute<'a>(
293 &'a mut self,
294 sql: &'a str,
295 params: DbParams,
296 ) -> BoxFuture<'a, Result<u64, DbError>> {
297 DbExecutor::execute(*self, sql, params)
298 }
299
300 fn fetch_all<'a>(
301 &'a mut self,
302 sql: &'a str,
303 params: DbParams,
304 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
305 DbExecutor::fetch_all(*self, sql, params)
306 }
307}
308
309impl<T: DbExecutorArg + ?Sized> DbExecutorArg for &mut T {
310 fn execute<'a>(
311 &'a mut self,
312 sql: &'a str,
313 params: DbParams,
314 ) -> BoxFuture<'a, Result<u64, DbError>> {
315 (**self).execute(sql, params)
316 }
317
318 fn fetch_all<'a>(
319 &'a mut self,
320 sql: &'a str,
321 params: DbParams,
322 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
323 (**self).fetch_all(sql, params)
324 }
325}
326
/// A database backend that can be wrapped in a [`Database`] handle.
///
/// Queries come from the [`DbExecutor`] supertrait. In addition, a driver starts
/// transactions and may provide a notification stream.
pub trait DatabaseDriver: DbExecutor + fmt::Debug + Any {
    /// Returns the driver as `&dyn Any`; [`Database::downcast_ref`] downcasts this value.
    /// NOTE(review): implementations are expected to return `self` — nothing here enforces it.
    fn as_any(&self) -> &dyn Any;

    /// Starts a new transaction.
    fn begin<'a>(&'a self) -> BoxFuture<'a, Result<DbTransaction, DbError>>;

    /// Subscribes to `channel` and resolves to a stream of notifications.
    /// NOTE(review): `Ok(None)` presumably means the driver does not support
    /// notifications — confirm in the driver modules.
    fn listen<'a>(
        &'a self,
        channel: &'a str,
    ) -> BoxFuture<'a, Result<Option<NotificationStream>, DbError>>;
}
337
/// Driver-specific transaction wrapped by [`DbTransaction`]; queries go through the
/// [`DbExecutor`] supertrait.
pub trait TransactionDriver: DbExecutor {
    /// Commits the transaction.
    ///
    /// It takes `self: Box<Self>` so it can be called on a `Box<dyn TransactionDriver>`,
    /// consuming it. The returned future is `'static`, so it borrows nothing from the caller.
    fn commit(self: Box<Self>) -> BoxFuture<'static, Result<(), DbError>>;
}
341
342#[derive(Clone)]
343pub struct Database {
344 inner: Arc<dyn DatabaseDriver>,
345}
346
347impl Database {
348 pub fn new(driver: impl DatabaseDriver + 'static) -> Self {
349 Self {
350 inner: Arc::new(driver),
351 }
352 }
353
354 pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
355 self.inner.as_any().downcast_ref()
356 }
357
358 pub async fn begin(&self) -> Result<DbTransaction, DbError> {
359 self.inner.begin().await
360 }
361
362 pub async fn listen(&self, channel: &str) -> Result<Option<NotificationStream>, DbError> {
363 self.inner.listen(channel).await
364 }
365}
366
367impl From<&Database> for Database {
368 fn from(database: &Database) -> Self {
369 database.clone()
370 }
371}
372
373impl fmt::Debug for Database {
374 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375 f.debug_struct("Database").finish_non_exhaustive()
376 }
377}
378
379impl DbExecutor for Database {
380 fn execute<'a>(
381 &'a self,
382 sql: &'a str,
383 params: DbParams,
384 ) -> BoxFuture<'a, Result<u64, DbError>> {
385 self.inner.execute(sql, params)
386 }
387
388 fn fetch_all<'a>(
389 &'a self,
390 sql: &'a str,
391 params: DbParams,
392 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
393 self.inner.fetch_all(sql, params)
394 }
395}
396
397pub struct DbTransaction {
398 inner: Box<dyn TransactionDriver>,
399}
400
401impl DbTransaction {
402 pub fn new(inner: Box<dyn TransactionDriver>) -> Self {
403 Self { inner }
404 }
405
406 pub async fn commit(self) -> Result<(), DbError> {
407 self.inner.commit().await
408 }
409}
410
411impl DbExecutor for DbTransaction {
412 fn execute<'a>(
413 &'a self,
414 sql: &'a str,
415 params: DbParams,
416 ) -> BoxFuture<'a, Result<u64, DbError>> {
417 self.inner.execute(sql, params)
418 }
419
420 fn fetch_all<'a>(
421 &'a self,
422 sql: &'a str,
423 params: DbParams,
424 ) -> BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
425 self.inner.fetch_all(sql, params)
426 }
427}
428
429pub mod row_mapping {
430 use super::*;
431
432 pub fn cells(values: impl IntoIterator<Item = (impl Into<String>, DbCell)>) -> DbRow {
433 DbRow::new(
434 values
435 .into_iter()
436 .map(|(name, value)| (name.into(), value))
437 .collect(),
438 )
439 }
440}
441
// Optional backend module, compiled only when the `driver-sqlx` cargo feature is enabled.
// NOTE(review): presumably a `DatabaseDriver` built on the `sqlx` crate — confirm.
#[cfg(feature = "driver-sqlx")]
pub mod sqlx;

// Optional backend module, compiled only when the `driver-tokio-postgres` cargo feature is
// enabled. NOTE(review): presumably a `DatabaseDriver` built on `tokio-postgres` — confirm.
#[cfg(feature = "driver-tokio-postgres")]
pub mod tokio_postgres;