use anyhow::{anyhow, Result};
use core::fmt;
use curl::easy::{Easy2, Handler, List, ReadError, WriteError};
use curl::multi::{Easy2Handle, Multi};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::io::Read;
use std::pin::Pin;
use std::str::{self, FromStr};
use std::task::{Context, Poll};
use std::time::Duration;
/// Calls the curl setter `$handle.$fn(value)?` only when `$field` is `Some(value)`.
///
/// `None` leaves the option at curl's default. Must be used inside a function returning a
/// `Result` because the setter's error is propagated with `?`.
macro_rules! set_handle_optional {
    ($field:expr, $handle:ident, $fn:ident) => {
        match $field {
            Some(value) => {
                $handle.$fn(value)?;
            }
            None => {}
        }
    };
}
/// HTTP method used for a request.
///
/// Common methods have dedicated variants; any other method token is kept in
/// [`RequestMethod::Custom`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequestMethod {
    Delete,
    Get,
    Head,
    Options,
    Patch,
    Post,
    Put,
    Trace,
    /// A method without a dedicated variant, stored as given (upper-cased when built through
    /// `From<String>`).
    Custom(String),
}

impl<'a> From<&'a RequestMethod> for &'a str {
    /// Returns the upper-case method token (e.g. `"GET"`); custom methods are returned as stored.
    fn from(request_method: &'a RequestMethod) -> &'a str {
        match request_method {
            RequestMethod::Delete => "DELETE",
            RequestMethod::Get => "GET",
            RequestMethod::Head => "HEAD",
            RequestMethod::Options => "OPTIONS",
            RequestMethod::Patch => "PATCH",
            RequestMethod::Post => "POST",
            RequestMethod::Put => "PUT",
            RequestMethod::Trace => "TRACE",
            RequestMethod::Custom(method) => method,
        }
    }
}

impl From<String> for RequestMethod {
    /// Parses a method name case-insensitively.
    ///
    /// The input is upper-cased first, so unrecognised names become an upper-case
    /// [`RequestMethod::Custom`] (e.g. `"propfind"` -> `Custom("PROPFIND")`).
    fn from(request_method: String) -> Self {
        let request_method = request_method.to_uppercase();
        match request_method.as_str() {
            "DELETE" => Self::Delete,
            "GET" => Self::Get,
            "HEAD" => Self::Head,
            "OPTIONS" => Self::Options,
            "PATCH" => Self::Patch,
            "POST" => Self::Post,
            "PUT" => Self::Put,
            "TRACE" => Self::Trace,
            _ => Self::Custom(request_method),
        }
    }
}
/// Options describing a single request made by `httpstat`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Follow redirects (enables curl's `follow_location`).
    pub location: bool,
    /// Connection timeout passed to curl's `connect_timeout`, if set.
    pub connect_timeout: Option<Duration>,
    /// HTTP method used for the request.
    pub request_method: RequestMethod,
    /// Optional request body. Its length is announced to curl (upload size for `PUT`,
    /// POST field size otherwise) and its bytes are supplied by `Collector::read`.
    pub data: Option<String>,
    /// Extra request headers, each sent as `Name: value`.
    pub headers: Vec<Header>,
    /// Disable TLS peer and host-name verification.
    pub insecure: bool,
    /// Value passed to curl's `ssl_cert` option (client certificate), if set.
    pub client_cert: Option<String>,
    /// Value passed to curl's `ssl_key` option (client private key), if set.
    pub client_key: Option<String>,
    /// Value passed to curl's `cainfo` option (CA bundle), if set.
    pub ca_cert: Option<String>,
    /// URL to request.
    pub url: String,
    /// Enable curl's verbose output.
    pub verbose: bool,
    /// Upper bound, in bytes, on the buffered response body; exceeding it aborts the transfer.
    pub max_response_size: Option<usize>,
}
impl Default for Config {
    /// A plain `GET` to an empty URL: no body, no extra headers, no TLS overrides,
    /// no connect timeout, no redirect following, no verbose output and no body-size cap.
    fn default() -> Self {
        Self {
            url: String::new(),
            request_method: RequestMethod::Get,
            headers: Vec::new(),
            data: None,
            location: false,
            connect_timeout: None,
            insecure: false,
            client_cert: None,
            client_key: None,
            ca_cert: None,
            verbose: false,
            max_response_size: None,
        }
    }
}
/// A single HTTP header as a name/value pair.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Header {
    /// Header field name, e.g. `Content-Type`.
    pub name: String,
    /// Header field value (surrounding whitespace is trimmed when parsed via `FromStr`).
    pub value: String,
}
impl fmt::Display for Header {
    /// Renders the header as `Name: value`, the form `httpstat` appends to curl's header list.
    ///
    /// Formatter flags such as width are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        f.write_str(": ")?;
        f.write_str(&self.value)
    }
}
impl FromStr for Header {
    type Err = anyhow::Error;

