use std::borrow::Cow;
use std::io::Read;
use std::time::Duration;
use crate::color::Color;
use crate::{make_color, print_error};
/// Settings describing a single HTTP request issued by `send_request`.
///
/// String fields are `Cow` so callers may pass either borrowed or owned text.
#[derive(Default)]
pub struct Config<'a> {
/// Target URL, handed to curl unchanged.
pub url: Cow<'a, str>,
/// HTTP method used for the request (defaults to `Method::Get`).
pub method: Method,
/// Not read anywhere in this file; presumably selects terminal colouring
/// for the output layer — confirm against callers.
pub color: Color,
/// Extra request headers; each one is rendered as `key: value` and sent.
pub request_headers: Vec<Header>,
/// Optional request payload, streamed to curl by `Decorator::read`.
pub request_body: Option<Cow<'a, str>>,
/// Not read anywhere in this file; presumably an output destination —
/// confirm against callers.
pub output: Option<Cow<'a, str>>,
/// Not read anywhere in this file; presumably a display toggle for the
/// response body — confirm against callers.
pub display_response_body: bool,
/// Not read anywhere in this file; presumably a display toggle for the
/// response headers — confirm against callers.
pub display_response_headers: bool,
/// Forwarded to curl's follow-location option.
pub follow_redirects: bool,
/// Forwarded to curl's verbose option.
pub verbose: bool,
}
/// curl transfer handler that feeds the configured request body to libcurl
/// and captures the raw response into caller-owned buffers.
pub struct Decorator<'a> {
    /// Request settings; only `request_body` is read by the handler.
    config: &'a Config<'a>,
    /// Receives every raw header line delivered by curl.
    pub response_headers: &'a mut Vec<u8>,
    /// Receives every chunk delivered to the write callback.
    pub response_body: &'a mut Vec<u8>,
    /// Number of request-body bytes already handed to libcurl.
    ///
    /// libcurl calls `read` repeatedly with a bounded buffer, so the handler
    /// has to remember where the previous call stopped.
    request_offset: usize,
}

impl<'a> Decorator<'a> {
    /// Creates a handler that uploads `config.request_body` (if any) and
    /// appends response data to the two supplied buffers.
    pub fn new(
        config: &'a Config<'a>,
        response_headers: &'a mut Vec<u8>,
        response_body: &'a mut Vec<u8>,
    ) -> Self {
        Self {
            config,
            response_headers,
            response_body,
            request_offset: 0,
        }
    }
}

impl<'a> curl::easy::Handler for Decorator<'a> {
    /// Appends one raw header line to `response_headers`; always asks curl
    /// to continue.
    fn header(&mut self, data: &[u8]) -> bool {
        self.response_headers.extend_from_slice(data);
        true
    }

    /// Copies the next unsent slice of the request body into `data`.
    ///
    /// Returns the number of bytes copied; `Ok(0)` signals end of body (or
    /// that no body is configured). Previously every call restarted at the
    /// beginning of the body, so any payload larger than curl's read buffer
    /// was uploaded as its first chunk repeated over and over.
    fn read(&mut self, data: &mut [u8]) -> Result<usize, curl::easy::ReadError> {
        // Copy the shared reference out so the body borrow is independent of
        // the mutable borrow needed to advance the offset.
        let config = self.config;
        let body = match &config.request_body {
            Some(body) => body.as_bytes(),
            None => return Ok(0),
        };
        // `request_offset` never exceeds `body.len()`; the default (empty)
        // slice is purely defensive.
        let mut pending: &[u8] = body.get(self.request_offset..).unwrap_or_default();
        match pending.read(data) {
            Ok(len) => {
                self.request_offset += len;
                Ok(len)
            }
            Err(e) => {
                print_error!("Error reading data: {}", e);
                Err(curl::easy::ReadError::Abort)
            }
        }
    }

