planetary_server/
lib.rs

1//! Implements a common server for Planetary microservices.
2
3use std::any::Any;
4use std::future::Future;
5
6use anyhow::Context as _;
7use axum::Router;
8use axum::body::Body;
9use axum::extract::FromRequest;
10use axum::extract::FromRequestParts;
11use axum::extract::rejection::JsonRejection;
12use axum::extract::rejection::PathRejection;
13use axum::http;
14use axum::http::HeaderName;
15use axum::http::StatusCode;
16use axum::http::header;
17use axum::response::IntoResponse;
18use axum::response::Response;
19use axum::routing::get;
20use bon::Builder;
21use serde::Serialize;
22use serde::Serializer;
23use tokio::net::TcpListener;
24use tower::ServiceBuilder;
25use tower_http::LatencyUnit;
26use tower_http::compression::CompressionLayer;
27use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
28use tower_http::sensitive_headers::SetSensitiveResponseHeadersLayer;
29use tower_http::trace::DefaultMakeSpan;
30use tower_http::trace::DefaultOnResponse;
31use tower_http::trace::TraceLayer;
32use tracing::Span;
33use tracing::debug;
34use tracing::error;
35use tracing::info;
36
37/// The default address to bind the server to.
38pub const DEFAULT_ADDRESS: &str = "0.0.0.0";
39
40/// The default port to bind the server to.
41pub const DEFAULT_PORT: u16 = 8080;
42
43/// Header values to be blocked from logging.
44const SENSITIVE_HEADERS: [HeaderName; 2] = [header::AUTHORIZATION, header::COOKIE];
45
46/// A panic handler for returning 500.
47fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response {
48    if let Some(s) = err.downcast_ref::<String>() {
49        error!("server panicked: {s}");
50    } else if let Some(s) = err.downcast_ref::<&str>() {
51        error!("server panicked: {s}");
52    } else {
53        error!("server panicked: unknown panic message");
54    };
55
56    Error::internal().into_response()
57}
58
59/// An extractor that wraps the JSON extractor of Axum.
60///
61/// This extractor returns an error object on rejection.
62#[derive(FromRequest)]
63#[from_request(via(axum::Json), rejection(Error))]
64pub struct Json<T>(pub T);
65
66impl<T> IntoResponse for Json<T>
67where
68    T: Serialize,
69{
70    fn into_response(self) -> Response {
71        axum::Json(self.0).into_response()
72    }
73}
74
75/// Helper for serializing a HTTP status code.
76fn serialize_status<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
77where
78    S: Serializer,
79{
80    serializer.serialize_u16(status.as_u16())
81}
82
83/// Represents a generic error from the server.
84#[derive(Serialize, Debug)]
85pub struct Error {
86    /// The status code being returned in the response.
87    #[serde(serialize_with = "serialize_status")]
88    pub status: StatusCode,
89    /// The error message.
90    pub message: String,
91}
92
93impl Error {
94    /// Returns a "not found" JSON error response.
95    pub fn not_found() -> Error {
96        Error {
97            status: StatusCode::NOT_FOUND,
98            message: "the requested resource was not found".to_string(),
99        }
100    }
101
102    /// Returns a "bad request" JSON error response.
103    pub fn bad_request(message: impl Into<String>) -> Error {
104        Error {
105            status: StatusCode::BAD_REQUEST,
106            message: message.into(),
107        }
108    }
109
110    /// Returns a "forbidden" JSON error response.
111    pub fn forbidden() -> Error {
112        Self {
113            status: StatusCode::FORBIDDEN,
114            message: StatusCode::FORBIDDEN.to_string(),
115        }
116    }
117
118    /// Returns an "internal server error" JSON error response.
119    pub fn internal() -> Error {
120        Self {
121            status: StatusCode::INTERNAL_SERVER_ERROR,
122            message: StatusCode::INTERNAL_SERVER_ERROR.to_string(),
123        }
124    }
125}
126
127impl From<anyhow::Error> for Error {
128    fn from(e: anyhow::Error) -> Self {
129        tracing::error!("{e:#}");
130        Self::internal()
131    }
132}
133
134impl From<reqwest::Error> for Error {
135    fn from(e: reqwest::Error) -> Self {
136        tracing::error!("{e:#}");
137        Self::internal()
138    }
139}
140
#[cfg(feature = "postgres")]
impl From<planetary_db::postgres::Error> for Error {
    /// Maps a Postgres database error to a server error.
    ///
    /// A missing task becomes a 404 carrying the error's message; pool and
    /// Diesel failures are logged and reported only as a generic 500.
    fn from(e: planetary_db::postgres::Error) -> Self {
        use planetary_db::postgres::Error as DbError;

        match &e {
            DbError::TaskNotFound(_) => Self {
                status: StatusCode::NOT_FOUND,
                message: e.to_string(),
            },
            DbError::Pool(e) => {
                // Log the error but do not return it to the client
                tracing::error!("database connection error: {e:#}");
                Self::internal()
            }
            DbError::Diesel(e) => {
                // Log the error but do not return it to the client
                tracing::error!("database error: {e:#}");
                Self::internal()
            }
        }
    }
}
163
164impl From<planetary_db::Error> for Error {
165    fn from(e: planetary_db::Error) -> Self {
166        let (status, message) = match e {
167            planetary_db::Error::InvalidPageToken(_) => (StatusCode::BAD_REQUEST, e.to_string()),
168            #[cfg(feature = "postgres")]
169            planetary_db::Error::Postgres(e) => return e.into(),
170            planetary_db::Error::Other(e) => return e.into(),
171        };
172
173        Self { status, message }
174    }
175}
176
177impl From<JsonRejection> for Error {
178    fn from(rejection: JsonRejection) -> Self {
179        Self {
180            status: rejection.status(),
181            message: rejection.body_text(),
182        }
183    }
184}
185
186impl IntoResponse for Error {
187    fn into_response(self) -> Response {
188        (self.status, axum::Json(self)).into_response()
189    }
190}
191
192/// Represents the response type for most endpoints.
193pub type ServerResponse<T> = Result<T, Error>;
194
195/// An extractor that wraps the path extractor of Axum.
196///
197/// This extractor returns an error on rejection.
198#[derive(FromRequestParts)]
199#[from_request(via(axum::extract::Path), rejection(Error))]
200pub struct Path<T>(pub T);
201
202impl From<PathRejection> for Error {
203    fn from(rejection: PathRejection) -> Self {
204        Self {
205            status: rejection.status(),
206            message: rejection.body_text(),
207        }
208    }
209}
210
211/// An extractor that wraps the query extractor of Axum (extra).
212///
213/// This extractor returns an error on rejection.
214#[derive(FromRequestParts)]
215#[from_request(via(axum_extra::extract::Query), rejection(Error))]
216pub struct Query<T>(pub T);
217
218impl From<axum_extra::extract::QueryRejection> for Error {
219    fn from(rejection: axum_extra::extract::QueryRejection) -> Self {
220        Self {
221            status: rejection.status(),
222            message: rejection.body_text(),
223        }
224    }
225}
226
227/// The state for a task execution service (TES) server.
228#[derive(Clone, Builder)]
229pub struct Server<S> {
230    /// The address to bind the server to.
231    #[builder(into, default = DEFAULT_ADDRESS)]
232    address: String,
233
234    /// The port to bind the server to.
235    #[builder(into, default = DEFAULT_PORT)]
236    port: u16,
237
238    /// The routers for the server.
239    #[builder(into, default)]
240    routers: Vec<Router<S>>,
241}
242
243impl<S> Server<S>
244where
245    S: Clone + Send + Sync + 'static,
246{
247    /// Runs the server with the given state and shutdown function.
248    pub async fn run<F>(self, state: S, shutdown: F) -> anyhow::Result<()>
249    where
250        F: Future<Output = ()> + Send + 'static,
251    {
252        // Hook up the axum middleware
253        let middleware = ServiceBuilder::new()
254            .layer(SetSensitiveRequestHeadersLayer::new(SENSITIVE_HEADERS))
255            .layer(
256                TraceLayer::new_for_http()
257                    .make_span_with(DefaultMakeSpan::new().include_headers(true))
258                    .on_request(|request: &http::Request<Body>, _span: &Span| {
259                        debug!(
260                            "{method} {path}",
261                            method = request.method(),
262                            path = request.uri().path()
263                        )
264                    })
265                    .on_response(
266                        DefaultOnResponse::new()
267                            .level(tracing::Level::DEBUG)
268                            .latency_unit(LatencyUnit::Micros),
269                    ),
270            )
271            .layer(CompressionLayer::new())
272            .layer(SetSensitiveResponseHeadersLayer::new(SENSITIVE_HEADERS));
273
274        // Construct the axum app
275        let mut router = Router::new().route("/ping", get(|| async {}));
276
277        for merge in self.routers {
278            router = router.merge(merge);
279        }
280
281        // Run the server
282        let addr = format!("{address}:{port}", address = self.address, port = self.port);
283        let listener = TcpListener::bind(&addr)
284            .await
285            .context("binding to the provided address")?;
286
287        info!("listening at {addr}");
288
289        axum::serve(
290            listener,
291            router
292                .fallback(async || Error::not_found())
293                .layer(middleware)
294                .layer(tower_http::catch_panic::CatchPanicLayer::custom(
295                    handle_panic,
296                ))
297                .with_state(state),
298        )
299        .with_graceful_shutdown(shutdown)
300        .await
301        .context("failed to run server")?;
302
303        Ok(())
304    }
305}