    /// Parses a `Name: value` line, splitting at the first `:`.
    ///
    /// Both the name and the value are trimmed, so `"X-Foo : bar"` yields name `X-Foo`.
    ///
    /// # Errors
    ///
    /// Returns `Invalid header "<s>"` when `s` contains no `:` or the name before it is blank.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.split_once(':') {
            // A header without a name cannot be sent or meaningfully reported.
            Some((name, value)) if !name.trim().is_empty() => Ok(Self {
                name: name.trim().into(),
                value: value.trim().into(),
            }),
            _ => Err(anyhow!("Invalid header \"{}\"", s)),
        }
    }
}
/// Parsed HTTP status line (e.g. `HTTP/1.1 200 OK`).
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct HttpResponseHeader {
    /// Protocol version text after `HTTP/`, e.g. `1.1` or `2`.
    pub http_version: String,
    /// Numeric status code.
    pub response_code: i32,
    /// Reason phrase taken from the status line, if one was present.
    pub response_message: Option<String>,
}
impl From<String> for HttpResponseHeader {
    /// Parses an HTTP status line such as `HTTP/1.1 404 Not Found`.
    ///
    /// The version is the text between the first `/` and the first space. The code is the next
    /// space-separated token. The message is everything after that, so multi-word reason phrases
    /// like `Not Found` are kept whole.
    ///
    /// `From` cannot fail, so malformed input degrades instead of panicking:
    /// - a line without `/` is parsed as if it were the part after `HTTP/`;
    /// - a missing or non-numeric code becomes `-1`, the value `httpstat` also uses when no
    ///   status line was seen;
    /// - an absent or blank message becomes `None`.
    fn from(line: String) -> Self {
        let cleaned = line.trim().replace('\r', "").replace('\n', "");
        let after_slash = cleaned
            .split_once('/')
            .map_or(cleaned.as_str(), |(_, rest)| rest);
        // At most three parts: version, code, and the untouched remainder as the message.
        let mut parts = after_slash.splitn(3, ' ');
        let http_version = parts.next().unwrap_or("").to_string();
        let response_code = parts
            .next()
            .and_then(|code| code.parse().ok())
            .unwrap_or(-1);
        let response_message = parts
            .next()
            .map(str::trim)
            .filter(|message| !message.is_empty())
            .map(String::from);
        Self {
            http_version,
            response_code,
            response_message,
        }
    }
}
/// Request timing: curl's elapsed-time markers plus the per-phase durations derived from them.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Timing {
    /// curl `namelookup_time`: until name resolution completed.
    pub namelookup: Duration,
    /// curl `connect_time`: until the connection was established.
    pub connect: Duration,
    /// curl `pretransfer_time`: until the transfer was about to begin.
    pub pretransfer: Duration,
    /// curl `starttransfer_time`: until the first response byte.
    pub starttransfer: Duration,
    /// curl `total_time`: the whole transfer.
    pub total: Duration,
    /// DNS phase; same value as `namelookup`.
    pub dns_resolution: Duration,
    /// `connect - namelookup`.
    pub tcp_connection: Duration,
    /// `pretransfer - connect` (TLS handshake and other pre-transfer setup).
    pub tls_connection: Duration,
    /// `starttransfer - pretransfer`.
    pub server_processing: Duration,
    /// `total - starttransfer`.
    pub content_transfer: Duration,
}
impl Timing {
    /// Builds a [`Timing`] from the curl timers of a finished transfer.
    ///
    /// The first five fields are curl's own markers. The rest are the gaps between consecutive
    /// markers: DNS, TCP connect, TLS/setup, server processing, and content transfer.
    ///
    /// # Panics
    ///
    /// Panics if curl cannot report one of the timers.
    pub fn new(handle: &mut Easy2Handle<Collector>) -> Self {
        let namelookup = handle
            .namelookup_time()
            .expect("curl did not report namelookup_time");
        let connect = handle
            .connect_time()
            .expect("curl did not report connect_time");
        let pretransfer = handle
            .pretransfer_time()
            .expect("curl did not report pretransfer_time");
        let starttransfer = handle
            .starttransfer_time()
            .expect("curl did not report starttransfer_time");
        let total = handle
            .total_time()
            .expect("curl did not report total_time");
        // `Duration - Duration` panics on underflow, and nothing here checks that each curl
        // timer is at least the previous one, so each phase is clamped at zero instead.
        Self {
            namelookup,
            connect,
            pretransfer,
            starttransfer,
            total,
            dns_resolution: namelookup,
            tcp_connection: connect.saturating_sub(namelookup),
            tls_connection: pretransfer.saturating_sub(connect),
            server_processing: starttransfer.saturating_sub(pretransfer),
            content_transfer: total.saturating_sub(starttransfer),
        }
    }
}
/// Outcome of a request made by `httpstat`.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct StatResult {
    /// Protocol version from the status line, or `"Unknown"` if none was seen.
    pub http_version: String,
    /// Status code, or `-1` if no status line was seen.
    pub response_code: i32,
    /// Reason phrase from the status line, if any.
    pub response_message: Option<String>,
    /// Response headers parsed from the raw header lines curl delivered.
    pub headers: Vec<Header>,
    /// Timing information read from curl.
    pub timing: Timing,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
