opentalk_database/
lib.rs

// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
//
// SPDX-License-Identifier: EUPL-1.2

//! OpenTalk database connector, interface, and connection handling.
6
7use diesel::{
8    pg::Pg,
9    query_builder::{AstPass, Query, QueryFragment, QueryId},
10    sql_types::BigInt,
11    QueryResult,
12};
13use diesel_async::{
14    methods::LoadQuery,
15    pooled_connection::deadpool::{BuildError, Object, PoolError},
16    AsyncConnection, AsyncPgConnection,
17};
18use snafu::Snafu;
19
20mod db;
21mod metrics;
22pub mod query_helper;
23
24pub use db::Db;
25pub use metrics::DatabaseMetrics;
26
27/// Pooled connection alias
28pub type DbConnection = metrics::MetricsConnection<Object<AsyncPgConnection>>;
29
30/// Result type using [`DatabaseError`] as a default Error
31pub type Result<T, E = DatabaseError> = std::result::Result<T, E>;
32
33/// Error types for the database abstraction
34#[derive(Debug, Snafu)]
35pub enum DatabaseError {
36    #[snafu(display("Database Error: `{message}`",))]
37    Custom { message: String },
38
39    #[snafu(display("Diesel Error: `{source}`",))]
40    DieselError { source: diesel::result::Error },
41
42    #[snafu(display("A requested resource could not be found"))]
43    NotFound,
44
45    #[snafu(display("Deadpool build error: `{source}`",), context(false))]
46    DeadpoolBuildError { source: BuildError },
47
48    #[snafu(display("Deadpool error: `{source}`",))]
49    DeadpoolError { source: PoolError },
50
51    #[snafu(context(false))]
52    UrlParseError { source: url::ParseError },
53}
54
55impl DatabaseError {
56    /// Returns `true` if the database error is [`NotFound`].
57    ///
58    /// [`NotFound`]: DatabaseError::NotFound
59    #[must_use]
60    pub fn is_not_found(&self) -> bool {
61        matches!(self, Self::NotFound)
62    }
63}
64
65impl From<diesel::result::Error> for DatabaseError {
66    fn from(err: diesel::result::Error) -> Self {
67        match err {
68            diesel::result::Error::NotFound => Self::NotFound,
69            source => Self::DieselError { source },
70        }
71    }
72}
73
74pub trait OptionalExt<T, E> {
75    fn optional(self) -> Result<Option<T>, E>;
76}
77
78impl<T> OptionalExt<T, DatabaseError> for Result<T, DatabaseError> {
79    fn optional(self) -> Result<Option<T>, DatabaseError> {
80        match self {
81            Ok(t) => Ok(Some(t)),
82            Err(DatabaseError::NotFound) => Ok(None),
83            Err(e) => Err(e),
84        }
85    }
86}
87
88/// Pagination trait for diesel
89pub trait Paginate: Sized {
90    fn paginate(self, page: i64) -> Paginated<Self>;
91    fn paginate_by(self, per_page: i64, page: i64) -> Paginated<Self>;
92}
93
94impl<T> Paginate for T {
95    fn paginate(self, page: i64) -> Paginated<Self> {
96        Paginated {
97            query: self,
98            per_page: DEFAULT_PER_PAGE,
99            offset: (page - 1) * DEFAULT_PER_PAGE,
100        }
101    }
102    fn paginate_by(self, per_page: i64, page: i64) -> Paginated<Self> {
103        Paginated {
104            query: self,
105            per_page,
106            offset: (page - 1) * per_page,
107        }
108    }
109}
110
111const DEFAULT_PER_PAGE: i64 = 10;
112
/// Paginated diesel database response
///
/// Wraps a query so that it runs as
/// `SELECT *, COUNT(*) OVER () FROM (<query>) t LIMIT <per_page> OFFSET <offset>`,
/// i.e. every returned row carries the total row count of the unpaginated
/// query next to it. The fields are private; build it through [`Paginate`].
#[derive(Debug, Clone, Copy, QueryId)]
pub struct Paginated<T> {
    /// The query being paginated.
    query: T,
    /// Maximum number of rows returned (bound as `LIMIT`).
    per_page: i64,
    /// Number of rows to skip (bound as `OFFSET`).
    ///
    /// The offset is stored instead of the page number because
    /// `QueryFragment::walk_ast(...)` binds its parameters by reference for the
    /// lifetime of `&self`, so the bound value has to live in a field.
    offset: i64,
}
122
123impl<T> Paginated<T> {
124    pub fn per_page(self, per_page: i64) -> Self {
125        Paginated { per_page, ..self }
126    }
127
128    pub async fn load_and_count<'query, U, Conn>(
129        self,
130        conn: &mut Conn,
131    ) -> QueryResult<(Vec<U>, i64)>
132    where
133        Self: LoadQuery<'query, Conn, (U, i64)>,
134        Conn: AsyncConnection,
135        U: Send + 'static,
136        T: 'query,
137    {
138        let results: Vec<(U, i64)> = {
139            // When `diesel_async::RunQueryDsl` is imported globally, the call
140            // to `results.first()` below will cause compiler errors because the
141            // compiler mistakes it for `diesel_async::RunQueryDsl::first(…)`
142            // and fails finding a trait implementation of `results` that
143            // matches, so we restrict the import scope.
144            use diesel_async::RunQueryDsl;
145            self.load::<(U, i64)>(conn).await?
146        };
147        let total = results.first().map(|x: &(U, i64)| x.1).unwrap_or(0);
148        let records = results.into_iter().map(|x| x.0).collect();
149        Ok((records, total))
150    }
151}
152
153impl<T: Query> Query for Paginated<T> {
154    type SqlType = (T::SqlType, BigInt);
155}
156
/// Renders the pagination wrapper around the inner query for PostgreSQL:
///
/// ```sql
/// SELECT *, COUNT(*) OVER () FROM (<inner query>) t LIMIT $n OFFSET $m
/// ```
impl<T> QueryFragment<Pg> for Paginated<T>
where
    T: QueryFragment<Pg>,
{
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
        // Wrap the inner query in a subquery aliased `t`. The window function
        // attaches the total row count (computed before LIMIT/OFFSET apply) to
        // every returned row.
        // NOTE(review): no ORDER BY is emitted on the outer SELECT, so row order
        // relies on the subquery's ordering being preserved. Confirm this is
        // acceptable.
        out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
        self.query.walk_ast(out.reborrow())?;
        out.push_sql(") t LIMIT ");
        // Both values are sent as bind parameters, not interpolated into the
        // SQL text. They are borrowed from `self` for `'b`, which is why the
        // struct stores the offset itself rather than the page number.
        out.push_bind_param::<BigInt, _>(&self.per_page)?;
        out.push_sql(" OFFSET ");
        out.push_bind_param::<BigInt, _>(&self.offset)?;
        Ok(())
    }
}
171
#[cfg(test)]
mod tests {
    use crate::{DatabaseError, OptionalExt, Paginate};

