// trotter 1.0.2
//
// Trotter 🎠 is an experimental crate that aims to make writing Gemini clients fun and easy.
// Documentation
use crate::{error::ActorError, Response, Titan, UserAgent};
use openssl::{
	nid::Nid,
	ssl::{Ssl, SslConnector, SslFiletype, SslMethod, SslVerifyMode},
};
use std::{collections::HashMap, path::PathBuf, time::Duration};
use tokio::{
	io::{AsyncReadExt, AsyncWriteExt},
	net::TcpStream,
};
use tokio_openssl::SslStream;
use url::Url;
use wildmatch::WildMatch;

/// 🎠 Shorthand for [`Actor::get`] on a freshly built default actor.
///
/// ```
/// # use trotter::Actor;
/// # async fn f() -> Result<(), trotter::error::ActorError> {
/// let response = trotter::trot("localhost").await?;
/// # Ok(())
/// # }
/// ```
pub async fn trot(url: impl Into<String>) -> Result<Response> {
	let actor = Actor::default();
	actor.get(url.into()).await
}

/// 🎠 Shorthand for [`Actor::input`] on a freshly built default actor.
///
/// ```
/// # use trotter::Actor;
/// # async fn f() -> Result<(), trotter::error::ActorError> {
/// let response = trotter::trot_in("localhost/input", "notice me!").await?;
/// # Ok(())
/// # }
/// ```
pub async fn trot_in(url: impl Into<String>, input: impl Into<String>) -> Result<Response> {
	let actor = Actor::default();
	let (address, text) = (url.into(), input.into());
	actor.input(address, text).await
}

/// Make a gemini request.
///
/// Holds the settings that every request sent through this actor shares.
/// Start from [`Actor::default`] and adjust it with the chained setters.
pub struct Actor {
	/// Path of the client certificate file, loaded as PEM when a request is sent.
	cert: Option<PathBuf>,
	/// Path of the client private key file, loaded as PEM when a request is sent.
	key: Option<PathBuf>,
	/// Identity used for the robots.txt check; when `None` the check is skipped.
	user_agent: Option<UserAgent>,
	/// Upper bound on how long the TCP connect may take.
	timeout: Duration,
	/// Host and port to connect to instead of the one named in the url.
	proxy: Option<(String, u16)>,
}

/// Shorthand for a [`std::result::Result`] whose error type is [`ActorError`].
type Result<T> = std::result::Result<T, ActorError>;

impl Default for Actor {
	/// An actor with no client identity, no user-agent, no proxy and a
	/// five second connection timeout.
	fn default() -> Self {
		let timeout = Duration::from_secs(5);

		Self {
			cert: None,
			key: None,
			user_agent: None,
			timeout,
			proxy: None,
		}
	}
}

impl Actor {
	/// Set your client certificate file path
	pub fn cert_file(mut self, cert: impl Into<PathBuf>) -> Self {
		self.cert = Some(cert.into());
		self
	}

	/// Set your client key file path
	pub fn key_file(mut self, key: impl Into<PathBuf>) -> Self {
		self.key = Some(key.into());
		self
	}

	/// *Please* include a user-agent if you're making any
	/// kind of service that indescriminately uses other
	/// peoples' content on gemini.
	///
	/// This allows people to block types of services they
	/// don't want to access their content.
	///
	/// More info: [robots.txt for Gemini](https://geminiprotocol.net/docs/companion/robots.gmi)
	pub fn user_agent(mut self, useragent: UserAgent) -> Self {
		self.user_agent = Some(useragent);
		self
	}

	/// Set timeout for the initial connection.
	///
	/// The default is 5 seconds.
	pub fn timeout(mut self, timeout: impl Into<Duration>) -> Self {
		self.timeout = timeout.into();
		self
	}

	/// Send a gemini request to `url`.
	///
	/// The `gemini://` prefix is optional; it is added when missing.
	/// The robots.txt check runs before the request is sent.
	pub async fn get(&self, url: impl Into<String>) -> Result<Response> {
		let target = self.build_url(url.into(), None, "gemini://")?;
		self.obey_robots(&target).await?;

		self.send_request(target, None).await
	}

	/// Send a gemini request to `url` carrying `input` as its query.
	///
	/// The input is percent-encoded for you.
	pub async fn input(
		&self,
		url: impl Into<String>,
		input: impl Into<String>,
	) -> Result<Response> {
		let raw: String = input.into();
		let encoded = urlencoding::encode(&raw);

		let target = self.build_url(url.into(), Some(&encoded), "gemini://")?;
		self.obey_robots(&target).await?;

		self.send_request(target, None).await
	}

