#![cfg(feature = "logging")]
mod error;
mod logger;
#[cfg(test)]
mod test_server;
use crate::error::ExitReason;
use ::clap::{Parser, ValueHint};
use ::gload::{
net::{
CancellationToken, FetchOptions, IpResolutionLimit, REDIR_LIMIT, RedirectPolicy,
ResponseHandler,
},
request::Request,
response::{StatusCategory, StatusCode},
};
use ::log::{debug, info, warn};
use std::{io::Write, path::PathBuf};
/// Program entry point.
///
/// Parses the command line and initializes logging. Then it either prints one of
/// the bundled documents (`--changelog` / `--license`) and exits successfully, or
/// fetches the requested URL via [`run`].
#[cfg(not(tarpaulin_include))]
#[tokio::main(flavor = "current_thread")]
async fn main() -> ExitReason {
    let args = Args::parse();
    logger::init(args.verbose);
    // The changelog is checked first; the `InfoArgs` group normally prevents both
    // flags from being set at once.
    let bundled_text = match (args.info.changelog, args.info.license) {
        (true, _) => Some(include_str!("../CHANGELOG.gmi")),
        (false, true) => Some(include_str!("../LICENSE")),
        (false, false) => None,
    };
    if let Some(text) = bundled_text {
        info!("{text}");
        ExitReason::Success
    } else {
        run(args).await
    }
}
async fn run(args: Args) -> ExitReason {
let absolute_uri = args.url.as_ref().expect("clap ensures url is given");
let req = match Request::from_uri_string(absolute_uri) {
Err(err) => return err.into(),
Ok(req) => req,
};
let redirect_policy = if args.location {
match args.max_redirs {
Some(max) => RedirectPolicy::follow_only(max),
None => RedirectPolicy::follow(),
}
} else {
RedirectPolicy::no_follow()
};
let address_resolution_limit = args
.ip_resolution_limit
.as_ref()
.map(IpResolutionLimit::from);
let cancellation_token = CancellationToken::new();
let token = cancellation_token.clone();
tokio::spawn(async move {
use tokio::signal;
#[cfg(unix)]
{
let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt()).unwrap();
let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate()).unwrap();
tokio::select! {
_ = sigint.recv() => {},
_ = sigterm.recv() => {},
}
}
#[cfg(windows)]
{
let mut ctrl_break = signal::windows::ctrl_break().unwrap();
let mut ctrl_c = signal::windows::ctrl_c().unwrap();
let mut ctrl_close = signal::windows::ctrl_close().unwrap();
let mut ctrl_logoff = signal::windows::ctrl_logoff().unwrap();
let mut ctrl_shutdown = signal::windows::ctrl_shutdown().unwrap();
tokio::select! {
_ = ctrl_break.recv() => {},
_ = ctrl_c.recv() => {},
_ = ctrl_close.recv() => {},
_ = ctrl_logoff.recv() => {},
_ = ctrl_shutdown.recv() => {},
}
}
#[cfg(any(unix, windows))]
token.cancel();
});
let options = FetchOptions {
address_resolution_limit,
allow_truncation: args.tls_allow_truncation,
redirect_policy,
response_handler: PrintingHandler(&args),
cancellation_token,
};
let response = match gload::fetch(req, options).await {
Err(err) => return err.into(),
Ok(r) => r,
};
if args.fail && response.status().is_failure() {
return ExitReason::ServerErrorResponse(response.status());
}
if args.head {
return ExitReason::Success;
}
match response.status().category() {
StatusCategory::InputExpected => ExitReason::InputExpected,
StatusCategory::Success
| StatusCategory::TemporaryFailure
| StatusCategory::PermanentFailure
| StatusCategory::Redirection => ExitReason::Success,
StatusCategory::ClientCertificates => {
if response.status() == StatusCode::CERTIFICATE_NOT_VALID {
ExitReason::ClientCertificateNotValid
} else {
ExitReason::ClientCertificateRequired
}
}
}
}
/// [`ResponseHandler`] that prints responses according to the borrowed [`Args`].
struct PrintingHandler<'a>(&'a Args);

