use std::{fs::File, io, path::Path, str::FromStr};
use anyhow::{anyhow, Result};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{blocking::multipart, Method};
use structopt::clap;
use crate::cli::RequestType;
pub const FORM_CONTENT_TYPE: &str = "application/x-www-form-urlencoded";
pub const JSON_CONTENT_TYPE: &str = "application/json";
pub const JSON_ACCEPT: &str = "application/json, */*;q=0.5";
/// One parsed request item, as produced by the `FromStr` impl below.
///
/// Each variant corresponds to a separator recognised while parsing; the
/// strings stored here have already had their backslash escapes resolved
/// (except for the `name;` header form, which is stored as written).
#[derive(Debug, Clone, PartialEq)]
pub enum RequestItem {
/// `name:value` — header name and value. Also produced by `name;`, in which
/// case the value is the empty string.
HttpHeader(String, String),
/// `name:` (empty value) — name of a header to unset.
HttpHeaderToUnset(String),
/// `name==value` — query string parameter.
UrlParam(String, String),
/// `name=value` — string body field.
DataField(String, String),
/// `name:=json` — body field whose value was parsed as JSON.
JSONField(String, serde_json::Value),
/// `name@path` or `name@path;type=mime` — field name, file path, and the
/// optional text that followed the last `;type=`.
FormFile(String, String, Option<String>),
}
impl FromStr for RequestItem {
type Err = clap::Error;
fn from_str(request_item: &str) -> clap::Result<RequestItem> {
const SPECIAL_CHARS: &str = "=@:;\\";
const SEPS: &[&str] = &["==", ":=", "=", "@", ":"];
fn unescape(text: &str) -> String {
let mut out = String::new();
let mut chars = text.chars();
while let Some(ch) = chars.next() {
if ch == '\\' {
match chars.next() {
Some(next) if SPECIAL_CHARS.contains(next) => {
out.push(next);
}
Some(next) => {
out.push(ch);
out.push(next);
}
None => {
out.push(ch);
}
}
} else {
out.push(ch);
}
}
out
}
fn split(request_item: &str) -> Option<(String, &'static str, String)> {
let mut char_inds = request_item.char_indices();
while let Some((ind, ch)) = char_inds.next() {
if ch == '\\' {
char_inds.next();
continue;
}
for sep in SEPS {
if let Some(value) = request_item[ind..].strip_prefix(sep) {
let key = &request_item[..ind];
return Some((unescape(key), sep, unescape(value)));
}
}
}
None
}
if let Some((key, sep, value)) = split(request_item) {
match sep {
"==" => Ok(RequestItem::UrlParam(key, value)),
"=" => Ok(RequestItem::DataField(key, value)),
":=" => Ok(RequestItem::JSONField(
key,
serde_json::from_str(&value).map_err(|err| {
clap::Error::with_description(
&format!("{:?}: {}", request_item, err),
clap::ErrorKind::InvalidValue,
)
})?,
)),
"@" => {
let with_type: Vec<&str> = value.rsplitn(2, ";type=").collect();
if let Some(&typed_filename) = with_type.get(1) {
Ok(RequestItem::FormFile(
key,
typed_filename.to_owned(),
Some(with_type[0].to_owned()),
))
} else {
Ok(RequestItem::FormFile(key, value, None))
}
}
":" if value.is_empty() => Ok(RequestItem::HttpHeaderToUnset(key)),
":" => Ok(RequestItem::HttpHeader(key, value)),
_ => unreachable!(),
}
} else if let Some(header) = request_item.strip_suffix(';') {
Ok(RequestItem::HttpHeader(header.to_owned(), "".to_owned()))
} else {
Err(clap::Error::with_description(
&format!("{:?} is not a valid request item", request_item),
clap::ErrorKind::InvalidValue,
))
}
}
}
/// Newtype around the list of parsed request items, kept in the order given.
pub struct RequestItems(pub Vec<RequestItem>);
/// A request body in one of the supported encodings.
pub enum Body {
/// JSON object; `RequestItems::body` builds it from data and JSON fields.
Json(serde_json::Map<String, serde_json::Value>),
/// Name/value pairs; `RequestItems::body` builds it from data fields.
Form(Vec<(String, String)>),
/// Multipart form, holding text fields and file parts.
Multipart(multipart::Form),
/// Raw bytes. Not constructed anywhere in this file.
Raw(Vec<u8>),
}
impl Body {
    /// Reports whether the body carries no content.
    ///
    /// A multipart body is never considered empty, even without any parts.
    pub fn is_empty(&self) -> bool {
        let len = match self {
            Body::Multipart(_) => return false,
            Body::Json(fields) => fields.len(),
            Body::Form(pairs) => pairs.len(),
            Body::Raw(bytes) => bytes.len(),
        };
        len == 0
    }

