use bytes::Bytes;
use http::{HeaderMap, HeaderValue, Method, Request};
use http_body_util::Full;
use serde::Serialize;
use serde_json::Value;
use std::collections::HashMap;
use crate::client::ClientError;
/// Factory that hands out [`RequestBuilder`]s pre-populated with shared defaults.
///
/// The factory stores a default body format and a set of default headers.
/// Every builder it creates starts from a copy of those headers.
#[derive(Debug, Clone)]
pub struct APIRequestFactory {
    /// Body format name handed to the builders this factory creates
    /// (`"json"` unless changed through `with_format`).
    default_format: String,
    /// Headers copied into every builder created by this factory.
    default_headers: HeaderMap,
}
impl APIRequestFactory {
    /// Creates a factory with the `"json"` default format and no default headers.
    pub fn new() -> Self {
        Self {
            default_format: "json".to_string(),
            default_headers: HeaderMap::new(),
        }
    }

    /// Replaces the default body format applied to builders created by this factory.
    pub fn with_format(mut self, format: impl Into<String>) -> Self {
        self.default_format = format.into();
        self
    }

    /// Adds a default header that is copied into every builder created afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::RequestFailed` when `name` is not a valid header
    /// name, and the error converted from `HeaderValue::from_str` when `value`
    /// is not a valid header value.
    pub fn with_header(
        mut self,
        name: impl AsRef<str>,
        value: impl AsRef<str>,
    ) -> Result<Self, ClientError> {
        let name = name.as_ref();
        let header_name: http::header::HeaderName = name
            .parse()
            .map_err(|_| ClientError::RequestFailed(format!("Invalid header name: {}", name)))?;
        let header_value = HeaderValue::from_str(value.as_ref())?;
        self.default_headers.insert(header_name, header_value);
        Ok(self)
    }

    /// Starts a `GET` request for `path`.
    pub fn get(&self, path: &str) -> RequestBuilder {
        self.request(Method::GET, path)
    }

    /// Starts a `POST` request for `path`.
    pub fn post(&self, path: &str) -> RequestBuilder {
        self.request(Method::POST, path)
    }

    /// Starts a `PUT` request for `path`.
    pub fn put(&self, path: &str) -> RequestBuilder {
        self.request(Method::PUT, path)
    }

    /// Starts a `PATCH` request for `path`.
    pub fn patch(&self, path: &str) -> RequestBuilder {
        self.request(Method::PATCH, path)
    }

    /// Starts a `DELETE` request for `path`.
    pub fn delete(&self, path: &str) -> RequestBuilder {
        self.request(Method::DELETE, path)
    }

    /// Starts a `HEAD` request for `path`.
    pub fn head(&self, path: &str) -> RequestBuilder {
        self.request(Method::HEAD, path)
    }

    /// Starts an `OPTIONS` request for `path`.
    pub fn options(&self, path: &str) -> RequestBuilder {
        self.request(Method::OPTIONS, path)
    }

    /// Starts a request with an arbitrary `method` for `path`.
    ///
    /// Every convenience constructor funnels through here so that the default
    /// headers and the default format are applied uniformly, whatever the
    /// method. The format only influences the `Content-Type` header, which the
    /// builder emits solely when a body is attached, so requests without a
    /// body are unaffected by it.
    pub fn request(&self, method: Method, path: &str) -> RequestBuilder {
        RequestBuilder::new(method, path, &self.default_headers)
            .with_format(&self.default_format)
    }
}
impl Default for APIRequestFactory {
fn default() -> Self {
Self::new()
}
}
/// Incrementally assembles a single HTTP request.
///
/// Collects the method, path, headers, query parameters and an optional body,
/// then turns them into an `http::Request` through `build`.
#[derive(Debug, Clone)]
pub struct RequestBuilder {
    /// HTTP method of the request.
    method: Method,
    /// Request path, used as the base of the final URI.
    path: String,
    /// Headers sent with the request.
    headers: HeaderMap,
    /// Query parameters keyed by name; setting a name again replaces its value.
    query_params: HashMap<String, String>,
    /// Request body, if one has been set.
    body: Option<Bytes>,
    /// Body format name (`"json"`, `"form"`, or anything else) used to pick the
    /// `Content-Type` header when a body is present.
    format: String,
}
impl RequestBuilder {
    /// Starts a builder for `method` and `path`, seeded with a copy of `default_headers`.
    ///
    /// The builder has no query parameters and no body, and its format is `"json"`.
    pub fn new(method: Method, path: &str, default_headers: &HeaderMap) -> Self {
        Self {
            method,
            path: path.to_string(),
            headers: default_headers.clone(),
            query_params: HashMap::new(),
            body: None,
            format: "json".to_string(),
        }
    }

