use std::{fmt, time::Duration};
use crate::{
http::{Error, HttpClient, Metrics, Request, Response},
io::{BufReader, ReadLimited},
};
use futures_lite::{AsyncBufReadExt as _, AsyncReadExt, AsyncWriteExt as _, io::Cursor};
use http::Uri;
use once_cell::sync::Lazy;
use zng_unit::{ByteLength, ByteUnits as _};
use zng_var::{Var, const_var, var};
use super::uri::Scheme;
/// [`HttpClient`] implementation that executes every request in its own `curl` process.
///
/// The executable is taken from the `ZNG_CURL` environment variable, falling back to `curl`.
#[derive(Default)]
pub struct CurlProcessClient {}

impl HttpClient for CurlProcessClient {
    fn send(&'static self, request: Request) -> std::pin::Pin<Box<dyn Future<Output = Result<Response, Error>> + Send>> {
        let task = run(request);
        Box::pin(task)
    }

    /// Always returns `false`.
    fn is_cache_manager(&self) -> bool {
        false
    }
}
/// Executes `request` by spawning a `curl` process and parsing its output.
///
/// The request is translated into curl command line arguments and the body (if any) is piped
/// to curl's stdin. curl's stdout includes the response heads (`--include`) and is parsed until
/// the final response head is complete. The rest of stdout becomes the response body stream.
/// When `request.metrics` is set, curl's stderr progress meter is parsed in the background by
/// `read_metrics`.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the URI is not HTTP or HTTPS;
/// - curl cannot be spawned, or its stdin cannot be written;
/// - curl's output does not contain a complete, valid HTTP/1.x response head.
async fn run(request: Request) -> Result<Response, Error> {
    let not_http = match request.uri.scheme() {
        Some(s) => s != &Scheme::HTTP && s != &Scheme::HTTPS,
        None => true,
    };
    if not_http {
        return Err(Box::new(NotHttpUriError));
    }

    let has_body = !request.body.is_empty();
    let follow_redirects = request.redirect_limit > 0;

    let mut curl = crate::process::Command::new(&*CURL);
    curl.stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        // stderr is only read to produce metrics; without them, don't leave curl writing
        // into a pipe nobody drains.
        .stderr(if request.metrics {
            std::process::Stdio::piped()
        } else {
            std::process::Stdio::null()
        });

    // `--disable` only works as the very first argument. It stops curl from loading a user
    // `.curlrc`, which could change the output format parsed below (`-o`, `--compressed`, ...).
    curl.arg("--disable");
    curl.arg("--include");
    curl.arg("--http1.1");
    // The URI is passed verbatim; `[]{}` (common in query strings) must not trigger curl URL
    // globbing.
    curl.arg("--globoff");

    // `-X` only replaces the method word. With `-X HEAD`, curl waits for a body that never
    // comes. With `-L`, a forced method is reused for every redirect. So use curl's own
    // method selection whenever it already yields the requested method.
    match request.method.as_str() {
        "HEAD" if !has_body => {
            curl.arg("--head");
        }
        "GET" if !has_body => {}
        "POST" if has_body => {}
        method => {
            curl.arg("-X").arg(method);
        }
    }

    #[cfg(feature = "http_compression")]
    if request.auto_decompress && !request.headers.contains_key(http::header::ACCEPT_ENCODING) {
        // one `-H` argument, in the `name: value` form curl expects
        curl.arg("-H").arg("accept-encoding: zstd, br, gzip");
    }

    // `HeaderMap::iter` yields every value of a multi-value header. Consuming iteration names
    // only the first value of each header, so the later values would be lost.
    for (name, value) in request.headers.iter() {
        // values that are not visible ASCII are skipped
        if let Ok(value) = value.to_str() {
            if value.trim().is_empty() {
                // curl drops a blank `name:` header; the `name;` form sends it with an empty value
                curl.arg("-H").arg(format!("{name};"));
            } else {
                curl.arg("-H").arg(format!("{name}: {value}"));
            }
        }
    }

    // connecting cannot take longer than the whole request is allowed to
    let connect_timeout = request.timeout.min(request.connect_timeout);
    if connect_timeout < Duration::MAX {
        curl.arg("--connect-timeout").arg(curl_secs(connect_timeout));
    }
    if request.timeout < Duration::MAX {
        curl.arg("--max-time").arg(curl_secs(request.timeout));
    }
    if request.low_speed_timeout.0 < Duration::MAX && request.low_speed_timeout.1 > 0.bytes() {
        curl.arg("-y")
            .arg(request.low_speed_timeout.0.as_secs().to_string())
            .arg("-Y")
            .arg(request.low_speed_timeout.1.bytes().to_string());
    }
    if follow_redirects {
        curl.arg("-L").arg("--max-redirs").arg(request.redirect_limit.to_string());
    }
    let rate_limit = request.max_upload_speed.min(request.max_download_speed);
    if rate_limit < ByteLength::MAX {
        // A plain number means bytes per second. This is exact, and it avoids a fractional
        // `K` value, which curl's size parser may reject.
        curl.arg("--limit-rate").arg(rate_limit.bytes().to_string());
    }
    if has_body {
        curl.arg("--data-binary").arg("@-");
    }
    curl.arg(request.uri.to_string());

    let mut child = curl.spawn()?;
    let mut stdin = child.stdin.take().expect("curl stdin is piped");
    let mut stdout = BufReader::new(child.stdout.take().expect("curl stdout is piped"));

    if has_body {
        stdin.write_all(&request.body[..]).await?;
        stdin.flush().await?;
    }
    // closing stdin ends the `--data-binary @-` body
    stdin.close().await?;
    drop(stdin);

    let metrics = if request.metrics {
        let m = var(Metrics::zero());
        let stderr = child.stderr.take().expect("curl stderr is piped when metrics are enabled");
        read_metrics(m.clone(), stderr);
        m.read_only()
    } else {
        const_var(Metrics::zero())
    };

    let mut response_bytes = Vec::with_capacity(1024);
    let mut buffer = [0u8; 1024];
    let mut effective_uri = request.uri;
    let mut eof = false;
    loop {
        let mut response_headers = [httparse::EMPTY_HEADER; 64];
        let mut response = httparse::Response::new(&mut response_headers);
        match response.parse(&response_bytes)? {
            httparse::Status::Complete(header_length) => {
                let code = http::StatusCode::from_u16(response.code.unwrap_or(502))?;

                // curl prints interim heads before the final response:
                // - 1xx responses, which have no body;
                // - when following redirects, each redirect it follows. curl does not print
                //   the body of those redirects.
                let redirect_location = if follow_redirects && code.is_redirection() {
                    response
                        .headers
                        .iter()
                        .find(|h| h.name.eq_ignore_ascii_case("Location"))
                        .map(|h| h.value.trim_ascii())
                        .filter(|v| !v.is_empty())
                } else {
                    None
                };
                let is_interim = code.is_informational() || redirect_location.is_some();
                let has_more = response_bytes.len() > header_length;

                if !is_interim || (eof && !has_more) {
                    // This is the final response, or the last head curl printed before
                    // exiting (e.g. after `--max-redirs` was exceeded).
                    let initial_body_chunk = &response_bytes[header_length..];
                    return run_response(
                        response,
                        effective_uri,
                        #[cfg(feature = "http_compression")]
                        request.auto_decompress,
                        request.require_length,
                        request.max_length,
                        metrics,
                        initial_body_chunk,
                        stdout,
                    );
                }

                if has_more {
                    // another head follows this interim one, so track the redirect and skip it
                    if let Some(location) = redirect_location
                        && let Ok(location) = std::str::from_utf8(location)
                        && let Ok(location) = location.parse::<Uri>()
                    {
                        effective_uri = resolve_location(&effective_uri, location);
                    }
                    response_bytes.drain(..header_length);
                    continue;
                }
                // only the interim head is buffered; read on to find out what follows it
            }
            httparse::Status::Partial if eof => {
                return Err(Box::new(UnexpectedPartialError));
            }
            httparse::Status::Partial => {}
        }

        let bytes_read = stdout.read(&mut buffer).await?;
        if bytes_read == 0 {
            eof = true;
        } else {
            response_bytes.extend_from_slice(&buffer[..bytes_read]);
        }
    }
}

