use super::config::RetryStrategy;
use super::error::HttpError;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// HTTP request method.
///
/// `Hash` is derived alongside `Eq` so a `Method` can be used as a key in
/// hash maps/sets (e.g. per-method configuration or statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// `GET`
    GET,
    /// `POST`
    POST,
    /// `PUT`
    PUT,
    /// `DELETE`
    DELETE,
    /// `PATCH`
    PATCH,
    /// `HEAD`
    HEAD,
    /// `OPTIONS`
    OPTIONS,
}
impl std::fmt::Display for Method {
    /// Writes the method's upper-case wire token, e.g. `GET` or `OPTIONS`.
    ///
    /// Formatter flags such as width/fill are not applied; the token is
    /// written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Method::GET => "GET",
            Method::POST => "POST",
            Method::PUT => "PUT",
            Method::DELETE => "DELETE",
            Method::PATCH => "PATCH",
            Method::HEAD => "HEAD",
            Method::OPTIONS => "OPTIONS",
        };
        f.write_str(token)
    }
}
/// Payload attached to a request by the [`RequestBuilder`] body setters.
#[derive(Debug, Clone)]
pub(crate) enum RequestBody {
    /// No body has been set.
    None,
    /// JSON bytes produced by `RequestBuilder::json` via `serde_json::to_vec`.
    Json(Vec<u8>),
    /// `application/x-www-form-urlencoded` string produced by `RequestBuilder::form`.
    Form(String),
    /// Raw bytes set by `RequestBuilder::body`.
    Bytes(Vec<u8>),
    /// Plain text set by `RequestBuilder::text`.
    Text(String),
}
/// Builder describing a single HTTP request.
///
/// Every configuration method consumes `self` and returns the updated builder,
/// so the type is `#[must_use]`: dropping the returned value silently discards
/// whatever was just configured.
#[derive(Debug)]
#[must_use = "a RequestBuilder has no effect unless it is used"]
pub struct RequestBuilder {
    /// HTTP method of the request.
    pub(crate) method: Method,
    /// Target URL as given by the caller; `query` pairs are appended to it when
    /// the final URL is built.
    pub(crate) url: String,
    /// Header name -> value, one value per name.
    pub(crate) headers: HashMap<String, String>,
    /// Query parameters in insertion order; duplicates are kept.
    pub(crate) query: Vec<(String, String)>,
    /// Request payload.
    pub(crate) body: RequestBody,
    /// Per-request timeout; `None` when not set on this builder
    /// (presumably falls back to client configuration — confirm in the executor).
    pub(crate) timeout: Option<Duration>,
    /// Per-request retry limit; `None` when not set on this builder.
    pub(crate) max_retries: Option<u32>,
    /// Per-request retry strategy; `None` when not set on this builder.
    pub(crate) retry_strategy: Option<RetryStrategy>,
}
impl RequestBuilder {
    /// Creates a builder for a `method` request to `url` with no headers,
    /// query parameters, body, timeout, or retry overrides.
    pub fn new(method: Method, url: impl Into<String>) -> Self {
        Self {
            method,
            url: url.into(),
            headers: HashMap::new(),
            query: Vec::new(),
            body: RequestBody::None,
            timeout: None,
            max_retries: None,
            retry_strategy: None,
        }
    }

