#![cfg(test)]
use core::{net::SocketAddr, time::Duration};
use mktemp::Temp;
use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
use rcgen::{CertifiedKey, KeyPair};
use std::{collections::HashMap, io::ErrorKind, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
runtime::Runtime,
sync::RwLock,
};
use tokio_openssl::SslStream;
pub(crate) static FRIENDLY: &[u8] = b"20 text/gemini; charset=utf-8; lang=en\r\n:3\n";
/// Builds a multi-threaded tokio runtime with two worker threads named `thread_name`,
/// with both the I/O and time drivers enabled.
///
/// # Panics
/// Panics if the runtime cannot be created.
pub(crate) fn new_runtime(thread_name: &'static str) -> Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder
        .enable_all()
        .thread_name(thread_name)
        .worker_threads(2);
    builder.build().unwrap()
}
/// Builds a single-threaded (current-thread) tokio runtime whose thread is named
/// `thread_name`, with both the I/O and time drivers enabled.
///
/// # Panics
/// Panics if the runtime cannot be created.
pub(crate) fn new_simple_runtime(thread_name: &'static str) -> Runtime {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all().thread_name(thread_name);
    builder.build().unwrap()
}
/// Starts a server with a fresh self-signed `localhost` certificate that answers `/` with
/// [`FRIENDLY`] but closes connections without sending a TLS close_notify.
pub(crate) fn start_unfriendly_server_no_close_notify() -> ServerHandle {
    let page = Response::Immediate(Vec::from(FRIENDLY));
    let mut responses = Responses::new();
    responses.insert("/", page);
    start_gemini_server_with_certs(false, responses)
}
/// Starts a server with a fresh self-signed `localhost` certificate that waits five seconds
/// before answering `/` with [`FRIENDLY`], and closes without a TLS close_notify.
// NOTE(review): presumably used to exercise client-side timeouts.
pub(crate) fn start_unfriendly_server_slow() -> ServerHandle {
    let mut responses = Responses::new();
    responses.insert(
        "/",
        Response::Wait(Duration::from_secs(5), FRIENDLY.to_vec()),
    );
    start_gemini_server_with_certs(false, responses)
}
/// Starts a server using the caller-provided certificate/key that answers `/` with
/// [`FRIENDLY`] and shuts each connection down cleanly with a TLS close_notify.
pub(crate) fn start_friendly_server(key: CertifiedKey<KeyPair>) -> ServerHandle {
    let page = Response::Immediate(Vec::from(FRIENDLY));
    let mut responses = Responses::new();
    responses.insert("/", page);
    start_gemini_server_with_key(key, true, responses)
}
/// Starts a server with a fresh self-signed `localhost` certificate that answers `/` with a
/// `30` redirect to `target` and `/hello` with [`FRIENDLY`].
pub(crate) fn start_redir_server(target: &'static str) -> ServerHandle {
    let redirect = format!("30 {target}\r\n").into_bytes();
    let mut responses = Responses::new();
    responses.insert("/hello", Response::Immediate(FRIENDLY.to_vec()));
    responses.insert("/", Response::Immediate(redirect));
    start_gemini_server_with_certs(true, responses)
}
/// Starts a server with a freshly generated self-signed certificate for `localhost`.
///
/// # Panics
/// Panics if certificate generation fails.
fn start_gemini_server_with_certs(send_close_notify: bool, handlers: Responses) -> ServerHandle {
    let subject_alt_names = vec![String::from("localhost")];
    start_gemini_server_with_key(
        rcgen::generate_simple_self_signed(subject_alt_names).unwrap(),
        send_close_notify,
        handlers,
    )
}
fn start_gemini_server_with_key(
key: CertifiedKey<KeyPair>,
send_close_notify: bool,
handlers: Responses,
) -> ServerHandle {
let certs_dir = mktemp::Temp::new_dir().unwrap();
let key_pem = certs_dir.join("key.pem");
let cert_pem = certs_dir.join("cert.pem");
std::fs::write(cert_pem, key.cert.pem()).unwrap();
std::fs::write(key_pem, key.signing_key.serialize_pem()).unwrap();
start_gemini_server(certs_dir, send_close_notify, handlers)
}
fn start_gemini_server(
certs_dir: Temp,
send_close_notify: bool,
responses: Responses,
) -> ServerHandle {
let listener = std::net::TcpListener::bind("[::]:0").unwrap();
listener.set_nonblocking(true).unwrap();
let addr = listener.local_addr().unwrap();
let key_pem = certs_dir.join("key.pem");
let cert_pem = certs_dir.join("cert.pem");
let runtime = new_runtime("test-server-runtime-worker");
let requests = Arc::new(RwLock::new(0));
let requests_c = requests.clone();
runtime.spawn(async move {
let listener = tokio::net::TcpListener::from_std(listener).unwrap();
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls_server()).unwrap();
builder
.set_private_key_file(key_pem, SslFiletype::PEM)
.unwrap();
builder
.set_certificate_file(cert_pem, SslFiletype::PEM)
.unwrap();
builder.check_private_key().unwrap();
builder.set_verify_callback(SslVerifyMode::PEER, |_, _| true);
builder
.set_session_id_context(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
.to_string()
.as_bytes(),
)
.unwrap();
let acceptor = builder.build();
let requests = requests_c.clone();
loop {
let Ok((stream, _)) = listener.accept().await else {
continue;
};
let Ok(ssl) = Ssl::new(acceptor.context()) else {
continue;
};
let Ok(mut stream) = SslStream::new(ssl, stream) else {
continue;
};
let handlers = responses.to_owned();
let requests_c = requests.clone();
tokio::spawn(async move {
std::pin::Pin::new(&mut stream).accept().await.unwrap();
{
let mut r = requests_c.write().await;
*r += 1
}
let mut input = Vec::with_capacity(1024); let mut previous = b' ';
for _ in 0..1020 {
let current = match stream.read_u8().await {
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
let res = b"59 not utf-8\r\n";
stream.write_all(res).await.unwrap();
if send_close_notify {
stream.shutdown().await.unwrap();
}
return;
}
Err(e) => {
eprintln!("failed to read buffer: {e}");
let res = b"50 internal error\r\n";
if let Err(e) = stream.write_all(res).await {
eprintln!("failed to send 59: {e}");
}
if send_close_notify {
stream.shutdown().await.unwrap();
}
return;
}
Ok(b) => b,
};
if previous == b'\r' && current == b'\n' {
input.pop(); break;
}
input.push(current);
previous = current;
}
let input = match str::from_utf8(&input) {
Err(_) => {
let res = b"59 not utf-8\r\n";
stream.write_all(res).await.unwrap();
if send_close_notify {
stream.shutdown().await.unwrap();
}
return;
}
Ok(url) => url,
};
eprintln!("got request: {input}");
let url = match url::Url::parse(input) {
Err(_) => {
let res = b"59 not a URL\r\n";
stream.write_all(res).await.unwrap();
if send_close_notify {
stream.shutdown().await.unwrap();
}
return;
}
Ok(url) => url,
};
let res = match handlers.get(url.path()) {
None => {
let res = b"51 not found\r\n";
stream.write_all(res).await.unwrap();
if send_close_notify {
stream.shutdown().await.unwrap();
}
return;
}
Some(h) => h,
};
let res = match res {
Response::Immediate(bytes) => bytes.as_slice(),
Response::Wait(timeout, bytes) => {
tokio::time::sleep(*timeout).await;
bytes.as_slice()
}
};
stream.write_all(res).await.unwrap();
if send_close_notify {
stream.shutdown().await.unwrap();
}
});
}
});
std::thread::sleep(std::time::Duration::from_millis(50));
ServerHandle::new(requests, addr, certs_dir, runtime)
}
/// Handle to a running test Gemini server.
///
/// Owns everything the server needs: dropping the handle drops the tokio runtime that drives
/// the server tasks and the temporary certificate directory.
pub(crate) struct ServerHandle {
    /// Count of connections whose TLS handshake completed, shared with the server tasks.
    requests: Arc<RwLock<u8>>,
    /// Address the server's listener is bound to.
    addr: SocketAddr,
    /// Temporary directory holding `key.pem` / `cert.pem`; held only to keep it alive.
    _certs: Temp,
    /// Runtime driving the server; held only to keep it alive.
    // Fields drop in declaration order, so the certificate directory is removed before the
    // runtime shuts down. Do not reorder without considering that.
    _runtime: Runtime,
}
impl ServerHandle {
    /// Bundles the shared request counter, bound address, certificate directory and runtime
    /// into a handle.
    const fn new(
        requests: Arc<RwLock<u8>>,
        addr: SocketAddr,
        certs: Temp,
        runtime: Runtime,
    ) -> Self {
        Self {
            _runtime: runtime,
            _certs: certs,
            addr,
            requests,
        }
    }

