use super::client::ProxyClient;
use crate::{Error, Result};
use axum::http::{HeaderMap, Method, Uri};
use std::collections::HashMap;
/// Forwards HTTP requests to an upstream server as directed by a
/// [`ProxyConfig`](super::config::ProxyConfig).
pub struct ProxyHandler {
    /// Settings consulted on every request: whether proxying is enabled, which extra
    /// headers to inject, which requests qualify, and how upstream URLs are built.
    pub config: super::config::ProxyConfig,
}
impl ProxyHandler {
pub fn new(config: super::config::ProxyConfig) -> Self {
Self { config }
}
pub async fn handle_request(
&self,
method: &str,
url: &str,
headers: &HashMap<String, String>,
body: Option<&[u8]>,
) -> Result<ProxyResponse> {
if !self.config.enabled {
return Err(Error::internal("Proxy is not enabled"));
}
let reqwest_method = match method.to_uppercase().as_str() {
"GET" => Method::GET,
"POST" => Method::POST,
"PUT" => Method::PUT,
"DELETE" => Method::DELETE,
"HEAD" => Method::HEAD,
"OPTIONS" => Method::OPTIONS,
"PATCH" => Method::PATCH,
_ => return Err(Error::internal(format!("Unsupported HTTP method: {}", method))),
};
let mut request_headers = headers.clone();
for (key, value) in &self.config.headers {
request_headers.insert(key.clone(), value.clone());
}
let client = ProxyClient::new();
let response = client.send_request(reqwest_method, url, &request_headers, body).await?;
let mut response_headers = HeaderMap::new();
for (key, value) in response.headers() {
if let Ok(header_name) = axum::http::HeaderName::try_from(key.as_str()) {
response_headers.insert(header_name, value.clone());
}
}
let status_code = response.status().as_u16();
let body_bytes = response
.bytes()
.await
.map_err(|e| Error::io_with_context("response body", e.to_string()))?;
Ok(ProxyResponse {
status_code,
headers: response_headers,
body: Some(body_bytes.to_vec()),
})
}
pub async fn proxy_request(
&self,
method: &Method,
uri: &Uri,
headers: &HeaderMap,
body: Option<&[u8]>,
) -> Result<ProxyResponse> {
if !self.config.enabled {
return Err(Error::internal("Proxy is not enabled"));
}
if !self.config.should_proxy_with_condition(method, uri, headers, body) {
return Err(Error::internal("Request should not be proxied"));
}
let upstream_url = self.config.get_upstream_url(uri.path());
let mut header_map = HashMap::new();
for (key, value) in headers {
if let Ok(value_str) = value.to_str() {
header_map.insert(key.to_string(), value_str.to_string());
}
}
for (key, value) in &self.config.headers {
header_map.insert(key.clone(), value.clone());
}
let reqwest_method = match *method {
Method::GET => Method::GET,
Method::POST => Method::POST,
Method::PUT => Method::PUT,
Method::DELETE => Method::DELETE,
Method::HEAD => Method::HEAD,
Method::OPTIONS => Method::OPTIONS,
Method::PATCH => Method::PATCH,
_ => return Err(Error::internal(format!("Unsupported HTTP method: {}", method))),
};
let client = ProxyClient::new();
let response =
client.send_request(reqwest_method, &upstream_url, &header_map, body).await?;
let mut response_headers = HeaderMap::new();
for (key, value) in response.headers() {
if let Ok(header_name) = axum::http::HeaderName::try_from(key.as_str()) {
response_headers.insert(header_name, value.clone());
}
}
let status_code = response.status().as_u16();
let body_bytes = response
.bytes()
.await
.map_err(|e| Error::io_with_context("response body", e.to_string()))?;
Ok(ProxyResponse {
status_code,
headers: response_headers,
body: Some(body_bytes.to_vec()),
})
}
}
/// A fully buffered response returned by `ProxyHandler`.
#[derive(Debug, Clone)]
pub struct ProxyResponse {
    /// HTTP status code reported by the upstream server.
    pub status_code: u16,
    /// Headers copied from the upstream response.
    pub headers: HeaderMap,
    /// Buffered response body. The `ProxyHandler` methods in this module always set `Some`.
    pub body: Option<Vec<u8>>,
}