gload 0.5.1

A command line client for the Gemini protocol.
Documentation
//! This software is licensed as described in the file LICENSE, which
//! you should have received as part of this distribution.
//!
//! You may opt to use, copy, modify, merge, publish, distribute and/or sell
//! copies of the Software, and permit persons to whom the Software is
//! furnished to do so, under the terms of the LICENSE file.
//!
//! This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
//! KIND, either express or implied.
//!
//! SPDX-License-Identifier: BSD-3-Clause

#![cfg(test)]

use core::{net::SocketAddr, time::Duration};
use mktemp::Temp;
use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
use rcgen::{CertifiedKey, KeyPair};
use std::{collections::HashMap, io::ErrorKind, sync::Arc};
use tokio::{
	io::{AsyncReadExt, AsyncWriteExt},
	runtime::Runtime,
	sync::RwLock,
};
use tokio_openssl::SslStream;

/// Canned successful response served by the test servers: status `20` with
/// meta `text/gemini; charset=utf-8; lang=en`, followed by the body `:3` and
/// a trailing newline.
pub(crate) static FRIENDLY: &[u8] = b"20 text/gemini; charset=utf-8; lang=en\r\n:3\n";

/// Builds a multi-threaded Tokio runtime with two worker threads, each named
/// `thread_name`, and with all runtime drivers enabled.
/// # Panics
/// Panics if the runtime cannot be built.
pub(crate) fn new_runtime(thread_name: &'static str) -> Runtime {
	let mut builder = tokio::runtime::Builder::new_multi_thread();
	builder
		.worker_threads(2)
		.thread_name(thread_name)
		.enable_all();
	builder.build().unwrap()
}

/// Builds a single-threaded (current-thread) Tokio runtime whose thread name
/// is `thread_name`, with all runtime drivers enabled.
///
/// The runtime is only constructed here; the caller decides what to run on it.
/// # Panics
/// Panics if the runtime cannot be built.
pub(crate) fn new_simple_runtime(thread_name: &'static str) -> Runtime {
	let mut builder = tokio::runtime::Builder::new_current_thread();
	builder.thread_name(thread_name).enable_all();
	builder.build().unwrap()
}

/// Starts a server with a freshly generated certificate whose root path (`/`)
/// answers with [`FRIENDLY`], but which never sends `close_notify` once the
/// response has been written.
pub(crate) fn start_unfriendly_server_no_close_notify() -> ServerHandle {
	let body = FRIENDLY.to_vec();

	let mut routes = Responses::new();
	routes.insert("/", Response::Immediate(body));

	start_gemini_server_with_certs(false, routes)
}

/// Starts a server with a freshly generated certificate whose root path (`/`)
/// stalls for five seconds before answering with [`FRIENDLY`]. The server does
/// not send `close_notify` afterwards.
pub(crate) fn start_unfriendly_server_slow() -> ServerHandle {
	let delayed = Response::Wait(Duration::from_secs(5), FRIENDLY.to_vec());

	let mut routes = Responses::new();
	routes.insert("/", delayed);

	start_gemini_server_with_certs(false, routes)
}

/// Starts a well-behaved server that presents the given certificate and key.
/// Its root path (`/`) answers with [`FRIENDLY`], and the connection is closed
/// with `close_notify` after the response.
pub(crate) fn start_friendly_server(key: CertifiedKey<KeyPair>) -> ServerHandle {
	let body = FRIENDLY.to_vec();

	let mut routes = Responses::new();
	routes.insert("/", Response::Immediate(body));

	start_gemini_server_with_key(key, true, routes)
}

/// Starts a server with a freshly generated certificate whose root path (`/`)
/// answers with a `30` redirect to `target`, and whose `/hello` path answers
/// with [`FRIENDLY`]. The connection is closed with `close_notify` after each
/// response.
pub(crate) fn start_redir_server(target: &'static str) -> ServerHandle {
	// Header line: status 30, a space, the redirect target, then CRLF.
	let redirect = format!("30 {target}\r\n").into_bytes();
	let landing = FRIENDLY.to_vec();

	let mut routes = Responses::new();
	routes.insert("/", Response::Immediate(redirect));
	routes.insert("/hello", Response::Immediate(landing));

	start_gemini_server_with_certs(true, routes)
}

