use std::sync::{Arc, OnceLock};
use async_trait::async_trait;
use regex::RegexSet;
pub use ureq::http::request::Builder;
pub use ureq::http::{self, Method, Request, Response, StatusCode, Uri};
use ureq::middleware::Middleware;
use ureq::tls::TlsConfig;
pub use ureq::{Agent, AsSendBody, Body, RequestBuilder, SendBody};
/// Per-request instruction describing how 3xx redirect responses should be handled.
///
/// Attach it to a request with [`HttpRequestExt::follow_redirects`]. The default value is
/// [`RedirectPolicy::NoFollow`].
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub enum RedirectPolicy {
    /// Requests that redirects are not followed.
    #[default]
    NoFollow,
    /// Requests that at most the given number of redirects are followed.
    FollowLimit(u32),
    /// Requests that redirects are followed without a per-request limit
    /// (a client may still apply its own cap).
    FollowAll,
}
/// Boolean "follow redirects" flag.
///
/// `http` 1.x only accepts request extensions that are `Clone + Send + Sync + 'static`.
/// The derives below make this type insertable with `Builder::extension`; without `Clone`
/// it could not be attached to a request at all.
///
/// NOTE(review): not referenced by the code visible here; presumably superseded by
/// [`RedirectPolicy`] — confirm before relying on it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FollowRedirects(pub bool);
pub trait HttpRequestExt {
fn follow_redirects(self, follow: RedirectPolicy) -> Self;
}
impl HttpRequestExt for Builder {
fn follow_redirects(self, follow: RedirectPolicy) -> Self {
self.extension(follow)
}
}
#[async_trait]
pub trait HttpClient: Send + Sync + 'static {
fn type_name(&self) -> &'static str;
async fn send(
&self,
req_builder: Builder,
body: Option<Vec<u8>>,
) -> Result<Response<Body>, anyhow::Error>;
async fn get(
&self,
uri: &str,
follow_redirects: bool,
) -> Result<Response<Body>, anyhow::Error> {
let request_builder = Builder::new()
.uri(uri)
.method(Method::GET)
.follow_redirects(if follow_redirects {
RedirectPolicy::FollowAll
} else {
RedirectPolicy::NoFollow
});
self.send(request_builder, None).await
}
async fn post_json(
&self,
uri: &str,
body: Option<Vec<u8>>,
) -> Result<Response<Body>, anyhow::Error> {
let request_builder = Builder::new()
.uri(uri)
.method(Method::POST)
.header("Content-Type", "application/json");
self.send(request_builder, body).await
}
}
/// [`HttpClient`] implementation backed by a ureq [`Agent`].
#[warn(dead_code)]
#[derive(Clone, Debug)]
pub struct DefaultHttpClient {
    /// Agent carrying the user agent, TLS, proxy and allow-host middleware configuration
    /// that `DefaultHttpClient::new` sets up.
    agent: Agent,
}
impl DefaultHttpClient {
    /// Creates a client whose ureq [`Agent`] is configured with:
    ///
    /// * a `wasm-http/<crate version>` user agent;
    /// * TLS certificate verification left enabled;
    /// * the optional `proxy` URL;
    /// * an `AllowHostMiddleware` built from `allow_hosts`. These are regular-expression
    ///   patterns matched against each request URI. Requests matching none of them are
    ///   answered with a synthetic `403` response instead of being sent.
    ///
    /// # Panics
    ///
    /// Panics if `proxy` is `Some` and ureq cannot parse it as a proxy URL.
    pub fn new(proxy: Option<String>, allow_hosts: Option<Vec<String>>) -> Self {
        let req_proxy = proxy
            .as_deref()
            .map(|proxy_url| ureq::Proxy::new(proxy_url).expect("invalid proxy url"));
        let user_agent = format!("wasm-http/{}", env!("CARGO_PKG_VERSION"));
        let config = ureq::config::Config::builder()
            .user_agent(ureq::config::AutoHeaderValue::Provided(Arc::new(user_agent)))
            // Spelled out explicitly so certificate verification is never relaxed by accident.
            .tls_config(TlsConfig::builder().disable_verification(false).build())
            .proxy(req_proxy)
            // `allow_hosts` is not used afterwards, so it is moved instead of cloned.
            .middleware(AllowHostMiddleware::new(allow_hosts))
            .build();
        Self {
            agent: Agent::new_with_config(config),
        }
    }
}
#[async_trait]
impl HttpClient for DefaultHttpClient {
fn type_name(&self) -> &'static str {
std::any::type_name::<Self>()
}
async fn send(
&self,
req_builder: Builder,
body: Option<Vec<u8>>,
) -> Result<Response<Body>, anyhow::Error> {
match body {
Some(bytes) => self
.agent
.run(req_builder.body(bytes)?)
.map_err(anyhow::Error::from),
None => self
.agent
.run(req_builder.body(())?)
.map_err(anyhow::Error::from),
}
}
}
struct AllowHostMiddleware {
allows: Vec<String>,
}
impl AllowHostMiddleware {
pub fn new(allows: Option<Vec<String>>) -> Self {
Self {
allows: allows.unwrap_or_default(),
}
}
}
impl Middleware for AllowHostMiddleware {
fn handle(
&self,
request: http::Request<SendBody>,
next: ureq::middleware::MiddlewareNext,
) -> Result<http::Response<Body>, ureq::Error> {
static ALLOW_REGEX: OnceLock<RegexSet> = OnceLock::new();
let allow_regex =
ALLOW_REGEX.get_or_init(|| RegexSet::new(self.allows.as_slice()).expect("regex error"));
if !allow_regex
.matches(&request.uri().to_string())
.matched_any()
{
let body = Body::builder()
.mime_type("application/json")
.charset("utf-8")
.data(r#"{"code":"-1","msg":"uri is not allowed"}"#);
return Response::builder()
.status(403)
.header("Content-Type", "application/json")
.body(body)
.map_err(|e| ureq::Error::from(e));
}
next.handle(request)
}
}
#[cfg(test)]
mod tests {
    //! These tests perform real requests to httpbin.org and therefore need network access.
    use crate::{DefaultHttpClient, HttpClient};

