1use std::{fmt::Display, sync::Arc};
2
3use axum::{
4 async_trait,
5 extract::{FromRequest, FromRequestParts, Host, Path},
6 http::{self, request::Parts, Request},
7 response::{IntoResponse, Response},
8 Extension, Json,
9};
10use http::header;
11use serde::{de::DeserializeOwned, Serialize};
12use serde_json::json;
13
14use crate::server::ControlServer;
15
16pub type ApiResponse<D> = Result<Json<D>, ApiError>;
17
18pub fn ok<D: Serialize>(data: D) -> ApiResponse<D> {
19 Ok(Json(data))
20}
21
/// Error type returned by API handlers and by the request extractors in this module.
///
/// Converted into an HTTP response by its `IntoResponse` impl: a JSON body
/// `{ "message": ..., "code": ... }` with a status chosen per variant.
#[derive(Debug)]
pub enum ApiError {
    /// Unspecified server-side failure; carries no client-facing message.
    Internal,
    /// The request carried no valid credentials.
    NotAuthenticated,
    /// The caller is authenticated but not allowed to perform the action.
    NotAuthorized,
    /// Invalid request body; the string is the client-facing message
    /// (e.g. a JSON rejection from `JsonExtractor`).
    InvalidData(String),
    /// Invalid path parameter; the string is the client-facing message
    /// (e.g. a rejection from `PathExtractor`).
    InvalidPathArg(String),
    /// Invalid query argument; the string is the client-facing message.
    InvalidQueryArg(String),
    /// Application-defined error with a static machine-readable `code`
    /// and an optional client-facing `message`.
    Custom {
        code: &'static str,
        message: Option<String>,
    },
}
35
36impl ApiError {
37 pub fn code(&self) -> &str {
38 match self {
39 ApiError::Internal => "internal",
40 ApiError::NotAuthenticated => "unauthenticated",
41 ApiError::NotAuthorized => "unauthorized",
42 ApiError::InvalidData(_) => "invalid_data",
43 ApiError::InvalidPathArg(_) => "invalid_path_arg",
44 ApiError::InvalidQueryArg(_) => "invalid_query_arg",
45 ApiError::Custom { code, .. } => code,
46 }
47 }
48
49 pub fn message(&self) -> String {
50 match self {
51 ApiError::Internal => "".into(),
52 ApiError::NotAuthenticated => "Not authenticated".into(),
53 ApiError::NotAuthorized => "Not authorized".into(),
54 ApiError::InvalidData(msg) => msg.clone(),
55 ApiError::InvalidPathArg(msg) => msg.clone(),
56 ApiError::InvalidQueryArg(msg) => msg.clone(),
57 ApiError::Custom { message, .. } => message.clone().unwrap_or_else(|| "".into()),
58 }
59 }
60
61 #[allow(dead_code)]
62 pub fn log_internal(msg: &str, e: impl std::fmt::Debug) -> Self {
63 log::error!("{}: {:?}", msg, e);
64 Self::Internal
65 }
66
67 pub fn custom(code: &'static str, message: String) -> Self {
68 Self::Custom {
69 code,
70 message: Some(message),
71 }
72 }
73
74 pub fn custom_code(code: &'static str) -> Self {
75 Self::Custom {
76 code,
77 message: None,
78 }
79 }
80}
81
82impl Display for ApiError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.write_str("Error ")?;
85 f.write_str(self.code())?;
86 let msg = self.message();
87 if !msg.is_empty() {
88 f.write_str(": ")?;
89 f.write_str(&msg)?;
90 }
91 Ok(())
92 }
93}
94
95impl IntoResponse for ApiError {
96 fn into_response(self) -> Response {
97 use http::StatusCode as S;
98 use ApiError::*;
99
100 let body = Json(json!({
101 "message": self.message(),
102 "code": self.code(),
103 }));
104
105 let status = match self {
106 Self::Internal => S::INTERNAL_SERVER_ERROR,
107 Self::NotAuthenticated => S::UNAUTHORIZED,
108 Self::NotAuthorized => S::FORBIDDEN,
109 InvalidData(_) | InvalidPathArg(_) | InvalidQueryArg(_) | Custom { .. } => {
110 S::BAD_REQUEST
111 }
112 };
113
114 (status, body).into_response()
115 }
116}
117
118pub struct JsonExtractor<T>(pub T);
119
120#[async_trait]
121impl<S, B, T> FromRequest<S, B> for JsonExtractor<T>
122where
123 axum::Json<T>: FromRequest<S, B, Rejection = axum::extract::rejection::JsonRejection>,
124 S: Send + Sync,
125 B: Send + 'static,
126{
127 type Rejection = ApiError;
128
129 async fn from_request(req: Request<B>, state: &S) -> Result<JsonExtractor<T>, Self::Rejection> {
130 match Json::from_request(req, state).await {
131 Ok(Json(value)) => Ok(JsonExtractor(value)),
132 Err(e) => Err(ApiError::InvalidData(e.to_string())),
133 }
134 }
135}
136
137pub struct PathExtractor<T>(pub T);
138
139#[async_trait]
140impl<S, T> FromRequestParts<S> for PathExtractor<T>
141where
142 S: Send + Sync,
143 T: DeserializeOwned + Send,
144{
145 type Rejection = ApiError;
146
147 async fn from_request_parts(
148 req: &mut Parts,
149 state: &S,
150 ) -> Result<PathExtractor<T>, Self::Rejection> {
151 match Path::from_request_parts(req, state).await {
152 Ok(Path(value)) => Ok(PathExtractor(value)),
153 Err(e) => Err(ApiError::InvalidPathArg(e.to_string())),
154 }
155 }
156}
157
158pub struct HostExtractor(pub String);
159
160#[async_trait]
161impl<S> FromRequestParts<S> for HostExtractor
162where
163 S: Send + Sync,
164{
165 type Rejection = ApiError;
166
167 async fn from_request_parts(
168 req: &mut Parts,
169 state: &S,
170 ) -> Result<HostExtractor, Self::Rejection> {
171 match Host::from_request_parts(req, state).await {
172 Ok(Host(host)) => Ok(HostExtractor(host)),
173 Err(e) => Err(ApiError::Custom {
174 code: "no_host",
175 message: Some(e.to_string()),
176 }),
177 }
178 }
179}
180
/// Identity of an authenticated node, produced by the `FromRequestParts` impl below.
///
/// Extraction succeeds only when the bearer token and `x-lunatic-node-name` header
/// match an entry in the control server's registrations.
#[derive(Debug)]
pub struct NodeAuth {
    /// Key of the matching entry in the control server's `registrations` map, cast to `i64`.
    pub registration_id: i64,
    /// Node UUID parsed from the `x-lunatic-node-name` header.
    pub node_name: uuid::Uuid,
}
186
187#[async_trait]
188impl<S> FromRequestParts<S> for NodeAuth
189where
190 S: Send + Sync,
191{
192 type Rejection = ApiError;
193
194 async fn from_request_parts(req: &mut Parts, state: &S) -> Result<NodeAuth, Self::Rejection> {
195 let headers = req.headers.clone();
196 let auth_header = headers
197 .get(header::AUTHORIZATION)
198 .ok_or_else(|| {
199 ApiError::custom("no_auth_header", "Missing node authorization header".into())
200 })?
201 .to_str()
202 .map_err(|_| {
203 ApiError::custom(
204 "invalid_auth_header",
205 "Invalid authorization header value".into(),
206 )
207 })?;
208
209 let token = auth_header
210 .strip_prefix("Bearer ")
211 .to_owned()
212 .ok_or_else(|| {
213 ApiError::custom(
214 "invalid_auth_token",
215 "Header value doesn't start with Bearer".into(),
216 )
217 })?;
218
219 let node_name = headers
220 .get("x-lunatic-node-name")
221 .ok_or_else(|| {
222 ApiError::custom(
223 "no_lunatic_node_name_header",
224 "Missing x-lunatic-node-name header".into(),
225 )
226 })?
227 .to_str()
228 .map_err(|_| {
229 ApiError::custom(
230 "invalid_lunatic_node_name_header",
231 "Invalid x-lunatic-node-name header value".into(),
232 )
233 })?;
234
235 let node_name: uuid::Uuid = node_name.parse().map_err(|_| {
236 ApiError::custom(
237 "invalid_lunatic_node_name_header",
238 format!("Invalid x-lunatic-node-name header: {node_name} not a valid UUID"),
239 )
240 })?;
241
242 let cs: Extension<Arc<ControlServer>> = Extension::from_request_parts(req, state)
243 .await
244 .map_err(|e| ApiError::log_internal("Error getting cs in registration auth", e))?;
245
246 let (registration_id, reg) = cs
247 .registrations
248 .iter()
249 .find(|r| r.node_name == node_name && r.authentication_token == token)
250 .map(|r| (*r.key(), r.value().clone()))
251 .ok_or(ApiError::NotAuthenticated)?;
252 let node_auth = NodeAuth {
253 registration_id: registration_id as i64,
254 node_name: reg.node_name,
255 };
256
257 Ok(node_auth)
258 }
259}