1use std::any::Any;
4use std::future::Future;
5
6use anyhow::Context as _;
7use axum::Router;
8use axum::body::Body;
9use axum::extract::FromRequest;
10use axum::extract::FromRequestParts;
11use axum::extract::rejection::JsonRejection;
12use axum::extract::rejection::PathRejection;
13use axum::http;
14use axum::http::HeaderName;
15use axum::http::StatusCode;
16use axum::http::header;
17use axum::response::IntoResponse;
18use axum::response::Response;
19use axum::routing::get;
20use bon::Builder;
21use serde::Serialize;
22use serde::Serializer;
23use tokio::net::TcpListener;
24use tower::ServiceBuilder;
25use tower_http::LatencyUnit;
26use tower_http::compression::CompressionLayer;
27use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
28use tower_http::sensitive_headers::SetSensitiveResponseHeadersLayer;
29use tower_http::trace::DefaultMakeSpan;
30use tower_http::trace::DefaultOnResponse;
31use tower_http::trace::TraceLayer;
32use tracing::Span;
33use tracing::debug;
34use tracing::error;
35use tracing::info;
36
/// The address used by the [`Server`] builder when no address is provided.
pub const DEFAULT_ADDRESS: &str = "0.0.0.0";
39
/// The port used by the [`Server`] builder when no port is provided.
pub const DEFAULT_PORT: u16 = 8080;
42
/// The headers passed to both the request and response sensitive-header
/// middleware layers installed by [`Server::run`].
const SENSITIVE_HEADERS: [HeaderName; 2] = [header::AUTHORIZATION, header::COOKIE];
45
46fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response {
48 if let Some(s) = err.downcast_ref::<String>() {
49 error!("server panicked: {s}");
50 } else if let Some(s) = err.downcast_ref::<&str>() {
51 error!("server panicked: {s}");
52 } else {
53 error!("server panicked: unknown panic message");
54 };
55
56 Error::internal().into_response()
57}
58
/// A JSON extractor and response wrapper.
///
/// Extraction is delegated to `axum::Json`, with rejections converted into
/// [`Error`].
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(Error))]
pub struct Json<T>(pub T);
65
66impl<T> IntoResponse for Json<T>
67where
68 T: Serialize,
69{
70 fn into_response(self) -> Response {
71 axum::Json(self.0).into_response()
72 }
73}
74
75fn serialize_status<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
77where
78 S: Serializer,
79{
80 serializer.serialize_u16(status.as_u16())
81}
82
/// An error returned by the server.
///
/// When converted into a response, the error is sent as a JSON body using
/// its own status code.
#[derive(Serialize, Debug)]
pub struct Error {
    /// The HTTP status code of the error; serialized as a number.
    #[serde(serialize_with = "serialize_status")]
    pub status: StatusCode,
    /// The message describing the error.
    pub message: String,
}
92
93impl Error {
94 pub fn not_found() -> Error {
96 Error {
97 status: StatusCode::NOT_FOUND,
98 message: "the requested resource was not found".to_string(),
99 }
100 }
101
102 pub fn bad_request(message: impl Into<String>) -> Error {
104 Error {
105 status: StatusCode::BAD_REQUEST,
106 message: message.into(),
107 }
108 }
109
110 pub fn forbidden() -> Error {
112 Self {
113 status: StatusCode::FORBIDDEN,
114 message: StatusCode::FORBIDDEN.to_string(),
115 }
116 }
117
118 pub fn internal() -> Error {
120 Self {
121 status: StatusCode::INTERNAL_SERVER_ERROR,
122 message: StatusCode::INTERNAL_SERVER_ERROR.to_string(),
123 }
124 }
125}
126
127impl From<anyhow::Error> for Error {
128 fn from(e: anyhow::Error) -> Self {
129 tracing::error!("{e:#}");
130 Self::internal()
131 }
132}
133
134impl From<reqwest::Error> for Error {
135 fn from(e: reqwest::Error) -> Self {
136 tracing::error!("{e:#}");
137 Self::internal()
138 }
139}
140
/// Maps errors from the `postgres` database module to server errors.
///
/// A missing task becomes a `404 Not Found` carrying the error's display
/// text; pool and Diesel errors are logged and reported as internal errors.
#[cfg(feature = "postgres")]
impl From<planetary_db::postgres::Error> for Error {
    fn from(e: planetary_db::postgres::Error) -> Self {
        use planetary_db::postgres::Error as PgError;

        match &e {
            PgError::TaskNotFound(_) => Self {
                status: StatusCode::NOT_FOUND,
                message: e.to_string(),
            },
            // In the arms below the inner error shadows the outer `e`, so it
            // is the inner error that gets logged.
            PgError::Pool(e) => {
                error!("database connection error: {e:#}");
                Self::internal()
            }
            PgError::Diesel(e) => {
                error!("database error: {e:#}");
                Self::internal()
            }
        }
    }
}
163
164impl From<planetary_db::Error> for Error {
165 fn from(e: planetary_db::Error) -> Self {
166 let (status, message) = match e {
167 planetary_db::Error::InvalidPageToken(_) => (StatusCode::BAD_REQUEST, e.to_string()),
168 #[cfg(feature = "postgres")]
169 planetary_db::Error::Postgres(e) => return e.into(),
170 planetary_db::Error::Other(e) => return e.into(),
171 };
172
173 Self { status, message }
174 }
175}
176
177impl From<JsonRejection> for Error {
178 fn from(rejection: JsonRejection) -> Self {
179 Self {
180 status: rejection.status(),
181 message: rejection.body_text(),
182 }
183 }
184}
185
186impl IntoResponse for Error {
187 fn into_response(self) -> Response {
188 (self.status, axum::Json(self)).into_response()
189 }
190}
191
/// A `Result` alias whose error type is the server [`Error`].
pub type ServerResponse<T> = Result<T, Error>;
194
/// A path extractor.
///
/// Extraction is delegated to `axum::extract::Path`, with rejections
/// converted into [`Error`].
#[derive(FromRequestParts)]
#[from_request(via(axum::extract::Path), rejection(Error))]
pub struct Path<T>(pub T);
201
202impl From<PathRejection> for Error {
203 fn from(rejection: PathRejection) -> Self {
204 Self {
205 status: rejection.status(),
206 message: rejection.body_text(),
207 }
208 }
209}
210
/// A query string extractor.
///
/// Extraction is delegated to `axum_extra::extract::Query`, with rejections
/// converted into [`Error`].
#[derive(FromRequestParts)]
#[from_request(via(axum_extra::extract::Query), rejection(Error))]
pub struct Query<T>(pub T);
217
218impl From<axum_extra::extract::QueryRejection> for Error {
219 fn from(rejection: axum_extra::extract::QueryRejection) -> Self {
220 Self {
221 status: rejection.status(),
222 message: rejection.body_text(),
223 }
224 }
225}
226
/// An HTTP server, configured through its derived builder.
#[derive(Clone, Builder)]
pub struct Server<S> {
    /// The address to bind to; defaults to [`DEFAULT_ADDRESS`].
    #[builder(into, default = DEFAULT_ADDRESS)]
    address: String,