impl core::fmt::Debug for PrintingHandler<'_> {
    /// Prints only the type name, because `Args` does not implement `Debug`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        formatter.write_str("PrintingHandler")
    }
}
impl ResponseHandler for PrintingHandler<'_> {
    type Error = PrintingError;

    /// Logs the response status and prints the response as the command-line flags request.
    ///
    /// The status line is always logged at debug level. With `--out-null`, nothing
    /// else happens. Otherwise the header is printed when `--head` or
    /// `--show-header` is set. Any body is printed unless `--head` is set:
    /// - to the `--output` file;
    /// - verbatim to stdout when that file is `-`;
    /// - to stdout, only if it is valid UTF-8, when no `--output` is given.
    ///
    /// # Errors
    ///
    /// * [`PrintingError::BinaryOutput`] if the body is not UTF-8 and no
    ///   `--output` was given.
    /// * [`PrintingError::File`] if the `--output` file cannot be written.
    /// * [`PrintingError::Stdout`] if writing to or flushing stdout fails (for
    ///   example, on a closed pipe), instead of panicking as `print!` would.
    fn handle(&self, res: &gload::Response) -> Result<(), Self::Error> {
        let args = self.0;
        let print_header = args.head || args.show_header;
        let print_body = !args.head;
        let status = res.status().as_u8();
        match res.meta() {
            None => debug!("< {status}"),
            Some(meta) => debug!("< {status} {meta}"),
        }
        debug!("<");
        if args.out_null {
            return Ok(());
        }
        if print_header {
            match res.meta() {
                None => info!("{status}"),
                Some(meta) => info!("{status} {meta}"),
            }
        }
        if print_body && let Some(body_bytes) = res.body() {
            match &args.output {
                // "-" explicitly asks for stdout, so the bytes are written as-is
                // even if they are not valid UTF-8.
                Some(output) if output.as_os_str() == "-" => write_stdout(body_bytes)?,
                Some(output) => {
                    if let Err(err) = std::fs::write(output, body_bytes) {
                        warn!(
                            "Warning: Failed to write to the file {}: {err}",
                            output.display()
                        );
                        return Err(PrintingError::File(body_bytes.len()));
                    }
                }
                None => {
                    // By default, refuse to dump non-UTF-8 bytes on stdout.
                    if std::str::from_utf8(body_bytes).is_err() {
                        warn!(
                            r#"Warning: Binary output can mess up your terminal. Use "--output <FILE>" to save to a file or consider "--output -" to tell {} to output to your terminal anyway."#,
                            clap::crate_name!()
                        );
                        return Err(PrintingError::BinaryOutput);
                    }
                    write_stdout(body_bytes)?;
                }
            }
        }
        Ok(())
    }
}

/// Writes `bytes` to stdout and flushes, so output is complete before returning.
///
/// # Errors
///
/// Returns [`PrintingError::Stdout`] carrying `bytes.len()` if the write or the
/// flush fails.
fn write_stdout(bytes: &[u8]) -> Result<(), PrintingError> {
    let mut stdout = std::io::stdout().lock();
    stdout
        .write_all(bytes)
        .and_then(|()| stdout.flush())
        .map_err(|_| PrintingError::Stdout(bytes.len()))
}
/// Failure raised by `PrintingHandler` while writing out a response.
#[derive(Debug)]
enum PrintingError {
    /// The body is not valid UTF-8 and no `--output` target was chosen.
    BinaryOutput,
    /// Writing this many body bytes to the `--output` file failed.
    File(usize),
    /// Writing this many body bytes to stdout failed.
    Stdout(usize),
}

impl std::fmt::Display for PrintingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both write failures share one message shape and differ only in the target.
        let (size, target) = match *self {
            Self::BinaryOutput => {
                return f.write_str("could not parse data as a string for printing to stdout");
            }
            Self::File(size) => (size, "file"),
            Self::Stdout(size) => (size, "stdout"),
        };
        write!(f, "failed to write {size} bytes to {target}")
    }
}

impl std::error::Error for PrintingError {}
/// Version string shown by `--version`, read at compile time from `OUT_DIR`
/// (presumably generated by the build script).
static FULL_VERSION: &str = include_str!(concat!(env!("OUT_DIR"), "/full_version.txt"));

