use std::{
convert::Infallible,
fmt::Debug,
pin::Pin,
str::FromStr,
task::{Context, Poll},
};
use axum::{
Json, RequestExt,
body::Body,
http::{HeaderValue, Uri, header::CONTENT_TYPE, uri::PathAndQuery},
response::IntoResponse,
};
use axum_reverse_proxy::ReverseProxy;
use http::header::CONTENT_LENGTH;
use http_body_util::BodyExt;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::{Connect, HttpConnector};
use ogcapi_types::common::{
Link,
media_type::{GEO_JSON, JSON},
};
use serde_json::{Value, json};
use tower::Service;
use crate::{extractors::RemoteUrl, proxied_linked::ProxiedLinked};
/// A service that forwards requests for a single collection to a remote collection URL
/// and rewrites the `links` member of successful JSON/GeoJSON responses.
///
/// `C` is the hyper connector used by the underlying [`ReverseProxy`]; it defaults to
/// `HttpsConnector<HttpConnector>`.
#[derive(Clone)]
pub struct CollectionProxy<
    C: Connect + Clone + Send + Sync + 'static = HttpsConnector<HttpConnector>,
> {
    /// Collection identifier exactly as passed to the constructor (it may carry a
    /// leading `/`, which `collection_id()` strips).
    collection_id: String,
    /// Reverse proxy mounted at `/{collection_id}` (leading slashes normalized) that
    /// performs the actual forwarding.
    proxy: ReverseProxy<C>,
}
impl<C> Debug for CollectionProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    /// Shows the collection id and the remote URL this proxy forwards to; the inner
    /// HTTP client is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let remote_collection_url = self.proxy.target();
        let mut builder = f.debug_struct("CollectionProxy");
        builder.field("collection_id", &self.collection_id);
        builder.field("remote_collection_url", &remote_collection_url);
        builder.finish()
    }
}
impl CollectionProxy<HttpsConnector<HttpConnector>> {
    /// Creates a proxy for the collection `path` that forwards to
    /// `remote_collection_url` using the `HttpsConnector<HttpConnector>` client built by
    /// [`ReverseProxy::new`].
    ///
    /// The proxy is mounted at `path` with its leading slashes collapsed to exactly one.
    pub fn new(path: String, remote_collection_url: String) -> Self {
        let mount_path = format!("/{}", path.trim_start_matches('/'));
        let proxy = ReverseProxy::new(mount_path, remote_collection_url);
        Self {
            collection_id: path,
            proxy,
        }
    }
}
impl<C: Connect + Clone + Send + Sync + 'static> CollectionProxy<C> {
pub fn new_with_client(
path: String,
remote_collection_url: String,
client: hyper_util::client::legacy::Client<C, Body>,
) -> Self {
Self {
collection_id: path.clone(),
proxy: ReverseProxy::new_with_client(
format!("/{}", path.trim_start_matches("/")),
remote_collection_url,
client,
),
}
}
pub async fn handle_request(
&self,
mut req: axum::http::Request<Body>,
) -> Result<axum::http::Response<axum::body::Body>, Infallible> {
let proxy_uri = req.extract_parts::<RemoteUrl>().await.unwrap().0;
let mut parts = proxy_uri.into_parts();
parts.path_and_query = parts
.path_and_query
.map(|path_and_query| PathAndQuery::from_str(path_and_query.path()).unwrap());
let request_uri_without_query: Uri = parts.try_into().unwrap();
rewrite_req_to_accept_json(&mut req);
let response = self.proxy.proxy_request(req).await.unwrap();
if response.status().is_success()
&& response
.headers()
.get(CONTENT_TYPE)
.is_some_and(|ct| ct == JSON || ct == GEO_JSON)
{
let (mut parts, body) = response.into_parts();
let bytes = body.collect().await.ok().map(|b| b.to_bytes());
let value: Option<Value> =
bytes.and_then(|bytes| serde_json::from_slice(bytes.as_ref()).ok());
if let Some(mut value) = value {
if let Some(object) = value.as_object_mut() {
if let Some((key, links_value)) = object.remove_entry("links") {
let mut links: Vec<Link> = serde_json::from_value(links_value).unwrap();
links.rewrite_links(self.target(), &request_uri_without_query.to_string());
object.insert(key, json!(links));
};
};
parts.headers.remove(CONTENT_LENGTH);
Ok((parts, Json(value)).into_response())
} else {
parts.headers.insert(CONTENT_LENGTH, HeaderValue::from(0));
Ok((parts, ()).into_response())
}
} else {
Ok(response)
}
}
pub fn target(&self) -> &str {
self.proxy.target()
}
pub fn path(&self) -> &str {
self.proxy.path()
}
pub fn collection_id(&self) -> &str {
self.collection_id.trim_start_matches('/')
}
}
/// Replaces the request's `Accept` header with
/// `application/json, application/geo+json` so the upstream is asked for JSON or GeoJSON.
fn rewrite_req_to_accept_json(req: &mut axum::http::Request<Body>) {
    // `from_static` takes the constant literal directly instead of a fallible runtime
    // parse followed by `unwrap`.
    req.headers_mut().insert(
        axum::http::header::ACCEPT,
        HeaderValue::from_static("application/json, application/geo+json"),
    );
}
impl<C> Service<axum::http::Request<Body>> for CollectionProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    type Response = axum::http::Response<axum::body::Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports readiness; this impl keeps no readiness state of its own.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles `req` on a clone of this proxy, so the boxed future owns everything it
    /// needs and does not borrow `self`.
    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
        let proxy = self.clone();
        let response = async move { proxy.handle_request(req).await };
        Box::pin(response)
    }
}