/// Generates a fresh self-signed certificate for `localhost` and starts a
/// server that presents it.
/// # Panics
/// Panics if the certificate cannot be generated.
fn start_gemini_server_with_certs(send_close_notify: bool, handlers: Responses) -> ServerHandle {
	let subject_alt_names = [String::from("localhost")];
	let generated = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();

	start_gemini_server_with_key(generated, send_close_notify, handlers)
}

/// Writes the given certificate and signing key as `cert.pem` and `key.pem`
/// into a new temporary directory, then starts a server that loads them from
/// there.
/// # Panics
/// Panics if the directory cannot be created or either file cannot be written.
fn start_gemini_server_with_key(
	key: CertifiedKey<KeyPair>,
	send_close_notify: bool,
	handlers: Responses,
) -> ServerHandle {
	let certs_dir = mktemp::Temp::new_dir().unwrap();

	// Certificate first, then the private key, both PEM-encoded.
	let pem_files = [
		(certs_dir.join("cert.pem"), key.cert.pem()),
		(certs_dir.join("key.pem"), key.signing_key.serialize_pem()),
	];
	for (path, contents) in pem_files {
		std::fs::write(path, contents).unwrap();
	}

	start_gemini_server(certs_dir, send_close_notify, handlers)
}

/// Starts a Gemini server for testing.
/// # Panics
/// Panics if anything goes wrong.
fn start_gemini_server(
	certs_dir: Temp,
	send_close_notify: bool,
	responses: Responses,
) -> ServerHandle {
	// Pick a random open port
	let listener = std::net::TcpListener::bind("[::]:0").unwrap();
	listener.set_nonblocking(true).unwrap();
	let addr = listener.local_addr().unwrap();

	// Spawn a thread for the server to use,
	let key_pem = certs_dir.join("key.pem");
	let cert_pem = certs_dir.join("cert.pem");
	let runtime = new_runtime("test-server-runtime-worker");
	let requests = Arc::new(RwLock::new(0));
	let requests_c = requests.clone();
	runtime.spawn(async move {
		// This part's adapted from fluffer: https://codeberg.org/catboomer/fluffer/src/commit/36a38239891ebc208f856f049c1daf47f4196a3a/src/app.rs#L121

		// Create a TCP listener
		let listener = tokio::net::TcpListener::from_std(listener).unwrap();

		// Configure TLS
		let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls_server()).unwrap();
		builder
			.set_private_key_file(key_pem, SslFiletype::PEM)
			.unwrap();
		builder
			.set_certificate_file(cert_pem, SslFiletype::PEM)
			.unwrap();
		builder.check_private_key().unwrap();
		builder.set_verify_callback(SslVerifyMode::PEER, |_, _| true);
		builder
			.set_session_id_context(
				std::time::SystemTime::now()
					.duration_since(std::time::UNIX_EPOCH)
					.unwrap()
					.as_secs()
					.to_string()
					.as_bytes(),
			)
			.unwrap();

		let acceptor = builder.build();
		let requests = requests_c.clone();

		loop {
			let Ok((stream, _)) = listener.accept().await else {
				continue;
			};
			let Ok(ssl) = Ssl::new(acceptor.context()) else {
				continue;
			};
			let Ok(mut stream) = SslStream::new(ssl, stream) else {
				continue;
			};

			let handlers = responses.to_owned();
			let requests_c = requests.clone();
			tokio::spawn(async move {
				std::pin::Pin::new(&mut stream).accept().await.unwrap();
				{
					let mut r = requests_c.write().await;
					*r += 1
				}

				// Read in the request
				let mut input = Vec::with_capacity(1024); // specful request is <= 1024 bytes
				let mut previous = b' ';

				for _ in 0..1020 {
					let current = match stream.read_u8().await {
						Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
							let res = b"59 not utf-8\r\n";
							stream.write_all(res).await.unwrap();
							if send_close_notify {
								stream.shutdown().await.unwrap();
							}
							return;
						}
						Err(e) => {
							eprintln!("failed to read buffer: {e}");
							let res = b"50 internal error\r\n";
							if let Err(e) = stream.write_all(res).await {
								eprintln!("failed to send 59: {e}");
							}
							if send_close_notify {
								stream.shutdown().await.unwrap();
							}
							return;
						}
						Ok(b) => b,
					};

					// Stop on CRLF.
					if previous == b'\r' && current == b'\n' {
						input.pop(); // ditch the CR from last round
						break;
					}

					input.push(current);
					previous = current;
				}

				// Parse request
				let input = match str::from_utf8(&input) {
					Err(_) => {
						let res = b"59 not utf-8\r\n";
						stream.write_all(res).await.unwrap();
						if send_close_notify {
							stream.shutdown().await.unwrap();
						}
						return;
					}
					Ok(url) => url,
				};
				eprintln!("got request: {input}");
				let url = match url::Url::parse(input) {
					Err(_) => {
						let res = b"59 not a URL\r\n";
						stream.write_all(res).await.unwrap();
						if send_close_notify {
							stream.shutdown().await.unwrap();
						}
						return;
					}
					Ok(url) => url,
				};

				// Prepare response
				let res = match handlers.get(url.path()) {
					None => {
						let res = b"51 not found\r\n";
						stream.write_all(res).await.unwrap();
						if send_close_notify {
							stream.shutdown().await.unwrap();
						}
						return;
					}
					Some(h) => h,
				};

				let res = match res {
					Response::Immediate(bytes) => bytes.as_slice(),
					Response::Wait(timeout, bytes) => {
						tokio::time::sleep(*timeout).await;
						bytes.as_slice()
					}
				};

				stream.write_all(res).await.unwrap();
				if send_close_notify {
					stream.shutdown().await.unwrap();
				}
			});
		}
	});
	std::thread::sleep(std::time::Duration::from_millis(50)); // should be plenty for the server to spin up

	ServerHandle::new(requests, addr, certs_dir, runtime)
}

