// cinema 0.1.0
//
// HTTP record-replay proxy for Rust tests.
// Documentation
use crate::prelude::*;
use http_body_util::Full;
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use tokio::sync::oneshot::{Sender, channel};
use tokio::task::JoinHandle;

mod address;
pub use address::*;

mod route;
pub use route::*;

mod forward;
pub use forward::*;

mod server;

/// A local HTTP proxy whose server runs on a spawned Tokio task.
///
/// Create one with [`Proxy::start`] or [`Proxy::start_with`], register upstream origins
/// with [`Proxy::forward`], [`Proxy::forward_with`] or [`Proxy::forward_route`], and shut
/// the server down with [`Proxy::stop`].
pub struct Proxy {
    /// Local address the proxy's TCP listener is bound to.
    pub address: ProxyAddress,
    // One-shot shutdown signal for the server task; `None` once `stop` has taken it.
    shutdown_tx: Option<Sender<()>>,
    // Handle to the spawned server task; `None` once `stop` has taken it.
    server_task: Option<JoinHandle<()>>,
    // Route table; the same `Arc` is handed to the server task at startup.
    routes: ProxyRoutes,
}

/// Set of registered routes, shared (via `Arc`) between a [`Proxy`] and its server task.
pub(crate) type ProxyRoutes = Arc<RwLock<HashSet<ProxyRoute>>>;

/// Shared, thread-safe hook the proxy server is started with.
///
/// It is called with a [`ProxyHandlerPayload`] borrowed for the duration of the call.
/// NOTE(review): presumably `Ok(Some(response))` answers the request directly and
/// `Ok(None)` — what [`Proxy::start`] installs — lets normal forwarding proceed;
/// confirm against the `server` module.
pub type ProxyHandler = Arc<
    dyn Fn(ProxyHandlerPayload<'_>) -> Result<Option<Response<Full<Bytes>>>, CinemaError>
        + Send
        + Sync,
>;

/// Arguments handed to a [`ProxyHandler`] for a single request.
///
/// All fields are borrowed for `'a`, so a handler cannot keep them beyond the call.
pub struct ProxyHandlerPayload<'a> {
    /// The request as prepared for forwarding.
    pub request: &'a ProxyForwardRequestPayload<'a>,
    /// NOTE(review): presumably the origin of the matched route (e.g. `https://example.com`);
    /// confirm against the `server` module.
    pub origin: &'a str,
    /// NOTE(review): presumably the request path remaining after the encoded origin;
    /// confirm against the `server` module.
    pub path: &'a str,
}

impl Proxy {
    /// Starts a proxy with a handler that always returns `Ok(None)`.
    ///
    /// Shorthand for [`Proxy::start_with`] with that default handler.
    ///
    /// # Errors
    ///
    /// Returns a [`CinemaError`] if no unused local address can be picked or the TCP
    /// listener cannot be bound.
    pub async fn start() -> Result<Self, CinemaError> {
        Self::start_with(Arc::new(|_| Ok(None))).await
    }

    /// Starts a proxy that passes `handler` to its server.
    ///
    /// Picks an unused local address, binds a TCP listener to it and spawns the server on
    /// the current Tokio runtime (so this must be called from within one). The proxy starts
    /// with an empty route table.
    ///
    /// # Errors
    ///
    /// Returns a [`CinemaError`] if no unused local address can be picked or the TCP
    /// listener cannot be bound.
    pub async fn start_with(handler: ProxyHandler) -> Result<Self, CinemaError> {
        let address = ProxyAddress::pick_unused()?;
        let listener = address.bind_tcp().await?;
        let (shutdown_tx, shutdown_rx) = channel::<()>();
        let routes = Arc::new(RwLock::new(HashSet::new()));

        // The server task shares the route table with `self`, so routes registered after
        // startup land in the same set the server was given. The handler is only needed by
        // the server, so it is moved rather than cloned.
        let server_task = {
            let routes = Arc::clone(&routes);
            tokio::spawn(async move {
                Self::serve(routes, handler, listener, shutdown_rx).await;
            })
        };

        Ok(Self {
            address,
            shutdown_tx: Some(shutdown_tx),
            server_task: Some(server_task),
            routes,
        })
    }

