poem_ext/
db.rs

1//! Contains a middleware that automatically creates and manages a
2//! [`sea_orm::DatabaseTransaction`](sea_orm::DatabaseTransaction) for each
//! incoming request. The transaction is automatically committed (via
//! [`commit()`](sea_orm::DatabaseTransaction::commit)) if the endpoint returns
//! a successful response, or rolled back (via
//! [`rollback()`](sea_orm::DatabaseTransaction::rollback)) if the endpoint
//! returns an error or a response that is not considered successful.
8//!
9//! #### Example
10//! ```no_run
11//! use poem::{web::Data, EndpointExt, Route};
12//! use poem_ext::db::{DbTransactionMiddleware, DbTxn};
13//! use poem_openapi::{payload::PlainText, OpenApi, OpenApiService};
14//! use sea_orm::DatabaseTransaction;
15//!
16//! struct Api;
17//!
18//! #[OpenApi]
19//! impl Api {
20//!     #[oai(path = "/test", method = "get")]
21//!     async fn test(&self, txn: Data<&DbTxn>) -> PlainText<&'static str> {
22//!         let txn: &DatabaseTransaction = &txn;
23//!         todo!()
24//!     }
25//! }
26//!
27//! # let db_connection = todo!();
28//! let api_service = OpenApiService::new(Api, "test", "0.1.0");
29//! let app = Route::new()
30//!     .nest("/", api_service)
31//!     .with(DbTransactionMiddleware::new(db_connection));
32//! ```
33
34use std::{fmt::Debug, sync::Arc};
35
36use poem::{Endpoint, IntoResponse, Middleware, Response};
37use sea_orm::{DatabaseConnection, DatabaseTransaction, TransactionTrait};
38
39use crate::responses::internal_server_error;
40
41/// Param type to use in endpoints that need a database transaction.
42pub type DbTxn = Arc<DatabaseTransaction>;
43
44/// A function that checks if a response is successful.
45pub type CheckFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
46
47/// A middleware for automatically creating and managing
48/// [`sea_orm::DatabaseTransaction`](sea_orm::DatabaseTransaction)s for incoming
49/// requests.
50pub struct DbTransactionMiddleware {
51    db: DatabaseConnection,
52    check_fn: Option<CheckFn>,
53}
54
55impl Debug for DbTransactionMiddleware {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("DbTransactionMiddleware")
58            .field("db", &self.db)
59            .finish_non_exhaustive()
60    }
61}
62
63impl DbTransactionMiddleware {
64    /// Create a new DbTransactionMiddleware.
65    pub fn new(db: DatabaseConnection) -> Self {
66        Self { db, check_fn: None }
67    }
68
69    /// Use a custom function to check if a response is successful.
70    ///
71    /// By default a response is considered successful iff it is neither a
72    /// client error (400-499) nor a server error (500-599).
73    ///
74    /// #### Example
75    /// ```no_run
76    /// use poem::{EndpointExt, Route};
77    /// use poem_ext::db::DbTransactionMiddleware;
78    ///
79    /// # let api_service: poem_openapi::OpenApiService<(), ()> = todo!();
80    /// # let db_connection = todo!();
81    /// let app = Route::new().nest("/", api_service).with(
82    ///     // commit only if the response status is "200 OK", otherwise rollback the transaction
83    ///     DbTransactionMiddleware::new(db_connection).with_check_fn(|response| response.is_ok()),
84    /// );
85    /// ```
86    pub fn with_check_fn<F>(self, check_fn: F) -> Self
87    where
88        F: Fn(&Response) -> bool + Send + Sync + 'static,
89    {
90        Self {
91            db: self.db,
92            check_fn: Some(Arc::new(check_fn)),
93        }
94    }
95}
96
97impl<E: Endpoint> Middleware<E> for DbTransactionMiddleware {
98    type Output = DbTransactionMwEndpoint<E>;
99
100    fn transform(&self, ep: E) -> Self::Output {
101        DbTransactionMwEndpoint {
102            inner: ep,
103            db: self.db.clone(),
104            check_fn: self.check_fn.clone(),
105        }
106    }
107}
108
109#[doc(hidden)]
110pub struct DbTransactionMwEndpoint<E> {
111    inner: E,
112    db: DatabaseConnection,
113    check_fn: Option<CheckFn>,
114}
115
116impl<E: Debug> Debug for DbTransactionMwEndpoint<E> {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        f.debug_struct("DbTransactionMwEndpoint")
119            .field("inner", &self.inner)
120            .field("db", &self.db)
121            .finish_non_exhaustive()
122    }
123}
124
125impl<E: Endpoint> Endpoint for DbTransactionMwEndpoint<E> {
126    type Output = Response;
127
128    async fn call(&self, mut req: poem::Request) -> Result<Self::Output, poem::Error> {
129        let txn = Arc::new(self.db.begin().await.map_err(internal_server_error)?);
130        req.extensions_mut().insert(txn.clone());
131        let result = self.inner.call(req).await;
132        let txn = Arc::try_unwrap(txn).map_err(|_| {
133            internal_server_error("db transaction has not been dropped in endpoint")
134        })?;
135        match result {
136            Ok(resp) => {
137                let resp = resp.into_response();
138                if self.check_fn.as_ref().map_or_else(
139                    || !resp.status().is_server_error() && !resp.status().is_client_error(),
140                    |check_fn| check_fn(&resp),
141                ) {
142                    txn.commit().await.map_err(internal_server_error)?;
143                } else {
144                    txn.rollback().await.map_err(internal_server_error)?;
145                }
146                Ok(resp)
147            }
148            Err(err) => {
149                txn.rollback().await.map_err(internal_server_error)?;
150                Err(err)
151            }
152        }
153    }
154}