/// curl [`Handler`] that records the response into caller-owned buffers and feeds
/// `Config::data` to curl as the request body.
pub struct Collector<'a> {
    config: &'a Config,
    headers: &'a mut Vec<u8>,
    data: &'a mut Vec<u8>,
    /// Number of bytes of `config.data` already handed to curl by `read`.
    read_offset: usize,
}

impl<'a> Collector<'a> {
    /// Creates a collector that appends the body to `data` and the raw header bytes to `headers`.
    pub fn new(config: &'a Config, data: &'a mut Vec<u8>, headers: &'a mut Vec<u8>) -> Self {
        Self {
            config,
            data,
            headers,
            read_offset: 0,
        }
    }
}

impl<'a> Handler for Collector<'a> {
    /// Appends a chunk of the response body, refusing it once `max_response_size` would be
    /// exceeded.
    fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        if let Some(max_response_size) = self.config.max_response_size {
            if self.data.len() + data.len() > max_response_size {
                // Reporting fewer bytes than were delivered makes curl fail the transfer with a
                // write error; checking first avoids buffering past the limit.
                return Ok(0);
            }
        }
        self.data.extend_from_slice(data);
        Ok(data.len())
    }

    /// Copies the next unread part of `config.data` into curl's upload buffer.
    ///
    /// curl may call this several times for one body (the buffer can be smaller than the
    /// data), so the position is tracked in `read_offset`. Returning `0` signals the end of
    /// the body.
    fn read(&mut self, into: &mut [u8]) -> Result<usize, ReadError> {
        let config = self.config;
        let remaining = match &config.data {
            Some(data) => data.as_bytes().get(self.read_offset..).unwrap_or(&[]),
            None => return Ok(0),
        };
        let count = remaining.len().min(into.len());
        into[..count].copy_from_slice(&remaining[..count]);
        self.read_offset += count;
        Ok(count)
    }

    /// Appends a raw header line (status lines included) for later parsing.
    fn header(&mut self, data: &[u8]) -> bool {
        self.headers.extend_from_slice(data);
        true
    }
}
/// Future that drives every transfer registered on a curl [`Multi`] until none are running.
///
/// NOTE(review): each poll calls `perform` and, while work remains, immediately re-wakes
/// itself, so waiting is a busy loop rather than socket-readiness driven.
pub struct HttpstatFuture<'a>(&'a Multi);

