use crate::{
request::{Request, RequestConstructError},
response::{Response, ResponseParseError, StatusCategory},
tls::{self, Stream, TlsError},
};
use core::{cmp::Ordering, net::SocketAddr};
use rustls::pki_types::ServerName;
use std::net::ToSocketAddrs;
pub use tokio_util::sync::CancellationToken; use url::Url;
// Module-local logging shim. With the `logging` feature enabled, the arguments
// are forwarded to `log::debug!` under the "gload::net" target. With the
// feature disabled, the `cfg`'d statement is removed and the macro expands to
// an empty block.
macro_rules! debug {
($($arg:tt)+) => {{
#[cfg(feature = "logging")]
::log::debug!(target: "gload::net", $($arg)+)
}}
}
/// Settings that control how [`fetch`] and [`fetch_sync`] perform a request.
#[derive(Debug, Clone)]
pub struct FetchOptions<H: ResponseHandler> {
/// Restricts the resolved addresses to a single IP family. `None` keeps
/// both IPv4 and IPv6 addresses.
pub address_resolution_limit: Option<IpResolutionLimit>,
/// Forwarded unchanged to `tls::send`.
/// NOTE(review): presumably accepts a response whose connection ended
/// without a clean shutdown — confirm against `tls::send`.
pub allow_truncation: bool,
/// Cancelling this token makes the fetch fail with `SendError::Cancelled`.
/// It is checked before any work starts and raced against connecting and
/// sending.
pub cancellation_token: CancellationToken,
/// How many redirects, if any, the fetch is allowed to follow.
pub redirect_policy: RedirectPolicy,
/// Called with every response received, including the responses of
/// intermediate redirect hops. An `Err` aborts the fetch.
pub response_handler: H,
}
impl Default for FetchOptions<DefaultHandler> {
fn default() -> Self {
Self {
address_resolution_limit: None,
allow_truncation: false,
cancellation_token: CancellationToken::new(),
redirect_policy: RedirectPolicy::no_follow(),
response_handler: DefaultHandler,
}
}
}
/// Blocking counterpart of [`fetch`].
///
/// Builds a dedicated current-thread tokio runtime with all drivers enabled
/// and blocks the calling thread on [`fetch`] until it completes.
///
/// # Errors
///
/// Returns `SendError::TokioRuntimeError` when the runtime cannot be built,
/// and otherwise whatever [`fetch`] returns.
///
/// NOTE(review): tokio documents that blocking on a runtime from inside
/// another runtime panics, so this is presumably unsuitable for async
/// callers — confirm.
pub fn fetch_sync<H: ResponseHandler>(
    req: Request,
    options: FetchOptions<H>,
) -> Result<Response, SendError<H>> {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.thread_name("gload-worker-thread").enable_all();
    match builder.build() {
        Ok(rt) => rt.block_on(fetch(req, options)),
        Err(io_err) => Err(SendError::TokioRuntimeError(io_err)),
    }
}
/// Sends `req` and returns the final response, following redirects as far as
/// `options.redirect_policy` allows.
///
/// Every response received (the first one and each one obtained by following
/// a redirect) is passed to `options.response_handler` before it is inspected.
///
/// A redirect target is interpreted relative to the request that produced the
/// redirect. After a hop to another host, port or path, a later relative
/// target is therefore resolved against that hop and sent to that hop's
/// server, not to the one named by the original `req`.
///
/// When the policy does not follow redirects at all, a redirect response is
/// returned as `Ok` like any other response. A redirect response without a
/// target is also returned as `Ok`.
///
/// # Errors
///
/// * `SendError::Cancelled` if the cancellation token fires before or during
///   the exchange.
/// * `SendError::CouldNotResolveHost` if a host yields no usable address.
/// * `SendError::Handler` if the response handler rejects a response.
/// * `SendError::RequestConstruct` if a redirect target cannot be turned into
///   a request.
/// * `SendError::TooManyRedirects` if the server still redirects after the
///   allowed number of redirects has been followed.
/// * `SendError::Tls` / `SendError::ResponseParse` for transport and parsing
///   failures.
pub async fn fetch<H: ResponseHandler>(
    req: Request,
    options: FetchOptions<H>,
) -> Result<Response, SendError<H>> {
    if options.cancellation_token.is_cancelled() {
        return Err(SendError::Cancelled);
    }
    let resolution_limit = options.address_resolution_limit;

    // State of the most recent hop. Redirect targets are interpreted against
    // these, never against the very first request.
    let mut current_req = req;
    let (mut authority, mut addrs, payload) = prepare_req(&current_req, resolution_limit)?;

    let mut res = send_req_to(
        &authority,
        &addrs,
        payload,
        options.allow_truncation,
        &options.cancellation_token,
    )
    .await?;
    options
        .response_handler
        .handle(&res)
        .map_err(SendError::Handler)?;
    debug!("* shutting down connection #0");

    // No limit means redirects are not followed: hand back whatever we got.
    let Some(max_redirs) = options.redirect_policy.max_redirs() else {
        return Ok(res);
    };

    // One iteration more than the limit: the extra round only checks whether
    // the last followed redirect led to yet another redirect.
    for conn_idx in 1..=(max_redirs + 1) {
        if res.status().category() != StatusCategory::Redirection {
            return Ok(res);
        }
        let Some(target) = res.meta() else {
            return Ok(res);
        };
        if conn_idx > max_redirs {
            break;
        }

        // `Some(request)` when the target names a different host or port and
        // therefore needs fresh address resolution; `None` when the target
        // stays on the server of the current hop.
        let retarget = match Url::parse(target) {
            Err(_) => None,
            Ok(absolute_uri) if absolute_uri.host() == Some(current_req.host()) => {
                let new_req = Request::new(absolute_uri)?;
                if new_req.port() == current_req.port() {
                    None
                } else {
                    debug!(
                        "* Redirects to port from {} to {}",
                        current_req.port(),
                        new_req.port()
                    );
                    debug!("* Issue another request to this URL: '{new_req}'",);
                    Some(new_req)
                }
            }
            Ok(absolute_uri) => {
                debug!("* Issue another request to this URL: '{absolute_uri}'");
                Some(Request::new(absolute_uri)?)
            }
        };

        let payload = match retarget {
            Some(new_req) => {
                let (new_authority, new_addrs, payload) =
                    prepare_req(&new_req, resolution_limit)?;
                authority = new_authority;
                addrs = new_addrs;
                current_req = new_req;
                payload
            }
            None => {
                // Same server: keep the connection info and only derive the
                // new request from the current one.
                let (_, _, payload) =
                    handle_redir_to_path(&current_req, &authority, &addrs, target)?;
                current_req = current_req.with_new_path(target)?;
                payload
            }
        };

        res = send_req_to(
            &authority,
            &addrs,
            payload,
            options.allow_truncation,
            &options.cancellation_token,
        )
        .await?;
        options
            .response_handler
            .handle(&res)
            .map_err(SendError::Handler)?;
        debug!("* shutting down connection #{conn_idx}");
    }
    Err(SendError::TooManyRedirects { max: max_redirs })
}
/// The largest number of redirects any [`RedirectPolicy`] may allow.
pub const REDIR_LIMIT: u8 = 5;

