Skip to main content

scalar_axum/
lib.rs

1use axum::{
2    extract::{FromRef, Path, Request, State},
3    http::{self, StatusCode},
4    middleware::Next,
5    response::{IntoResponse, Response},
6    Extension, Json, Router,
7};
8use scalar_cms::{
9    db::{Credentials, DatabaseFactory, User},
10    validations::{Valid, ValidationError},
11    DatabaseConnection, Document, Item, Schema,
12};
13use serde::{de::DeserializeOwned, Serialize};
14
15#[cfg(feature = "img")]
16pub mod img;
17
18pub struct ValidationFailiure(pub ValidationError);
19
20impl IntoResponse for ValidationFailiure {
21    fn into_response(self) -> axum::response::Response {
22        let mut response = Json(self.0).into_response();
23        *response.status_mut() = StatusCode::NOT_ACCEPTABLE;
24        response
25    }
26}
27
/// Merges the image and file endpoints into `router`.
///
/// * `PUT /images/upload` — `img::upload_image`, request bodies capped at 25,000,000 bytes.
/// * `PUT /files/upload` — `img::upload_file`, with axum's default body limit disabled.
/// * `GET /images/list` — `img::list`.
#[cfg(feature = "img")]
#[doc(hidden)]
pub fn add_image_routes__<S>(router: Router<S>) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
    scalar_img::WrappedBucket: FromRef<S>,
{
    use axum::{
        extract::DefaultBodyLimit,
        routing::{get, put},
    };

    const IMAGE_UPLOAD_LIMIT: usize = 25_000_000;

    let image_upload = put(img::upload_image).layer(DefaultBodyLimit::max(IMAGE_UPLOAD_LIMIT));
    let file_upload = put(img::upload_file).layer(DefaultBodyLimit::disable());

    router.merge(
        Router::new()
            .route("/images/upload", image_upload)
            .route("/files/upload", file_upload)
            .route("/images/list", get(img::list)),
    )
}
50
51#[cfg(not(feature = "img"))]
52#[doc(hidden)]
53pub fn add_image_routes__<S: Clone + Send + Sync + 'static>(router: Router<S>) -> Router<S> {
54    router
55}
56
/// Adds the read/draft/schema routes for one or more document types to `$router`
/// (internal helper of `generate_routes!`).
///
/// For each document type, with `{identifier}` taken from `<$doc>::identifier()`:
/// * `GET /docs/{identifier}` — [`get_all_docs`]
/// * `GET /docs/{identifier}/{id}` — [`get_doc_by_id`]
/// * `PUT /docs/{identifier}/drafts/{id}` — [`update_draft`]
/// * `GET /docs/{identifier}/schema` — [`get_schema`]
///
/// This crate's items are referenced through `$crate`, so expansion still resolves when the
/// crate is renamed in the caller's `Cargo.toml` or the macro is reached through a re-export.
#[macro_export]
#[doc(hidden)]
macro_rules! crud_routes__ {
    ($router:ident, $db:ty, $doc:ty) => {
        let path = format!("/docs/{}", <$doc>::identifier());
        let drafts_path = format!("{path}/drafts/{{id}}");
        $router = $router
            .route(
                &path,
                ::axum::routing::get($crate::get_all_docs::<$doc, $db>),
            )
            .route(
                &format!("{path}/{{id}}"),
                ::axum::routing::get($crate::get_doc_by_id::<$doc, $db>),
            )
            .route(
                &drafts_path,
                ::axum::routing::put($crate::update_draft::<$doc, $db>),
            )
            .route(
                &format!("{path}/schema"),
                ::axum::routing::get($crate::get_schema::<$doc>),
            );
    };

    // Several document types: expand the single-type arm once per type.
    ($router:ident, $db:ty, $($doc:ty),+) => {
        $($crate::crud_routes__!($router, $db, $doc);)*
    };
}
74
/// Adds `POST /docs/{identifier}/{id}/publish`, served by [`publish_doc`], for one or more
/// document types (internal helper of `generate_routes!`).
///
/// This crate's items are referenced through `$crate`, so expansion still resolves when the
/// crate is renamed in the caller's `Cargo.toml` or the macro is reached through a re-export.
#[macro_export]
#[doc(hidden)]
macro_rules! publish_routes__ {
    ($router:ident, $db:ty, $doc:ty) => {
        let path = format!("/docs/{}", <$doc>::identifier());
        $router = $router.route(
            &format!("{path}/{{id}}/publish"),
            ::axum::routing::post($crate::publish_doc::<$doc, $db>),
        );
    };

    // Several document types: expand the single-type arm once per type.
    ($router:ident, $db:ty, $($doc:ty),+) => {
        $($crate::publish_routes__!($router, $db, $doc);)*
    };
}
88
/// Adds `POST /docs/{identifier}/validate`, served by [`validate`], for one or more document
/// types (internal helper of `generate_routes!`).
///
/// This crate's items are referenced through `$crate`, so expansion still resolves when the
/// crate is renamed in the caller's `Cargo.toml` or the macro is reached through a re-export.
#[macro_export]
#[doc(hidden)]
macro_rules! validate_routes__ {
    ($router:ident, $doc:ty) => {
        $router = $router.route(
            &format!("/docs/{}/validate", <$doc>::identifier()),
            ::axum::routing::post($crate::validate::<$doc>),
        );
    };

    // Several document types: expand the single-type arm once per type.
    ($router:ident, $($doc:ty),+) => {
        $($crate::validate_routes__!($router, $doc);)*
    };
}
101
/// Builds an `axum::Router` exposing the CMS HTTP API for a set of document types.
///
/// ```ignore
/// let router = scalar_axum::generate_routes!({ AppState }, db: DbFactory, [Post, Page]);
/// ```
///
/// * `{ $app_state }` — the state type of the returned `Router<$app_state>`.
/// * `$db_instance: $db` — an identifier holding the database factory, and the factory's type.
///   The value is cloned into the authentication middleware; `signin` obtains its factory from
///   the router state through `State<$db>`.
/// * `[$doc, ...]` — the document types to expose (a trailing comma is accepted).
///
/// Routes that require an `Authorization: Bearer <token>` header (they are registered before
/// [`authenticated_connection_middleware`] is layered on):
/// * `GET /docs` — identifier and title of every listed document type
/// * `GET /me`
/// * per document type: `GET /docs/{identifier}`, `GET /docs/{identifier}/{id}`,
///   `PUT /docs/{identifier}/drafts/{id}`, `GET /docs/{identifier}/schema`,
///   `POST /docs/{identifier}/{id}/publish`
/// * with the `img` feature: `PUT /images/upload`, `PUT /files/upload`, `GET /images/list`
///
/// Routes reachable without authentication: `POST /signin` and, per document type,
/// `POST /docs/{identifier}/validate`.
#[macro_export]
macro_rules! generate_routes {
    ({$app_state:ty}, $db_instance:ident: $db:ty, [$($doc:ty),+ $(,)?]) => {
        {
            let mut router = ::axum::Router::<$app_state>::new();
            $crate::crud_routes__!(router, $db, $($doc),+);
            $crate::publish_routes__!(router, $db, $($doc),+);

            async fn get_docs() -> ::axum::Json<Vec<::scalar_cms::DocInfo>> {
                ::axum::Json(vec![
                    $(::scalar_cms::DocInfo {
                        identifier: <$doc>::identifier(),
                        title: <$doc>::title()
                    }),+
                ])
            }
            router = router.route("/docs", ::axum::routing::get(get_docs));

            router = router.route("/me", ::axum::routing::get($crate::me::<$db>));
            router = $crate::add_image_routes__(router);

            // `Router::layer` only wraps routes that already exist, so everything registered
            // above requires a valid bearer token.
            router = router.layer(::axum::middleware::from_fn_with_state(
                $db_instance.clone(),
                $crate::authenticated_connection_middleware::<$db>,
            ));

            // Registered after the layer: reachable without authentication.
            router = router.route("/signin", ::axum::routing::post($crate::signin::<$db>));
            $crate::validate_routes__!(router, $($doc),+);

            router
        }
    };
}
130
131pub async fn authenticated_connection_middleware<F: DatabaseFactory + Clone>(
132    State(db_factory): State<F>,
133    mut req: Request,
134    next: Next,
135) -> Result<Response, StatusCode>
136where
137    <F as DatabaseFactory>::Connection: 'static,
138{
139    let auth_header = req
140        .headers()
141        .get(http::header::AUTHORIZATION)
142        .map(|header| {
143            header
144                .to_str()
145                .map(str::trim)
146                .map_err(|_| StatusCode::BAD_REQUEST)
147        })
148        .ok_or(StatusCode::UNAUTHORIZED)??;
149
150    let connection = db_factory.init().await.map_err(|e| {
151        println!("{e}");
152        StatusCode::INTERNAL_SERVER_ERROR
153    })?;
154
155    let (_, token) = auth_header
156        .starts_with("Bearer ")
157        .then(|| {
158            auth_header
159                .split_at_checked(7)
160                .ok_or(StatusCode::UNAUTHORIZED)
161        })
162        .ok_or(StatusCode::UNAUTHORIZED)??;
163
164    connection.authenticate(token).await.map_err(|e| match e {
165        scalar_cms::db::AuthenticationError::BadToken => StatusCode::UNAUTHORIZED,
166        scalar_cms::db::AuthenticationError::BadCredentials => StatusCode::UNAUTHORIZED,
167        scalar_cms::db::AuthenticationError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
168    })?;
169
170    req.extensions_mut().insert(connection);
171
172    Ok(next.run(req).await)
173}
174
175//#[axum_macros::debug_handler]
176pub async fn signin<F: DatabaseFactory + Clone>(
177    State(factory): State<F>,
178    Json(credentials): Json<Credentials>,
179) -> Result<String, StatusCode> {
180    let connection = factory.init().await.map_err(|e| {
181        println!("{e}");
182        StatusCode::INTERNAL_SERVER_ERROR
183    })?;
184
185    println!("connection");
186
187    let token = connection.signin(credentials).await.map_err(|e| match e {
188        scalar_cms::db::AuthenticationError::BadToken => StatusCode::UNAUTHORIZED,
189        scalar_cms::db::AuthenticationError::BadCredentials => StatusCode::UNAUTHORIZED,
190        scalar_cms::db::AuthenticationError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
191    })?;
192
193    Ok(token)
194}
195
196pub async fn get_schema<T: Document>() -> Json<Schema> {
197    Json(T::schema())
198}
199
200pub async fn validate<D: Document>(
201    Json(doc): Json<D>,
202) -> Result<(), (StatusCode, Json<ValidationError>)> {
203    doc.validate()
204        .map_err(|e| (StatusCode::UNPROCESSABLE_ENTITY, Json(e)))
205}
206
207pub async fn me<F: DatabaseFactory>(
208    state: Extension<<F as DatabaseFactory>::Connection>,
209) -> Result<Json<User>, StatusCode> {
210    Ok(Json(state.me().await.map_err(|e| {
211        println!("{e}");
212        StatusCode::INTERNAL_SERVER_ERROR
213    })?))
214}
215
216pub async fn update_draft<T: Document + Serialize + DeserializeOwned + Send, F: DatabaseFactory>(
217    state: Extension<<F as DatabaseFactory>::Connection>,
218    Path(id): Path<String>,
219    Json(data): Json<serde_json::Value>,
220) -> Result<Json<Item<serde_json::Value>>, StatusCode> {
221    Ok(Json(
222        state
223            .draft::<T>(&id, data)
224            .await
225            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
226    ))
227}
228
229pub async fn publish_doc<
230    D: Document + Serialize + DeserializeOwned + Send + 'static,
231    F: DatabaseFactory,
232>(
233    Path(id): Path<String>,
234    state: Extension<<F as DatabaseFactory>::Connection>,
235    doc: Json<D>,
236) -> Result<(), StatusCode> {
237    state
238        .publish(
239            &id,
240            None,
241            Valid::new(doc.0).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?,
242        )
243        .await
244        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
245
246    Ok(())
247}
248
249pub async fn get_all_docs<T: Document + Serialize + DeserializeOwned + Send, F: DatabaseFactory>(
250    state: Extension<<F as DatabaseFactory>::Connection>,
251) -> Result<Json<Vec<Item<serde_json::Value>>>, StatusCode> {
252    let items = state
253        .get_all::<T>()
254        .await
255        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
256
257    Ok(Json(items))
258}
259
260pub async fn get_doc_by_id<
261    T: Document + Serialize + DeserializeOwned + Send,
262    F: DatabaseFactory,
263>(
264    state: Extension<<F as DatabaseFactory>::Connection>,
265    id: Path<String>,
266) -> Result<Json<Item<serde_json::Value>>, StatusCode> {
267    state
268        .get_by_id::<T>(id.as_str())
269        .await
270        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
271        .map(Json)
272        .ok_or(StatusCode::NOT_FOUND)
273}