use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use rustls::{ServerConfig, ServerConnection, StreamOwned};
/// Socket-level timeout applied to both reads and writes on every accepted connection.
const READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Wall-clock budget for reading one whole request (headers plus any declared body).
const REQUEST_READ_DEADLINE: Duration = Duration::from_millis(5_000);
/// Starts a detached background TLS server on an ephemeral `127.0.0.1` port.
///
/// The server answers the first `responses.len()` accepted connections with the given raw
/// HTTP responses, in order. Each connection gets exactly one response and is then shut
/// down. After the canned responses are used up, the listener keeps accepting for about
/// 250 ms. During that window any extra connection is answered with
/// `503 Service Unavailable`.
///
/// Returns the bound address and a counter of connections that received a canned response.
/// The 503 fallback connections are not counted.
///
/// # Panics
///
/// Panics if the listener cannot be bound, its local address cannot be read, or the
/// self-signed TLS configuration cannot be built.
pub fn spawn_oneshot_https_responder(responses: Vec<&'static str>) -> (SocketAddr, Arc<AtomicU32>) {
    const SERVICE_UNAVAILABLE: &str =
        "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n";

    install_default_provider_once();
    let tls_config = Arc::new(build_self_signed_server_config());
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
    let addr = listener.local_addr().expect("local_addr");
    let served = Arc::new(AtomicU32::new(0));
    let served_by_worker = Arc::clone(&served);

    std::thread::spawn(move || {
        // Phase 1: one accepted connection per canned response, strictly in order.
        for canned in &responses {
            let Ok((conn, _peer)) = listener.accept() else {
                return;
            };
            served_by_worker.fetch_add(1, Ordering::SeqCst);
            serve_one_tls(conn, Arc::clone(&tls_config), canned);
        }

        // Phase 2: poll briefly for extra connections and answer them with a 503.
        let _ = listener.set_nonblocking(true);
        let stop_at = Instant::now() + Duration::from_millis(250);
        while Instant::now() < stop_at {
            let (straggler, _peer) = match listener.accept() {
                Ok(pair) => pair,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(_) => break,
            };
            // Serve the straggler with ordinary blocking I/O, even if the socket
            // inherited non-blocking mode from the listener.
            let _ = straggler.set_nonblocking(false);
            serve_one_tls(straggler, Arc::clone(&tls_config), SERVICE_UNAVAILABLE);
        }
    });

    (addr, served)
}
/// Builds a `reqwest::Client` that accepts any server certificate.
///
/// This lets it talk to the self-signed responder. Certificate validation is disabled,
/// so use this only in tests.
///
/// # Panics
///
/// Panics if the client cannot be built.
pub fn https_test_client() -> reqwest::Client {
    let builder = reqwest::ClientBuilder::new().danger_accept_invalid_certs(true);
    builder
        .build()
        .expect("build reqwest::Client with invalid-certs override")
}
/// Generates a fresh self-signed certificate for `127.0.0.1` and `localhost`.
///
/// The certificate and its key are wrapped in a rustls [`ServerConfig`] that performs no
/// client authentication.
///
/// # Panics
///
/// Panics if certificate generation or server-config construction fails.
fn build_self_signed_server_config() -> ServerConfig {
    let generated = rcgen::generate_simple_self_signed(vec![
        "127.0.0.1".to_string(),
        "localhost".to_string(),
    ])
    .expect("generate self-signed cert");

    let chain: Vec<CertificateDer<'static>> = vec![generated.cert.der().clone()];
    // The generated key's DER serialization is wrapped as PKCS#8 for rustls.
    let private_key: PrivateKeyDer<'static> =
        PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(generated.key_pair.serialize_der()));

    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .expect("build rustls::ServerConfig from self-signed cert")
}
/// Installs rustls' `ring` crypto provider as the process-wide default.
///
/// The install is attempted at most once per process. Its result is deliberately
/// ignored; presumably it fails only when some default provider is already in place.
fn install_default_provider_once() {
    static INSTALLED: std::sync::Once = std::sync::Once::new();
    INSTALLED.call_once(|| {
        let _ = rustls::crypto::ring::default_provider().install_default();
    });
}
/// Serves exactly one HTTP exchange over TLS on an already-accepted TCP connection.
///
/// The sequence is:
/// 1. Apply [`READ_TIMEOUT`] to reads and writes. `StreamOwned` drives the TLS handshake
///    as part of the first I/O.
/// 2. Read and discard the client's request.
/// 3. Write `resp` verbatim.
/// 4. Send a TLS `close_notify`, flush all pending TLS records, and shut down the socket.
///
/// Every step is best-effort. Failures end the exchange early and are not reported,
/// because the function returns `()`.
fn serve_one_tls(stream: TcpStream, config: Arc<ServerConfig>, resp: &str) {
    let _ = stream.set_read_timeout(Some(READ_TIMEOUT));
    let _ = stream.set_write_timeout(Some(READ_TIMEOUT));
    let conn = match ServerConnection::new(config) {
        Ok(c) => c,
        Err(_) => return,
    };
    let mut tls = StreamOwned::new(conn, stream);
    consume_request(&mut tls);
    let _ = tls.write_all(resp.as_bytes());
    let _ = tls.flush();
    tls.conn.send_close_notify();
    // A single `write_tls` call may leave queued records unwritten (rustls documents
    // checking `wants_write` afterwards). Keep writing until the buffer is empty, the
    // socket stops accepting bytes, or an error occurs (e.g. the write timeout set above).
    while tls.conn.wants_write() {
        match tls.conn.write_tls(&mut tls.sock) {
            Ok(0) => break,
            Ok(_) => {}
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => break,
        }
    }
    let _ = tls.sock.shutdown(std::net::Shutdown::Both);
}
/// Reads and discards one HTTP/1.x request from `stream`.
///
/// Reads until the `\r\n\r\n` header terminator. If the headers carry a `Content-Length`,
/// it then reads and discards that many body bytes as well.
///
/// It returns early, silently, in any of these cases:
/// - end of stream;
/// - any read error other than `ErrorKind::Interrupted` (those are retried);
/// - [`REQUEST_READ_DEADLINE`] has elapsed;
/// - the header block grows past 1 MiB without a terminator.
///
/// Requests without a parseable `Content-Length` (e.g. chunked bodies) are not drained
/// beyond the bytes that arrived together with the headers.
fn consume_request<R: Read>(stream: &mut R) {
    const MAX_HEADER_BYTES: usize = 1 << 20;
    let deadline = Instant::now() + REQUEST_READ_DEADLINE;
    let mut accum: Vec<u8> = Vec::with_capacity(8 * 1024);
    let mut chunk = [0u8; 8 * 1024];

    let header_end = loop {
        let Some(n) = read_before_deadline(stream, &mut chunk, deadline) else {
            return;
        };
        // Earlier bytes were already searched, so the first terminator can only start in
        // the new data or in the last 3 old bytes (if it straddles two reads).
        let search_from = accum.len().saturating_sub(3);
        accum.extend_from_slice(&chunk[..n]);
        if let Some(pos) = find_double_crlf(&accum[search_from..]) {
            break search_from + pos + 4;
        }
        if accum.len() > MAX_HEADER_BYTES {
            return;
        }
    };

    let Some(total_body) = parse_content_length(&accum[..header_end]) else {
        return;
    };
    // Body bytes that arrived in the same reads as the headers count toward the total.
    let already_have = accum.len() - header_end;
    let mut remaining = total_body.saturating_sub(already_have);
    while remaining > 0 {
        let want = remaining.min(chunk.len());
        let Some(n) = read_before_deadline(stream, &mut chunk[..want], deadline) else {
            return;
        };
        remaining -= n;
    }
}