    /// Picks the default HTTP method: `POST` when there is content to send,
    /// `GET` otherwise.
    pub fn pick_method(&self) -> Method {
        if !self.is_empty() {
            return Method::POST;
        }
        Method::GET
    }

    /// Reports whether this is a multipart body.
    pub fn is_multipart(&self) -> bool {
        matches!(*self, Body::Multipart(_))
    }
}
impl RequestItems {
    /// Wraps an already parsed list of request items.
    pub fn new(request_items: Vec<RequestItem>) -> RequestItems {
        RequestItems(request_items)
    }

    /// Reports whether at least one item is a file field.
    pub fn has_form_files(&self) -> bool {
        for item in &self.0 {
            if let RequestItem::FormFile(..) = item {
                return true;
            }
        }
        false
    }

    /// Collects the header items.
    ///
    /// Returns the headers to set (a later item with the same name replaces
    /// an earlier one) and the names of the headers to unset.
    ///
    /// # Errors
    ///
    /// Fails when a header name or value is rejected by `HeaderName` /
    /// `HeaderValue`.
    pub fn headers(&self) -> Result<(HeaderMap<HeaderValue>, Vec<HeaderName>)> {
        let mut to_set = HeaderMap::new();
        let mut to_unset = Vec::new();
        for item in &self.0 {
            match item {
                RequestItem::HttpHeader(name, value) => {
                    // The name is validated before the value.
                    let name = HeaderName::from_bytes(name.as_bytes())?;
                    to_set.insert(name, HeaderValue::from_str(value)?);
                }
                RequestItem::HttpHeaderToUnset(name) => {
                    to_unset.push(HeaderName::from_bytes(name.as_bytes())?);
                }
                RequestItem::UrlParam(..)
                | RequestItem::DataField(..)
                | RequestItem::JSONField(..)
                | RequestItem::FormFile(..) => {}
            }
        }
        Ok((to_set, to_unset))
    }

    /// Returns the query parameters as borrowed name/value pairs, in order.
    pub fn query(&self) -> Vec<(&str, &str)> {
        self.0
            .iter()
            .filter_map(|item| match item {
                RequestItem::UrlParam(name, value) => Some((name.as_str(), value.as_str())),
                _ => None,
            })
            .collect()
    }

    /// Builds a JSON object from the data and JSON fields; file fields are
    /// rejected.
    fn body_as_json(self) -> Result<Body> {
        let mut fields = serde_json::Map::new();
        for item in self.0 {
            let (name, value) = match item {
                RequestItem::JSONField(name, value) => (name, value),
                RequestItem::DataField(name, text) => (name, serde_json::Value::String(text)),
                RequestItem::FormFile(..) => {
                    return Err(anyhow!(
                        "Sending Files is not supported when the request body is in JSON format"
                    ));
                }
                RequestItem::HttpHeader(..)
                | RequestItem::HttpHeaderToUnset(..)
                | RequestItem::UrlParam(..) => continue,
            };
            fields.insert(name, value);
        }
        Ok(Body::Json(fields))
    }

    /// Builds a URL-encoded form from the data fields; JSON fields are
    /// rejected.
    fn body_as_form(self) -> Result<Body> {
        let pairs = self
            .0
            .into_iter()
            .filter_map(|item| match item {
                RequestItem::DataField(name, value) => Some(Ok((name, value))),
                RequestItem::JSONField(..) => {
                    Some(Err(anyhow!("JSON values are not supported in Form fields")))
                }
                // `body` sends every item list containing a file to
                // `body_as_multipart`, so no file field can reach this point.
                RequestItem::FormFile(..) => unreachable!(),
                RequestItem::HttpHeader(..)
                | RequestItem::HttpHeaderToUnset(..)
                | RequestItem::UrlParam(..) => None,
            })
            .collect::<Result<Vec<(String, String)>>>()?;
        Ok(Body::Form(pairs))
    }

    /// Builds a multipart form from the data and file fields; JSON fields are
    /// rejected. Files are opened while the form is being built.
    fn body_as_multipart(self) -> Result<Body> {
        let form = self.0.into_iter().try_fold(
            multipart::Form::new(),
            |form, item| -> Result<multipart::Form> {
                let form = match item {
                    RequestItem::JSONField(..) => {
                        return Err(anyhow!("JSON values are not supported in multipart fields"));
                    }
                    RequestItem::DataField(name, text) => form.text(name, text),
                    RequestItem::FormFile(name, path, file_type) => {
                        let part = file_to_part(&path)?;
                        let part = match file_type {
                            Some(mime) => part.mime_str(&mime)?,
                            None => part,
                        };
                        form.part(name, part)
                    }
                    RequestItem::HttpHeader(..)
                    | RequestItem::HttpHeaderToUnset(..)
                    | RequestItem::UrlParam(..) => form,
                };
                Ok(form)
            },
        )?;
        Ok(Body::Multipart(form))
    }

