ogcapi_proxy/
collection_proxy.rs

1use std::{
2    convert::Infallible,
3    fmt::Debug,
4    pin::Pin,
5    str::FromStr,
6    task::{Context, Poll},
7};
8
9use axum::{
10    Json, RequestExt,
11    body::Body,
12    http::{HeaderValue, Uri, header::CONTENT_TYPE, uri::PathAndQuery},
13    response::IntoResponse,
14};
15use axum_reverse_proxy::ReverseProxy;
16use http::header::CONTENT_LENGTH;
17use http_body_util::BodyExt;
18use hyper_rustls::HttpsConnector;
19use hyper_util::client::legacy::connect::{Connect, HttpConnector};
20use ogcapi_types::common::{
21    Link,
22    media_type::{GEO_JSON, JSON},
23};
24use serde_json::{Value, json};
25use tower::Service;
26
27use crate::{extractors::RemoteUrl, proxied_linked::ProxiedLinked};
28
29/// Proxy for one individual OGC API Collection and its sub-endpoints.
30///
31/// Links in JSON responses are rewriten to match the proxy URL.
32/// ```rust
33/// # use ogcapi_proxy::CollectionProxy;
34/// # use ogcapi_types::common::Collection;
35/// # use ogcapi_types::features::FeatureCollection;
36/// # use ogcapi_types::common::Linked;
37/// # use http::request::Request;
38/// # use axum::body::Body;
39/// # use axum::body::to_bytes;
40/// # use tower::Service;
41/// #
42/// # tokio_test::block_on(async {
43/// #
44/// let mut collection_proxy = CollectionProxy::new(
45///     "/collections/proxied-lakes".to_string(),
46///     "https://demo.pygeoapi.io/stable/collections/lakes".to_string(),
47/// );
48///
49/// let req = Request::builder()
50///     .uri("/collections/proxied-lakes")
51///     .header("Host", "test-host.example.org")
52///     .body(Body::empty())
53///     .unwrap();
54///
55/// let res = collection_proxy.call(req).await
56///     .expect("Proxied request to demo.pygeoapi.io might fail if that api is not available");
57///
58/// assert!(res.status().is_success());
59///
60/// let mut collection: Collection = serde_json::from_slice(
61///     &to_bytes(res.into_body(), 200_000).await.unwrap()
62/// ).unwrap();
63///
64/// assert!(
65///     collection.links.get_base_url().as_ref().is_some_and(|base_url| {
66///         base_url.as_str().starts_with("http://test-host.example.org/collections/proxied-lakes")
67///     })
68/// );
69/// let items_url = collection.links.iter().find(|link| link.rel == "items")
70///     .expect("Collection is expected to have an items link");
71///
72/// assert!(
73///     items_url.href.starts_with("http://test-host.example.org/collections/proxied-lakes/items")
74/// );
75/// # })
76/// ```
77#[derive(Clone)]
78pub struct CollectionProxy<
79    C: Connect + Clone + Send + Sync + 'static = HttpsConnector<HttpConnector>,
80> {
81    collection_id: String,
82    proxy: ReverseProxy<C>,
83}
84
85impl<C: Connect + Clone + Send + Sync + 'static> Debug for CollectionProxy<C> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("CollectionProxy")
88            .field("collection_id", &self.collection_id)
89            .field("remote_collection_url", &self.proxy.target())
90            .finish()
91    }
92}
93
94impl CollectionProxy<HttpsConnector<HttpConnector>> {
95    pub fn new(path: String, remote_collection_url: String) -> Self {
96        Self {
97            // FIXME this should distinguish between collection_id and path
98            // although this is the same most of the time to enable rewriting
99            // of collection ids in json responses
100            collection_id: path.clone(),
101            proxy: ReverseProxy::new(
102                format!("/{}", path.trim_start_matches("/")),
103                remote_collection_url,
104            ),
105        }
106    }
107}
108
109impl<C: Connect + Clone + Send + Sync + 'static> CollectionProxy<C> {
110    pub fn new_with_client(
111        path: String,
112        remote_collection_url: String,
113        client: hyper_util::client::legacy::Client<C, Body>,
114    ) -> Self {
115        Self {
116            collection_id: path.clone(),
117            proxy: ReverseProxy::new_with_client(
118                format!("/{}", path.trim_start_matches("/")),
119                remote_collection_url,
120                client,
121            ),
122        }
123    }
124
125    pub async fn handle_request(
126        &self,
127        mut req: axum::http::Request<Body>,
128    ) -> Result<axum::http::Response<axum::body::Body>, Infallible> {
129        let proxy_uri = req.extract_parts::<RemoteUrl>().await.unwrap().0;
130        let mut parts = proxy_uri.into_parts();
131        parts.path_and_query = parts
132            .path_and_query
133            .map(|path_and_query| PathAndQuery::from_str(path_and_query.path()).unwrap());
134        let request_uri_without_query: Uri = parts.try_into().unwrap();
135
136        // modify request to accept json, as we currently can only rewrite links in json
137        rewrite_req_to_accept_json(&mut req);
138
139        // Unwrap is safe as proxy_request is Infallible
140        let response = self.proxy.proxy_request(req).await.unwrap();
141
142        if response.status().is_success()
143            // && proxy_uri.path() == self.proxy.path()
144            && response
145                .headers()
146                .get(CONTENT_TYPE)
147                .is_some_and(|ct| ct == JSON || ct == GEO_JSON)
148        {
149            let (mut parts, body) = response.into_parts();
150            // FIXME this silently drops the response body if it can't be rewritten.
151            // FIXME this drops response headers etc.
152            let bytes = body.collect().await.ok().map(|b| b.to_bytes());
153            let value: Option<Value> =
154                bytes.and_then(|bytes| serde_json::from_slice(bytes.as_ref()).ok());
155
156            if let Some(mut value) = value {
157                if let Some(object) = value.as_object_mut() {
158                    if let Some((key, links_value)) = object.remove_entry("links") {
159                        let mut links: Vec<Link> = serde_json::from_value(links_value).unwrap();
160                        links.rewrite_links(self.target(), &request_uri_without_query.to_string());
161                        object.insert(key, json!(links));
162                    };
163                };
164
165                parts.headers.remove(CONTENT_LENGTH);
166                Ok((parts, Json(value)).into_response())
167            } else {
168                parts.headers.insert(CONTENT_LENGTH, HeaderValue::from(0));
169                Ok((parts, ()).into_response())
170            }
171        } else {
172            Ok(response)
173        }
174    }
175
176    /// Get the original URL of the proxied collection.
177    pub fn target(&self) -> &str {
178        self.proxy.target()
179    }
180
181    /// Get the relative URI at which this proxy is accessed.
182    pub fn path(&self) -> &str {
183        self.proxy.path()
184    }
185
186    /// Get the new id of the proxied collection
187    pub fn collection_id(&self) -> &str {
188        self.collection_id.trim_start_matches('/')
189    }
190}
191
192fn rewrite_req_to_accept_json(req: &mut axum::http::Request<Body>) {
193    req.headers_mut().insert(
194        "accept",
195        HeaderValue::from_str("application/json, application/geo+json").unwrap(),
196    );
197}
198
199impl<C> Service<axum::http::Request<Body>> for CollectionProxy<C>
200where
201    C: Connect + Clone + Send + Sync + 'static,
202{
203    type Response = axum::http::Response<axum::body::Body>;
204    type Error = Infallible;
205    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
206
207    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208        Poll::Ready(Ok(()))
209    }
210
211    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
212        let this = self.clone();
213        Box::pin(async move { this.handle_request(req).await })
214    }
215}