1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/// Possible errors when extracting [`Tx`] from a request.
///
/// Errors can occur at two points during the request lifecycle:
///
/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a
///    transaction. This could be due to:
///
///    - Forgetting to add the middleware: [`Error::MissingExtension`].
///    - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`].
///    - A problem communicating with the database: [`Error::Database`].
///
/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem
///    communicating with the database, or else a logic error (e.g. an unsatisfied deferred
///    constraint): [`Error::Database`].
///
/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into an
/// HTTP 500 response with the error message as the response body. This may be suitable for
/// development or internal services, but it's generally not advisable to return internal error
/// details to clients.
///
/// You can override the error types for both the [`Tx`] extractor and [`Layer`]:
///
/// - Override the [`Tx`]`<DB, E>` error type using the `E` generic type parameter. `E` must be
///   convertible from [`Error`] (i.e. [`Error`]`: Into<E>`).
///
/// - Override the [`Layer`] error type using [`Config::layer_error`](crate::Config::layer_error).
///   The layer error type must be convertible from `sqlx::Error` (i.e.
///   `sqlx::Error: Into<LayerError>`).
///
/// In both cases, the error type must implement `axum::response::IntoResponse`.
///
/// ```
/// use axum::{response::IntoResponse, routing::post};
///
/// enum MyError {
///     Extractor(axum_sqlx_tx::Error),
///     Layer(sqlx::Error),
/// }
///
/// impl From<axum_sqlx_tx::Error> for MyError {
///     fn from(error: axum_sqlx_tx::Error) -> Self {
///         Self::Extractor(error)
///     }
/// }
///
/// impl From<sqlx::Error> for MyError {
///     fn from(error: sqlx::Error) -> Self {
///         Self::Layer(error)
///     }
/// }
///
/// impl IntoResponse for MyError {
///     fn into_response(self) -> axum::response::Response {
///         // note that you would probably want to log the error as well
///         (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
///     }
/// }
///
/// // Override the `Tx` error type using the second generic type parameter
/// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite, MyError>;
///
/// # async fn foo() {
/// let pool = sqlx::SqlitePool::connect("...").await.unwrap();
///
/// let (state, layer) = Tx::config(pool)
///     // Override the `Layer` error type using the `Config` API
///     .layer_error::<MyError>()
///     .setup();
/// # let app = axum::Router::new()
/// #    .route("/", post(create_user))
/// #    .layer(layer)
/// #    .with_state(state);
/// # axum::serve(todo!(), app);
/// # }
/// # async fn create_user(mut tx: Tx, /* ... */) {
/// #     /* ... */
/// # }
/// ```
///
/// [`Tx`]: crate::Tx
/// [`Layer`]: crate::Layer
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Indicates that the [`Layer`](crate::Layer) middleware was not installed.
    #[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")]
    MissingExtension,

    /// Indicates that [`Tx`](crate::Tx) was extracted multiple times in a single
    /// handler/middleware.
    #[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")]
    OverlappingExtractors,

    /// A database error occurred when starting or committing the transaction.
    ///
    /// The `Display` output (and source) is forwarded unchanged from the wrapped `sqlx::Error`,
    /// and `From<sqlx::Error>` is derived so `?` converts database errors into this variant.
    #[error(transparent)]
    Database {
        #[from]
        error: sqlx::Error,
    },
}

/// Renders any [`Error`] as an HTTP `500 Internal Server Error` whose body is the error's
/// `Display` text.
///
/// NOTE(review): this exposes internal error details to clients; applications that care should
/// override the error type as described on [`Error`].
impl axum_core::response::IntoResponse for Error {
    fn into_response(self) -> axum_core::response::Response {
        let status = http::StatusCode::INTERNAL_SERVER_ERROR;
        let body = self.to_string();
        // Delegate to the `(StatusCode, String)` response impl.
        axum_core::response::IntoResponse::into_response((status, body))
    }
}