ogcapi-proxy 0.2.0

OGC API proxy service
Documentation
use std::{
    convert::Infallible,
    fmt::Debug,
    pin::Pin,
    str::FromStr,
    task::{Context, Poll},
};

use axum::{
    Json, RequestExt,
    body::Body,
    http::{HeaderValue, Uri, header::CONTENT_TYPE, uri::PathAndQuery},
    response::IntoResponse,
};
use axum_reverse_proxy::ReverseProxy;
use http::header::CONTENT_LENGTH;
use http_body_util::BodyExt;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::{Connect, HttpConnector};
use ogcapi_types::common::{
    Link,
    media_type::{GEO_JSON, JSON},
};
use serde_json::{Value, json};
use tower::Service;

use crate::{extractors::RemoteUrl, proxied_linked::ProxiedLinked};

/// Reverse proxy for a single OGC API collection, including its sub-endpoints.
///
/// Links found in JSON responses are rewritten so that they point at the proxy
/// URL rather than at the remote collection.
///
/// ```rust
/// # use ogcapi_proxy::CollectionProxy;
/// # use ogcapi_types::common::Collection;
/// # use ogcapi_types::features::FeatureCollection;
/// # use ogcapi_types::common::Linked;
/// # use http::request::Request;
/// # use axum::body::Body;
/// # use axum::body::to_bytes;
/// # use tower::Service;
/// #
/// # tokio_test::block_on(async {
/// #
/// let mut collection_proxy = CollectionProxy::new(
///     "/collections/proxied-lakes".to_string(),
///     "https://demo.pygeoapi.io/stable/collections/lakes".to_string(),
/// );
///
/// let req = Request::builder()
///     .uri("/collections/proxied-lakes")
///     .header("Host", "test-host.example.org")
///     .body(Body::empty())
///     .unwrap();
///
/// let res = collection_proxy.call(req).await
///     .expect("Proxied request to demo.pygeoapi.io might fail if that api is not available");
///
/// assert!(res.status().is_success());
///
/// let mut collection: Collection = serde_json::from_slice(
///     &to_bytes(res.into_body(), 200_000).await.unwrap()
/// ).unwrap();
///
/// assert!(
///     collection.links.get_base_url().as_ref().is_some_and(|base_url| {
///         base_url.as_str().starts_with("http://test-host.example.org/collections/proxied-lakes")
///     })
/// );
/// let items_url = collection.links.iter().find(|link| link.rel == "items")
///     .expect("Collection is expected to have an items link");
///
/// assert!(
///     items_url.href.starts_with("http://test-host.example.org/collections/proxied-lakes/items")
/// );
/// # })
/// ```
#[derive(Clone)]
pub struct CollectionProxy<C = HttpsConnector<HttpConnector>>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    /// Identifier under which the collection is exposed. The constructors
    /// currently store the proxy path here unchanged.
    collection_id: String,
    /// Underlying reverse proxy that forwards requests to the remote collection.
    proxy: ReverseProxy<C>,
}

impl<C> Debug for CollectionProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    /// Formats the proxy as a `CollectionProxy` struct showing the stored
    /// collection id and the target URL of the wrapped reverse proxy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("CollectionProxy");
        repr.field("collection_id", &self.collection_id);
        repr.field("remote_collection_url", &self.proxy.target());
        repr.finish()
    }
}

impl CollectionProxy<HttpsConnector<HttpConnector>> {
    /// Create a proxy that serves `remote_collection_url` under `path`.
    ///
    /// `path` may be given with or without leading slashes; the proxy is
    /// always mounted at the path with exactly one leading slash. The
    /// unmodified `path` is stored as the collection id.
    pub fn new(path: String, remote_collection_url: String) -> Self {
        // Strip any leading slashes and prepend a single one.
        let mount_path = format!("/{}", path.trim_start_matches('/'));
        let proxy = ReverseProxy::new(mount_path, remote_collection_url);

        Self {
            // FIXME this should distinguish between collection_id and path
            // although this is the same most of the time to enable rewriting
            // of collection ids in json responses
            collection_id: path,
            proxy,
        }
    }
}