	/// Upload `titan` to `url` over the titan protocol.
	///
	/// The `titan://` prefix is added when missing, and the robots.txt
	/// check runs before anything is sent.
	pub async fn upload(&self, url: impl Into<String>, titan: Titan) -> Result<Response> {
		let target = self.build_url(url.into(), None, "titan://")?;
		self.obey_robots(&target).await?;

		let payload = Some(titan);
		self.send_request(target, payload).await
	}

	/// Designate a proxy server to send all future requests through.
	pub fn proxy(mut self, host: String, port: u16) -> Self {
		self.proxy = Some((host, port));
		self
	}

	/// Internal function for sending a request.
	///
	/// Connects over TCP to the configured proxy, or otherwise to the url's
	/// domain (port 1965 unless the url names one), wraps the connection in
	/// TLS with SNI, writes a gemini request line or a titan request line
	/// followed by its content, checks the peer certificate's names against
	/// the host that was dialed, and finally reads the response header and body.
	///
	/// # Errors
	///
	/// - [`ActorError::KeyCertFileError`] if the key or certificate file can't be loaded.
	/// - [`ActorError::DomainErr`] if there is no proxy and the url has no domain.
	/// - [`ActorError::Timeout`] / [`ActorError::TcpError`] if connecting fails.
	/// - [`ActorError::NoCertificate`], [`ActorError::SubjectNameNotUtf8`] or
	///   [`ActorError::DomainUncerified`] if the peer certificate is missing,
	///   unreadable, or names none of the dialed host.
	/// - [`ActorError::MalformedHeader`], [`ActorError::HeaderNotUtf8`] or
	///   [`ActorError::MalformedStatus`] if the response header can't be parsed,
	///   including when no `\r\n` arrives within the maximum header length.
	/// - Whatever the TLS and I/O errors convert into through `?`.
	async fn send_request(&self, url: Url, titan: Option<Titan>) -> Result<Response> {
		// Longest legal header: two status digits, a space, up to 1024 bytes of meta, `\r\n`
		const MAX_HEADER_LEN: usize = 2 + 1 + 1024 + 2;

		// Build connector
		let mut connector = SslConnector::builder(SslMethod::tls_client())?;
		// The callback accepts every chain; the name check below decides whether the peer is accepted
		connector.set_verify_callback(SslVerifyMode::FAIL_IF_NO_PEER_CERT, |_, _| true);

		// Add client certificate
		if let Some(key) = &self.key {
			connector
				.set_private_key_file(key, SslFiletype::PEM)
				.map_err(ActorError::KeyCertFileError)?;
		}
		if let Some(cert) = &self.cert {
			connector
				.set_certificate_file(cert, SslFiletype::PEM)
				.map_err(ActorError::KeyCertFileError)?;
		}

		// Connect with tcp (to the proxy when one is set)
		let (domain, port) = match self.proxy {
			Some((ref proxy_host, proxy_port)) => (proxy_host.as_str(), proxy_port),
			None => (
				url.domain().ok_or(ActorError::DomainErr)?,
				url.port().unwrap_or(1965),
			),
		};

		let tcp = tokio::time::timeout(
			self.timeout,
			TcpStream::connect(&format!("{domain}:{port}")),
		)
		.await
		.map_err(ActorError::Timeout)?
		.map_err(ActorError::TcpError)?;

		// Wrap connection in ssl stream
		let mut ssl = Ssl::new(connector.build().context())?;
		ssl.set_connect_state();
		ssl.set_hostname(domain)?; // <- SNI (Server name indication) and don't you forget it 💢

		let mut stream = SslStream::new(ssl, tcp)?;

		// NOTE(review): the request (titan content included) is written before
		// the certificate names are checked below — confirm this order is intended.
		if let Some(titan) = titan {
			// Titan request
			let token = match &titan.token {
				Some(token) => format!("token={token};"),
				None => String::new(),
			};
			let request = format!(
				"{url};{token}mime={};size={}\r\n",
				titan.mimetype,
				titan.content.len()
			);

			stream.write_all(request.as_bytes()).await?;
			stream.write_all(&titan.content).await?;
		} else {
			// Gemini request
			let request = format!("{url}\r\n");
			stream.write_all(request.as_bytes()).await?;
		}

		// Get certificate
		let certificate = stream
			.ssl()
			.peer_certificate()
			.ok_or(ActorError::NoCertificate)?;

		// Begin collecting a list of valid domains
		let mut valid_domains: Vec<String> = Vec::new();

		// Add subject's common name to list
		for entry in certificate.subject_name().entries_by_nid(Nid::COMMONNAME) {
			let common_name = entry
				.data()
				.as_utf8()
				.map_err(ActorError::SubjectNameNotUtf8)?;
			valid_domains.push(common_name.to_string());
		}

		// Add list of alternative common names tagged under `DNS`
		if let Some(alt_names) = certificate.subject_alt_names() {
			for alt_name in &alt_names {
				if let Some(dns) = alt_name.dnsname() {
					valid_domains.push(dns.to_string());
				}
			}
		}

		// Error if none of them match (names may contain wildcards)
		let domain_is_certified = valid_domains
			.iter()
			.any(|pattern| WildMatch::new(pattern).matches(domain));
		if !domain_is_certified {
			return Err(ActorError::DomainUncerified(
				format!("{valid_domains:?}"),
				domain.to_string(),
			));
		}

		// Get response header: read up to and including the first \r\n
		let mut header: Vec<u8> = Vec::with_capacity(MAX_HEADER_LEN);
		loop {
			header.push(stream.read_u8().await?);

			if header.ends_with(b"\r\n") {
				// Drop the terminator
				header.truncate(header.len() - 2);
				break;
			}

			// Refuse to parse a header that never terminates
			if header.len() >= MAX_HEADER_LEN {
				return Err(ActorError::MalformedHeader);
			}
		}

		let header = std::str::from_utf8(&header).map_err(ActorError::HeaderNotUtf8)?;

		// Strip status and meta from the header
		let (status, meta) = header.split_once(' ').ok_or(ActorError::MalformedHeader)?;
		let status = status.parse::<u8>().map_err(ActorError::MalformedStatus)?;
		let meta = meta.to_string();

		// Get remaining response content
		let mut content: Vec<u8> = Vec::new();
		stream.read_to_end(&mut content).await?;

		Ok(Response {
			content,
			status,
			meta,
			certificate,
		})
	}