    /// Returns a copy of the HTTP method this builder was created with.
    pub fn method(&self) -> Method {
        self.method.clone()
    }

    /// Returns the request path, without any query parameters added later.
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Sets the body format used to choose the `Content-Type` header.
    pub fn with_format(mut self, format: &str) -> Self {
        self.format = format.to_string();
        self
    }

    /// Sets a header on this request, replacing any value already stored under `name`.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::RequestFailed` when `name` is not a valid header
    /// name, and the error converted from `HeaderValue::from_str` when `value`
    /// is not a valid header value.
    pub fn header(mut self, name: &str, value: &str) -> Result<Self, ClientError> {
        let header_name: http::header::HeaderName = name
            .parse()
            .map_err(|_| ClientError::RequestFailed(format!("Invalid header name: {}", name)))?;
        let header_value = HeaderValue::from_str(value)?;
        self.headers.insert(header_name, header_value);
        Ok(self)
    }

    /// Sets the query parameter `key` to `value`, replacing an earlier value for the same key.
    pub fn query(mut self, key: &str, value: &str) -> Self {
        self.query_params.insert(key.to_string(), value.to_string());
        self
    }

    /// Alias for [`RequestBuilder::query`].
    pub fn query_param(self, key: &str, value: &str) -> Self {
        self.query(key, value)
    }

    /// Serializes `data` as JSON, stores it as the body and switches the format to `"json"`.
    ///
    /// # Errors
    ///
    /// Returns the error converted from `serde_json` when serialization fails.
    pub fn json<T: Serialize>(mut self, data: &T) -> Result<Self, ClientError> {
        let encoded = serde_json::to_vec(data)?;
        self.body = Some(Bytes::from(encoded));
        self.format = "json".to_string();
        Ok(self)
    }