/// Decides whether, and how many times, a fetch follows redirect responses.
///
/// The wrapped value is the number of redirects that may be followed, or
/// `None` when redirects are not followed at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RedirectPolicy(Option<u8>);

impl RedirectPolicy {
    /// Follows up to [`REDIR_LIMIT`] redirects.
    pub const fn follow() -> Self {
        Self(Some(REDIR_LIMIT))
    }

    /// Does not follow redirects; a redirect response is handed back to the
    /// caller like any other response.
    pub const fn no_follow() -> Self {
        Self(None)
    }

    /// Follows at most `max` redirects.
    ///
    /// Unlike [`RedirectPolicy::no_follow`], `follow_only(0)` makes a fetch
    /// fail with `TooManyRedirects` when the server answers with a redirect.
    ///
    /// # Panics
    ///
    /// Panics if `max` is greater than [`REDIR_LIMIT`].
    pub const fn follow_only(max: u8) -> Self {
        assert!(max <= REDIR_LIMIT, "cannot follow more than 5 redirections");
        Self(Some(max))
    }

    /// Number of redirects that may be followed, or `None` when redirects are
    /// not followed.
    const fn max_redirs(self) -> Option<u8> {
        self.0
    }
}
/// Hook that inspects the responses a fetch receives.
///
/// `fetch` calls [`ResponseHandler::handle`] on the first response and on
/// every response obtained by following a redirect. An `Err` stops the fetch
/// and is surfaced as `SendError::Handler`.
pub trait ResponseHandler: core::fmt::Debug {
/// Error returned by [`ResponseHandler::handle`] to reject a response.
type Error: core::error::Error;
/// Inspects `res`; returning `Err` aborts the fetch.
fn handle(&self, res: &Response) -> Result<(), Self::Error>;
}
/// A [`ResponseHandler`] that accepts every response unconditionally.
///
/// It is `Clone`/`Copy` so that `FetchOptions<DefaultHandler>`, whose `Clone`
/// impl is derived and therefore requires a cloneable handler, can itself be
/// cloned.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultHandler;