	/// Internal function for obeying robots.txt
	async fn obey_robots(&self, url: &Url) -> Result<()> {
		let Some(user_agent) = &self.user_agent else {
			return Ok(());
		};

		if let Ok(response) = self
			.send_request(
				Url::parse(&format!(
					"gemini://{}/robots.txt",
					url.domain().ok_or(ActorError::DomainErr)?
				))?,
				None,
			)
			.await
		{
			if let Ok(txt) = response.text() {
				// Remove comments
				let txt = txt.lines().filter_map(|x| {
					if !x.trim_start().starts_with('#') {
						if let Some((x, _)) = x.split_once('#') {
							Some(x)
						} else {
							Some(x)
						}
					} else {
						None
					}
				});

				// Parse robots
				let mut robots_map: HashMap<&str, Vec<&str>> = HashMap::new();
				let mut active_agents: Vec<&str> = Vec::new();
				let mut was_user = false; // True if the last line was a user agent
				for line in txt {
					if let Some((_, agent)) = line.trim().split_once("User-agent:") {
						if !was_user {
							// Clear active agents if we're in a new user-agent block
							active_agents.clear();
						}
						// Add active agents
						active_agents.push(agent.trim());
						was_user = true;
					} else if let Some((_, disallow)) = line.trim().split_once("Disallow:") {
						for a in &active_agents {
							// Add disallow entry to all active agents
							if let Some(entry) = robots_map.get_mut(a) {
								entry.push(disallow.trim());
							} else {
								robots_map.insert(a, vec![disallow.trim()]);
							}
						}
						was_user = false;
					}
				}

				// Track the disallows that affect us
				let mut disallow_list: Vec<&str> = Vec::new();

				// Add our useragent's entries to disallow list
				if let Some(for_me) = robots_map.get_mut(user_agent.to_string().as_str()) {
					disallow_list.append(for_me);
				}

				// Add * entries to disallow list
				if let Some(for_everyone) = robots_map.get_mut("*") {
					disallow_list.append(for_everyone);
				}

				for path in disallow_list {
					if path == "/" || url.path().starts_with(path) {
						return Err(ActorError::RobotDenied(
							path.to_string(),
							user_agent.clone(),
						));
					}
				}
			}
		}
		Ok(())
	}

	/// Turns a user-supplied address into a [`Url`].
	///
	/// `scheme` is prepended unless `url` already begins with it, an empty
	/// path becomes `/`, and `input` (when given) becomes the query string.
	fn build_url(&self, url: String, input: Option<&str>, scheme: &str) -> Result<Url> {
		// Prepend the scheme unless the address already leads with it
		let address = if url.starts_with(scheme) {
			url
		} else {
			format!("{scheme}{url}")
		};

		let mut parsed = Url::parse(&address)?;

		// An empty path is replaced with the root path
		if parsed.path().is_empty() {
			parsed.set_path("/");
		}

		// Leave the query alone when no input was given
		if let Some(query) = input {
			parsed.set_query(Some(query));
		}

		Ok(parsed)
	}
}