    /// The port to bind to; defaults to [`DEFAULT_PORT`].
    #[builder(into, default = DEFAULT_PORT)]
    port: u16,

    /// The routers merged into the server's router when it runs; defaults to
    /// an empty list.
    #[builder(into, default)]
    routers: Vec<Router<S>>,
}
242
243impl<S> Server<S>
244where
245 S: Clone + Send + Sync + 'static,
246{
247 pub async fn run<F>(self, state: S, shutdown: F) -> anyhow::Result<()>
249 where
250 F: Future<Output = ()> + Send + 'static,
251 {
252 let middleware = ServiceBuilder::new()
254 .layer(SetSensitiveRequestHeadersLayer::new(SENSITIVE_HEADERS))
255 .layer(
256 TraceLayer::new_for_http()
257 .make_span_with(DefaultMakeSpan::new().include_headers(true))
258 .on_request(|request: &http::Request<Body>, _span: &Span| {
259 debug!(
260 "{method} {path}",
261 method = request.method(),
262 path = request.uri().path()
263 )
264 })
265 .on_response(
266 DefaultOnResponse::new()
267 .level(tracing::Level::DEBUG)
268 .latency_unit(LatencyUnit::Micros),
269 ),
270 )
271 .layer(CompressionLayer::new())
272 .layer(SetSensitiveResponseHeadersLayer::new(SENSITIVE_HEADERS));
273
274 let mut router = Router::new().route("/ping", get(|| async {}));
276
277 for merge in self.routers {
278 router = router.merge(merge);
279 }
280
281 let addr = format!("{address}:{port}", address = self.address, port = self.port);
283 let listener = TcpListener::bind(&addr)
284 .await
285 .context("binding to the provided address")?;
286
287 info!("listening at {addr}");
288
289 axum::serve(
290 listener,
291 router
292 .fallback(async || Error::not_found())
293 .layer(middleware)
294 .layer(tower_http::catch_panic::CatchPanicLayer::custom(
295 handle_panic,
296 ))
297 .with_state(state),
298 )
299 .with_graceful_shutdown(shutdown)
300 .await
301 .context("failed to run server")?;
302
303 Ok(())
304 }
305}