impl<C: Connect + Clone + Send + Sync + 'static> CollectionProxy<C> {
    pub fn new_with_client(
        path: String,
        remote_collection_url: String,
        client: hyper_util::client::legacy::Client<C, Body>,
    ) -> Self {
        Self {
            collection_id: path.clone(),
            proxy: ReverseProxy::new_with_client(
                format!("/{}", path.trim_start_matches("/")),
                remote_collection_url,
                client,
            ),
        }
    }

    pub async fn handle_request(
        &self,
        mut req: axum::http::Request<Body>,
    ) -> Result<axum::http::Response<axum::body::Body>, Infallible> {
        let proxy_uri = req.extract_parts::<RemoteUrl>().await.unwrap().0;
        let mut parts = proxy_uri.into_parts();
        parts.path_and_query = parts
            .path_and_query
            .map(|path_and_query| PathAndQuery::from_str(path_and_query.path()).unwrap());
        let request_uri_without_query: Uri = parts.try_into().unwrap();

        // modify request to accept json, as we currently can only rewrite links in json
        rewrite_req_to_accept_json(&mut req);

        // Unwrap is safe as proxy_request is Infallible
        let response = self.proxy.proxy_request(req).await.unwrap();

        if response.status().is_success()
            // && proxy_uri.path() == self.proxy.path()
            && response
                .headers()
                .get(CONTENT_TYPE)
                .is_some_and(|ct| ct == JSON || ct == GEO_JSON)
        {
            let (mut parts, body) = response.into_parts();
            // FIXME this silently drops the response body if it can't be rewritten.
            // FIXME this drops response headers etc.
            let bytes = body.collect().await.ok().map(|b| b.to_bytes());
            let value: Option<Value> =
                bytes.and_then(|bytes| serde_json::from_slice(bytes.as_ref()).ok());

            if let Some(mut value) = value {
                if let Some(object) = value.as_object_mut() {
                    if let Some((key, links_value)) = object.remove_entry("links") {
                        let mut links: Vec<Link> = serde_json::from_value(links_value).unwrap();
                        links.rewrite_links(self.target(), &request_uri_without_query.to_string());
                        object.insert(key, json!(links));
                    };
                };

                parts.headers.remove(CONTENT_LENGTH);
                Ok((parts, Json(value)).into_response())
            } else {
                parts.headers.insert(CONTENT_LENGTH, HeaderValue::from(0));
                Ok((parts, ()).into_response())
            }
        } else {
            Ok(response)
        }
    }

    /// Get the original URL of the proxied collection.
    pub fn target(&self) -> &str {
        self.proxy.target()
    }

    /// Get the relative URI at which this proxy is accessed.
    pub fn path(&self) -> &str {
        self.proxy.path()
    }

    /// Get the new id of the proxied collection
    pub fn collection_id(&self) -> &str {
        self.collection_id.trim_start_matches('/')
    }
}

/// Overwrite the `accept` header of `req` so that it accepts only
/// `application/json` and `application/geo+json`.
///
/// Any `accept` header already present on the request is replaced.
fn rewrite_req_to_accept_json(req: &mut axum::http::Request<Body>) {
    let accepted = HeaderValue::from_static("application/json, application/geo+json");
    let headers = req.headers_mut();
    headers.insert("accept", accepted);
}

impl<C: Connect + Clone + Send + Sync + 'static> Service<axum::http::Request<Body>>
    for CollectionProxy<C>
{
    type Response = axum::http::Response<axum::body::Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports readiness; no back-pressure is applied at this layer.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles `req` via [`CollectionProxy::handle_request`] on a clone of
    /// this proxy, so the returned future does not borrow from `self`.
    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
        let proxy = self.clone();
        let fut = async move { proxy.handle_request(req).await };
        Box::pin(fut)
    }
}