use crate::prelude::*;
use http_body_util::Full;
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use tokio::sync::oneshot::{Sender, channel};
use tokio::task::JoinHandle;
mod address;
pub use address::*;
mod route;
pub use route::*;
mod forward;
pub use forward::*;
mod server;
/// A local proxy server whose accept loop runs on a spawned tokio task.
///
/// Create one with [`Proxy::start`] or [`Proxy::start_with`], register upstream origins with
/// [`Proxy::forward`] / [`Proxy::forward_with`] / [`Proxy::forward_route`], and shut it down with
/// [`Proxy::stop`].
///
/// NOTE(review): there is no `Drop` impl. Dropping a `Proxy` without calling `stop` drops the
/// shutdown sender and detaches the server task. Whether `serve` exits when the channel closes
/// is decided in the `server` module — confirm there.
pub struct Proxy {
    /// Address picked by `ProxyAddress::pick_unused` and bound by the server's TCP listener.
    pub address: ProxyAddress,
    /// One-shot shutdown signal for the server task; `None` once `stop` has sent it.
    shutdown_tx: Option<Sender<()>>,
    /// Handle of the spawned server task; `None` once `stop` has awaited it.
    server_task: Option<JoinHandle<()>>,
    /// Route table shared (same `Arc`) with the server task.
    routes: ProxyRoutes,
}
/// Set of registered routes, shared between [`Proxy`] and its server task.
pub(crate) type ProxyRoutes = Arc<RwLock<HashSet<ProxyRoute>>>;
/// Custom per-request handler installed with [`Proxy::start_with`].
///
/// [`Proxy::start`] installs a handler that always returns `Ok(None)`.
/// NOTE(review): how `Ok(Some(response))`, `Ok(None)` and `Err(_)` are acted upon is decided in
/// the `server` module (not visible here). Presumably `Some` answers the request directly and
/// `None` falls through to normal route forwarding — confirm.
pub type ProxyHandler = Arc<
    dyn Fn(ProxyHandlerPayload) -> Result<Option<Response<Full<Bytes>>>, CinemaError> + Send + Sync,
>;
/// Borrowed view of a proxied request, passed to a [`ProxyHandler`].
pub struct ProxyHandlerPayload<'a> {
    /// The request being forwarded.
    pub request: &'a ProxyForwardRequestPayload<'a>,
    /// Origin string for this request.
    /// NOTE(review): presumably the upstream origin registered via `Proxy::forward*`.
    /// Confirm in the `server` module.
    pub origin: &'a str,
    /// Path for this request.
    /// NOTE(review): presumably relative to `origin`. Confirm in the `server` module.
    pub path: &'a str,
}
impl Proxy {
    /// Starts a proxy on an unused local address with a handler that always returns `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Propagates any [`CinemaError`] returned by [`Proxy::start_with`].
    pub async fn start() -> Result<Self, CinemaError> {
        Self::start_with(Arc::new(|_| Ok(None))).await
    }

    /// Starts a proxy on an unused local address, passing each request to `handler`.
    ///
    /// Picks a free address, binds a TCP listener to it, and spawns the server loop
    /// (`Self::serve`) on a tokio task. The returned proxy has no routes yet. Add them with
    /// [`Proxy::forward`], [`Proxy::forward_with`] or [`Proxy::forward_route`], and shut the
    /// server down with [`Proxy::stop`].
    ///
    /// # Errors
    ///
    /// Returns [`CinemaError`] if no unused address can be picked or the listener cannot be bound.
    pub async fn start_with(handler: ProxyHandler) -> Result<Self, CinemaError> {
        let address = ProxyAddress::pick_unused()?;
        let listener = address.bind_tcp().await?;
        let (shutdown_tx, shutdown_rx) = channel::<()>();
        let routes = Arc::new(RwLock::new(HashSet::new()));
        let server_task = {
            // The task gets a handle to the same route set that `forward*` / `unforward` mutate.
            // `handler` is moved in directly; it is not needed here afterwards.
            let routes = Arc::clone(&routes);
            tokio::spawn(async move {
                Self::serve(routes, handler, listener, shutdown_rx).await;
            })
        };
        Ok(Self {
            address,
            shutdown_tx: Some(shutdown_tx),
            server_task: Some(server_task),
            routes,
        })
    }

