lunatic_control_axum/
api.rs

1use std::{fmt::Display, sync::Arc};
2
3use axum::{
4    async_trait,
5    extract::{FromRequest, FromRequestParts, Host, Path},
6    http::{self, request::Parts, Request},
7    response::{IntoResponse, Response},
8    Extension, Json,
9};
10use http::header;
11use serde::{de::DeserializeOwned, Serialize};
12use serde_json::json;
13
14use crate::server::ControlServer;
15
16pub type ApiResponse<D> = Result<Json<D>, ApiError>;
17
18pub fn ok<D: Serialize>(data: D) -> ApiResponse<D> {
19    Ok(Json(data))
20}
21
/// Error returned by the control-server HTTP API.
///
/// Every variant has a machine-readable code (`ApiError::code`), a
/// human-readable message (`ApiError::message`) and, when turned into a
/// response, an HTTP status noted on each variant below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApiError {
    /// Unexpected server-side failure; sent with an empty message (500).
    Internal,
    /// The caller could not be authenticated (401).
    NotAuthenticated,
    /// The caller is not allowed to perform the request (403).
    NotAuthorized,
    /// The request body was rejected; holds the rejection text (400).
    InvalidData(String),
    /// A path parameter was rejected; holds the rejection text (400).
    InvalidPathArg(String),
    /// A query-string argument was rejected; holds the reason (400).
    InvalidQueryArg(String),
    /// Application-specific error with its own code and optional message (400).
    Custom {
        code: &'static str,
        message: Option<String>,
    },
}
35
36impl ApiError {
37    pub fn code(&self) -> &str {
38        match self {
39            ApiError::Internal => "internal",
40            ApiError::NotAuthenticated => "unauthenticated",
41            ApiError::NotAuthorized => "unauthorized",
42            ApiError::InvalidData(_) => "invalid_data",
43            ApiError::InvalidPathArg(_) => "invalid_path_arg",
44            ApiError::InvalidQueryArg(_) => "invalid_query_arg",
45            ApiError::Custom { code, .. } => code,
46        }
47    }
48
49    pub fn message(&self) -> String {
50        match self {
51            ApiError::Internal => "".into(),
52            ApiError::NotAuthenticated => "Not authenticated".into(),
53            ApiError::NotAuthorized => "Not authorized".into(),
54            ApiError::InvalidData(msg) => msg.clone(),
55            ApiError::InvalidPathArg(msg) => msg.clone(),
56            ApiError::InvalidQueryArg(msg) => msg.clone(),
57            ApiError::Custom { message, .. } => message.clone().unwrap_or_else(|| "".into()),
58        }
59    }
60
61    #[allow(dead_code)]
62    pub fn log_internal(msg: &str, e: impl std::fmt::Debug) -> Self {
63        log::error!("{}: {:?}", msg, e);
64        Self::Internal
65    }
66
67    pub fn custom(code: &'static str, message: String) -> Self {
68        Self::Custom {
69            code,
70            message: Some(message),
71        }
72    }
73
74    pub fn custom_code(code: &'static str) -> Self {
75        Self::Custom {
76            code,
77            message: None,
78        }
79    }
80}
81
82impl Display for ApiError {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.write_str("Error ")?;
85        f.write_str(self.code())?;
86        let msg = self.message();
87        if !msg.is_empty() {
88            f.write_str(": ")?;
89            f.write_str(&msg)?;
90        }
91        Ok(())
92    }
93}
94
95impl IntoResponse for ApiError {
96    fn into_response(self) -> Response {
97        use http::StatusCode as S;
98        use ApiError::*;
99
100        let body = Json(json!({
101            "message": self.message(),
102            "code": self.code(),
103        }));
104
105        let status = match self {
106            Self::Internal => S::INTERNAL_SERVER_ERROR,
107            Self::NotAuthenticated => S::UNAUTHORIZED,
108            Self::NotAuthorized => S::FORBIDDEN,
109            InvalidData(_) | InvalidPathArg(_) | InvalidQueryArg(_) | Custom { .. } => {
110                S::BAD_REQUEST
111            }
112        };
113
114        (status, body).into_response()
115    }
116}
117
118pub struct JsonExtractor<T>(pub T);
119
120#[async_trait]
121impl<S, B, T> FromRequest<S, B> for JsonExtractor<T>
122where
123    axum::Json<T>: FromRequest<S, B, Rejection = axum::extract::rejection::JsonRejection>,
124    S: Send + Sync,
125    B: Send + 'static,
126{
127    type Rejection = ApiError;
128
129    async fn from_request(req: Request<B>, state: &S) -> Result<JsonExtractor<T>, Self::Rejection> {
130        match Json::from_request(req, state).await {
131            Ok(Json(value)) => Ok(JsonExtractor(value)),
132            Err(e) => Err(ApiError::InvalidData(e.to_string())),
133        }
134    }
135}
136
137pub struct PathExtractor<T>(pub T);
138
139#[async_trait]
140impl<S, T> FromRequestParts<S> for PathExtractor<T>
141where
142    S: Send + Sync,
143    T: DeserializeOwned + Send,
144{
145    type Rejection = ApiError;
146
147    async fn from_request_parts(
148        req: &mut Parts,
149        state: &S,
150    ) -> Result<PathExtractor<T>, Self::Rejection> {
151        match Path::from_request_parts(req, state).await {
152            Ok(Path(value)) => Ok(PathExtractor(value)),
153            Err(e) => Err(ApiError::InvalidPathArg(e.to_string())),
154        }
155    }
156}
157
158pub struct HostExtractor(pub String);
159
160#[async_trait]
161impl<S> FromRequestParts<S> for HostExtractor
162where
163    S: Send + Sync,
164{
165    type Rejection = ApiError;
166
167    async fn from_request_parts(
168        req: &mut Parts,
169        state: &S,
170    ) -> Result<HostExtractor, Self::Rejection> {
171        match Host::from_request_parts(req, state).await {
172            Ok(Host(host)) => Ok(HostExtractor(host)),
173            Err(e) => Err(ApiError::Custom {
174                code: "no_host",
175                message: Some(e.to_string()),
176            }),
177        }
178    }
179}
180
181#[derive(Debug)]
182pub struct NodeAuth {
183    pub registration_id: i64,
184    pub node_name: uuid::Uuid,
185}
186
187#[async_trait]
188impl<S> FromRequestParts<S> for NodeAuth
189where
190    S: Send + Sync,
191{
192    type Rejection = ApiError;
193
194    async fn from_request_parts(req: &mut Parts, state: &S) -> Result<NodeAuth, Self::Rejection> {
195        let headers = req.headers.clone();
196        let auth_header = headers
197            .get(header::AUTHORIZATION)
198            .ok_or_else(|| {
199                ApiError::custom("no_auth_header", "Missing node authorization header".into())
200            })?
201            .to_str()
202            .map_err(|_| {
203                ApiError::custom(
204                    "invalid_auth_header",
205                    "Invalid authorization header value".into(),
206                )
207            })?;
208
209        let token = auth_header
210            .strip_prefix("Bearer ")
211            .to_owned()
212            .ok_or_else(|| {
213                ApiError::custom(
214                    "invalid_auth_token",
215                    "Header value doesn't start with Bearer".into(),
216                )
217            })?;
218
219        let node_name = headers
220            .get("x-lunatic-node-name")
221            .ok_or_else(|| {
222                ApiError::custom(
223                    "no_lunatic_node_name_header",
224                    "Missing x-lunatic-node-name header".into(),
225                )
226            })?
227            .to_str()
228            .map_err(|_| {
229                ApiError::custom(
230                    "invalid_lunatic_node_name_header",
231                    "Invalid x-lunatic-node-name header value".into(),
232                )
233            })?;
234
235        let node_name: uuid::Uuid = node_name.parse().map_err(|_| {
236            ApiError::custom(
237                "invalid_lunatic_node_name_header",
238                format!("Invalid x-lunatic-node-name header: {node_name} not a valid UUID"),
239            )
240        })?;
241
242        let cs: Extension<Arc<ControlServer>> = Extension::from_request_parts(req, state)
243            .await
244            .map_err(|e| ApiError::log_internal("Error getting cs in registration auth", e))?;
245
246        let (registration_id, reg) = cs
247            .registrations
248            .iter()
249            .find(|r| r.node_name == node_name && r.authentication_token == token)
250            .map(|r| (*r.key(), r.value().clone()))
251            .ok_or(ApiError::NotAuthenticated)?;
252        let node_auth = NodeAuth {
253            registration_id: registration_id as i64,
254            node_name: reg.node_name,
255        };
256
257        Ok(node_auth)
258    }
259}