impl ResponseHandler for DefaultHandler {
    /// This handler never fails.
    type Error = core::convert::Infallible;

    fn handle(&self, _: &Response) -> Result<(), Self::Error> {
        Ok(())
    }
}
fn prepare_req<H: ResponseHandler>(
req: &Request,
resolution_limit: Option<IpResolutionLimit>,
) -> Result<(ServerName<'static>, Vec<SocketAddr>, Vec<u8>), SendError<H>> {
let host = req.host();
let port = req.port();
let addrs = resolve(&host, port, resolution_limit)?;
let authority = match ServerName::try_from(host.to_string()) {
Err(_) => {
return Err(SendError::CouldNotResolveHost(to_owned_host(&host)));
}
Ok(a) => a,
};
Ok((authority, addrs, req.as_bytes()))
}
/// Builds the follow-up for a redirect that stays on the same server.
///
/// The new request is derived from `old_req` via `with_new_path(path)`. The
/// server name and addresses are passed through as copies, so the result has
/// the same shape as the one produced by `prepare_req`.
///
/// # Errors
///
/// Fails when `with_new_path` rejects `path`.
fn handle_redir_to_path<H: ResponseHandler>(
    old_req: &Request,
    authority: &ServerName<'static>,
    addrs: &[SocketAddr],
    path: &str,
) -> Result<(ServerName<'static>, Vec<SocketAddr>, Vec<u8>), SendError<H>> {
    let req = old_req.with_new_path(path)?;
    debug!("* Issue another request to this URL: '{req}'");
    let payload = req.as_bytes();
    let same_authority = authority.clone();
    let same_addrs = addrs.to_vec();
    Ok((same_authority, same_addrs, payload))
}
/// Connects to the first reachable address in `addrs`, sends `payload`, and
/// parses the reply.
///
/// Addresses are tried in order. The first successful connection is used; if
/// none succeeds, the error of the last attempt is returned. Both connecting
/// and sending are raced against `token`, and cancellation yields
/// `SendError::Cancelled`.
///
/// # Panics
///
/// Panics if `addrs` is empty; callers pass the non-empty result of address
/// resolution.
async fn send_req_to<H: ResponseHandler>(
    authority: &ServerName<'static>,
    addrs: &[SocketAddr],
    payload: Vec<u8>,
    allow_truncation: bool,
    token: &CancellationToken,
) -> Result<Response, SendError<H>> {
    let mut connected: Option<(Stream, SocketAddr)> = None;
    let mut last_failure: Option<TlsError> = None;

    for addr in addrs {
        debug!("* Trying {addr}...");
        let attempt = tokio::select! {
            _ = token.cancelled() => { return Err(SendError::Cancelled) }
            r = tls::open_stream(addr, authority.clone()) => r
        };
        match attempt {
            Ok(stream) => {
                connected = Some((stream, *addr));
                break;
            }
            // Remember the failure and move on to the next candidate.
            Err(err) => last_failure = Some(err),
        }
    }

    let (tcp_stream, addr) = match (connected, last_failure) {
        (Some(pair), _) => pair,
        (None, Some(err)) => return Err(err.into()),
        (None, None) => unreachable!("addrs is nonempty"),
    };

    let exchange = tokio::select! {
        _ = token.cancelled() => { return Err(SendError::Cancelled) }
        r = tls::send(&addr, authority, tcp_stream, payload, allow_truncation) => r
    };
    let response_bytes = exchange?;
    Response::new(response_bytes).map_err(Into::into)
}
/// Everything that can go wrong while fetching a response.
#[derive(Debug)]
#[non_exhaustive]
pub enum SendError<H: ResponseHandler> {
/// The cancellation token fired before or during the exchange.
Cancelled,
/// The host yielded no usable socket address, or could not be expressed
/// as a TLS server name.
CouldNotResolveHost(url::Host<String>),
/// The response handler rejected a response.
Handler(H::Error),
/// A follow-up request could not be built from a redirect target.
RequestConstruct(RequestConstructError),
/// The bytes received from the server could not be parsed as a response.
ResponseParse(ResponseParseError),
/// Opening the TLS stream or exchanging data over it failed.
Tls(TlsError),
/// The tokio runtime used by `fetch_sync` could not be built.
TokioRuntimeError(std::io::Error),
/// The server kept redirecting after `max` redirects had been followed.
TooManyRedirects { max: u8 },
}
impl<H: ResponseHandler> From<RequestConstructError> for SendError<H> {
fn from(value: RequestConstructError) -> Self {
Self::RequestConstruct(value)
}
}
impl<H: ResponseHandler> From<ResponseParseError> for SendError<H> {
fn from(value: ResponseParseError) -> Self {
Self::ResponseParse(value)
}
}
impl<H: ResponseHandler> From<TlsError> for SendError<H> {
fn from(value: TlsError) -> Self {
Self::Tls(value)
}
}
/// User-facing description of each failure. Wrapped errors are appended to a
/// short explanation; a handler error is shown as-is.
impl<H: ResponseHandler> core::fmt::Display for SendError<H> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Cancelled => f.write_str("request cancelled"),
            Self::CouldNotResolveHost(host) => {
                write!(f, "could not resolve IP address from host: {host}")
            }
            Self::Handler(err) => write!(f, "{err}"),
            Self::RequestConstruct(err) => write!(
                f,
                "could not create a request from a redirect target: {err}"
            ),
            Self::ResponseParse(err) => {
                write!(f, "could not parse the server's response: {err}")
            }
            Self::Tls(err) => {
                write!(f, "something went wrong while sending the request: {err}")
            }
            Self::TokioRuntimeError(err) => {
                write!(f, "could not spawn a worker thread: {err}")
            }
            Self::TooManyRedirects { max } => write!(
                f,
                "received a redirect response, but we've already reached the maximum number of redirects ({max})"
            ),
        }
    }
}
// Empty impl: `source()` is not overridden, so the trait's default applies and
// the wrapped errors are reachable only through the variant payloads.
impl<H: ResponseHandler> core::error::Error for SendError<H> {}
/// Turns `host` and `port` into the socket addresses to connect to.
///
/// Domain names are looked up through the standard library; IP literals are
/// used directly. After the optional family filter is applied, the result
/// holds at most two entries: the first IPv6 address (if any) followed by the
/// first IPv4 address (if any).
///
/// # Errors
///
/// Returns `SendError::CouldNotResolveHost` when no address remains. This
/// includes a failed lookup, which is treated as "no addresses".
///
/// NOTE(review): `to_socket_addrs` is the standard library's resolver, which
/// is documented as potentially blocking, and this function is reached from
/// the async `fetch` — confirm that this is acceptable.
fn resolve<H: ResponseHandler>(
    host: &url::Host<&str>,
    port: u16,
    resolution_limit: Option<IpResolutionLimit>,
) -> Result<Vec<SocketAddr>, SendError<H>> {
    let mut candidates: Vec<SocketAddr> = match host {
        url::Host::Domain(domain) => match format!("{domain}:{port}").to_socket_addrs() {
            Ok(found) => found.collect(),
            Err(_) => Vec::new(),
        },
        url::Host::Ipv4(ip) => vec![SocketAddr::from((*ip, port))],
        url::Host::Ipv6(ip) => vec![SocketAddr::from((*ip, port))],
    };

    match resolution_limit {
        Some(IpResolutionLimit::Ipv4) => candidates.retain(SocketAddr::is_ipv4),
        Some(IpResolutionLimit::Ipv6) => candidates.retain(SocketAddr::is_ipv6),
        None => {}
    }

    if candidates.is_empty() {
        return Err(SendError::CouldNotResolveHost(to_owned_host(host)));
    }

    // Stable sort that moves IPv6 addresses in front of IPv4 ones while
    // keeping the resolver's order within each family.
    candidates.sort_by(|a, b| match (a.is_ipv6(), b.is_ipv6()) {
        (true, false) => Ordering::Less,
        (false, true) => Ordering::Greater,
        _ => Ordering::Equal,
    });

    debug!("* Host {host}:{port} was resolved.");

    let mut res = Vec::with_capacity(2);
    match candidates.iter().copied().find(SocketAddr::is_ipv6) {
        Some(ipv6) => {
            debug!("* IPv6: {}", ipv6.ip());
            res.push(ipv6);
        }
        None => {
            debug!("* IPv6: (none)");
        }
    }
    match candidates.iter().copied().find(SocketAddr::is_ipv4) {
        Some(ipv4) => {
            debug!("* IPv4: {}", ipv4.ip());
            res.push(ipv4);
        }
        None => {
            debug!("* IPv4: (none)");
        }
    }
    Ok(res)
}
/// Restricts address resolution to a single IP family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpResolutionLimit {
/// Keep only IPv4 addresses.
Ipv4,
/// Keep only IPv6 addresses.
Ipv6,
}
/// Converts a borrowed host into one that owns its domain string, so it can
/// be stored inside an error value.
fn to_owned_host(host: &url::Host<&str>) -> url::Host<String> {
    match *host {
        url::Host::Ipv4(addr) => url::Host::Ipv4(addr),
        url::Host::Ipv6(addr) => url::Host::Ipv6(addr),
        url::Host::Domain(name) => url::Host::Domain(String::from(name)),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_server::{
new_runtime, start_friendly_server, start_redir_server, start_unfriendly_server_slow,
};
use core::time::Duration;
use url_static::url;
// `prepare_req` on absolute `gemini://localhost` URIs must produce the same
// payload as a request built directly from the URI, use "localhost" as the
// server name, and resolve to a loopback address on port 1965.
// NOTE(review): the last two cases are identical — possibly one of them was
// meant to differ (e.g. a trailing slash); confirm.
#[test]
fn test_redir_by_uri_target() {
let cases = [
url!("gemini://localhost"),
url!("gemini://localhost/"),
url!("gemini://localhost/foo"),
url!("gemini://localhost/foo/bar"),
url!("gemini://localhost/foo/bar"),
];
let localhost_v6 = "[::1]:1965".to_socket_addrs().unwrap().next().unwrap();
let localhost_v4 = "127.0.0.1:1965".to_socket_addrs().unwrap().next().unwrap();
for uri in cases {
let expected = Request::new(uri.clone()).unwrap();
let req = Request::new(uri).unwrap();
let (authority, addrs, payload) = prepare_req::<DefaultHandler>(&req, None).unwrap();
assert_eq!(payload, expected.as_bytes());
assert_eq!(authority.to_str(), "localhost");
// Either loopback family is acceptable; which ones exist depends on the machine.
assert!(
addrs.contains(&localhost_v6) || addrs.contains(&localhost_v4),
"{addrs:?}"
);
}
}
// Without a family limit, resolving `localhost` must yield exactly one IPv4
// and one IPv6 loopback address.
// NOTE(review): presumably requires a machine where `localhost` resolves to
// both families — confirm for CI environments.
#[test]
fn test_resolves_all_address_families() {
let localhost_v4 = "127.0.0.1:1965".to_socket_addrs().unwrap().next().unwrap();
let localhost_v6 = "[::1]:1965".to_socket_addrs().unwrap().next().unwrap();
let uri = url!("gemini://localhost:1965");
let host = uri.host().unwrap();
let port = uri.port().unwrap();
let res = resolve::<DefaultHandler>(&host, port, None).unwrap();
assert_eq!(res.len(), 2);
assert!(res.contains(&localhost_v4));
assert!(res.contains(&localhost_v6));
}
// With the IPv4 limit, only the IPv4 loopback address may be returned.
#[test]
fn test_resolves_only_ipv4() {
let localhost_v4 = "127.0.0.1:1965".to_socket_addrs().unwrap().next().unwrap();
let uri = url!("gemini://localhost:1965");
let host = uri.host().unwrap();
let port = uri.port().unwrap();
let res = resolve::<DefaultHandler>(&host, port, Some(IpResolutionLimit::Ipv4)).unwrap();
assert_eq!(res.len(), 1);
assert!(res.contains(&localhost_v4));
}
// With the IPv6 limit, only the IPv6 loopback address may be returned.
#[test]
fn test_resolves_only_ipv6() {
let localhost_v6 = "[::1]:1965".to_socket_addrs().unwrap().next().unwrap();
let uri = url!("gemini://localhost:1965");
let host = uri.host().unwrap();
let port = uri.port().unwrap();
let res = resolve::<DefaultHandler>(&host, port, Some(IpResolutionLimit::Ipv6)).unwrap();
assert_eq!(res.len(), 1);
assert!(res.contains(&localhost_v6));
}
// Table-driven check of `handle_redir_to_path`. Each case is
// (starting URI, redirect target, URI the follow-up request must match).
// Only the payload (third tuple element) is compared; the server name and
// addresses passed in are placeholders.
#[test]
fn test_redir_by_path_target() {
#[rustfmt::skip]
let cases = [
("gemini://localhost/", "..", "gemini://localhost/"),
("gemini://localhost/", "/..", "gemini://localhost/"),
("gemini://localhost/", "../..", "gemini://localhost/"),
("gemini://localhost/", "/../..", "gemini://localhost/"),
("gemini://localhost/", "../../", "gemini://localhost/"),
("gemini://localhost/", "/", "gemini://localhost/"),
("gemini://localhost", "/", "gemini://localhost/"),
("gemini://localhost", "", "gemini://localhost/"), ("gemini://localhost", ".", "gemini://localhost/"),
("gemini://localhost", "..", "gemini://localhost/"),
("gemini://localhost", "../..", "gemini://localhost/"),
("gemini://localhost/target", "", "gemini://localhost/target"),
("gemini://localhost/target/", "", "gemini://localhost/target/"),
("gemini://localhost/target", "/..", "gemini://localhost/"),
("gemini://localhost/target/", "/..", "gemini://localhost/"),
("gemini://localhost/target", "/../..", "gemini://localhost/"),
("gemini://localhost/target/target", "/..", "gemini://localhost/"),
("gemini://localhost/target/target", "/../", "gemini://localhost/"),
("gemini://localhost/target/target", "..", "gemini://localhost/target"),
("gemini://localhost/target/target", "../", "gemini://localhost/target/"),
("gemini://localhost/target/target", "/../..", "gemini://localhost/"),
("gemini://localhost/target/target/", "/../..", "gemini://localhost/"),
("gemini://localhost/target/target/", "", "gemini://localhost/target/target/"),
("gemini://localhost/target/target/", ".", "gemini://localhost/target/target"),
("gemini://localhost", "/target", "gemini://localhost/target"),
("gemini://localhost/", "/target", "gemini://localhost/target"),
("gemini://localhost/", "/target/", "gemini://localhost/target/"),
("gemini://localhost/target", "/target", "gemini://localhost/target"),
("gemini://localhost/target", "/target/", "gemini://localhost/target/"),
("gemini://localhost/target/", "/target", "gemini://localhost/target"),
("gemini://localhost/target/", "/target/", "gemini://localhost/target/"),
("gemini://localhost/target", "target", "gemini://localhost/target/target"),
("gemini://localhost/target/", "target", "gemini://localhost/target/target"),
("gemini://localhost/target/", "target/", "gemini://localhost/target/target/"),
("gemini://localhost/target/", "target/foo", "gemini://localhost/target/target/foo"),
("gemini://localhost/target", "..", "gemini://localhost/"),
("gemini://localhost/target", "#foo", "gemini://localhost/target"),
("gemini://localhost/target", "foo#foo", "gemini://localhost/target/foo"),
("gemini://localhost/target", "foo/#foo", "gemini://localhost/target/foo/"),
("gemini://localhost", "foo and bar", "gemini://localhost/foo%20and%20bar"),
("gemini://localhost/", "foo and bar", "gemini://localhost/foo%20and%20bar"),
("gemini://localhost/", "foo and bar/", "gemini://localhost/foo%20and%20bar/"),
("gemini://localhost/", "foo and bar/baz", "gemini://localhost/foo%20and%20bar/baz"),
("gemini://localhost/", "/gemini://", "gemini://localhost/gemini:/"), ("gemini://localhost/", "/gemini://example.com", "gemini://localhost/gemini:/example.com"),
("gemini://localhost/", "/可愛い", "gemini://localhost/%E5%8F%AF%E6%84%9B%E3%81%84"),
("gemini://localhost/", "/可愛い/", "gemini://localhost/%E5%8F%AF%E6%84%9B%E3%81%84/"),
];
// Placeholder connection info; the helper passes it through untouched.
let authority = ServerName::try_from("localhost").unwrap();
let addrs = "[::]:0".to_socket_addrs().unwrap().collect::<Vec<_>>();
for (start, path, expected) in cases {
let expected = Request::from_uri_string(expected).unwrap();
let exp_payload = expected.as_bytes();
let old_req = Request::from_uri_string(start).unwrap();
let new_req =
handle_redir_to_path::<DefaultHandler>(&old_req, &authority, &addrs, path).unwrap();
// On failure, show both payloads as text.
assert_eq!(
new_req.2,
exp_payload,
"{} != {}",
String::from_utf8_lossy(&new_req.2).trim(),
String::from_utf8_lossy(&exp_payload).trim()
);
}
}
// With a limit of zero, a redirect response must fail the fetch with
// `TooManyRedirects { max: 0 }` after exactly one request reached the server.
// NOTE(review): `start_redir_server` presumably answers every request with a
// redirect to the given target — confirm in `test_server`.
#[test]
fn test_fails_to_redir_when_limit_is_zero() {
let server = start_redir_server("/nowhere"); let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let options = FetchOptions {
redirect_policy: RedirectPolicy::follow_only(0),
..Default::default()
};
let err = fetch_sync(req, options).err().unwrap();
assert!(
matches!(err, SendError::TooManyRedirects { max: 0 }),
"{err:?}"
);
assert_eq!(server.request_count(), 1); }
// With a limit of one against a server that keeps redirecting, the fetch must
// fail with `TooManyRedirects { max: 1 }` after two requests: the original
// one plus one followed redirect.
#[test]
fn test_fails_to_redir_when_limit_is_one() {
let server = start_redir_server("/");
let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let options = FetchOptions {
redirect_policy: RedirectPolicy::follow_only(1),
..Default::default()
};
let err = fetch_sync(req, options).err().unwrap();
assert!(
matches!(err, SendError::TooManyRedirects { max: 1 }),
"{err:?}"
);
assert_eq!(server.request_count(), 2); }
// With a limit of one, a single redirect that ends in a success response must
// succeed after two requests.
// NOTE(review): this relies on the redirect server answering the "/hello"
// target with a success response — confirm in `test_server`.
#[test]
fn test_succeeds_when_redir_limit_is_one_and_only_one_redir() {
let server = start_redir_server("/hello"); let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let options = FetchOptions {
redirect_policy: RedirectPolicy::follow_only(1),
..Default::default()
};
let res = fetch_sync(req, options).unwrap();
assert!(res.is_success());
assert_eq!(server.request_count(), 2); }
// With a limit of two against a server that keeps redirecting, the fetch must
// fail with `TooManyRedirects { max: 2 }` after three requests.
#[test]
fn test_fails_to_redir_when_limit_is_two() {
let server = start_redir_server("/");
let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let options = FetchOptions {
redirect_policy: RedirectPolicy::follow_only(2),
..Default::default()
};
let err = fetch_sync(req, options).err().unwrap();
assert!(
matches!(err, SendError::TooManyRedirects { max: 2 }),
"{err:?}"
);
assert_eq!(server.request_count(), 3); }
// `RedirectPolicy::follow()` allows five redirects: against a server that
// keeps redirecting, the fetch must fail with `TooManyRedirects { max: 5 }`
// after six requests.
#[test]
fn test_fails_to_redir_when_limit_is_default() {
let server = start_redir_server("/");
let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let options = FetchOptions {
redirect_policy: RedirectPolicy::follow(),
..Default::default()
};
let err = fetch_sync(req, options).err().unwrap();
assert!(
matches!(err, SendError::TooManyRedirects { max: 5 }),
"{err:?}"
);
assert_eq!(server.request_count(), 6); }
// `fetch_sync` must surface a refused connection as `SendError::Tls` wrapping
// `TlsError::Open` with `ErrorKind::ConnectionRefused`.
#[test]
fn test_sync_api() {
// Bind to an OS-assigned port only to learn its number. The listener is a
// temporary dropped at the end of this statement, so nothing is listening
// on the port when the request is made.
let port = std::net::TcpListener::bind("[::]:0")
.unwrap()
.local_addr()
.unwrap()
.port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let err = fetch_sync(req, Default::default()).err().unwrap();
assert!(
matches!(err, SendError::Tls(TlsError::Open(ref err)) if err.kind() == std::io::ErrorKind::ConnectionRefused),
"{err:?}"
);
}
// A token cancelled before the fetch starts must yield `SendError::Cancelled`
// without any request reaching the server.
#[test]
fn test_cancelling_before_flight() {
let key = rcgen::generate_simple_self_signed(["localhost".into()]).unwrap();
let server = start_friendly_server(key);
let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let token = CancellationToken::new();
// Cancel up front, before the options are even built.
token.cancel();
let options = FetchOptions {
cancellation_token: token,
..Default::default()
};
let err = fetch_sync(req, options).err().unwrap();
assert!(matches!(err, SendError::Cancelled), "{err:?}");
assert_eq!(server.request_count(), 0); }
// Cancelling while a request is in progress must yield `SendError::Cancelled`
// after the server has seen exactly one request.
// NOTE(review): `start_unfriendly_server_slow` presumably delays its reply for
// longer than the 100 ms wait below — confirm in `test_server`.
#[test]
fn test_cancelling_during_flight() {
let server = start_unfriendly_server_slow();
let port = server.addr().port();
let req = Request::from_uri_string(format!("gemini://localhost:{port}")).unwrap();
let token = CancellationToken::new();
let options = FetchOptions {
cancellation_token: token.clone(),
..Default::default()
};
let runtime = new_runtime("test-client-runtime-worker");
let err = runtime
.block_on(async move {
// Run the fetch as its own task so it makes progress while we wait.
let handle = tokio::spawn(fetch(req, options));
tokio::time::sleep(Duration::from_millis(100)).await;
token.cancel();
handle.await.unwrap()
})
.err()
.unwrap();
assert!(matches!(err, SendError::Cancelled), "{err:?}");
assert_eq!(server.request_count(), 1); }
}