ogcapi_services/
service.rs

1use std::{any::Any, net::SocketAddr};
2
3use axum::{
4    Router,
5    body::Body,
6    http::{
7        Response, StatusCode,
8        header::{AUTHORIZATION, CONTENT_TYPE, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE},
9    },
10    response::IntoResponse,
11    routing::get,
12};
13use tokio::net::TcpListener;
14use tower::ServiceBuilder;
15use tower_http::{
16    ServiceBuilderExt,
17    catch_panic::CatchPanicLayer,
18    compression::CompressionLayer,
19    cors::CorsLayer,
20    request_id::MakeRequestUuid,
21    sensitive_headers::SetSensitiveRequestHeadersLayer,
22    trace::{DefaultMakeSpan, TraceLayer},
23};
24
25use ogcapi_types::common::Exception;
26
27use crate::{AppState, Config, ConfigParser, Error, routes};
28
29/// OGC API Services
30pub struct Service {
31    pub state: AppState,
32    pub router: Router<AppState>,
33    listener: TcpListener,
34}
35
36impl Service {
37    pub async fn new() -> Self {
38        // config
39        let config = Config::parse();
40
41        // state
42        let state = AppState::new_from(&config).await;
43
44        Service::new_with(&config, state).await
45    }
46
47    pub async fn new_with(config: &Config, state: AppState) -> Self {
48        // router
49        let router = Router::new()
50            .route("/", get(routes::root))
51            .route("/api", get(routes::api::api))
52            .route("/redoc", get(routes::api::redoc))
53            .route("/swagger", get(routes::api::swagger))
54            .route("/conformance", get(routes::conformance));
55
56        let router = router.merge(routes::collections::router(&state));
57
58        #[cfg(feature = "stac")]
59        let router = router.route(
60            "/search",
61            get(routes::stac::search_get).post(routes::stac::search_post),
62        );
63
64        #[cfg(feature = "features")]
65        let router = router.merge(routes::features::router(&state));
66
67        #[cfg(feature = "edr")]
68        let router = router.merge(routes::edr::router(&state));
69
70        #[cfg(feature = "styles")]
71        let router = router.merge(routes::styles::router(&state));
72
73        #[cfg(feature = "tiles")]
74        let router = router.merge(routes::tiles::router(&state));
75
76        #[cfg(feature = "processes")]
77        let router = router.merge(routes::processes::router(&state));
78
79        // add a fallback service for handling routes to unknown paths
80        let router = router.fallback(handler_404);
81
82        // middleware stack
83        let router = router.layer(
84            ServiceBuilder::new()
85                .set_x_request_id(MakeRequestUuid)
86                .layer(SetSensitiveRequestHeadersLayer::new([
87                    AUTHORIZATION,
88                    PROXY_AUTHORIZATION,
89                    COOKIE,
90                    SET_COOKIE,
91                ]))
92                .layer(TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::new()))
93                .layer(CompressionLayer::new())
94                .layer(CorsLayer::permissive())
95                .layer(CatchPanicLayer::custom(handle_panic))
96                .propagate_x_request_id(),
97        );
98
99        let listener = TcpListener::bind((config.host.as_str(), config.port))
100            .await
101            .expect("create listener");
102
103        Service {
104            state,
105            router,
106            listener,
107        }
108    }
109
110    /// Serve application
111    pub async fn serve(self) {
112        // add state
113        let router = self.router.with_state(self.state);
114
115        // serve
116        tracing::info!(
117            "listening on http://{}",
118            self.listener.local_addr().unwrap()
119        );
120
121        axum::serve::serve(self.listener, router)
122            .with_graceful_shutdown(shutdown_signal())
123            .await
124            .unwrap()
125    }
126
127    // helper function to get randomized port
128    #[doc(hidden)]
129    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
130        self.listener.local_addr()
131    }
132}
133
134/// Custom 404 handler
135async fn handler_404() -> impl IntoResponse {
136    Error::NotFound
137}
138
139/// Custom panic handler
140fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Body> {
141    let details = if let Some(s) = err.downcast_ref::<String>() {
142        s.clone()
143    } else if let Some(s) = err.downcast_ref::<&str>() {
144        s.to_string()
145    } else {
146        "Unknown panic message".to_string()
147    };
148
149    let body =
150        Exception::new_from_status(StatusCode::INTERNAL_SERVER_ERROR.as_u16()).detail(details);
151
152    let body = serde_json::to_string(&body).unwrap();
153
154    Response::builder()
155        .status(StatusCode::INTERNAL_SERVER_ERROR)
156        .header(CONTENT_TYPE, "application/json")
157        .body(Body::from(body))
158        .unwrap()
159}
160
161/// Handle shutdown signals
162async fn shutdown_signal() {
163    let ctrl_c = async {
164        tokio::signal::ctrl_c()
165            .await
166            .expect("failed to install Ctrl+C handler");
167    };
168
169    #[cfg(unix)]
170    let terminate = async {
171        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
172            .expect("failed to install signal handler")
173            .recv()
174            .await;
175    };
176
177    #[cfg(not(unix))]
178    let terminate = std::future::pending::<()>();
179
180    tokio::select! {
181        _ = ctrl_c => {},
182        _ = terminate => {},
183    }
184
185    tracing::debug!("signal received, starting graceful shutdown");
186}