/// Performs one `read` into `buf`, retrying when the read is interrupted.
///
/// Returns `Some(n)` with `n > 0` when bytes were read. Returns `None` on end of stream,
/// on any other read error, or once `deadline` has passed. The deadline is checked before
/// every attempt.
fn read_before_deadline<R: Read>(
    stream: &mut R,
    buf: &mut [u8],
    deadline: Instant,
) -> Option<usize> {
    loop {
        if Instant::now() >= deadline {
            return None;
        }
        match stream.read(buf) {
            Ok(0) => return None,
            Ok(n) => return Some(n),
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(_) => return None,
        }
    }
}
/// Locates the HTTP header terminator in `buf`.
///
/// Returns the byte offset at which the first `\r\n\r\n` begins. Returns `None` when there
/// is no such sequence, including when `buf` is shorter than 4 bytes.
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8; 4] = b"\r\n\r\n";
    let last_start = buf.len().checked_sub(TERMINATOR.len())?;
    (0..=last_start).find(|&start| &buf[start..start + TERMINATOR.len()] == TERMINATOR)
}
/// Extracts the `Content-Length` value from a raw HTTP header block.
///
/// Header names are matched case-insensitively, and surrounding whitespace (including a
/// trailing `\r`) is trimmed from both name and value. Only the first `Content-Length`
/// header is considered. If its value is not a valid `usize`, the result is `None`.
///
/// Lines are decoded individually. A line that is not valid UTF-8 (e.g. an opaque
/// byte-valued header) is skipped instead of discarding the whole block, so a valid
/// `Content-Length` elsewhere is still found.
fn parse_content_length(header_block: &[u8]) -> Option<usize> {
    header_block
        .split(|&byte| byte == b'\n')
        .filter_map(|raw_line| std::str::from_utf8(raw_line).ok())
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse::<usize>().ok())
}
#[cfg(test)]
mod self_tests {
    use super::*;

    /// A single canned 200 response reaches the client intact, and the responder
    /// records exactly one served connection.
    #[tokio::test]
    async fn https_responder_serves_canned_response() {
        let canned = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 16\r\n\r\n{\"remaining\":99}";
        let (addr, calls) = spawn_oneshot_https_responder(vec![canned]);
        let endpoint = format!("https://{}/rate_limit", addr);

        let response = https_test_client()
            .get(&endpoint)
            .send()
            .await
            .expect("send request");
        assert_eq!(response.status(), 200);

        let text = response.text().await.expect("read body");
        assert_eq!(text, r#"{"remaining":99}"#);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two sequential requests receive the canned responses in the order they were given.
    #[tokio::test]
    async fn https_responder_serves_multiple_responses_in_order() {
        let (addr, calls) = spawn_oneshot_https_responder(vec![
            "HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA",
            "HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nB",
        ]);
        let client = https_test_client();
        let endpoint = format!("https://{}/x", addr);

        let first = client
            .get(&endpoint)
            .send()
            .await
            .expect("req 1")
            .text()
            .await
            .expect("body 1");
        let second = client
            .get(&endpoint)
            .send()
            .await
            .expect("req 2")
            .text()
            .await
            .expect("body 2");

        assert_eq!(first, "A");
        assert_eq!(second, "B");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}