    /// Address clients should connect to: the listener's bound address, including the
    /// OS-assigned port.
    pub(crate) const fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Number of connections that have completed the TLS handshake so far.
    ///
    /// Uses tokio's `blocking_read`, which is documented to panic inside an asynchronous
    /// execution context. Call this from synchronous test code only.
    pub(crate) fn request_count(&self) -> u8 {
        let count = self.requests.blocking_read();
        *count
    }
}
/// Canned reply the test server sends for a request path.
#[derive(Clone, Debug)]
enum Response {
    /// Send these bytes as soon as the request has been read.
    Immediate(Vec<u8>),
    /// Sleep for the given duration, then send these bytes.
    Wait(Duration, Vec<u8>),
}

/// Routing table mapping an exact URL path (e.g. `"/"`) to the reply served for it.
#[derive(Clone, Debug)]
struct Responses(HashMap<&'static str, Response>);

impl Responses {
    /// Creates an empty table, pre-sized for the one or two paths the helpers register.
    fn new() -> Self {
        Self(HashMap::with_capacity(2))
    }

    /// Registers `response` for `path`, replacing any previous entry for that path.
    fn insert(&mut self, path: &'static str, response: Response) {
        self.0.insert(path, response);
    }

    /// Looks up the reply registered for exactly `path`, if any.
    fn get(&self, path: &str) -> Option<&Response> {
        self.0.get(path)
    }
}