    /// Appends a chunk of response data to `response_body` and reports the
    /// whole chunk as consumed.
    fn write(&mut self, data: &[u8]) -> Result<usize, curl::easy::WriteError> {
        self.response_body.extend_from_slice(data);
        Ok(data.len())
    }
}
/// A single HTTP header as a name/value pair.
///
/// Both parts are stored as given (after trimming when parsed via `FromStr`);
/// use `header_key` for a `Title-Case` rendering of the name.
#[derive(Clone, Debug, PartialEq)]
pub struct Header {
/// Header name, e.g. `content-type`.
pub key: String,
/// Header value, e.g. `application/json`.
pub value: String,
}
impl Header {
    /// Returns the header name with each `-`-separated word capitalised:
    /// first character upper-cased, the remaining characters lower-cased
    /// (`content-type` becomes `Content-Type`).
    pub fn header_key(&self) -> String {
        let mut canonical = String::with_capacity(self.key.len());
        for (position, word) in self.key.split('-').enumerate() {
            // Re-insert the separator between words, never before the first.
            if position > 0 {
                canonical.push('-');
            }
            let mut letters = word.chars();
            if let Some(initial) = letters.next() {
                // Case mappings may expand to several chars (e.g. 'ß' -> "SS").
                canonical.extend(initial.to_uppercase());
                for letter in letters {
                    canonical.extend(letter.to_lowercase());
                }
            }
        }
        canonical
    }
}
impl core::fmt::Display for Header {
    /// Renders the header in wire form: `key: value`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { key, value } = self;
        f.write_str(key)?;
        f.write_str(": ")?;
        f.write_str(value)
    }
}
impl std::str::FromStr for Header {
    type Err = anyhow::Error;

    /// Parses `key: value`, splitting on the first `:` and trimming
    /// surrounding whitespace from both parts.
    ///
    /// # Errors
    /// Fails when the input contains no `:` separator.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (raw_key, raw_value) = s
            .split_once(':')
            .ok_or_else(|| anyhow::anyhow!("Invalid header format, please use key: value"))?;
        Ok(Self {
            key: raw_key.trim().to_string(),
            value: raw_value.trim().to_string(),
        })
    }
}
/// Result of one request: connection metadata, cumulative curl timings and
/// the captured response.
///
/// The timing fields are filled from the curl handle's `*_time()` getters by
/// the `TryFrom` conversion in this file; the phase helpers on `Stat` derive
/// per-phase durations by subtracting neighbouring values.
#[derive(Default)]
pub struct Stat {
/// Value of the handle's `primary_ip()`, if curl reported one.
pub ip_address: Option<String>,
/// Text between `HTTP/` and the first space of the status line, e.g. `1.1`.
pub http_version: Option<String>,
/// Value of the handle's `namelookup_time()`.
pub name_lookup: Duration,
/// Value of the handle's `connect_time()`.
pub connect: Duration,
/// Value of the handle's `appconnect_time()`.
pub app_connect: Duration,
/// Value of the handle's `pretransfer_time()`.
pub pre_transfer: Duration,
/// Value of the handle's `starttransfer_time()`.
pub start_transfer: Duration,
/// Value of the handle's `total_time()`.
pub total: Duration,
/// Second space-separated token of the status line parsed as an integer;
/// `None` when it is missing or not numeric.
pub response_status_code: Option<i32>,
/// Header lines parsed into name/value pairs.
pub response_headers: Vec<Header>,
/// Raw bytes collected by the write callback. `send_request` enables
/// curl's header-in-body option, so this normally starts with the response
/// headers; `utf8_response_body` strips that prefix.
pub response_body: Vec<u8>,
}
impl Stat {
    /// Time spent resolving the host name.
    ///
    /// Always `Some`: this is the raw `name_lookup` value, not a difference.
    pub fn dns_lookup(&self) -> Option<Duration> {
        Some(self.name_lookup)
    }

    /// Time between the end of name resolution and the connection being
    /// established, or `None` when `connect` is not later than `name_lookup`.
    pub fn tcp_handshake(&self) -> Option<Duration> {
        Self::phase(self.name_lookup, self.connect)
    }

    /// Time between `connect` and `app_connect`, or `None` when
    /// `app_connect` is not later than `connect`.
    pub fn tls_handshake(&self) -> Option<Duration> {
        Self::phase(self.connect, self.app_connect)
    }

    /// Time between `pre_transfer` and `start_transfer`, or `None` when the
    /// latter is not strictly greater.
    pub fn server_processing(&self) -> Option<Duration> {
        Self::phase(self.pre_transfer, self.start_transfer)
    }

    /// Time between `start_transfer` and `total`, or `None` when `total` is
    /// not strictly greater.
    pub fn content_transfer(&self) -> Option<Duration> {
        Self::phase(self.start_transfer, self.total)
    }