/// Resolves a redirect `Location` against the URI of the response that carried it.
///
/// Absolute URIs are returned as-is. Absolute-path references (`/path?query`) take the scheme
/// and authority of `base`. Other relative forms are not resolved here; `base` is returned
/// unchanged in that case.
fn resolve_location(base: &Uri, location: Uri) -> Uri {
    if location.scheme().is_some() {
        return location;
    }
    if location.authority().is_none() && location.path().starts_with('/') {
        let base_parts = base.clone().into_parts();
        let mut parts = location.into_parts();
        parts.scheme = base_parts.scheme;
        parts.authority = base_parts.authority;
        if let Ok(uri) = Uri::from_parts(parts) {
            return uri;
        }
    }
    base.clone()
}

/// Formats `timeout` as a curl "fractional seconds" argument.
///
/// Whole seconds are written as an integer. Otherwise milliseconds (rounded up) are appended.
/// This keeps a non-zero sub-second timeout from becoming `0`, which curl treats as
/// "no timeout"/"default".
fn curl_secs(timeout: Duration) -> String {
    let millis = timeout.as_nanos().div_ceil(1_000_000);
    let (secs, frac) = (millis / 1000, millis % 1000);
    if frac == 0 {
        secs.to_string()
    } else {
        format!("{secs}.{frac:03}")
    }
}
/// Spawns a background task (`crate::spawn`) that parses curl's progress meter from `stderr`
/// and writes each parsed line into `metrics`.
///
/// The stream is split on `\r`, which curl uses to redraw its progress line. Chunks whose first
/// non-whitespace character is not an ASCII digit (column titles, error messages) are skipped.
/// The task ends at EOF, or on the first read or UTF-8 error; the error is discarded.
fn read_metrics(metrics: Var<Metrics>, stderr: crate::process::ChildStderr) {
let mut stderr = BufReader::new(stderr);
// reused for every chunk; initial capacity presumably sized for one progress line
let mut progress_bytes = Vec::with_capacity(92);
let mut run = async move || -> std::io::Result<()> {
loop {
progress_bytes.clear();
let len = stderr.read_until(b'\r', &mut progress_bytes).await?;
if len == 0 {
// EOF: curl closed stderr
break;
}
let progress = str::from_utf8(&progress_bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
if !progress.trim_start().chars().next().unwrap_or('\0').is_ascii_digit() {
continue;
}
// Columns are read in this order; each column is assumed to be a single
// whitespace-free token:
//   total %, total size, received %, received, upload %, uploaded, download speed,
//   upload speed, time total, time spent, ...
// Missing columns fall back to the defaults passed to `unwrap_or`.
let mut iter = progress.split_whitespace();
let _pct = iter.next();
let _total = iter.next();
let pct_down: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
let down = parse_curl_bytes(iter.next().unwrap_or("0"));
// The expected total is estimated as `received * 100 / percent`.
// NOTE(review): a percent of `0` divides by zero here (inf or NaN). The resulting byte
// value depends on how `f64::bytes()` converts it; confirm the intended "unknown" value.
let response_total = (down.0 as f64 * 100.0 / pct_down as f64).bytes();
let pct_up: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
let up = parse_curl_bytes(iter.next().unwrap_or("0"));
// same estimate (and same divide-by-zero caveat) for the upload side
let request_total = (up.0 as f64 * 100.0 / pct_up as f64).bytes();
let down_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
let up_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
let _total_time = iter.next();
// the "time spent" column; the non-numeric default parses to zero
let time_current = parse_curl_duration(iter.next().unwrap_or("HH:MM:SS"));
metrics.set(Metrics {
read_progress: (down, response_total),
read_speed: down_speed,
write_progress: (up, request_total),
write_speed: up_speed,
total_time: time_current,
});
}
Ok(())
};
crate::spawn(async move {
// errors only stop metric updates, they are intentionally ignored
let _ = run().await;
});
}
/// Parses a size column of curl's progress meter.
///
/// curl prints either a plain byte count (`"12345"`) or a number followed by a single binary
/// unit letter (`"1234k"`, `"12.3M"`, `"1024G"`). The unit is matched case-insensitively
/// (curl uses a lowercase `k`). The number may be fractional. Unparsable values yield zero bytes.
fn parse_curl_bytes(s: &str) -> ByteLength {
    let s = s.trim();
    let exponent = match s.as_bytes().last().map(u8::to_ascii_uppercase) {
        Some(b'K') => 1,
        Some(b'M') => 2,
        Some(b'G') => 3,
        Some(b'T') => 4,
        Some(b'P') => 5,
        Some(b'E') => 6,
        _ => 0,
    };
    // the unit is a single ASCII letter, so removing it stays on a char boundary
    let number = if exponent > 0 { &s[..s.len() - 1] } else { s };
    let value = number.parse::<f64>().unwrap_or(0.0);
    // the float-to-int `as` cast saturates: negative or NaN become 0, too large becomes usize::MAX
    ByteLength::from_byte((value * 1024f64.powi(exponent)) as usize)
}
/// Parses an `HH:MM:SS` time column of curl's progress meter.
///
/// Missing or unparsable fields count as zero, so placeholders such as `--:--:--` yield
/// `Duration::ZERO`. Hours may have any number of digits; minutes and seconds must fit a `u8`
/// or they count as zero.
fn parse_curl_duration(s: &str) -> Duration {
    let mut fields = s.split(':');
    let mut next_field = move || fields.next().unwrap_or("0");

    let hours = next_field().parse::<usize>().unwrap_or(0) as u64;
    let minutes = next_field().parse::<u8>().unwrap_or(0);
    let seconds = next_field().parse::<u8>().unwrap_or(0);

    Duration::from_secs(hours) * 3600
        + Duration::from_secs(u64::from(minutes)) * 60
        + Duration::from_secs(u64::from(seconds))
}
/// Builds the final [`Response`] from the parsed response head and curl's remaining stdout.
///
/// `initial_body_chunk` holds body bytes that were already read past the head; they are
/// replayed before the rest of `reader`. The body stream is limited to `max_length`. With the
/// `http_compression` feature, the body is also decoded when `auto_decompress` is set and the
/// `Content-Encoding` is `zstd`, `br` or `gzip` (matched case-insensitively; `x-gzip` is
/// accepted as `gzip`).
///
/// # Errors
///
/// Returns an error in these cases:
/// - the status code, a header name or a header value is invalid;
/// - `require_length` is set and the `Content-Length` header is missing or not a number;
/// - `require_length` is set and `Content-Length` is larger than `max_length`.
fn run_response(
    response: httparse::Response<'_, '_>,
    effective_uri: Uri,
    #[cfg(feature = "http_compression")] auto_decompress: bool,
    require_length: bool,
    max_length: ByteLength,
    metrics: Var<Metrics>,
    initial_body_chunk: &[u8],
    reader: BufReader<crate::process::ChildStdout>,
) -> Result<Response, Error> {
    // body bytes read while searching for the end of the head come first
    let reader = Cursor::new(initial_body_chunk.to_owned()).chain(reader);
    let code = http::StatusCode::from_u16(response.code.unwrap_or(502))?;

    let mut header = http::header::HeaderMap::new();
    for r in response.headers {
        if r.name.is_empty() {
            continue;
        }
        header.append(
            http::HeaderName::from_bytes(r.name.as_bytes())?,
            http::HeaderValue::from_bytes(r.value)?,
        );
    }

    if require_length {
        let content_length = header
            .get(http::header::CONTENT_LENGTH)
            .and_then(|l| l.to_str().ok())
            .and_then(|l| l.parse::<usize>().ok());
        match content_length {
            None => return Err(Box::new(ContentLengthRequiredError)),
            // a body of exactly `max_length` bytes is allowed
            Some(l) if l > max_length.bytes() => return Err(Box::new(ContentLengthExceedsMaxError)),
            Some(_) => {}
        }
    }

    let reader = ReadLimited::new_default_err(reader, max_length);

    macro_rules! respond {
        ($read:expr) => {
            return Ok(Response::from_read(code, header, effective_uri, metrics, Box::new($read)))
        };
    }

    #[cfg(feature = "http_compression")]
    if auto_decompress && let Some(enc) = header.get(http::header::CONTENT_ENCODING) {
        // content codings are case-insensitive (RFC 9110); `x-gzip` is an alias of `gzip`
        let enc = enc.as_bytes();
        if enc.eq_ignore_ascii_case(b"zstd") {
            respond!(async_compression::futures::bufread::ZstdDecoder::new(reader))
        } else if enc.eq_ignore_ascii_case(b"br") {
            respond!(async_compression::futures::bufread::BrotliDecoder::new(reader))
        } else if enc.eq_ignore_ascii_case(b"gzip") || enc.eq_ignore_ascii_case(b"x-gzip") {
            respond!(async_compression::futures::bufread::GzipDecoder::new(reader))
        }
    }

    respond!(reader)
}
/// Name or path of the curl executable.
///
/// This is the `ZNG_CURL` environment variable when it is set and valid Unicode, `"curl"`
/// otherwise. It is read once, on first use.
static CURL: Lazy<String> = Lazy::new(|| match std::env::var("ZNG_CURL") {
    Ok(curl) => curl,
    Err(_) => String::from("curl"),
});
/// Error returned when the request URI has no scheme, or a scheme other than HTTP or HTTPS.
#[derive(Debug)]
struct NotHttpUriError;

impl fmt::Display for NotHttpUriError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("uri is not HTTP or HTTPS")
    }
}

impl std::error::Error for NotHttpUriError {}
/// Error returned when a response length is required but the response has no valid
/// `Content-Length` header.
#[derive(Debug)]
struct ContentLengthRequiredError;

impl fmt::Display for ContentLengthRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("response content length is required")
    }
}

impl std::error::Error for ContentLengthRequiredError {}
/// Error indicating that the response `Content-Length` exceeds the maximum allowed length.
#[derive(Debug)]
struct ContentLengthExceedsMaxError;

impl fmt::Display for ContentLengthExceedsMaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "response content length exceeds maximum")
    }
}

impl std::error::Error for ContentLengthExceedsMaxError {}
/// Error indicating that curl's output did not contain a complete HTTP response.
#[derive(Debug)]
struct UnexpectedPartialError;

impl fmt::Display for UnexpectedPartialError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("unexpected partial response from curl")
    }
}

impl std::error::Error for UnexpectedPartialError {}