impl<'a> Future for HttpstatFuture<'a> {
    type Output = Result<()>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        let multi = self.0;
        let still_running = match multi.perform() {
            Ok(count) => count,
            Err(error) => return Poll::Ready(Err(error.into())),
        };
        if still_running == 0 {
            return Poll::Ready(Ok(()));
        }
        // Nothing else registers a waker for curl's sockets, so request another poll now.
        context.waker().wake_by_ref();
        Poll::Pending
    }
}
/// Performs the request described by `config` and returns its status, headers, body and timing.
///
/// The transfer runs on a curl [`Multi`] handle driven by [`HttpstatFuture`].
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a curl option cannot be set;
/// - the transfer fails, including `"Maximum response size reached"` when the body exceeds
///   `config.max_response_size`;
/// - the received header bytes are not valid UTF-8.
pub async fn httpstat(config: &Config) -> Result<StatResult> {
    let mut body = Vec::new();
    let mut raw_headers = Vec::new();
    let mut handle = Easy2::new(Collector::new(config, &mut body, &mut raw_headers));
    handle.url(&config.url)?;
    handle.show_header(false)?;
    handle.progress(true)?;
    handle.verbose(config.verbose)?;
    if config.insecure {
        handle.ssl_verify_host(false)?;
        handle.ssl_verify_peer(false)?;
    }
    set_handle_optional!(&config.client_cert, handle, ssl_cert);
    set_handle_optional!(&config.client_key, handle, ssl_key);
    set_handle_optional!(&config.ca_cert, handle, cainfo);
    set_handle_optional!(config.connect_timeout, handle, connect_timeout);
    if config.location {
        handle.follow_location(true)?;
    }

    let data_len = config.data.as_ref().map(|data| data.len() as u64);
    let request_method = &config.request_method;
    match request_method {
        RequestMethod::Put => {
            handle.upload(true)?;
            set_handle_optional!(data_len, handle, in_filesize);
        }
        RequestMethod::Get => handle.get(true)?,
        RequestMethod::Head => handle.nobody(true)?,
        RequestMethod::Post => handle.post(true)?,
        _ => handle.custom_request(request_method.into())?,
    }
    // PUT announces its size via `in_filesize` above; every other method uses the POST size.
    if let Some(len) = data_len {
        if !matches!(request_method, RequestMethod::Put) {
            handle.post_field_size(len)?;
        }
    }
    if !config.headers.is_empty() {
        let mut header_list = List::new();
        for header in &config.headers {
            header_list.append(&header.to_string())?;
        }
        handle.http_headers(header_list)?;
    }

    let multi = Multi::new();
    let mut handle = multi.add2(handle)?;
    HttpstatFuture(&multi).await?;

    // Keep only the first failure reported for this handle.
    let mut transfer_result: Result<()> = Ok(());
    multi.messages(|message| {
        if transfer_result.is_err() {
            return;
        }
        if let Some(Err(error)) = message.result_for2(&handle) {
            // The only write error source is `Collector::write` refusing data over the limit.
            transfer_result = Err(if error.is_write_error() {
                anyhow!("Maximum response size reached")
            } else {
                error.into()
            });
        }
    });
    transfer_result?;

    let timing = Timing::new(&mut handle);
    // Releases the collector's borrows of `body` and `raw_headers`.
    drop(handle);

    let (status_line, headers) = parse_response_headers(&raw_headers)?;
    Ok(StatResult {
        http_version: status_line
            .as_ref()
            .map_or_else(|| "Unknown".into(), |h| h.http_version.clone()),
        response_code: status_line.as_ref().map_or(-1, |h| h.response_code),
        response_message: status_line.and_then(|h| h.response_message),
        headers,
        body,
        timing,
    })
}

/// Splits the raw header bytes collected by curl into the last status line and the headers that
/// follow it.
///
/// curl's header callback receives the header block of every response (interim `1xx` responses
/// and, when redirects are followed, each intermediate response). The header list is therefore
/// reset at every status line, so the returned headers belong to the same response as the
/// returned status line.
fn parse_response_headers(raw: &[u8]) -> Result<(Option<HttpResponseHeader>, Vec<Header>)> {
    let mut status_line = None;
    let mut headers = Vec::new();
    for line in str::from_utf8(raw)?.lines() {
        let line = line.trim_end_matches('\r');
        if line.is_empty() {
            continue;
        }
        let is_status_line = line
            .get(..5)
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case("HTTP/"));
        if is_status_line {
            status_line = Some(HttpResponseHeader::from(line.to_string()));
            headers.clear();
        } else if let Ok(header) = Header::from_str(line) {
            headers.push(header);
        }
    }
    Ok((status_line, headers))
}