poem_ext/db.rs
//! Contains a middleware that automatically creates and manages a
//! [`sea_orm::DatabaseTransaction`](sea_orm::DatabaseTransaction) for each
//! incoming request. The transaction is automatically
//! [`commit()`](sea_orm::DatabaseTransaction::commit)ed if the endpoint returns
//! a successful response, and
//! [`rollback()`](sea_orm::DatabaseTransaction::rollback)ed if the endpoint
//! returns an error or a response that is not considered successful (see
//! [`DbTransactionMiddleware::with_check_fn`]).
8//!
9//! #### Example
10//! ```no_run
11//! use poem::{web::Data, EndpointExt, Route};
12//! use poem_ext::db::{DbTransactionMiddleware, DbTxn};
13//! use poem_openapi::{payload::PlainText, OpenApi, OpenApiService};
14//! use sea_orm::DatabaseTransaction;
15//!
16//! struct Api;
17//!
18//! #[OpenApi]
19//! impl Api {
20//! #[oai(path = "/test", method = "get")]
21//! async fn test(&self, txn: Data<&DbTxn>) -> PlainText<&'static str> {
22//! let txn: &DatabaseTransaction = &txn;
23//! todo!()
24//! }
25//! }
26//!
27//! # let db_connection = todo!();
28//! let api_service = OpenApiService::new(Api, "test", "0.1.0");
29//! let app = Route::new()
30//! .nest("/", api_service)
31//! .with(DbTransactionMiddleware::new(db_connection));
32//! ```
33
34use std::{fmt::Debug, sync::Arc};
35
36use poem::{Endpoint, IntoResponse, Middleware, Response};
37use sea_orm::{DatabaseConnection, DatabaseTransaction, TransactionTrait};
38
39use crate::responses::internal_server_error;
40
41/// Param type to use in endpoints that need a database transaction.
42pub type DbTxn = Arc<DatabaseTransaction>;
43
44/// A function that checks if a response is successful.
45pub type CheckFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
46
47/// A middleware for automatically creating and managing
48/// [`sea_orm::DatabaseTransaction`](sea_orm::DatabaseTransaction)s for incoming
49/// requests.
50pub struct DbTransactionMiddleware {
51 db: DatabaseConnection,
52 check_fn: Option<CheckFn>,
53}
54
55impl Debug for DbTransactionMiddleware {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("DbTransactionMiddleware")
58 .field("db", &self.db)
59 .finish_non_exhaustive()
60 }
61}
62
63impl DbTransactionMiddleware {
64 /// Create a new DbTransactionMiddleware.
65 pub fn new(db: DatabaseConnection) -> Self {
66 Self { db, check_fn: None }
67 }
68
69 /// Use a custom function to check if a response is successful.
70 ///
71 /// By default a response is considered successful iff it is neither a
72 /// client error (400-499) nor a server error (500-599).
73 ///
74 /// #### Example
75 /// ```no_run
76 /// use poem::{EndpointExt, Route};
77 /// use poem_ext::db::DbTransactionMiddleware;
78 ///
79 /// # let api_service: poem_openapi::OpenApiService<(), ()> = todo!();
80 /// # let db_connection = todo!();
81 /// let app = Route::new().nest("/", api_service).with(
82 /// // commit only if the response status is "200 OK", otherwise rollback the transaction
83 /// DbTransactionMiddleware::new(db_connection).with_check_fn(|response| response.is_ok()),
84 /// );
85 /// ```
86 pub fn with_check_fn<F>(self, check_fn: F) -> Self
87 where
88 F: Fn(&Response) -> bool + Send + Sync + 'static,
89 {
90 Self {
91 db: self.db,
92 check_fn: Some(Arc::new(check_fn)),
93 }
94 }
95}
96
97impl<E: Endpoint> Middleware<E> for DbTransactionMiddleware {
98 type Output = DbTransactionMwEndpoint<E>;
99
100 fn transform(&self, ep: E) -> Self::Output {
101 DbTransactionMwEndpoint {
102 inner: ep,
103 db: self.db.clone(),
104 check_fn: self.check_fn.clone(),
105 }
106 }
107}
108
/// Endpoint produced by [`DbTransactionMiddleware`]: it wraps `inner` and runs
/// each call to it inside a dedicated database transaction.
#[doc(hidden)]
pub struct DbTransactionMwEndpoint<E> {
    /// The wrapped endpoint.
    inner: E,
    /// Connection on which a transaction is begun for every request.
    db: DatabaseConnection,
    /// Optional custom success predicate; `None` selects the default check
    /// (neither a 4xx nor a 5xx status).
    check_fn: Option<CheckFn>,
}
115
116impl<E: Debug> Debug for DbTransactionMwEndpoint<E> {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("DbTransactionMwEndpoint")
119 .field("inner", &self.inner)
120 .field("db", &self.db)
121 .finish_non_exhaustive()
122 }
123}
124
125impl<E: Endpoint> Endpoint for DbTransactionMwEndpoint<E> {
126 type Output = Response;
127
128 async fn call(&self, mut req: poem::Request) -> Result<Self::Output, poem::Error> {
129 let txn = Arc::new(self.db.begin().await.map_err(internal_server_error)?);
130 req.extensions_mut().insert(txn.clone());
131 let result = self.inner.call(req).await;
132 let txn = Arc::try_unwrap(txn).map_err(|_| {
133 internal_server_error("db transaction has not been dropped in endpoint")
134 })?;
135 match result {
136 Ok(resp) => {
137 let resp = resp.into_response();
138 if self.check_fn.as_ref().map_or_else(
139 || !resp.status().is_server_error() && !resp.status().is_client_error(),
140 |check_fn| check_fn(&resp),
141 ) {
142 txn.commit().await.map_err(internal_server_error)?;
143 } else {
144 txn.rollback().await.map_err(internal_server_error)?;
145 }
146 Ok(resp)
147 }
148 Err(err) => {
149 txn.rollback().await.map_err(internal_server_error)?;
150 Err(err)
151 }
152 }
153 }
154}