/// Handle to a running test server. Tears down the server on drop.
///
/// Fields are dropped in declaration order, so `_certs` is dropped before
/// `_runtime`.
pub(crate) struct ServerHandle {
	/// Counter shared with the server's connection tasks, which increment it
	/// once per completed TLS handshake.
	requests: Arc<RwLock<u8>>,
	/// Local address of the server's listener.
	addr: SocketAddr,
	/// Never read; held so the temporary certificate directory lives as long
	/// as the handle (presumably removed when dropped — see `mktemp::Temp`).
	_certs: Temp,
	/// Never read; held so the runtime driving the server lives as long as
	/// the handle.
	_runtime: Runtime,
}

impl ServerHandle {
	/// Bundles the shared request counter, the listener's address, the
	/// certificate directory and the server's runtime into a single handle.
	const fn new(
		requests: Arc<RwLock<u8>>,
		addr: SocketAddr,
		_certs: Temp,
		_runtime: Runtime,
	) -> Self {
		Self {
			_runtime,
			_certs,
			addr,
			requests,
		}
	}

	/// The address the server is listening on.
	pub(crate) const fn addr(&self) -> SocketAddr {
		self.addr
	}

	/// How many requests this server has seen so far.
	///
	/// Takes the counter's read lock with `blocking_read`, so this is meant to
	/// be called from synchronous test code, not from inside an async task.
	pub(crate) fn request_count(&self) -> u8 {
		let seen = self.requests.blocking_read();
		*seen
	}
}

/// A canned reply the test server sends for a path.
#[derive(Clone)]
enum Response {
	/// Send the bytes right away.
	Immediate(Vec<u8>),
	/// Sleep for the given duration, then send the bytes.
	Wait(Duration, Vec<u8>),
}

/// Maps request paths to the canned [`Response`] served for each of them.
#[derive(Clone)]
struct Responses(HashMap<&'static str, Response>);

impl Responses {
	/// Creates an empty table with room for two routes, the most any of the
	/// test servers in this file registers.
	fn new() -> Self {
		Self(HashMap::with_capacity(2))
	}

	/// Registers `response` for `path`, replacing any response previously
	/// registered for the same path.
	fn insert(&mut self, path: &'static str, response: Response) {
		self.0.insert(path, response);
	}

	/// Returns the response registered for exactly `path`, or `None` if there
	/// is none.
	fn get(&self, path: &str) -> Option<&Response> {
		self.0.get(path)
	}
}