use core::str::FromStr;
use url::Url;
/// A Gemini request, stored as the absolute URI it asks for.
///
/// The wrapped [`Url`] is private; within this file it is only ever set by
/// [`Request::new`], which rejects non-`gemini` schemes, userinfo, missing
/// hosts and over-long URIs, and which strips the fragment, fills in an empty
/// path with `/` and drops an explicit default port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request(Url);
impl Request {
    /// Parses `absolute_uri` and builds a request from the result.
    ///
    /// If the first parse fails with `RelativeUrlWithoutBase` (the string has
    /// no scheme), the string is parsed a second time with `gemini://`
    /// prepended, so `"localhost"` is accepted as `gemini://localhost/`.
    ///
    /// # Errors
    ///
    /// Returns [`RequestConstructError::UrlParse`] when the string cannot be
    /// parsed (with or without the assumed scheme), and otherwise whatever
    /// [`Request::new`] reports for the parsed URL.
    pub fn from_uri_string<S: AsRef<str>>(absolute_uri: S) -> Result<Self, RequestConstructError> {
        let raw = absolute_uri.as_ref();
        let parsed = match Url::parse(raw) {
            // No scheme was given: assume Gemini and try exactly once more.
            Err(url::ParseError::RelativeUrlWithoutBase) => {
                Url::parse(&[GEMINI_SCHEME, "://", raw].concat())
            }
            other => other,
        };
        Self::new(parsed.map_err(RequestConstructError::UrlParse)?)
    }

    /// Validates and normalises `absolute_uri` into a request.
    ///
    /// Normalisation, applied after validation: the fragment is removed, an
    /// empty path becomes `/`, and an explicit port equal to the default
    /// Gemini port is dropped.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - [`RequestConstructError::UnsupportedProtocol`] if the scheme is not `gemini`;
    /// - [`RequestConstructError::Userinfo`] if a username or password is present;
    /// - [`RequestConstructError::MissingAuthority`] if the URL has no host;
    /// - [`RequestConstructError::RequestTooLongError`] if the normalised URI
    ///   is longer than `URI_LIMIT` bytes.
    pub fn new(absolute_uri: Url) -> Result<Self, RequestConstructError> {
        let mut uri = absolute_uri;

        let scheme = uri.scheme();
        if scheme != GEMINI_SCHEME {
            return Err(RequestConstructError::UnsupportedProtocol(scheme.into()));
        }
        if !uri.username().is_empty() || uri.password().is_some() {
            return Err(RequestConstructError::Userinfo);
        }
        if !uri.has_host() {
            return Err(RequestConstructError::MissingAuthority);
        }

        uri.set_fragment(None);
        if uri.path().is_empty() {
            uri.set_path("/");
        }
        if uri.port() == Some(GEMINI_PORT) {
            uri.set_port(None).expect("setting None always succeeds");
        }

        // `str::len` counts bytes, which is the unit the limit is expressed in.
        // The check runs last so that it measures the normalised URI.
        let len = uri.as_str().len();
        if len > URI_LIMIT {
            return Err(RequestConstructError::RequestTooLongError(len));
        }

        Ok(Self(uri))
    }

    /// Builds a new request for `path`, resolved against this request's URI.
    ///
    /// An empty `path` yields a clone of `self`. Otherwise anything from the
    /// first `#` onwards is discarded, the remainder is joined onto the
    /// current path with `PathBuf::join`, passed through
    /// `NormalizePath::normalize`, and resolved against the current URI. If
    /// the input (minus fragment) ended in `/` and the resolved path does
    /// not, the slash is appended again. The result goes through
    /// [`Request::new`].
    ///
    /// # Errors
    ///
    /// Returns [`RequestConstructError::UrlParse`] if resolving against the
    /// current URI fails, or any error from [`Request::new`].
    pub(crate) fn with_new_path(&self, path: &str) -> Result<Self, RequestConstructError> {
        use normalize_path::NormalizePath;
        use std::path::PathBuf;

        if path.is_empty() {
            return Ok(self.clone());
        }

        // Keep only the part before the first `#`; the whole string if there is none.
        let sans_fragment = path.split_once('#').map_or(path, |(head, _)| head);
        // Remembered up front so the slash can be restored after resolution.
        let has_trailing = sans_fragment.ends_with('/');

        // NOTE(review): the join goes through `std::path::PathBuf`, whose
        // separator handling is platform-specific — confirm the result is
        // still a valid URL path on non-Unix targets.
        let new_path = PathBuf::from(self.0.path())
            .join(sans_fragment)
            .normalize();

        let mut new_uri = self
            .0
            .join(new_path.to_str().expect("path is utf-8"))
            .map_err(RequestConstructError::UrlParse)?;

        if has_trailing && !new_uri.path().ends_with('/') {
            new_uri.set_path(&[new_uri.path(), "/"].concat());
        }

        Self::new(new_uri)
    }

