//! `oxide_framework_core/extract.rs` — request extractors (`Config`, `Data`, `Inject`,
//! `Scoped`) and the `StateNotFound` rejection type.

1use axum::extract::FromRequestParts;
2use axum::http::StatusCode;
3use axum::http::request::Parts;
4use axum::response::{IntoResponse, Response};
5use std::sync::Arc;
6
7use crate::config::AppConfig;
8use crate::state::AppState;
9
10/// Extractor for the application configuration.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// use oxide_framework_core::{ApiResponse, Config};
16///
17/// async fn handler(Config(cfg): Config) -> ApiResponse<String> {
18///     ApiResponse::ok(format!("Welcome to {}", cfg.app_name))
19/// }
20/// ```
21pub struct Config(pub Arc<AppConfig>);
22
23impl<S: Send + Sync> FromRequestParts<S> for Config {
24    type Rejection = StateNotFound;
25
26    async fn from_request_parts(
27        parts: &mut Parts,
28        _state: &S,
29    ) -> Result<Self, Self::Rejection> {
30        parts
31            .extensions
32            .get::<AppState>()
33            .map(|s| Config(s.config.clone()))
34            .ok_or(StateNotFound("AppConfig"))
35    }
36}
37
38/// Extractor for user-provided state registered via [`App::state()`](crate::App::state).
39///
40/// Returns `Arc<T>` so the data can be shared cheaply across handlers.
41///
42/// # Example
43///
44/// ```rust,ignore
45/// use oxide_framework_core::{ApiResponse, Data};
46/// use std::sync::Arc;
47///
48/// struct DbPool { /* ... */ }
49///
50/// async fn handler(Data(pool): Data<DbPool>) -> ApiResponse<String> {
51///     ApiResponse::ok("connected".into())
52/// }
53/// ```
54pub struct Data<T: Send + Sync + 'static>(pub Arc<T>);
55
56impl<S: Send + Sync, T: Send + Sync + 'static> FromRequestParts<S> for Data<T> {
57    type Rejection = StateNotFound;
58
59    async fn from_request_parts(
60        parts: &mut Parts,
61        _state: &S,
62    ) -> Result<Self, Self::Rejection> {
63        let app_state = parts
64            .extensions
65            .get::<AppState>()
66            .ok_or(StateNotFound("AppState"))?;
67
68        app_state
69            .get::<T>()
70            .map(Data)
71            .ok_or(StateNotFound(std::any::type_name::<T>()))
72    }
73}
74
75/// Ergonomic alias for [`Data<T>`] — intended for use inside controllers.
76///
77/// Semantically identical to `Data`, but reads more naturally in constructor
78/// injection code:
79///
80/// ```rust,ignore
81/// fn new(state: &AppState) -> Self {
82///     Self {
83///         pool: state.get::<DbPool>().expect("DbPool missing").as_ref().clone(),
84///     }
85/// }
86///
87/// #[get("/")]
88/// async fn index(&self, Inject(cache): Inject<Cache>) -> ApiResponse<String> {
89///     // ...
90/// }
91/// ```
92pub struct Inject<T: Send + Sync + 'static>(pub Arc<T>);
93
94impl<S: Send + Sync, T: Send + Sync + 'static> FromRequestParts<S> for Inject<T> {
95    type Rejection = StateNotFound;
96
97    async fn from_request_parts(
98        parts: &mut Parts,
99        _state: &S,
100    ) -> Result<Self, Self::Rejection> {
101        let app_state = parts
102            .extensions
103            .get::<AppState>()
104            .ok_or(StateNotFound("AppState"))?;
105
106        app_state
107            .get::<T>()
108            .map(Inject)
109            .ok_or(StateNotFound(std::any::type_name::<T>()))
110    }
111}
112
/// Rejection produced by this module's state extractors when something they need is missing.
///
/// The wrapped `&'static str` names the missing item. It is either a fixed label
/// (`"AppState"`, `"AppConfig"`) or the `std::any::type_name` of the requested type.
#[derive(Debug)]
pub struct StateNotFound(pub &'static str);

impl std::fmt::Display for StateNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(what) = self;
        write!(f, "state not found: {what}")
    }
}

/// Plain error with no underlying `source`; the message comes from `Display`.
impl std::error::Error for StateNotFound {}
124
125impl IntoResponse for StateNotFound {
126    fn into_response(self) -> Response {
127        (
128            StatusCode::INTERNAL_SERVER_ERROR,
129            format!("internal error: missing state ({})", self.0),
130        )
131            .into_response()
132    }
133}
134
135/// Extractor for request-scoped dependencies.
136///
137/// If a dependency `T` was injected into the current request (e.g. via `App::scoped_state`),
138/// this extractor will retrieve it. Otherwise, it fails with a 500 Internal Server Error.
139pub struct Scoped<T>(pub T);
140
141impl<S, T> axum::extract::FromRequestParts<S> for Scoped<T>
142where
143    S: Send + Sync,
144    T: Clone + Send + Sync + 'static,
145{
146    type Rejection = axum::response::Response;
147
148    async fn from_request_parts(
149        parts: &mut axum::http::request::Parts,
150        _state: &S,
151    ) -> Result<Self, Self::Rejection> {
152        parts
153            .extensions
154            .get::<T>()
155            .cloned()
156            .map(Scoped)
157            .ok_or_else(|| {
158                crate::ApiResponse::<()>::error(
159                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
160                    format!(
161                        "Missing scoped dependency: {}",
162                        std::any::type_name::<T>()
163                    )
164                ).into_response()
165            })
166    }
167}
168
169