    #[test]
    fn test_database_error_from_implementation() {
        // The `diesel::result::Error::NotFound` should also be a
        // `DatabaseError::NotFound` and never be a `DatabaseError::DieselError`
        // All other cases should be `DatabaseError::DieselError`
        assert!(matches!(
            Into::<DatabaseError>::into(diesel::result::Error::NotFound),
            DatabaseError::NotFound,
        ));
        assert!(!matches!(
            Into::<DatabaseError>::into(diesel::result::Error::NotFound),
            DatabaseError::DieselError {
                source: diesel::result::Error::NotFound
            },
        ));
        assert!(matches!(
            Into::<DatabaseError>::into(diesel::result::Error::NotInTransaction),
            DatabaseError::DieselError {
                source: diesel::result::Error::NotInTransaction
            },
        ));
    }

    #[test]
    fn test_is_not_found() {
        assert!(DatabaseError::NotFound.is_not_found());
        let other = DatabaseError::Custom {
            message: "boom".to_owned(),
        };
        assert!(!other.is_not_found());
    }

    #[test]
    fn test_optional() {
        // `Ok` values are wrapped in `Some`.
        let found: crate::Result<u32> = Ok(7);
        assert!(matches!(found.optional(), Ok(Some(7))));

        // `NotFound` becomes `Ok(None)`.
        let missing: crate::Result<u32> = Err(DatabaseError::NotFound);
        assert!(matches!(missing.optional(), Ok(None)));

        // Every other error is passed through.
        let failed: crate::Result<u32> = Err(DatabaseError::Custom {
            message: "boom".to_owned(),
        });
        assert!(matches!(
            failed.optional(),
            Err(DatabaseError::Custom { .. })
        ));
    }

    #[test]
    fn test_paginate_offsets() {
        // Pages are 1-based, so the first page starts at offset 0.
        let first = "query".paginate(1);
        assert_eq!((first.per_page, first.offset), (10, 0));

        let third = "query".paginate(3);
        assert_eq!((third.per_page, third.offset), (10, 20));

        let custom = "query".paginate_by(25, 4);
        assert_eq!((custom.per_page, custom.offset), (25, 75));
    }
}