    /// Returns the host of the request URI.
    ///
    /// # Panics
    ///
    /// Panics if the URI has no host; [`Request::new`] rejects such URIs.
    #[inline]
    pub(crate) fn host(&self) -> url::Host<&str> {
        self.0.host().expect("constructor made sure host exists")
    }

    /// Returns the port named in the URI, or the default Gemini port if none is named.
    #[inline]
    pub(crate) fn port(&self) -> u16 {
        self.0.port().unwrap_or(GEMINI_PORT)
    }

    /// Returns the URI's bytes followed by CR LF, as a freshly allocated buffer.
    pub fn as_bytes(&self) -> Vec<u8> {
        [self.0.as_str().as_bytes(), &CRLF].concat()
    }
}
#[cfg(not(tarpaulin_include))]
impl core::fmt::Display for Request {
    /// Writes the request's URI; the trailing CR LF added by `as_bytes` is not included.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let serialized = self.0.as_str();
        f.write_str(serialized)
    }
}
/// Carriage return byte.
pub(crate) const CR: u8 = b'\r';
/// Line feed byte.
pub(crate) const LF: u8 = b'\n';
/// The CR LF pair appended to a serialised request.
const CRLF: [u8; 2] = [CR, LF];
/// Port assumed when a request URI does not name one; an explicit occurrence
/// of it is dropped from the URI during construction.
const GEMINI_PORT: u16 = 1965;
/// The only URI scheme a request accepts.
const GEMINI_SCHEME: &str = "gemini";
/// Largest permitted length of a request URI, in bytes.
const URI_LIMIT: usize = 1024;
impl FromStr for Request {
type Err = RequestConstructError;
#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_uri_string(s)
}
}
impl TryFrom<&str> for Request {
type Error = RequestConstructError;
#[inline]
fn try_from(value: &str) -> Result<Self, Self::Error> {
Self::from_str(value)
}
}
/// Converts an owned string; its contents are parsed exactly like a `&str`.
impl TryFrom<String> for Request {
    type Error = RequestConstructError;