// Command-line arguments. Clap uses the `///` comments on the fields below as
// their `--help` text.
#[derive(Parser)]
#[cfg_attr(test, derive(Default))]
#[command(version = FULL_VERSION, about)]
struct Args {
    /// Treat a failure status from the server as an error
    #[arg(short, long)]
    fail: bool,
    /// Only print the response header, not the body
    #[arg(short = 'I', long)]
    head: bool,
    // Optional `-4` / `-6` switch; at most one of them may be given.
    #[command(flatten)]
    ip_resolution_limit: Option<IpResolutionLimitArgs>,
    /// Follow redirects
    #[arg(short = 'L', long)]
    location: bool,
    /// Maximum number of redirects to follow with --location
    #[arg(long, value_parser = parse_max_redirs)]
    max_redirs: Option<u8>,
    /// Write the response body to this file, or to stdout if it is "-"
    #[arg(short = 'o', long, conflicts_with = "out_null")]
    output: Option<PathBuf>,
    /// Discard the response instead of printing it
    #[arg(long, conflicts_with = "output")]
    out_null: bool,
    /// Include the response header in the output
    #[arg(short = 'i', long)]
    show_header: bool,
    /// Accept responses whose TLS connection ends without a close_notify alert
    #[arg(long)]
    tls_allow_truncation: bool,
    /// Make the output more verbose
    #[arg(short, long)]
    verbose: bool,
    /// The URL to fetch
    #[arg(required_unless_present_any(["changelog", "license"]), value_hint = ValueHint::Url)]
    url: Option<String>,
    // `--changelog` / `--license`, which print bundled text instead of fetching.
    #[command(flatten)]
    info: InfoArgs,
}
/// Parses the `--max-redirs` value, rejecting anything above [`REDIR_LIMIT`].
///
/// # Errors
///
/// Returns a human-readable message if `value` is not a valid `u8` or exceeds
/// the limit.
fn parse_max_redirs(value: &str) -> Result<u8, String> {
    let requested = value.parse::<u8>().map_err(|err| err.to_string())?;
    if requested <= REDIR_LIMIT {
        Ok(requested)
    } else {
        Err(format!("{requested} is greater than the limit of {REDIR_LIMIT}"))
    }
}
// Mutually exclusive switches that restrict which address family is used for the
// host. The `///` comments on the fields are their `--help` text.
#[derive(clap::Args)]
#[group(multiple = false)]
struct IpResolutionLimitArgs {
    /// Only use IPv4 addresses when resolving the host
    #[arg(short = '4', long)]
    ipv4: bool,
    /// Only use IPv6 addresses when resolving the host
    #[arg(short = '6', long)]
    ipv6: bool,
}
impl From<&IpResolutionLimitArgs> for IpResolutionLimit {
    /// Converts the parsed `-4` / `-6` switches into the library's limit.
    ///
    /// `-4` takes precedence if both are somehow set.
    ///
    /// # Panics
    ///
    /// Panics if neither switch is set. Clap only builds this struct when one of
    /// the switches was given.
    fn from(value: &IpResolutionLimitArgs) -> Self {
        match (value.ipv4, value.ipv6) {
            (true, _) => Self::Ipv4,
            (false, true) => Self::Ipv6,
            (false, false) => unreachable!("clap ensures exactly one of these is true"),
        }
    }
}
// Informational switches that print bundled text and exit (handled in `main`);
// at most one may be given. The `///` comments on the fields are their `--help` text.
#[derive(clap::Args)]
#[cfg_attr(test, derive(Default))]
#[group(multiple = false)]
struct InfoArgs {
    /// Print the changelog and exit
    #[arg(long)]
    changelog: bool,
    /// Print the license and exit
    #[arg(long)]
    license: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_server::{
        new_simple_runtime, start_friendly_server, start_redir_server,
        start_unfriendly_server_no_close_notify, start_unfriendly_server_slow,
    };
    use assert_cmd::cargo::CommandCargoExt;
    use clap::crate_name;
    use core::time::Duration;
    use signal_child::Signalable;
    use std::process::{Child, Command, ExitStatus};

    /// Runs [`run`] in-process against `gemini://localhost:{port}/` on a fresh
    /// client runtime, appending `extra_args` after the URL.
    ///
    /// Returns the URI (for assertion messages) together with the exit reason.
    fn fetch_local(port: impl std::fmt::Display, extra_args: &[&str]) -> (String, ExitReason) {
        let runtime = new_simple_runtime("test-client-runtime-worker");
        let uri = format!("gemini://localhost:{port}/");
        let mut argv = vec![crate_name!(), uri.as_str()];
        argv.extend_from_slice(extra_args);
        let res = runtime.block_on(run(Args::parse_from(argv)));
        (uri, res)
    }

    /// Spawns the compiled binary against `gemini://localhost:{port}/`, waits
    /// 100 ms so the request can reach the server, delivers a signal via
    /// `send_signal`, and returns the child's exit status.
    fn spawn_and_signal(
        port: impl std::fmt::Display,
        send_signal: impl FnOnce(&mut Child),
    ) -> ExitStatus {
        let runtime = new_simple_runtime("test-client-runtime-worker");
        runtime
            .block_on(async move {
                let mut child = Command::cargo_bin(crate_name!())
                    .unwrap()
                    .arg(format!("gemini://localhost:{port}/"))
                    .spawn()
                    .unwrap();
                tokio::time::sleep(Duration::from_millis(100)).await;
                send_signal(&mut child);
                child.wait()
            })
            .unwrap()
    }

    #[tokio::test]
    async fn test_quits_for_malformed_url() {
        let cases = ["ge:// f d"];
        for url in cases {
            let args = Args::parse_from([crate_name!(), url]);
            let res = run(args).await;
            assert!(matches!(res, ExitReason::UrlMalformed));
        }
    }

