use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::TcpStream;
use crate::zhttp::builder::RequestBuilder;
use crate::zhttp::types::{ClientConfig, Error, Method, Response};
/// Blocking HTTP/1.1 client built on `std::net::TcpStream`.
///
/// Every request builder created from a `Client` receives its own clone of it, so the
/// configuration is copied into each request rather than shared by reference.
#[derive(Clone, Debug)]
pub struct Client {
    /// Client-wide defaults (timeouts, retries, redirects, user agent, debug logging) that
    /// per-request settings on a `RequestBuilder` may override.
    pub(crate) config: ClientConfig,
}
impl Client {
pub(crate) fn default() -> Self {
Self {
config: ClientConfig::default(),
}
}
pub(crate) fn new(config: ClientConfig) -> Self {
Self { config }
}
pub fn get(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::GET, url)
}
pub fn post(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::POST, url)
}
pub fn put(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::PUT, url)
}
pub fn delete(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::DELETE, url)
}
pub fn head(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::HEAD, url)
}
pub fn options(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::OPTIONS, url)
}
pub fn patch(&self, url: impl Into<String>) -> RequestBuilder {
self.create_request(Method::PATCH, url)
}
fn create_request(&self, method: Method, url: impl Into<String>) -> RequestBuilder {
RequestBuilder {
method,
url: url.into(),
headers: Default::default(),
body: None,
body_format: None,
client: self.clone(),
timeout: None,
retries: None,
follow_redirects: None,
verify_ssl: None,
}
}
pub(crate) fn send_request(
&self,
builder: RequestBuilder,
) -> std::result::Result<Response, Error> {
let mut redirects = 0;
let mut current_url = url::Url::parse(&builder.url).map_err(Error::UrlParse)?;
let follow_redirects = builder
.follow_redirects
.unwrap_or(self.config.allow_redirects);
let max_redirects = self.config.max_redirects;
loop {
let response = self.try_request(&builder, current_url.as_str())?;
if follow_redirects && (300..400).contains(&response.status) {
if let Some(location) = response.header("Location") {
redirects += 1;
if redirects > max_redirects {
return Err(Error::TooManyRedirects);
}
current_url = match url::Url::parse(location) {
Ok(absolute_url) => absolute_url,
Err(url::ParseError::RelativeUrlWithoutBase) => {
current_url.join(location).map_err(Error::UrlParse)?
}
Err(e) => return Err(Error::UrlParse(e)),
};
continue;
}
}
return Ok(response);
}
}
fn try_request(
&self,
builder: &RequestBuilder,
url: &str,
) -> std::result::Result<Response, Error> {
let retries = builder.retries.unwrap_or(self.config.max_retries);
let mut last_error = None;
for attempt in 0..=retries {
if attempt > 0 {
if self.config.debug {
eprintln!(">>> Retry attempt {} of {}", attempt, retries);
}
std::thread::sleep(self.config.retry_interval);
}
match self.execute_request(builder, url) {
Ok(response) => {
if response.status == 429 || response.status >= 500 {
if attempt < retries {
if self.config.debug {
eprintln!(">>> Got status {}, will retry", response.status);
}
if let Some(retry_after) = response.header("Retry-After") {
if let Ok(secs) = retry_after.parse::<u64>() {
std::thread::sleep(std::time::Duration::from_secs(secs));
}
}
continue;
}
}
return Ok(response);
}
Err(e) => {
last_error = Some(e);
if attempt < retries {
continue;
}
}
}
}
Err(last_error.unwrap_or_else(|| Error::Custom("Unknown error".into())))
}
fn execute_request(
&self,
builder: &RequestBuilder,
url: &str,
) -> std::result::Result<Response, Error> {
let url = url::Url::parse(url).map_err(Error::UrlParse)?;
let host = url
.host_str()
.ok_or_else(|| Error::Custom("Invalid host".into()))?;
let port = url
.port()
.unwrap_or(if url.scheme() == "https" { 443 } else { 80 });
if self.config.debug {
eprintln!(">>> Request URL: {}", url);
eprintln!(">>> Method: {}", builder.method.as_str());
}
let path = self.build_request_path(&url);
let addr = format!("{}:{}", host, port);
if self.config.debug {
eprintln!(">>> Connecting to: {}", addr);
}
let mut stream = self.create_connection(&addr, builder)?;
let request = self.build_request(builder, &path, host)?;
self.send_request_data(&mut stream, &request, builder)?;
let response = self.read_response(&mut stream)?;
self.parse_response(&response)
}
fn build_request_path(&self, url: &url::Url) -> String {
let mut path = if url.path().is_empty() {
String::from("/")
} else {
String::from(url.path())
};
if let Some(query) = url.query() {
path = format!("{}?{}", path, query);
}
path
}
fn create_connection(
&self,
addr: &str,
builder: &RequestBuilder,
) -> std::result::Result<TcpStream, Error> {
let stream = TcpStream::connect(addr).map_err(Error::Network)?;
let timeout = builder.timeout.unwrap_or(self.config.timeout);
stream
.set_read_timeout(Some(timeout))
.map_err(Error::Network)?;
stream
.set_write_timeout(Some(timeout))
.map_err(Error::Network)?;
Ok(stream)
}
fn build_request(
&self,
builder: &RequestBuilder,
path: &str,
host: &str,
) -> std::result::Result<String, Error> {
let mut request = format!(
"{} {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: {}\r\n",
builder.method.as_str(),
path,
host,
self.config.user_agent
);
for (key, value) in &builder.headers {
request.push_str(&format!("{}: {}\r\n", key, value));
}
if let Some(body) = &builder.body {
if !builder.headers.contains_key("Content-Length") {
request.push_str(&format!("Content-Length: {}\r\n", body.len()));
}
if !builder.headers.contains_key("Content-Type") {
let content_type = builder
.body_format
.as_ref()
.map(|f| f.content_type())
.unwrap_or_else(|| "application/x-www-form-urlencoded".into());
request.push_str(&format!("Content-Type: {}\r\n", content_type));
}
}
if !builder.headers.contains_key("Accept") {
request.push_str("Accept: */*\r\n");
}
request.push_str("Connection: close\r\n\r\n");
Ok(request)
}
fn send_request_data(
&self,
stream: &mut TcpStream,
request: &str,
builder: &RequestBuilder,
) -> std::result::Result<(), Error> {
if self.config.debug {
eprintln!(">>> Request Headers:");
for line in request.lines() {
eprintln!(">>> {}", line);
}
}
stream
.write_all(request.as_bytes())
.map_err(Error::Network)?;
if let Some(body) = &builder.body {
if self.config.debug {
eprintln!(">>> Request Body: {} bytes", body.len());
if let Ok(body_str) = String::from_utf8(body.clone()) {
eprintln!(">>> {}", body_str);
}
}
stream.write_all(body).map_err(Error::Network)?;
}
Ok(())
}
fn read_response(&self, stream: &mut TcpStream) -> std::result::Result<Vec<u8>, Error> {
if self.config.debug {
eprintln!(">>> Reading response...");
}
let mut response = Vec::new();
let mut buffer = [0; 4096];
loop {
match stream.read(&mut buffer) {
Ok(0) => break,
Ok(n) => response.extend_from_slice(&buffer[..n]),
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
return Err(Error::Timeout);
}
Err(e) => return Err(Error::Network(e)),
}
}
if response.is_empty() {
return Err(Error::ResponseParse("Empty response from server".into()));
}
if self.config.debug {
eprintln!(">>> Response received: {} bytes", response.len());
}
Ok(response)
}
fn parse_response(&self, response: &[u8]) -> std::result::Result<Response, Error> {
let response_str = String::from_utf8_lossy(response);
let mut headers = HashMap::new();
let separator_pos = response_str.find("\r\n\r\n").ok_or_else(|| {
Error::ResponseParse("Invalid response format: missing header separator".into())
})?;
let headers_part = &response_str[..separator_pos];
let mut lines = headers_part.lines();
let status = lines
.next()
.ok_or_else(|| Error::ResponseParse("Empty response headers".into()))
.and_then(|status_line| {
let parts: Vec<&str> = status_line.split_whitespace().collect();
if parts.len() >= 2 {
parts[1].parse().map_err(|_| Error::InvalidStatus)
} else {
Err(Error::ResponseParse("Invalid status line".into()))
}
})?;
for line in lines {
if let Some((key, value)) = line.split_once(':') {
headers.insert(key.trim().to_string(), value.trim().to_string());
}
}
let body = response[separator_pos + 4..].to_vec();
Ok(Response {
status,
headers,
body,
})
}
}