    /// Sets a header, replacing any existing header with the same name.
    ///
    /// Names are compared case-insensitively (HTTP field names are
    /// case-insensitive), so `header("content-type", ..)` replaces an earlier
    /// `Content-Type` instead of adding a second entry. The name is stored with
    /// the casing passed here.
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.set_header(name.into(), value.into());
        self
    }

    /// Sets several headers, with the same case-insensitive replacement rules
    /// as [`header`](Self::header). Later pairs win over earlier ones.
    pub fn headers(
        mut self,
        headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
    ) -> Self {
        for (k, v) in headers {
            self.set_header(k.into(), v.into());
        }
        self
    }

    /// Sets `Authorization: Bearer <token>`.
    pub fn bearer_auth(self, token: impl AsRef<str>) -> Self {
        self.header("Authorization", format!("Bearer {}", token.as_ref()))
    }

    /// Sets `Authorization: Basic <base64(username:password)>`.
    ///
    /// A missing password is encoded as an empty one (`username:`): RFC 7617
    /// defines `user-pass = user-id ":" password`, so the colon is mandatory.
    pub fn basic_auth(self, username: impl AsRef<str>, password: Option<&str>) -> Self {
        use base64::Engine;
        let credentials = format!("{}:{}", username.as_ref(), password.unwrap_or(""));
        let encoded = base64::engine::general_purpose::STANDARD.encode(credentials);
        self.header("Authorization", format!("Basic {}", encoded))
    }

    /// Appends one query parameter. Parameters keep insertion order and
    /// duplicates are allowed; they are percent-encoded when the URL is built.
    pub fn query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.query.push((key.into(), value.into()));
        self
    }

    /// Appends several query parameters, in iteration order.
    pub fn queries(
        mut self,
        params: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
    ) -> Self {
        self.query
            .extend(params.into_iter().map(|(k, v)| (k.into(), v.into())));
        self
    }

    /// Serializes `body` as JSON, stores it as the request body and sets
    /// `Content-Type: application/json`.
    ///
    /// # Errors
    /// Returns [`HttpError::JsonSerialize`] if serialization fails.
    pub fn json<T: Serialize + ?Sized>(mut self, body: &T) -> Result<Self, HttpError> {
        let bytes = serde_json::to_vec(body)
            .map_err(|e| HttpError::JsonSerialize(e.to_string()))?;
        self.body = RequestBody::Json(bytes);
        self.set_header("Content-Type".into(), "application/json".into());
        Ok(self)
    }

    /// URL-form-encodes `body`, stores it as the request body and sets
    /// `Content-Type: application/x-www-form-urlencoded`.
    ///
    /// # Errors
    /// Returns [`HttpError::FormEncode`] if encoding fails.
    pub fn form<T: Serialize + ?Sized>(mut self, body: &T) -> Result<Self, HttpError> {
        let encoded = serde_urlencoded::to_string(body)
            .map_err(|e| HttpError::FormEncode(e.to_string()))?;
        self.body = RequestBody::Form(encoded);
        self.set_header(
            "Content-Type".into(),
            "application/x-www-form-urlencoded".into(),
        );
        Ok(self)
    }

    /// Sets a raw byte body.
    ///
    /// Does not touch `Content-Type`; a value set earlier (for example by
    /// [`json`](Self::json)) is left as-is.
    pub fn body(mut self, bytes: impl Into<Vec<u8>>) -> Self {
        self.body = RequestBody::Bytes(bytes.into());
        self
    }

    /// Sets a plain-text body and `Content-Type: text/plain; charset=utf-8`.
    pub fn text(mut self, text: impl Into<String>) -> Self {
        self.body = RequestBody::Text(text.into());
        self.set_header("Content-Type".into(), "text/plain; charset=utf-8".into());
        self
    }

    /// Sets a timeout for this request only.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the maximum number of retries for this request only.
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = Some(retries);
        self
    }

    /// Disables retries for this request: sets `max_retries` to 0 and the
    /// retry strategy to `RetryStrategy::None`.
    pub fn no_retry(mut self) -> Self {
        self.max_retries = Some(0);
        self.retry_strategy = Some(RetryStrategy::None);
        self
    }

    /// Pairs this request with a cancellation token.
    pub fn with_cancellation(self, token: CancellationToken) -> CancellableRequest {
        CancellableRequest::new(self, token)
    }

    /// Builds the final URL by appending the percent-encoded query parameters.
    ///
    /// - Uses `?` if the URL has no query yet, and `&` otherwise. No separator
    ///   is added when the URL already ends in `?` or `&`.
    /// - Any `#fragment` is kept at the very end. Per RFC 3986 the query must
    ///   precede the fragment; appending after it would bury the parameters
    ///   inside the fragment.
    ///
    /// Currently always returns `Ok`.
    pub(crate) fn build_url(&self) -> Result<String, HttpError> {
        if self.query.is_empty() {
            return Ok(self.url.clone());
        }
        let (base, fragment) = match self.url.find('#') {
            Some(idx) => self.url.split_at(idx),
            None => (self.url.as_str(), ""),
        };
        let query_string = self
            .query
            .iter()
            .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
            .collect::<Vec<_>>()
            .join("&");
        let separator = if !base.contains('?') {
            "?"
        } else if base.ends_with('?') || base.ends_with('&') {
            ""
        } else {
            "&"
        };
        Ok(format!("{}{}{}{}", base, separator, query_string, fragment))
    }

    /// Inserts a header after removing any existing entry whose name matches
    /// case-insensitively (RFC 9110 §5.1). Without this, two spellings of the
    /// same header could coexist and which one is sent would depend on
    /// `HashMap` iteration order.
    fn set_header(&mut self, name: String, value: String) {
        self.headers
            .retain(|existing, _| !existing.eq_ignore_ascii_case(&name));
        self.headers.insert(name, value);
    }
}
/// Cloneable read-side handle used to observe cancellation.
///
/// Created by `cancellation_pair`. All clones share one flag, which is set
/// through the paired `CancellationTrigger`. `Debug` is derived so the token
/// can appear in logs and in `Debug` output of types that contain it.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// Shared flag; `true` once cancellation has been requested.
    cancelled: Arc<AtomicBool>,
}
impl CancellationToken {
    /// Reports whether the paired [`CancellationTrigger`] has cancelled.
    ///
    /// The `Acquire` load pairs with the trigger's `Release` store.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Resolves once the token is cancelled; returns immediately if it already is.
    ///
    /// This polls the flag every 10 ms with `tokio::time::sleep`. Cancellation
    /// is therefore noticed only on the next tick, and the future must run
    /// inside a Tokio runtime with timers enabled.
    pub async fn cancelled(&self) {
        loop {
            if self.is_cancelled() {
                return;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }
}
/// Write-side handle that requests cancellation.
///
/// Created by `cancellation_pair`. It shares its flag with the paired
/// `CancellationToken`. `Debug` is derived so the type can be logged and
/// embedded in other `Debug` types.
#[derive(Debug)]
pub struct CancellationTrigger {
    /// Flag shared with the paired token(s).
    cancelled: Arc<AtomicBool>,
}
impl CancellationTrigger {
pub fn cancel(&self) {
self.cancelled.store(true, Ordering::Release);
}
}
/// Creates a linked token/trigger pair that share one cancellation flag.
///
/// The flag starts out `false`. After `CancellationTrigger::cancel` is called,
/// `CancellationToken::is_cancelled` returns `true` on this token and on all
/// of its clones.
pub fn cancellation_pair() -> (CancellationToken, CancellationTrigger) {
    let shared_flag = Arc::new(AtomicBool::new(false));
    let token = CancellationToken {
        cancelled: Arc::clone(&shared_flag),
    };
    let trigger = CancellationTrigger {
        cancelled: shared_flag,
    };
    (token, trigger)
}
/// A request bundled with the token that can cancel it.
///
/// Produced by `RequestBuilder::with_cancellation` or
/// `CancellableRequest::new`.
pub struct CancellableRequest {
    /// The request to perform.
    pub(crate) builder: RequestBuilder,
    /// Token checked for cancellation while the request runs
    /// (presumably by the executor — confirm at the call site).
    pub(crate) cancel_token: CancellationToken,
}
impl CancellableRequest {
    /// Bundles `builder` with the cancellation `token`.
    pub fn new(builder: RequestBuilder, token: CancellationToken) -> Self {
        let cancel_token = token;
        Self {
            builder,
            cancel_token,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_builder_basic() {
        let req = RequestBuilder::new(Method::GET, "https://example.com")
            .header("Accept", "application/json")
            .query("page", "1")
            .query("limit", "10");
        assert_eq!(req.method, Method::GET);
        assert_eq!(
            req.headers.get("Accept").map(String::as_str),
            Some("application/json")
        );
        assert_eq!(req.query.len(), 2);
    }

    #[test]
    fn test_build_url_with_query() {
        let req = RequestBuilder::new(Method::GET, "https://example.com/api")
            .query("name", "test")
            .query("value", "hello world");
        let url = req.build_url().unwrap();
        assert_eq!(url, "https://example.com/api?name=test&value=hello%20world");
    }

    #[test]
    fn test_build_url_without_query_is_unchanged() {
        let req = RequestBuilder::new(Method::GET, "https://example.com/api");
        assert_eq!(req.build_url().unwrap(), "https://example.com/api");
    }

    #[test]
    fn test_bearer_auth() {
        let req = RequestBuilder::new(Method::GET, "https://example.com").bearer_auth("my_token");
        assert_eq!(
            req.headers.get("Authorization").map(String::as_str),
            Some("Bearer my_token")
        );
    }

    #[test]
    fn test_basic_auth() {
        let req = RequestBuilder::new(Method::GET, "https://example.com")
            .basic_auth("user", Some("pass"));
        // base64("user:pass") == "dXNlcjpwYXNz"
        assert_eq!(
            req.headers.get("Authorization").map(String::as_str),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[test]
    fn test_json_body() {
        #[derive(Serialize)]
        struct Data {
            name: String,
        }
        let req = RequestBuilder::new(Method::POST, "https://example.com")
            .json(&Data { name: "test".into() })
            .unwrap();
        assert_eq!(
            req.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );
        // `matches!` alone only produces a bool; it must be asserted.
        assert!(matches!(req.body, RequestBody::Json(_)));
        match &req.body {
            RequestBody::Json(bytes) => assert_eq!(bytes.as_slice(), br#"{"name":"test"}"#),
            other => panic!("expected a JSON body, got {:?}", other),
        }
    }

    #[test]
    fn test_text_body() {
        let req = RequestBuilder::new(Method::POST, "https://example.com").text("hello");
        assert_eq!(
            req.headers.get("Content-Type").map(String::as_str),
            Some("text/plain; charset=utf-8")
        );
        assert!(matches!(req.body, RequestBody::Text(ref t) if t == "hello"));
    }

    #[test]
    fn test_timeout_and_no_retry() {
        let req = RequestBuilder::new(Method::GET, "https://example.com")
            .timeout(Duration::from_secs(5))
            .max_retries(3)
            .no_retry();
        assert_eq!(req.timeout, Some(Duration::from_secs(5)));
        assert_eq!(req.max_retries, Some(0));
        assert!(req.retry_strategy.is_some());
    }

    #[test]
    fn test_method_display() {
        assert_eq!(Method::GET.to_string(), "GET");
        assert_eq!(Method::POST.to_string(), "POST");
        assert_eq!(Method::PUT.to_string(), "PUT");
        assert_eq!(Method::DELETE.to_string(), "DELETE");
        assert_eq!(Method::PATCH.to_string(), "PATCH");
        assert_eq!(Method::HEAD.to_string(), "HEAD");
        assert_eq!(Method::OPTIONS.to_string(), "OPTIONS");
    }

    #[test]
    fn test_cancellation_token() {
        let (token, trigger) = cancellation_pair();
        let clone = token.clone();
        assert!(!token.is_cancelled());
        trigger.cancel();
        assert!(token.is_cancelled());
        assert!(clone.is_cancelled());
    }
}