    #[inline]
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let text: &str = &value;
        Self::try_from(text)
    }
}
impl TryFrom<Url> for Request {
type Error = RequestConstructError;
#[inline]
fn try_from(value: Url) -> Result<Self, Self::Error> {
Self::new(value)
}
}
/// Reasons why a [`Request`] could not be constructed.
#[derive(Debug)]
pub enum RequestConstructError {
/// The URI has no host.
MissingAuthority,
/// The normalised URI is longer than `URI_LIMIT` bytes; carries the actual
/// length in bytes.
RequestTooLongError(usize),
/// The URI's scheme is not `gemini`; carries the scheme that was found.
UnsupportedProtocol(String),
/// The input could not be parsed as a URL; carries the parser's error.
UrlParse(url::ParseError),
/// The URI contains a username or a password.
Userinfo,
}
#[cfg(not(tarpaulin_include))]
impl core::fmt::Display for RequestConstructError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::MissingAuthority => f.write_str("the request URI did not have a Host portion"),
            Self::Userinfo => f.write_str("the request URI contained userinfo data"),
            // Messages that embed data carried by the variant.
            Self::UnsupportedProtocol(scheme) => {
                write!(f, "protocol '{scheme}' is not supported")
            }
            Self::UrlParse(err) => write!(f, "could not parse string as URI: {err}"),
            Self::RequestTooLongError(len) => write!(
                f,
                "the request URI was too long: expected {URI_LIMIT} bytes or fewer, but got {len}"
            ),
        }
    }
}
// Empty impl: every `Error` method keeps its default. The wrapped parse error
// is already included in the `Display` text of the `UrlParse` variant.
impl core::error::Error for RequestConstructError {}
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;
    use url_static::url;

    /// Valid URLs are normalised (root path added, default port and empty
    /// userinfo dropped) and serialised with a trailing CR LF.
    #[test]
    fn test_constructs_request() {
        #[rustfmt::skip]
        let cases = [
            (url!("gemini://localhost"),      b"gemini://localhost/\r\n".as_ref(), 1965),
            (url!("gemini://localhost:1965"), b"gemini://localhost/\r\n",          1965),
            (url!("gemini://localhost:1964"), b"gemini://localhost:1964/\r\n",     1964),
            (url!("gemini://localhost:80"),   b"gemini://localhost:80/\r\n",       80),
            (url!("gemini://:@localhost"),    b"gemini://localhost/\r\n",          1965),
        ];

        for (absolute_uri, expected, port) in cases {
            let request = Request::new(absolute_uri).unwrap();
            assert_eq!(request.port(), port);
            assert_eq!(request.host(), url::Host::Domain("localhost"));
            assert_eq!(request.as_bytes(), expected);
        }
    }

    /// A URL with a scheme but no host is rejected.
    #[test]
    fn test_fails_for_missing_host() {
        let result = Request::try_from(String::from("gemini://"));
        assert!(
            matches!(result, Err(RequestConstructError::MissingAuthority)),
            "{result:?}"
        );
    }

    /// Any non-empty username or password is rejected.
    #[test]
    fn test_fails_for_userinfo_url() {
        let cases = [
            url!("gemini://foo:bar@localhost"),
            url!("gemini://foo:@localhost"),
            url!("gemini://:bar@localhost"),
        ];

        for absolute_uri in cases {
            let result = Request::new(absolute_uri);
            assert!(
                matches!(result, Err(RequestConstructError::Userinfo)),
                "{result:?}"
            );
        }
    }

    /// Schemes other than `gemini` are rejected, and the error carries the
    /// (lower-cased) scheme.
    #[test]
    fn test_fails_for_non_gemini_protocol() {
        let cases = [
            (url!("foo://bar"), "foo"),
            (url!("Foo://bar"), "foo"),
            (url!("bar://"), "bar"),
            (url!("nope:"), "nope"),
            (url!("ReallyNo:"), "reallyno"),
        ];

        for (absolute_uri, proto) in cases {
            let result = Request::new(absolute_uri);
            assert!(
                matches!(
                    &result,
                    Err(RequestConstructError::UnsupportedProtocol(p)) if p == proto
                ),
                "{result:?}"
            );
        }
    }

    /// URIs one byte over the limit are rejected and the error reports the length.
    #[test]
    fn test_fails_for_too_long_uri() {
        let cases = [
            Url::parse(&format!("{:A<1025}", "gemini://localhost/A")).unwrap(),
            Url::parse(&format!("{:A<1025}", "gemini://localhost:1964/A")).unwrap(),
            Url::parse(&format!("{:A<1027}", "gemini://:@localhost/A")).unwrap(),
        ];

        for absolute_uri in cases {
            let error = Request::try_from(absolute_uri).unwrap_err();
            assert!(
                matches!(error, RequestConstructError::RequestTooLongError(1025)),
                "{error:?}"
            );
        }
    }

    /// A scheme-less string is interpreted as a Gemini URL.
    #[test]
    fn test_adds_gemini_protocol_if_none_given() {
        let request = Request::from_uri_string("localhost").unwrap();
        assert_eq!(request.as_bytes(), b"gemini://localhost/\r\n");
    }

    /// A string that is still invalid after assuming the scheme reports the
    /// parse error from the second attempt.
    #[test]
    fn test_fails_to_parse_bad_string_even_with_assumed_protocol() {
        let error = Request::from_uri_string("this is not a url").unwrap_err();
        assert!(
            matches!(
                error,
                RequestConstructError::UrlParse(url::ParseError::InvalidDomainCharacter)
            ),
            "{error:?}"
        );
    }

    /// URIs that normalise to exactly the limit are accepted.
    #[test]
    fn test_for_exact_length_uri() {
        let cases = [
            format!("{:A<1024}", "gemini://localhost/A"),
            format!("{:A<1024}", "gemini://localhost:1964/A"),
            format!("{:A<1026}", "gemini://:@localhost/A"),
            format!("{:A<1015}", "localhost/A"),
            format!("{:A<1017}", ":@localhost/A"),
        ];

        for absolute_uri in cases {
            let request = Request::from_str(&absolute_uri).unwrap();
            assert_eq!(request.as_bytes().len(), URI_LIMIT + CRLF.len());
        }
    }

    /// URIs that normalise to one byte under the limit are accepted.
    #[test]
    fn test_for_right_length_uri() {
        let cases = [
            format!("{:A<1023}", "gemini://localhost/A"),
            format!("{:A<1023}", "gemini://localhost:1964/A"),
            format!("{:A<1025}", "gemini://:@localhost/A"),
            format!("{:A<1014}", "localhost/A"),
            format!("{:A<1016}", ":@localhost/A"),
        ];

        for absolute_uri in cases {
            let request = Request::from_str(&absolute_uri).unwrap();
            assert_eq!(request.as_bytes().len(), URI_LIMIT + CRLF.len() - 1);
        }
    }
}