    /// Allow-list pattern for httpbin.org. It is a regular expression, not a glob: anchored
    /// at the start, with the dot escaped, so only URIs on the httpbin.org origin match.
    const HTTPBIN_PATTERN: &str = r"^https?://httpbin\.org/";

    #[tokio::test]
    async fn test_get() -> anyhow::Result<()> {
        let http_client = DefaultHttpClient::new(None, Some(vec![HTTPBIN_PATTERN.to_string()]));
        for _ in 0..100 {
            let mut resp = http_client.get("https://httpbin.org/get", true).await?;
            let body = resp.body_mut().read_to_string()?;
            println!("body:{body}");
        }
        Ok(())
    }

    /// Same request loop through reqwest, presumably kept as a side-by-side comparison.
    #[tokio::test]
    async fn test_reqwest_get() -> anyhow::Result<()> {
        let client = reqwest::Client::builder().gzip(true).build()?;
        for _ in 0..100 {
            let resp = client.get("https://httpbin.org/get").send().await?;
            let body = resp.text().await?;
            println!("body:{body}");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_post() -> anyhow::Result<()> {
        let http_client = DefaultHttpClient::new(None, Some(vec![HTTPBIN_PATTERN.to_string()]));
        let mut resp = http_client
            .post_json(
                "https://httpbin.org/post",
                Some(b"{\"id\":1,\"name\":\"zhangsan\"}".to_vec()),
            )
            .await?;
        println!("resp:{:#?}", &resp);
        let body = resp.body_mut().read_to_string()?;
        println!("body:{body}");
        Ok(())
    }
}