ogcapi_proxy/
collection_proxy.rs

1use std::{
2    convert::Infallible,
3    fmt::Debug,
4    pin::Pin,
5    str::FromStr,
6    task::{Context, Poll},
7};
8
9use axum::{
10    Json, RequestExt,
11    body::Body,
12    http::{HeaderValue, StatusCode, Uri, header::CONTENT_TYPE, uri::PathAndQuery},
13    response::IntoResponse,
14};
15use axum_reverse_proxy::ReverseProxy;
16use http_body_util::BodyExt;
17use hyper_rustls::HttpsConnector;
18use hyper_util::client::legacy::connect::{Connect, HttpConnector};
19use ogcapi_types::common::{
20    Link,
21    media_type::{GEO_JSON, JSON},
22};
23use serde_json::{Value, json};
24use tower::Service;
25
26use crate::{extractors::RemoteUrl, proxied_linked::ProxiedLinked};
27
28/// Proxy for one individual OGC API Collection and its sub-endpoints.
29///
30/// Links in JSON responses are rewriten to match the proxy URL.
31/// ```rust
32/// # use ogcapi_proxy::CollectionProxy;
33/// # use ogcapi_types::common::Collection;
34/// # use ogcapi_types::features::FeatureCollection;
35/// # use ogcapi_types::common::Linked;
36/// # use http::request::Request;
37/// # use axum::body::Body;
38/// # use axum::body::to_bytes;
39/// # use tower::Service;
40/// #
41/// # tokio_test::block_on(async {
42/// #
43/// let mut collection_proxy = CollectionProxy::new(
44///     "/collections/proxied-lakes".to_string(),
45///     "https://demo.pygeoapi.io/stable/collections/lakes".to_string(),
46/// );
47///
48/// let req = Request::builder()
49///     .uri("/collections/proxied-lakes")
50///     .header("Host", "test-host.example.org")
51///     .body(Body::empty())
52///     .unwrap();
53///
54/// let res = collection_proxy.call(req).await
55///     .expect("Proxied request to demo.pygeoapi.io might fail if that api is not available");
56///
57/// assert!(res.status().is_success());
58///
59/// let mut collection: Collection = serde_json::from_slice(
60///     &to_bytes(res.into_body(), 200_000).await.unwrap()
61/// ).unwrap();
62///
63/// assert!(
64///     collection.links.get_base_url().as_ref().is_some_and(|base_url| {
65///         base_url.as_str().starts_with("http://test-host.example.org/collections/proxied-lakes")
66///     })
67/// );
68/// let items_url = collection.links.iter().find(|link| link.rel == "items")
69///     .expect("Collection is expected to have an items link");
70///
71/// assert!(
72///     items_url.href.starts_with("http://test-host.example.org/collections/proxied-lakes/items")
73/// );
74/// # })
75/// ```
76#[derive(Clone)]
77pub struct CollectionProxy<
78    C: Connect + Clone + Send + Sync + 'static = HttpsConnector<HttpConnector>,
79> {
80    collection_id: String,
81    proxy: ReverseProxy<C>,
82}
83
84impl<C: Connect + Clone + Send + Sync + 'static> Debug for CollectionProxy<C> {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("CollectionProxy")
87            .field("collection_id", &self.collection_id)
88            .field("remote_collection_url", &self.proxy.target())
89            .finish()
90    }
91}
92
93impl CollectionProxy<HttpsConnector<HttpConnector>> {
94    pub fn new(path: String, remote_collection_url: String) -> Self {
95        Self {
96            // FIXME this should distinguish between collection_id and path
97            // although this is the same most of the time to enable rewriting
98            // of collection ids in json responses
99            collection_id: path.clone(),
100            proxy: ReverseProxy::new(
101                format!("/{}", path.trim_start_matches("/")),
102                remote_collection_url,
103            ),
104        }
105    }
106}
107
108impl<C: Connect + Clone + Send + Sync + 'static> CollectionProxy<C> {
109    pub fn new_with_client(
110        path: String,
111        remote_collection_url: String,
112        client: hyper_util::client::legacy::Client<C, Body>,
113    ) -> Self {
114        Self {
115            collection_id: path.clone(),
116            proxy: ReverseProxy::new_with_client(
117                format!("/{}", path.trim_start_matches("/")),
118                remote_collection_url,
119                client,
120            ),
121        }
122    }
123
124    pub async fn handle_request(
125        &self,
126        mut req: axum::http::Request<Body>,
127    ) -> Result<axum::http::Response<axum::body::Body>, Infallible> {
128        let proxy_uri = req.extract_parts::<RemoteUrl>().await.unwrap().0;
129        let mut parts = proxy_uri.into_parts();
130        parts.path_and_query = parts
131            .path_and_query
132            .map(|path_and_query| PathAndQuery::from_str(path_and_query.path()).unwrap());
133        let request_uri_without_query: Uri = parts.try_into().unwrap();
134
135        // modify request to accept json, as we currently can only rewrite links in json
136        rewrite_req_to_accept_json(&mut req);
137
138        // Unwrap is safe as proxy_request is Infallible
139        let response = self.proxy.proxy_request(req).await.unwrap();
140
141        if response.status().is_success()
142            // && proxy_uri.path() == self.proxy.path()
143            && response
144                .headers()
145                .get(CONTENT_TYPE)
146                .is_some_and(|ct| ct == JSON || ct == GEO_JSON)
147        {
148            let body = response.into_body();
149            // FIXME this silently drops the response body if it can't be rewritten.
150            // FIXME this drops response headers etc.
151            let bytes = body.collect().await.ok().map(|b| b.to_bytes());
152            let value: Option<Value> =
153                bytes.and_then(|bytes| serde_json::from_slice(bytes.as_ref()).ok());
154
155            if let Some(mut value) = value {
156                if let Some(object) = value.as_object_mut() {
157                    if let Some((key, links_value)) = object.remove_entry("links") {
158                        let mut links: Vec<Link> = serde_json::from_value(links_value).unwrap();
159                        links.rewrite_links(self.target(), &request_uri_without_query.to_string());
160                        object.insert(key, json!(links));
161                    };
162                };
163
164                Ok((StatusCode::OK, Json(value)).into_response())
165            } else {
166                Ok((StatusCode::OK, Json(())).into_response())
167            }
168        } else {
169            Ok(response)
170        }
171    }
172
173    /// Get the original URL of the proxied collection.
174    pub fn target(&self) -> &str {
175        self.proxy.target()
176    }
177
178    /// Get the relative URI at which this proxy is accessed.
179    pub fn path(&self) -> &str {
180        self.proxy.path()
181    }
182
183    /// Get the new id of the proxied collection
184    pub fn collection_id(&self) -> &str {
185        self.collection_id.trim_start_matches('/')
186    }
187}
188
189fn rewrite_req_to_accept_json(req: &mut axum::http::Request<Body>) {
190    req.headers_mut().insert(
191        "accept",
192        HeaderValue::from_str("application/json, application/geo+json").unwrap(),
193    );
194}
195
196impl<C> Service<axum::http::Request<Body>> for CollectionProxy<C>
197where
198    C: Connect + Clone + Send + Sync + 'static,
199{
200    type Response = axum::http::Response<axum::body::Body>;
201    type Error = Infallible;
202    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
203
204    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205        Poll::Ready(Ok(()))
206    }
207
208    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
209        let this = self.clone();
210        Box::pin(async move { this.handle_request(req).await })
211    }
212}