    /// Encodes `data` as a URL-encoded form body and switches the format to `"form"`.
    ///
    /// `data` is first converted to a JSON value and must serialize to an
    /// object. String fields are written verbatim; every other value is
    /// written using its JSON text representation.
    ///
    /// # Errors
    ///
    /// Returns the error converted from `serde_json` when serialization fails,
    /// and `ClientError::RequestFailed` when `data` does not serialize to an
    /// object.
    pub fn form<T: Serialize>(mut self, data: &T) -> Result<Self, ClientError> {
        let fields = match serde_json::to_value(data)? {
            Value::Object(fields) => fields,
            _ => {
                return Err(ClientError::RequestFailed(
                    "Expected object for form data".to_string(),
                ))
            }
        };
        let pairs = fields.iter().map(|(key, value)| {
            // Strings are used as-is (borrowed, no copy); everything else is
            // rendered as JSON text.
            let text: std::borrow::Cow<'_, str> = match value {
                Value::String(text) => std::borrow::Cow::Borrowed(text.as_str()),
                other => std::borrow::Cow::Owned(other.to_string()),
            };
            (key, text)
        });
        self.body = Some(Bytes::from(Self::encode_pairs(pairs)));
        self.format = "form".to_string();
        Ok(self)
    }

    /// Stores `body` as the raw request body without touching the current format.
    pub fn body(mut self, body: impl Into<Bytes>) -> Self {
        self.body = Some(body.into());
        self
    }

    /// Consumes the builder and produces the final request.
    ///
    /// Query parameters are URL-encoded and appended to the path in key order,
    /// so the resulting URI is deterministic. When the path already contains a
    /// `?`, the parameters are joined to the existing query string instead of
    /// starting a second one.
    ///
    /// When a body is present and no `Content-Type` header was set explicitly,
    /// one is derived from the format: `"json"` gives `application/json`,
    /// `"form"` gives `application/x-www-form-urlencoded`, and anything else
    /// gives `application/octet-stream`. An explicitly set `Content-Type` is
    /// left untouched. Without a body the request is built with an empty one
    /// and no `Content-Type` is added.
    ///
    /// # Errors
    ///
    /// Returns the error converted from `http::Error` when the request cannot
    /// be assembled, for example because the URI is invalid.
    pub fn build(self) -> Result<Request<Full<Bytes>>, ClientError> {
        let Self {
            method,
            path,
            mut headers,
            query_params,
            body,
            format,
        } = self;

        let uri = if query_params.is_empty() {
            path
        } else {
            // HashMap iteration order is unspecified; sort for a stable URI.
            let mut pairs: Vec<(&String, &String)> = query_params.iter().collect();
            pairs.sort_unstable();
            let separator = match path.find('?') {
                None => "?",
                Some(_) if path.ends_with('?') || path.ends_with('&') => "",
                Some(_) => "&",
            };
            format!("{}{}{}", path, separator, Self::encode_pairs(pairs))
        };

        // Only derive a Content-Type when the caller has not chosen one, so
        // the request never ends up with two conflicting values.
        if body.is_some() && !headers.contains_key(http::header::CONTENT_TYPE) {
            let content_type = match format.as_str() {
                "json" => "application/json",
                "form" => "application/x-www-form-urlencoded",
                _ => "application/octet-stream",
            };
            headers.insert(
                http::header::CONTENT_TYPE,
                HeaderValue::from_static(content_type),
            );
        }

        let mut request = Request::builder()
            .method(method)
            .uri(uri)
            .body(Full::new(body.unwrap_or_default()))?;
        // Move the collected headers into the request rather than copying
        // them one by one.
        *request.headers_mut() = headers;
        Ok(request)
    }

    /// URL-encodes each key and value and joins the pairs as `k=v&k=v`.
    fn encode_pairs<I, K, V>(pairs: I) -> String
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        pairs
            .into_iter()
            .map(|(key, value)| {
                let key: String =
                    url::form_urlencoded::byte_serialize(key.as_ref().as_bytes()).collect();
                let value: String =
                    url::form_urlencoded::byte_serialize(value.as_ref().as_bytes()).collect();
                format!("{}={}", key, value)
            })
            .collect::<Vec<_>>()
            .join("&")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde_json::json;

    /// Shorthand for a factory with default settings.
    fn new_factory() -> APIRequestFactory {
        APIRequestFactory::new()
    }

    /// Returns the `Content-Type` header of a built request, panicking when it is absent.
    fn content_type(request: &Request<Full<Bytes>>) -> &HeaderValue {
        request.headers().get("Content-Type").unwrap()
    }

    // ---- factory construction and configuration ----

    #[rstest]
    fn test_factory_new() {
        let built = new_factory().get("/api/users/").build().unwrap();
        assert_eq!(built.method(), Method::GET);
    }

    #[rstest]
    fn test_factory_default() {
        let from_new = APIRequestFactory::new().get("/test").build().unwrap();
        let from_default = APIRequestFactory::default()
            .get("/test")
            .build()
            .unwrap();
        assert_eq!(from_new.method(), from_default.method());
        assert_eq!(from_new.uri(), from_default.uri());
    }

    #[rstest]
    fn test_factory_with_format() {
        let xml_factory = new_factory().with_format("xml");
        let built = xml_factory
            .post("/api/data/")
            .body("payload")
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/octet-stream");
    }

    #[rstest]
    fn test_factory_with_header() {
        let configured = new_factory().with_header("X-Custom", "value123").unwrap();
        let built = configured.get("/api/items/").build().unwrap();
        assert_eq!(built.headers().get("x-custom").unwrap(), "value123");
    }

    // ---- one constructor per HTTP method ----

    #[rstest]
    fn test_factory_get() {
        let built = new_factory().get("/api/users/").build().unwrap();
        assert_eq!(built.method(), Method::GET);
    }

    #[rstest]
    fn test_factory_post() {
        let built = new_factory().post("/api/users/").build().unwrap();
        assert_eq!(built.method(), Method::POST);
    }

    #[rstest]
    fn test_factory_put() {
        let built = new_factory().put("/api/users/1/").build().unwrap();
        assert_eq!(built.method(), Method::PUT);
    }

    #[rstest]
    fn test_factory_patch() {
        let built = new_factory().patch("/api/users/1/").build().unwrap();
        assert_eq!(built.method(), Method::PATCH);
    }

    #[rstest]
    fn test_factory_delete() {
        let built = new_factory().delete("/api/users/1/").build().unwrap();
        assert_eq!(built.method(), Method::DELETE);
    }

    #[rstest]
    fn test_factory_head() {
        let built = new_factory().head("/api/users/").build().unwrap();
        assert_eq!(built.method(), Method::HEAD);
    }

    #[rstest]
    fn test_factory_options() {
        let built = new_factory().options("/api/users/").build().unwrap();
        assert_eq!(built.method(), Method::OPTIONS);
    }

    #[rstest]
    fn test_factory_request_custom() {
        let built = new_factory()
            .request(Method::TRACE, "/api/trace/")
            .build()
            .unwrap();
        assert_eq!(built.method(), Method::TRACE);
    }

    // ---- request bodies ----

    #[rstest]
    fn test_builder_json() {
        let payload = json!({"name": "test"});
        let built = new_factory()
            .post("/api/users/")
            .json(&payload)
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/json");
        assert_eq!(built.method(), Method::POST);
    }

    #[rstest]
    fn test_builder_form() {
        let payload = json!({"name": "test", "age": 30});
        let built = new_factory()
            .post("/api/users/")
            .form(&payload)
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/x-www-form-urlencoded");
    }

    #[rstest]
    fn test_builder_raw_body() {
        let built = new_factory()
            .post("/api/upload/")
            .body("raw data")
            .build()
            .unwrap();
        assert_eq!(built.method(), Method::POST);
        // A raw body keeps the builder's current format, which is JSON here.
        assert_eq!(content_type(&built), "application/json");
    }

    // ---- query parameters ----

    #[rstest]
    fn test_builder_query_single() {
        let built = new_factory()
            .get("/api/users/")
            .query("page", "1")
            .build()
            .unwrap();
        assert_eq!(built.uri().to_string(), "/api/users/?page=1");
    }

    #[rstest]
    fn test_builder_query_multiple() {
        let built = new_factory()
            .get("/api/users/")
            .query("page", "1")
            .query_param("limit", "10")
            .build()
            .unwrap();
        let rendered = built.uri().to_string();
        // Parameter order is not asserted, only presence and the separator.
        assert!(rendered.contains("page=1"));
        assert!(rendered.contains("limit=10"));
        assert!(rendered.contains('&'));
    }

    // ---- builder accessors and format override ----

    #[rstest]
    fn test_builder_method_getter() {
        let pending = new_factory().get("/test");
        assert_eq!(pending.method(), Method::GET);
    }

    #[rstest]
    fn test_builder_path_getter() {
        let pending = new_factory().get("/api/items/");
        assert_eq!(pending.path(), "/api/items/");
    }

    #[rstest]
    fn test_builder_with_format() {
        let built = new_factory()
            .post("/api/data/")
            .with_format("form")
            .body("key=val")
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/x-www-form-urlencoded");
    }

    // ---- error paths ----

    #[rstest]
    fn test_factory_with_header_invalid_name() {
        let outcome = new_factory().with_header("invalid header!", "value");
        assert!(outcome.is_err());
    }

    #[rstest]
    fn test_builder_form_non_object() {
        let payload = json!([1, 2, 3]);
        let outcome = new_factory().post("/api/users/").form(&payload);
        assert!(outcome.is_err());
    }

    #[rstest]
    fn test_builder_header_invalid_name() {
        let outcome = new_factory().get("/test").header("bad header!", "value");
        assert!(outcome.is_err());
    }

    // ---- edge cases ----

    #[rstest]
    fn test_builder_no_body_no_content_type() {
        let built = new_factory().get("/api/users/").build().unwrap();
        assert!(built.headers().get("Content-Type").is_none());
    }

    #[rstest]
    fn test_builder_json_empty_object() {
        let payload = json!({});
        let built = new_factory()
            .post("/api/data/")
            .json(&payload)
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/json");
    }

    #[rstest]
    fn test_builder_query_special_chars() {
        let built = new_factory()
            .get("/api/search/")
            .query("q", "hello world&foo=bar")
            .build()
            .unwrap();
        let rendered = built.uri().to_string();
        // The value must be encoded, not spliced into the URI verbatim.
        assert!(rendered.contains("hello+world"));
        assert!(!rendered.contains("hello world&foo=bar"));
    }

    #[rstest]
    fn test_builder_unknown_format() {
        let xml_factory = new_factory().with_format("xml");
        let built = xml_factory
            .post("/api/data/")
            .body("<xml/>")
            .build()
            .unwrap();
        assert_eq!(content_type(&built), "application/octet-stream");
    }
}