use std::thread;
use async_net::TcpListener;
use futures_lite::future::block_on;
use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
use ugi::{Client, RedirectPolicy};
/// Drives anything convertible into a future to completion on the current
/// thread and hands back its output.
///
/// Lets the tests pass a value that implements [`std::future::IntoFuture`]
/// straight to a blocking call without writing `.await` at every call site.
fn run<T>(value: T) -> T::Output
where
    T: std::future::IntoFuture,
{
    // Fully-qualified call so the trait does not need to be in scope.
    block_on(std::future::IntoFuture::into_future(value))
}
/// Starts a minimal canned-response server on an ephemeral loopback port.
///
/// A detached background thread accepts one connection per entry of
/// `responses`, in order. For each connection it performs a single read of up
/// to 8192 bytes (the request bytes are discarded), writes the canned bytes
/// verbatim and flushes. The connection is dropped at the end of the iteration.
///
/// Returns the base URL of the server in the form `http://<addr>`.
///
/// # Panics
///
/// Binding or querying the listener panics in the caller; accept/read/write
/// failures panic inside the background thread.
async fn serve_sequence(responses: Vec<Vec<u8>>) -> String {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
    let base_url = format!("http://{}", listener.local_addr().unwrap());

    thread::spawn(move || {
        let serve_all = async move {
            for canned in responses {
                let (mut conn, _peer) = listener.accept().await.unwrap();

                // Wait for the client to send something before answering; the
                // content itself is irrelevant for this helper.
                let mut discard = vec![0_u8; 8192];
                conn.read(discard.as_mut_slice()).await.unwrap();

                conn.write_all(canned.as_slice()).await.unwrap();
                conn.flush().await.unwrap();
            }
        };
        block_on(serve_all);
    });

    base_url
}
/// Starts a one-shot server on an ephemeral loopback port that lets the caller
/// inspect the raw request before a canned response is sent.
///
/// A detached background thread accepts exactly one connection, reads the
/// request until the end of the header section (`\r\n\r\n`) or EOF, passes the
/// lossily UTF-8 decoded bytes to `inspect`, then writes `response` verbatim
/// and flushes.
///
/// Reading up to the header terminator matters because a single `read` may
/// return only part of the request. Callers use `inspect` to assert that
/// certain headers are *absent*, and such a check would pass spuriously on a
/// truncated request head.
///
/// Returns the base URL of the server in the form `http://<addr>`.
///
/// # Panics
///
/// Binding or querying the listener panics in the caller. I/O failures and any
/// panic raised by `inspect` happen inside the background thread; in that case
/// no response is written and the connection is dropped.
async fn serve_once_with_inspect(
    inspect: impl Fn(String) + Send + 'static,
    response: Vec<u8>,
) -> String {
    const HEADER_END: &[u8] = b"\r\n\r\n";

    let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
    let addr = listener.local_addr().unwrap();
    thread::spawn(move || {
        block_on(async move {
            let (mut stream, _) = listener.accept().await.unwrap();

            let mut request = Vec::with_capacity(8192);
            let mut chunk = vec![0_u8; 8192];
            loop {
                let read = stream.read(&mut chunk).await.unwrap();
                if read == 0 {
                    // Peer closed before completing the head; inspect what we have.
                    break;
                }
                // The terminator may straddle two reads, so re-scan the last
                // three bytes that were already buffered.
                let scan_from = request.len().saturating_sub(HEADER_END.len() - 1);
                request.extend_from_slice(&chunk[..read]);
                if request[scan_from..]
                    .windows(HEADER_END.len())
                    .any(|window| window == HEADER_END)
                {
                    break;
                }
            }

            inspect(String::from_utf8_lossy(&request).into_owned());
            stream.write_all(&response).await.unwrap();
            stream.flush().await.unwrap();
        });
    });
    format!("http://{}", addr)
}
/// A `301` answer to a POST must be followed with a GET: the `/final` mock
/// only matches GET requests and is asserted to have been hit.
#[test]
fn redirect_301_changes_post_to_get() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: matches the original POST and answers with the redirect.
    let redirecting = mock_server.mock(|when, then| {
        when.method(POST).path("/origin");
        then.status(301).header("location", "/final");
    });
    // Second hop: reachable only if the method was rewritten to GET.
    let landing = mock_server.mock(|when, then| {
        when.method(GET).path("/final");
        then.status(200).body("ok-get");
    });

    let http = Client::builder().build().unwrap();
    let request = http.post(mock_server.url("/origin")).text("body").unwrap();
    let reply = run(request).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-get");
    redirecting.assert();
    landing.assert();
}
/// A `302` answer to a POST must be followed with a GET: the `/final` mock
/// only matches GET requests and is asserted to have been hit.
#[test]
fn redirect_302_changes_post_to_get() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: matches the original POST and answers with the redirect.
    let redirecting = mock_server.mock(|when, then| {
        when.method(POST).path("/origin");
        then.status(302).header("location", "/final");
    });
    // Second hop: reachable only if the method was rewritten to GET.
    let landing = mock_server.mock(|when, then| {
        when.method(GET).path("/final");
        then.status(200).body("ok-get");
    });

    let http = Client::builder().build().unwrap();
    let request = http.post(mock_server.url("/origin")).text("body").unwrap();
    let reply = run(request).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-get");
    redirecting.assert();
    landing.assert();
}
/// A `303` answer to a POST must be followed with a GET: the `/final` mock
/// only matches GET requests and is asserted to have been hit.
#[test]
fn redirect_303_always_becomes_get() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: matches the original POST and answers with `303`.
    let redirecting = mock_server.mock(|when, then| {
        when.method(POST).path("/origin");
        then.status(303).header("location", "/final");
    });
    // Second hop: reachable only via GET.
    let landing = mock_server.mock(|when, then| {
        when.method(GET).path("/final");
        then.status(200).body("ok-303");
    });

    let http = Client::builder().build().unwrap();
    let request = http.post(mock_server.url("/origin")).text("payload").unwrap();
    let reply = run(request).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-303");
    redirecting.assert();
    landing.assert();
}
/// A `307` answer to a POST must be followed with the same method and body:
/// the `/final` mock only matches a POST carrying `payload`.
#[test]
fn redirect_307_preserves_post_method_and_body() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: matches the original POST and answers with `307`.
    let redirecting = mock_server.mock(|when, then| {
        when.method(POST).path("/origin");
        then.status(307).header("location", "/final");
    });
    // Second hop: requires both the POST method and the original body.
    let landing = mock_server.mock(|when, then| {
        when.method(POST).path("/final").body("payload");
        then.status(200).body("ok-307");
    });

    let http = Client::builder().build().unwrap();
    let request = http.post(mock_server.url("/origin")).text("payload").unwrap();
    let reply = run(request).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-307");
    redirecting.assert();
    landing.assert();
}
/// A `308` answer to a POST must be followed with the same method and body:
/// the `/final` mock only matches a POST carrying `payload`.
#[test]
fn redirect_308_preserves_post_method_and_body() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: matches the original POST and answers with `308`.
    let redirecting = mock_server.mock(|when, then| {
        when.method(POST).path("/origin");
        then.status(308).header("location", "/final");
    });
    // Second hop: requires both the POST method and the original body.
    let landing = mock_server.mock(|when, then| {
        when.method(POST).path("/final").body("payload");
        then.status(200).body("ok-308");
    });

    let http = Client::builder().build().unwrap();
    let request = http.post(mock_server.url("/origin")).text("payload").unwrap();
    let reply = run(request).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-308");
    redirecting.assert();
    landing.assert();
}
/// A client configured with a cookie is redirected from one loopback server to
/// another (different port); the request reaching the second server must not
/// contain a `Cookie` header.
#[test]
fn cross_origin_redirect_strips_cookie_header() {
    // Runs on the target server's thread against the raw request text.
    let no_cookie_header = |request: String| {
        assert!(
            request
                .lines()
                .all(|line| !line.to_lowercase().starts_with("cookie:")),
            "cookie header must be stripped on cross-origin redirect; got request:\n{request}"
        );
    };
    let ok_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok".to_vec();
    let target_base = block_on(serve_once_with_inspect(no_cookie_header, ok_reply));

    // The redirecting server points at the target server's absolute URL.
    let redirect_response =
        format!("HTTP/1.1 302 Found\r\nLocation: {target_base}/final\r\nContent-Length: 0\r\n\r\n");
    let hops = vec![redirect_response.into_bytes()];
    let redirect_base = block_on(serve_sequence(hops));

    let http = Client::builder()
        .cookie("session", "secret-value")
        .build()
        .unwrap();
    let start_url = format!("{redirect_base}/start");
    let reply = run(http.get(start_url)).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok");
}
/// A client configured with bearer auth and a cookie is redirected from one
/// loopback server to another (different port); the request reaching the
/// second server must carry neither an `Authorization` nor a `Cookie` header.
#[test]
fn cross_origin_redirect_strips_authorization_and_cookie() {
    // Runs on the target server's thread against the raw request text.
    let no_credentials = |request: String| {
        let lowered = request.to_lowercase();
        assert!(
            lowered.find("\r\nauthorization:").is_none(),
            "authorization must be stripped; got:\n{request}"
        );
        assert!(
            lowered.find("\r\ncookie:").is_none(),
            "cookie must be stripped; got:\n{request}"
        );
    };
    let ok_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok".to_vec();
    let target_base = block_on(serve_once_with_inspect(no_credentials, ok_reply));

    // The redirecting server points at the target server's absolute URL.
    let redirect_response =
        format!("HTTP/1.1 302 Found\r\nLocation: {target_base}/final\r\nContent-Length: 0\r\n\r\n");
    let hops = vec![redirect_response.into_bytes()];
    let redirect_base = block_on(serve_sequence(hops));

    let http = Client::builder()
        .bearer_auth("top-secret-token")
        .unwrap()
        .cookie("sid", "abc123")
        .build()
        .unwrap();
    let start_url = format!("{redirect_base}/start");
    let reply = run(http.get(start_url)).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok");
}
/// When the redirect stays on the same mock server, the configured cookie must
/// be sent on both hops: each mock only matches requests with `cookie: sid=abc`.
#[test]
fn same_origin_redirect_preserves_cookie_header() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // First hop: requires the cookie and redirects to a relative location.
    let entry = mock_server.mock(|when, then| {
        when.method(GET).path("/start").header("cookie", "sid=abc");
        then.status(302).header("location", "/final");
    });
    // Second hop: requires the very same cookie header.
    let landing = mock_server.mock(|when, then| {
        when.method(GET).path("/final").header("cookie", "sid=abc");
        then.status(200).body("ok-same-origin");
    });

    let http = Client::builder().cookie("sid", "abc").build().unwrap();
    let reply = run(http.get(mock_server.url("/start"))).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "ok-same-origin");
    entry.assert();
    landing.assert();
}
/// With the cookie store enabled, a cookie set by the redirecting server via
/// `Set-Cookie` must not be sent to the redirect target, which is a second
/// loopback server on a different port.
#[test]
fn cookie_store_does_not_leak_across_cross_origin_redirect() {
    // Runs on the target server's thread against the raw request text.
    let no_cookie_header = |request: String| {
        assert!(
            request
                .lines()
                .all(|line| !line.to_lowercase().starts_with("cookie:")),
            "cookie store must not send server-A cookies to cross-origin server B; got:\n{request}"
        );
    };
    let done_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\ndone".to_vec();
    let target_base = block_on(serve_once_with_inspect(no_cookie_header, done_reply));

    // The redirecting server sets a cookie and points at the target server.
    let redirect_response = format!(
        "HTTP/1.1 302 Found\r\nSet-Cookie: server_a_cookie=leaked; Path=/\r\nLocation: {target_base}/final\r\nContent-Length: 0\r\n\r\n"
    );
    let hops = vec![redirect_response.into_bytes()];
    let redirect_base = block_on(serve_sequence(hops));

    let http = Client::builder().cookie_store().build().unwrap();
    let start_url = format!("{redirect_base}/start");
    let reply = run(http.get(start_url)).unwrap();
    let body = block_on(reply.text()).unwrap();

    assert_eq!(body, "done");
}
/// With `RedirectPolicy::None` the client must hand the `302` response back to
/// the caller instead of following its `Location` header.
#[test]
fn redirect_none_policy_does_not_follow_redirects() {
    use httpmock::prelude::*;

    let mock_server = MockServer::start();
    // Only `/redirect` is mocked; the `Location` target has no mock behind it.
    let redirecting = mock_server.mock(|when, then| {
        when.method(GET).path("/redirect");
        then.status(302).header("location", "/never-reached");
    });

    let http = Client::builder()
        .redirect(RedirectPolicy::None)
        .build()
        .unwrap();
    let reply = run(http.get(mock_server.url("/redirect"))).unwrap();
    let status_code = reply.status().as_u16();

    assert_eq!(status_code, 302);
    redirecting.assert();
}
/// Chains three loopback servers that each answer with a `302` and checks that
/// a client built with `RedirectPolicy::Limit(2)` reports an error of kind
/// `ErrorKind::Transport` instead of a response.
#[test]
fn redirect_limit_policy_stops_at_given_hop_count() {
    // The servers are created back to front because each redirect needs the
    // base URL of the next server in the chain.
    let last_hop = vec![b"HTTP/1.1 302 Found\r\nLocation: /x\r\nContent-Length: 0\r\n\r\n".to_vec()];
    let third = block_on(serve_sequence(last_hop));

    let second_response =
        format!("HTTP/1.1 302 Found\r\nLocation: {third}/c\r\nContent-Length: 0\r\n\r\n");
    let second = block_on(serve_sequence(vec![second_response.into_bytes()]));

    let first_response =
        format!("HTTP/1.1 302 Found\r\nLocation: {second}/b\r\nContent-Length: 0\r\n\r\n");
    let first = block_on(serve_sequence(vec![first_response.into_bytes()]));

    let http = Client::builder()
        .redirect(RedirectPolicy::Limit(2))
        .build()
        .unwrap();
    let outcome = run(http.get(format!("{first}/a")));

    // Anything other than an error means the limit was not enforced.
    let failure = match outcome {
        Err(failure) => failure,
        unexpected => panic!("expected redirect limit error, got: {:?}", unexpected),
    };
    assert_eq!(failure.kind(), &ugi::ErrorKind::Transport);
}