    /// Signals the server to shut down and waits for its task to finish.
    ///
    /// Only the first call has an effect; later calls return immediately because the
    /// shutdown sender and task handle have already been taken. If the server task
    /// panicked or was cancelled, that outcome is discarded.
    pub async fn stop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // `send` fails only if the receiver is already gone, i.e. the server has
            // stopped on its own; there is nothing left to signal in that case.
            let _ = tx.send(());
        }
        if let Some(handle) = self.server_task.take() {
            let _ = handle.await;
        }
    }

    /// Registers a route to `origin` and returns the proxy URL to request instead.
    pub fn forward<Str: Into<String>>(&mut self, origin: Str) -> String {
        self.forward_route(ProxyRoute::new(origin.into(), None))
    }

    /// Like [`Proxy::forward`], but attaches `callback` to the route.
    ///
    /// The callback receives a payload carrying the forwarded response's status,
    /// headers and body.
    pub fn forward_with<Str: Into<String>>(
        &mut self,
        origin: Str,
        callback: ProxyRouteCallback,
    ) -> String {
        self.forward_route(ProxyRoute::new(origin.into(), Some(callback)))
    }

    /// Adds `route` to the route table and returns its proxy URL.
    ///
    /// If an equal route is already registered the table is left unchanged
    /// (`HashSet::insert` semantics), and the URL is still returned.
    ///
    /// # Panics
    ///
    /// Panics if the route-table lock is poisoned.
    pub fn forward_route(&mut self, route: ProxyRoute) -> String {
        let proxy_url = route.proxy_url(&self.address);
        self.routes
            .write()
            .expect("proxy route table lock poisoned")
            .insert(route);
        proxy_url
    }

    /// Removes every registered route whose proxy URL equals `url`.
    ///
    /// URLs that match no route are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the route-table lock is poisoned.
    pub fn unforward(&mut self, url: String) {
        // Take the write lock exactly once. Looking the route up under a read guard and
        // then calling `write()` while that guard is still alive (as a scrutinee temporary
        // would be) deadlocks or panics on `std::sync::RwLock`.
        let address = &self.address;
        self.routes
            .write()
            .expect("proxy route table lock poisoned")
            .retain(|route| route.proxy_url(address) != url);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use hyper::body::Bytes;
    use hyper::header::HeaderValue;
    use hyper::{HeaderMap, StatusCode};
    use pretty_assertions::assert_eq;
    use reqwest::Client;

    /// What `test_forward_callback` records from the route callback: status, headers, body.
    type Captured = (StatusCode, Option<HeaderMap<HeaderValue>>, Bytes);

    #[tokio::test]
    async fn test_start() {
        let proxy = new_proxy().await;
        assert_eq!(proxy.address.ip(), [127, 0, 0, 1]);
    }

    #[tokio::test]
    async fn test_start_and_stop() {
        let mut proxy = new_proxy().await;
        proxy.stop().await;
    }

    #[tokio::test]
    async fn test_forward_route() {
        let mut proxy = new_proxy().await;
        let forward_url = proxy.forward("https://example.com");
        assert_eq!(
            forward_url,
            format!(
                "http://127.0.0.1:{}/https%3A%2F%2Fexample.com",
                proxy.address.port()
            )
        );
    }

    // NOTE: performs a real request to example.com, so it needs network access.
    #[tokio::test]
    async fn test_forward_request() {
        let mut proxy = new_proxy().await;
        let forward_url = proxy.forward("https://example.com");

        let client = Client::new();
        let response = client
            .get(forward_url)
            .send()
            .await
            .expect("Failed to send request");
        let text = response.text().await.expect("Failed to read response");
        assert!(
            text.contains("<title>Example Domain</title>"),
            "The response was not as expected:\n\n```\n{}\n```",
            text
        );
    }

    // NOTE: performs a real request to example.com, so it needs network access.
    #[tokio::test]
    async fn test_forward_callback() {
        let mut proxy = new_proxy().await;

        let captures: Arc<RwLock<Option<Captured>>> = Arc::new(RwLock::new(None));
        let callback_captures = Arc::clone(&captures);
        let forward_url = proxy.forward_with(
            "https://example.com",
            Arc::new(move |payload| {
                let mut lock = callback_captures
                    .write()
                    .expect("capture lock poisoned");
                *lock = Some((
                    payload.status,
                    payload.headers.cloned(),
                    payload.body.clone(),
                ));
                Ok(())
            }),
        );

        let client = Client::new();
        let _ = client
            .get(forward_url)
            .send()
            .await
            .expect("Failed to send request");

        let lock = captures.read().expect("capture lock poisoned");
        // A bare `unwrap()` here would hide *why* the test failed; name the cause.
        let (status, headers, body) = lock
            .clone()
            .expect("The route callback was never invoked");

        assert_eq!(status, 200);

        let headers = headers.expect("Headers were not captured");
        assert_eq!(
            headers.get("Content-Type"),
            Some(&HeaderValue::from_static("text/html"))
        );

        let body =
            String::from_utf8(body.to_vec()).expect("Failed to convert body bytes to string");
        assert!(
            body.contains("<title>Example Domain</title>"),
            "The response was not captured as expected:\n\n```\n{}\n```",
            body
        );
    }

    /// Starts a proxy with the default handler, failing the test if startup fails.
    async fn new_proxy() -> Proxy {
        Proxy::start().await.expect("Failed to start proxy")
    }
}