    /// Returns the response payload as (lossily decoded) UTF-8 text with the
    /// leading header section removed, or `None` when nothing was received.
    ///
    /// The captured bytes contain the response headers in front of the
    /// payload. When curl follows a redirect or the server sends an interim
    /// response such as `100 Continue`, there are *several* header blocks in
    /// a row; all of them are skipped, not just the first one.
    ///
    /// Limitation: a payload that itself begins with `HTTP/` and contains a
    /// blank line is indistinguishable from another header block.
    pub fn utf8_response_body(&self) -> Option<String> {
        const HEADER_END: &str = "\r\n\r\n";

        if self.response_body.is_empty() {
            return None;
        }
        let raw = String::from_utf8_lossy(&self.response_body);
        let mut body: &str = &raw;
        // Drop everything up to the first blank line (historic behaviour);
        // when there is none the text is returned whole.
        if let Some(end) = body.find(HEADER_END) {
            body = &body[end + HEADER_END.len()..];
        }
        // Any further block that opens with a status line belongs to a later
        // hop / final response and is header data as well.
        while body.starts_with("HTTP/") {
            match body.find(HEADER_END) {
                Some(end) => body = &body[end + HEADER_END.len()..],
                None => break,
            }
        }
        Some(body.to_string())
    }

    /// Returns `later - earlier`, or `None` unless `later` is strictly
    /// greater than `earlier`.
    fn phase(earlier: Duration, later: Duration) -> Option<Duration> {
        if later > earlier {
            Some(later - earlier)
        } else {
            None
        }
    }
}
impl<'a> TryFrom<&mut curl::easy::Easy2<Decorator<'a>>> for Stat {
    type Error = anyhow::Error;

    /// Builds a `Stat` from a handle whose transfer has completed.
    ///
    /// The captured header bytes are split into a status line (HTTP version
    /// and status code) and `name: value` pairs; timings and the peer IP are
    /// queried from the handle.
    ///
    /// Header bytes are decoded lossily: a stray non-UTF-8 byte (legal in
    /// HTTP header values) is replaced instead of failing the conversion,
    /// which matches how the body is decoded.
    ///
    /// When several responses were received (redirects, `100 Continue`) the
    /// version, status code *and* headers all describe the last one: the
    /// header list is reset whenever a new status line is seen.
    ///
    /// # Errors
    /// Propagates any error returned by the handle's info getters.
    fn try_from(handle: &mut curl::easy::Easy2<Decorator<'a>>) -> Result<Self, Self::Error> {
        let mut headers: Vec<Header> = vec![];
        let mut http_version = None;
        let mut response_code = None;
        {
            // Scoped so the shared borrow of the handler ends before the
            // handle's getters are called below.
            let decoded = String::from_utf8_lossy(handle.get_ref().response_headers.as_slice());
            for line in decoded.lines() {
                let header = line.replace(['\r', '\n'], "");
                if header.is_empty() {
                    continue;
                }
                if header.to_uppercase().starts_with("HTTP/") {
                    // A new status line starts a new response; drop headers
                    // collected for earlier hops.
                    headers.clear();
                    if let Some((_, rest)) = header.split_once('/') {
                        // `rest` looks like "1.1 200 OK".
                        let mut tokens = rest.split(' ');
                        http_version = tokens.next().map(|v| v.to_string());
                        response_code = tokens.next().and_then(|code| code.parse().ok());
                    }
                } else if let Some((name, value)) = header.split_once(':') {
                    headers.push(Header {
                        key: name.trim().to_string(),
                        value: value.trim().to_string(),
                    });
                }
            }
        }
        let ip_address = handle.primary_ip()?.map(|ip| ip.to_string());
        Ok(Stat {
            ip_address,
            http_version,
            response_status_code: response_code,
            response_headers: headers,
            name_lookup: handle.namelookup_time()?,
            connect: handle.connect_time()?,
            app_connect: handle.appconnect_time()?,
            pre_transfer: handle.pretransfer_time()?,
            start_transfer: handle.starttransfer_time()?,
            total: handle.total_time()?,
            response_body: handle.get_ref().response_body.to_owned(),
        })
    }
}
/// HTTP request methods understood by this module.
///
/// Converts to its upper-case wire name via `From<&Method> for &str` and is
/// parsed case-insensitively via `TryFrom<&str>`.
#[derive(Debug, Clone, PartialEq)]
pub enum Method {
/// `GET` (the default).
Get,
/// `HEAD`.
Head,
/// `POST`.
Post,
/// `PUT`.
Put,
/// `DELETE`.
Delete,
/// `CONNECT`.
Connect,
/// `OPTIONS`.
Options,
/// `TRACE`.
Trace,
/// `PATCH`.
Patch,
}
impl Default for Method {
fn default() -> Self {
Self::Get
}
}
impl<'a> From<&'a Method> for &'a str {
    /// Returns the upper-case wire name of the method (e.g. `"GET"`).
    fn from(method: &'a Method) -> &'a str {
        let verb: &'static str = match *method {
            Method::Get => "GET",
            Method::Head => "HEAD",
            Method::Post => "POST",
            Method::Put => "PUT",
            Method::Delete => "DELETE",
            Method::Connect => "CONNECT",
            Method::Options => "OPTIONS",
            Method::Trace => "TRACE",
            Method::Patch => "PATCH",
        };
        verb
    }
}
impl TryFrom<&str> for Method {
    type Error = anyhow::Error;