    /// Signals the server task to shut down and waits for it to finish.
    ///
    /// Calling this more than once is harmless. The sender and task handle are taken on the
    /// first call, so later calls do nothing.
    pub async fn stop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // `send` only fails if the receiver is already gone, i.e. the task has ended.
            let _ = tx.send(());
        }
        if let Some(handle) = self.server_task.take() {
            // A `JoinError` (task panicked or was cancelled) is deliberately ignored.
            let _ = handle.await;
        }
    }

    /// Registers `origin` without a callback and returns the proxy URL to request instead.
    ///
    /// The URL has the form `http://127.0.0.1:{port}/{percent-encoded origin}`, for example
    /// `http://127.0.0.1:1234/https%3A%2F%2Fexample.com`.
    ///
    /// # Panics
    ///
    /// Panics if the route lock is poisoned.
    pub fn forward<Str: Into<String>>(&mut self, origin: Str) -> String {
        self.forward_route(ProxyRoute::new(origin.into(), None))
    }

    /// Registers `origin` with a `callback` and returns the proxy URL to request instead.
    ///
    /// # Panics
    ///
    /// Panics if the route lock is poisoned.
    pub fn forward_with<Str: Into<String>>(
        &mut self,
        origin: Str,
        callback: ProxyRouteCallback,
    ) -> String {
        self.forward_route(ProxyRoute::new(origin.into(), Some(callback)))
    }

    /// Registers `route` and returns its proxy URL for this proxy's address.
    ///
    /// If an equal route is already registered, it is replaced by `route`.
    ///
    /// # Panics
    ///
    /// Panics if the route lock is poisoned.
    pub fn forward_route(&mut self, route: ProxyRoute) -> String {
        let proxy_url = route.proxy_url(&self.address);
        // `replace` (unlike `insert`) swaps out an existing equal element. This means
        // re-forwarding an origin with, e.g., a new callback takes effect instead of being
        // silently dropped. NOTE(review): exactly what counts as "equal" depends on
        // `ProxyRoute`'s Eq/Hash impls in `route`.
        self.routes.write().unwrap().replace(route);
        proxy_url
    }

    /// Removes every registered route whose proxy URL equals `url`.
    ///
    /// `url` is compared against the string returned by the `forward*` methods. Unknown URLs
    /// are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the route lock is poisoned.
    pub fn unforward(&mut self, url: String) {
        // Match and remove under a single write lock. Requesting the write lock while a read
        // guard from the same thread is still alive deadlocks (or panics) on std's `RwLock`.
        let address = &self.address;
        self.routes
            .write()
            .unwrap()
            .retain(|route| route.proxy_url(address) != url);
    }
}
#[cfg(test)]
mod tests {
    // Tests for `Proxy`. The forwarding tests make real requests to https://example.com and
    // therefore need outbound network access.
    use super::*;
    use hyper::body::Bytes;
    use hyper::header::HeaderValue;
    use hyper::{HeaderMap, StatusCode};
    use pretty_assertions::assert_eq;
    use reqwest::Client;

    /// Upstream origin used by the forwarding tests.
    const ORIGIN: &str = "https://example.com";

    /// What a route callback recorded: status code, headers (if present) and body bytes.
    type Capture = (StatusCode, Option<HeaderMap<HeaderValue>>, Bytes);

    /// Starts a proxy with the default handler, panicking if startup fails.
    async fn new_proxy() -> Proxy {
        Proxy::start().await.expect("Failed to start proxy")
    }

    #[tokio::test]
    async fn test_start() {
        let proxy = new_proxy().await;
        // The proxy is expected to bind to the IPv4 loopback address.
        assert_eq!(proxy.address.ip(), [127, 0, 0, 1]);
    }

    #[tokio::test]
    async fn test_start_and_stop() {
        new_proxy().await.stop().await;
    }

    #[tokio::test]
    async fn test_forward_route() {
        let mut proxy = new_proxy().await;
        let expected = format!(
            "http://127.0.0.1:{}/https%3A%2F%2Fexample.com",
            proxy.address.port()
        );
        assert_eq!(proxy.forward(ORIGIN), expected);
    }

    #[tokio::test]
    async fn test_forward_request() {
        let mut proxy = new_proxy().await;
        let url = proxy.forward(ORIGIN);
        let page = Client::new()
            .get(url)
            .send()
            .await
            .expect("Failed to send request")
            .text()
            .await
            .expect("Failed to read response");
        assert!(
            page.contains("<title>Example Domain</title>"),
            "The response was not as expected:\n\n```\n{}\n```",
            page
        );
    }

    #[tokio::test]
    async fn test_forward_callback() {
        let mut proxy = new_proxy().await;
        let captured: Arc<RwLock<Option<Capture>>> = Arc::new(RwLock::new(None));
        let sink = Arc::clone(&captured);
        let url = proxy.forward_with(
            ORIGIN,
            Arc::new(move |payload| {
                // Snapshot the upstream response so the test can inspect it afterwards.
                *sink.write().unwrap() = Some((
                    payload.status,
                    payload.headers.cloned(),
                    payload.body.clone(),
                ));
                Ok(())
            }),
        );
        let _ = Client::new()
            .get(url)
            .send()
            .await
            .expect("Failed to send request");

        let (status, headers, body) = captured.read().unwrap().clone().unwrap();
        assert_eq!(status, 200);

        let headers = headers.expect("Headers were not captured");
        assert_eq!(
            headers.get("Content-Type"),
            Some(&HeaderValue::from_static("text/html"))
        );

        let text =
            String::from_utf8(body.to_vec()).expect("Failed to convert body bytes to string");
        assert!(
            text.contains("<title>Example Domain</title>"),
            "The response was not captured as expected:\n\n```\n{}\n```",
            text
        );
    }
}