    /// Consumes the items and builds the body for the given request type.
    ///
    /// A form request that contains file fields is sent as multipart.
    pub fn body(self, request_type: RequestType) -> Result<Body> {
        match request_type {
            RequestType::Json => self.body_as_json(),
            RequestType::Multipart => self.body_as_multipart(),
            RequestType::Form => {
                if self.has_form_files() {
                    self.body_as_multipart()
                } else {
                    self.body_as_form()
                }
            }
        }
    }

    /// Picks the default HTTP method: `POST` for multipart requests or when
    /// any item contributes to the body, `GET` otherwise.
    pub fn pick_method(&self, request_type: RequestType) -> Method {
        let contributes_to_body = |item: &RequestItem| match item {
            RequestItem::HttpHeader(..)
            | RequestItem::HttpHeaderToUnset(..)
            | RequestItem::UrlParam(..) => false,
            RequestItem::DataField(..)
            | RequestItem::JSONField(..)
            | RequestItem::FormFile(..) => true,
        };
        if request_type == RequestType::Multipart || self.0.iter().any(contributes_to_body) {
            Method::POST
        } else {
            Method::GET
        }
    }
}
/// Opens the file at `path` and wraps it in a multipart part.
///
/// The part is created with the file's length taken from its metadata, and is
/// given the final path component (lossily converted to UTF-8) as its file
/// name when the path has one.
///
/// # Errors
///
/// Returns the I/O error from opening the file or reading its metadata.
pub fn file_to_part(path: impl AsRef<Path>) -> io::Result<multipart::Part> {
    let path = path.as_ref();
    let file = File::open(path)?;
    let size = file.metadata()?.len();
    let part = multipart::Part::reader_with_length(file, size);
    Ok(match path.file_name() {
        Some(name) => part.file_name(name.to_string_lossy().to_string()),
        None => part,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Exercises `RequestItem::from_str` across every separator, the escape
    /// rules, and the rejected inputs.
    #[test]
    fn request_item_parsing() {
        use RequestItem::*;

        fn parse(text: &str) -> RequestItem {
            text.parse().unwrap()
        }

        // Data field
        assert_eq!(parse("foo=bar"), DataField("foo".into(), "bar".into()));
        // URL param
        assert_eq!(parse("foo==bar"), UrlParam("foo".into(), "bar".into()));
        // Escaped right before separator
        assert_eq!(parse(r"foo\==bar"), DataField("foo=".into(), "bar".into()));
        // Header
        assert_eq!(parse("foo:bar"), HttpHeader("foo".into(), "bar".into()));
        // JSON field
        assert_eq!(parse("foo:=[1,2]"), JSONField("foo".into(), json!([1, 2])));
        // Bad JSON field
        "foo:=bar".parse::<RequestItem>().unwrap_err();
        // Backslashes in front of ordinary characters are kept
        assert_eq!(
            parse(r"f\o\o=\ba\r"),
            DataField(r"f\o\o".into(), r"\ba\r".into()),
        );
        // Backslashes in front of special characters are removed
        assert_eq!(
            parse(r"f\=\:\@\;oo=b\:\:\:ar"),
            DataField("f=:@;oo".into(), "b:::ar".into()),
        );
        // Header to unset
        assert_eq!(parse("foobar:"), HttpHeaderToUnset("foobar".into()));
        // Header with empty value
        assert_eq!(parse("foobar;"), HttpHeader("foobar".into(), "".into()));
        // File field
        assert_eq!(parse("foo@bar"), FormFile("foo".into(), "bar".into(), None));
        // File field with explicit type
        assert_eq!(
            parse("foo@bar;type=qux"),
            FormFile("foo".into(), "bar".into(), Some("qux".into())),
        );
        // Only the last `;type=` is treated as the type
        assert_eq!(
            parse("foo@bar;type=qux;type=qux"),
            FormFile("foo".into(), "bar;type=qux".into(), Some("qux".into())),
        );
        // Empty file name
        assert_eq!(parse("foo@"), FormFile("foo".into(), "".into(), None));
        // No separator
        "foobar".parse::<RequestItem>().unwrap_err();
        "".parse::<RequestItem>().unwrap_err();
        // Trailing backslash
        assert_eq!(parse(r"foo=bar\"), DataField("foo".into(), r"bar\".into()));
        // Escaped backslash
        assert_eq!(parse(r"foo\\=bar"), DataField(r"foo\".into(), "bar".into()),);
        // Multi-byte characters
        assert_eq!(
            parse("\u{00B5}=\u{00B5}"),
            DataField("\u{00B5}".into(), "\u{00B5}".into()),
        );
        // Empty key and value
        assert_eq!(parse("="), DataField("".into(), "".into()));
    }
}