// ogcapi_proxy/collection_proxy.rs
use std::{
    convert::Infallible,
    fmt::Debug,
    pin::Pin,
    str::FromStr,
    task::{Context, Poll},
};

use axum::{
    Json, RequestExt,
    body::Body,
    http::{
        HeaderValue, StatusCode, Uri,
        header::{ACCEPT, CONTENT_TYPE},
        uri::PathAndQuery,
    },
    response::IntoResponse,
};
use axum_reverse_proxy::ReverseProxy;
use http::header::CONTENT_LENGTH;
use http_body_util::BodyExt;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::{Connect, HttpConnector};
use ogcapi_types::common::{
    Link,
    media_type::{GEO_JSON, JSON},
};
use serde_json::{Value, json};
use tower::Service;

use crate::{extractors::RemoteUrl, proxied_linked::ProxiedLinked};
28
29#[derive(Clone)]
78pub struct CollectionProxy<
79 C: Connect + Clone + Send + Sync + 'static = HttpsConnector<HttpConnector>,
80> {
81 collection_id: String,
82 proxy: ReverseProxy<C>,
83}
84
85impl<C: Connect + Clone + Send + Sync + 'static> Debug for CollectionProxy<C> {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("CollectionProxy")
88 .field("collection_id", &self.collection_id)
89 .field("remote_collection_url", &self.proxy.target())
90 .finish()
91 }
92}
93
94impl CollectionProxy<HttpsConnector<HttpConnector>> {
95 pub fn new(path: String, remote_collection_url: String) -> Self {
96 Self {
97 collection_id: path.clone(),
101 proxy: ReverseProxy::new(
102 format!("/{}", path.trim_start_matches("/")),
103 remote_collection_url,
104 ),
105 }
106 }
107}
108
109impl<C: Connect + Clone + Send + Sync + 'static> CollectionProxy<C> {
110 pub fn new_with_client(
111 path: String,
112 remote_collection_url: String,
113 client: hyper_util::client::legacy::Client<C, Body>,
114 ) -> Self {
115 Self {
116 collection_id: path.clone(),
117 proxy: ReverseProxy::new_with_client(
118 format!("/{}", path.trim_start_matches("/")),
119 remote_collection_url,
120 client,
121 ),
122 }
123 }
124
125 pub async fn handle_request(
126 &self,
127 mut req: axum::http::Request<Body>,
128 ) -> Result<axum::http::Response<axum::body::Body>, Infallible> {
129 let proxy_uri = req.extract_parts::<RemoteUrl>().await.unwrap().0;
130 let mut parts = proxy_uri.into_parts();
131 parts.path_and_query = parts
132 .path_and_query
133 .map(|path_and_query| PathAndQuery::from_str(path_and_query.path()).unwrap());
134 let request_uri_without_query: Uri = parts.try_into().unwrap();
135
136 rewrite_req_to_accept_json(&mut req);
138
139 let response = self.proxy.proxy_request(req).await.unwrap();
141
142 if response.status().is_success()
143 && response
145 .headers()
146 .get(CONTENT_TYPE)
147 .is_some_and(|ct| ct == JSON || ct == GEO_JSON)
148 {
149 let (mut parts, body) = response.into_parts();
150 let bytes = body.collect().await.ok().map(|b| b.to_bytes());
153 let value: Option<Value> =
154 bytes.and_then(|bytes| serde_json::from_slice(bytes.as_ref()).ok());
155
156 if let Some(mut value) = value {
157 if let Some(object) = value.as_object_mut() {
158 if let Some((key, links_value)) = object.remove_entry("links") {
159 let mut links: Vec<Link> = serde_json::from_value(links_value).unwrap();
160 links.rewrite_links(self.target(), &request_uri_without_query.to_string());
161 object.insert(key, json!(links));
162 };
163 };
164
165 parts.headers.remove(CONTENT_LENGTH);
166 Ok((parts, Json(value)).into_response())
167 } else {
168 parts.headers.insert(CONTENT_LENGTH, HeaderValue::from(0));
169 Ok((parts, ()).into_response())
170 }
171 } else {
172 Ok(response)
173 }
174 }
175
176 pub fn target(&self) -> &str {
178 self.proxy.target()
179 }
180
181 pub fn path(&self) -> &str {
183 self.proxy.path()
184 }
185
186 pub fn collection_id(&self) -> &str {
188 self.collection_id.trim_start_matches('/')
189 }
190}
191
192fn rewrite_req_to_accept_json(req: &mut axum::http::Request<Body>) {
193 req.headers_mut().insert(
194 "accept",
195 HeaderValue::from_str("application/json, application/geo+json").unwrap(),
196 );
197}
198
199impl<C> Service<axum::http::Request<Body>> for CollectionProxy<C>
200where
201 C: Connect + Clone + Send + Sync + 'static,
202{
203 type Response = axum::http::Response<axum::body::Body>;
204 type Error = Infallible;
205 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
206
207 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208 Poll::Ready(Ok(()))
209 }
210
211 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
212 let this = self.clone();
213 Box::pin(async move { this.handle_request(req).await })
214 }
215}