    /// Parses a method name, ignoring case (`"get"` and `"GET"` both work).
    ///
    /// # Errors
    /// Fails for any name that is not one of the nine supported methods.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let upper = value.to_uppercase();
        let method = match upper.as_str() {
            "GET" => Self::Get,
            "HEAD" => Self::Head,
            "POST" => Self::Post,
            "PUT" => Self::Put,
            "DELETE" => Self::Delete,
            "CONNECT" => Self::Connect,
            "OPTIONS" => Self::Options,
            "TRACE" => Self::Trace,
            "PATCH" => Self::Patch,
            _ => anyhow::bail!("Invalid method, please use GET, HEAD, POST, PUT, DELETE, CONNECT, OPTIONS, TRACE, PATCH"),
        };
        Ok(method)
    }
}
/// Performs the HTTP request described by `conf` and returns the collected
/// timings, status, headers and raw response.
///
/// The transfer runs synchronously. Response headers are requested in the
/// body stream as well (`show_header`), which `Stat::utf8_response_body`
/// relies on when stripping them again.
///
/// Request bodies are uploaded for `POST`, `PUT` and `PATCH`. For `PATCH`,
/// libcurl only uploads a body when POST mode is enabled in addition to the
/// custom verb; previously only the verb and a field size were set, so the
/// payload was silently dropped.
///
/// # Errors
/// Returns any error reported by curl while configuring or performing the
/// transfer, or while extracting the statistics afterwards.
pub fn send_request(conf: &Config) -> anyhow::Result<Stat> {
    let mut response_headers = vec![];
    let mut response_body = vec![];
    let mut easy = curl::easy::Easy2::new(Decorator::new(
        conf,
        &mut response_headers,
        &mut response_body,
    ));
    easy.url(&conf.url)?;
    easy.show_header(true)?;
    easy.follow_location(conf.follow_redirects)?;
    easy.verbose(conf.verbose)?;

    if !conf.request_headers.is_empty() {
        let mut request_headers = curl::easy::List::new();
        for header in &conf.request_headers {
            request_headers.append(&header.to_string())?;
        }
        easy.http_headers(request_headers)?;
    }

    let data_size = conf.request_body.as_ref().map(|d| d.len() as u64);
    match &conf.method {
        Method::Get => easy.get(true)?,
        Method::Head => easy.nobody(true)?,
        Method::Post => {
            easy.post(true)?;
            if let Some(ds) = data_size {
                easy.post_field_size(ds)?;
            }
        }
        Method::Patch => {
            // A custom verb alone never triggers an upload: switch to POST
            // mode first (only when there is something to send), then
            // override the verb.
            if let Some(ds) = data_size {
                easy.post(true)?;
                easy.post_field_size(ds)?;
            }
            easy.custom_request("PATCH")?;
        }
        Method::Put => {
            easy.put(true)?;
            if let Some(ds) = data_size {
                easy.in_filesize(ds)?;
            }
        }
        // Remaining verbs are sent without a request body.
        Method::Delete | Method::Connect | Method::Options | Method::Trace => {
            easy.custom_request((&conf.method).into())?
        }
    }

    easy.perform()?;
    Stat::try_from(&mut easy)
}
// Unit and integration tests for this module. The `send_request` tests start
// a local `httpmock` server, so they perform real HTTP over loopback.
#[cfg(test)]
mod test {
use super::*;
use httpmock::prelude::*;
use std::str::FromStr;
// Parsing keeps the key as written; `header_key` produces the Title-Case form.
#[test]
fn test_header_key_value() {
let header = Header::from_str("content-type: application/json").unwrap();
assert_eq!(header.key, "content-type");
assert_eq!(header.header_key(), "Content-Type");
assert_eq!(header.value, "application/json");
}
// Every supported method name parses to its enum variant.
#[test]
fn test_method_try_from_str() {
let table = vec![
("GET", Method::Get),
("POST", Method::Post),
("PUT", Method::Put),
("DELETE", Method::Delete),
("CONNECT", Method::Connect),
("OPTIONS", Method::Options),
("TRACE", Method::Trace),
("PATCH", Method::Patch),
("HEAD", Method::Head),
];
for (method, expected) in table {
let result = Method::try_from(method).unwrap();
assert_eq!(result, expected);
}
}
// Unknown method names are rejected.
#[test]
fn test_method_try_from_str_invalid() {
let result = Method::try_from("INVALID");
assert!(result.is_err());
}
// Every variant converts back to its upper-case wire name.
#[test]
fn test_method_from_str() {
let table = vec![
(Method::Get, "GET"),
(Method::Post, "POST"),
(Method::Put, "PUT"),
(Method::Delete, "DELETE"),
(Method::Connect, "CONNECT"),
(Method::Options, "OPTIONS"),
(Method::Trace, "TRACE"),
(Method::Patch, "PATCH"),
(Method::Head, "HEAD"),
];
for (method, expected) in table {
let result: &str = (&method).into();
assert_eq!(result, expected);
}
}
// Requests carrying a body: the mock matches on path only, so this checks
// that the transfer completes and the response body is extracted, not that
// the request payload actually reached the server.
#[test]
fn test_send_request_body() {
let methods = vec![Method::Post, Method::Put, Method::Patch];
for method in methods {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.path("/");
then.status(200)
.header("content-type", "text/html")
.body("ohi");
});
let conf = Config {
url: server.url("/").into(),
method,
verbose: true,
follow_redirects: true,
request_body: Some("oh".into()),
request_headers: vec![Header::from_str("content-type: text/plain").unwrap()],
..Default::default()
};
let stat = send_request(&conf).unwrap();
mock.assert();
assert_eq!(stat.response_status_code.unwrap(), 200);
assert_eq!(stat.utf8_response_body().unwrap(), "ohi");
}
}
// Requests without a request body succeed for each listed method.
#[test]
fn test_send_request_nobody() {
let methods = vec![
Method::Head,
Method::Options,
Method::Trace,
Method::Connect,
Method::Delete,
Method::Get,
Method::Post,
Method::Put,
];
for method in methods {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.path("/");
then.status(200).header("content-type", "text/html");
});
let conf = Config {
url: server.url("/").into(),
method,
verbose: true,
follow_redirects: true,
..Default::default()
};
let stat = send_request(&conf).unwrap();
mock.assert();
assert_eq!(stat.response_status_code.unwrap(), 200);
}
}
// With cumulative timestamps one second apart, every phase lasts one second.
#[test]
fn test_timing_stat() {
let stat = Stat {
name_lookup: Duration::from_secs(1),
connect: Duration::from_secs(2),
app_connect: Duration::from_secs(3),
pre_transfer: Duration::from_secs(4),
start_transfer: Duration::from_secs(5),
total: Duration::from_secs(6),
..Default::default()
};
let one_sec = Duration::from_secs(1);
assert_eq!(stat.dns_lookup().unwrap(), one_sec);
assert_eq!(stat.tcp_handshake().unwrap(), one_sec);
assert_eq!(stat.tls_handshake().unwrap(), one_sec);
assert_eq!(stat.server_processing().unwrap(), one_sec);
assert_eq!(stat.content_transfer().unwrap(), one_sec);
}
// `app_connect` earlier than `connect` means there is no TLS phase to report.
#[test]
fn test_no_tls() {
let stat = Stat {
name_lookup: Duration::from_secs(1),
connect: Duration::from_secs(2),
app_connect: Duration::from_secs(1),
pre_transfer: Duration::from_secs(4),
start_transfer: Duration::from_secs(5),
total: Duration::from_secs(6),
..Default::default()
};
assert!(stat.tls_handshake().is_none());
}
}