    #[tokio::test]
    async fn test_quits_for_non_gemini_protocol() {
        let cases = [
            ("foo://bar", "foo"),
            ("Foo://bar", "foo"),
            ("bar://", "bar"),
            ("nope:", "nope"),
            ("ReallyNo:", "reallyno"),
        ];
        for (url, proto) in cases {
            let args = Args::parse_from([crate_name!(), url]);
            let res = run(args).await;
            assert!(matches!(res, ExitReason::UnsupportedProtocol(p) if p == proto));
        }
    }

    #[tokio::test]
    async fn test_quits_for_userinfo_url() {
        let cases = [
            "gemini://foo:bar@localhost",
            "gemini://foo:@localhost",
            "gemini://:bar@localhost",
        ];
        for url in cases {
            let args = Args::parse_from([crate_name!(), url]);
            let res = run(args).await;
            assert!(matches!(res, ExitReason::UrlMalformed));
        }
    }

    #[test]
    fn test_default_behavior() {
        let key = rcgen::generate_simple_self_signed(["localhost".into()]).unwrap();
        let server = start_friendly_server(key);
        let (uri, res) = fetch_local(server.addr().port(), &[]);
        assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_default_redir_behavior() {
        let server = start_redir_server("/");
        let (uri, res) = fetch_local(server.addr().port(), &["-I"]);
        assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_default_behavior_on_redir() {
        let server = start_redir_server("/now-you-see-me");
        let (uri, res) = fetch_local(server.addr().port(), &["-IL"]);
        assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 2);
    }

    #[test]
    fn test_redirects_at_default_limit() {
        let server = start_redir_server("/");
        let (uri, res) = fetch_local(server.addr().port(), &["-L"]);
        assert!(
            matches!(res, ExitReason::TooManyRedirects { max: 5 }),
            "{uri}: {res:?}"
        );
        assert_eq!(server.request_count(), 6);
    }

    #[test]
    fn test_redirects_at_limit() {
        for limit in 0..=5 {
            let server = start_redir_server("/");
            let max_arg = limit.to_string();
            let (uri, res) = fetch_local(
                server.addr().port(),
                &["-L", "--max-redirs", max_arg.as_str()],
            );
            assert!(
                matches!(res, ExitReason::TooManyRedirects { max } if max == limit),
                "{uri}: {res:?}"
            );
            assert_eq!(server.request_count(), limit + 1);
        }
    }

    #[test]
    fn test_redirects_to_not_found_page() {
        let server = start_redir_server("/foo");
        let (uri, res) = fetch_local(server.addr().port(), &["-fL"]);
        assert!(
            matches!(res, ExitReason::ServerErrorResponse(StatusCode::NOT_FOUND)),
            "{uri}: {res:?}"
        );
        assert_eq!(server.request_count(), 2);
    }

    #[test]
    fn test_redirects_only_once() {
        let server = start_redir_server("/hello");
        let (uri, res) = fetch_local(server.addr().port(), &["-fL", "--max-redirs", "1"]);
        assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 2);
    }

    #[test]
    fn test_fails_to_redirects_when_none_permitted() {
        let server = start_redir_server("/hello");
        let (uri, res) = fetch_local(server.addr().port(), &["-fL", "--max-redirs", "0"]);
        assert!(
            matches!(res, ExitReason::TooManyRedirects { max: 0 }),
            "{uri}: {res:?}"
        );
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_redirects_only_once_with_more_permitted() {
        for limit in 2..=5 {
            let server = start_redir_server("/hello");
            let max_arg = limit.to_string();
            let (uri, res) = fetch_local(
                server.addr().port(),
                &["-fL", "--max-redirs", max_arg.as_str()],
            );
            assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
            assert_eq!(server.request_count(), 2);
        }
    }

    #[test]
    fn test_fails_at_no_close_notify() {
        let server = start_unfriendly_server_no_close_notify();
        let (uri, res) = fetch_local(server.addr().port(), &[]);
        assert!(matches!(res, ExitReason::AbruptClosure), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_allows_missing_close_notify_with_flag() {
        let server = start_unfriendly_server_no_close_notify();
        let (uri, res) = fetch_local(server.addr().port(), &["--tls-allow-truncation"]);
        assert!(matches!(res, ExitReason::Success), "{uri}: {res:?}");
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_cancels_on_sigint() {
        let server = start_unfriendly_server_slow();
        let status = spawn_and_signal(server.addr().port(), |child| child.interrupt().unwrap());
        assert_eq!(status.code(), None);
        assert_eq!(server.request_count(), 1);
    }

    #[test]
    fn test_cancels_on_sigterm() {
        let server = start_unfriendly_server_slow();
        let status = spawn_and_signal(server.addr().port(), |child| child.term().unwrap());
        assert_eq!(status.code(), None);
        assert_eq!(server.request_count(), 1);
    }
}