use bytes::Bytes;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use serde::Serialize;
use std::collections::HashMap;
use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
/// Client-level defaults used to build test requests.
///
/// The verb methods (`get`, `post`, ...) consume the client and move these
/// defaults into the returned [`RequestBuilder`].
#[derive(Debug, Clone)]
pub struct TestClient {
    /// Prefix joined with relative request URLs; URLs starting with `http://` or `https://` ignore it.
    base_url: Option<String>,
    /// Default headers carried into every request.
    headers: HeaderMap,
    /// Default query parameters carried into every request.
    query_params: HashMap<String, String>,
}
impl TestClient {
    /// Creates a client with no base URL, no default headers and no default query parameters.
    pub fn new() -> Self {
        Self {
            base_url: None,
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
        }
    }

    /// Sets the base URL that relative request paths are appended to.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self { base_url: Some(url.into()), ..self }
    }

    /// Appends a default header.
    ///
    /// Best-effort: if either the name or the value fails to convert, the
    /// header is silently skipped and the client is returned unchanged.
    /// The value is only converted once the name has converted successfully.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        V: TryInto<HeaderValue>,
    {
        let name: HeaderName = match key.try_into() {
            Ok(name) => name,
            Err(_) => return self,
        };
        let val: HeaderValue = match value.try_into() {
            Ok(val) => val,
            Err(_) => return self,
        };
        self.headers.append(name, val);
        self
    }

    /// Merges a whole header map into the defaults via `HeaderMap::extend`.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        self.headers.extend(headers);
        self
    }

    /// Sets a default query parameter, replacing any previous value for the same key.
    pub fn query_param<K, V>(mut self, key: K, value: V) -> Self
    where
        K: Into<String>,
        V: Into<String>,
    {
        let name = key.into();
        let val = value.into();
        self.query_params.insert(name, val);
        self
    }

    /// Inserts every pair from `params`, overwriting existing keys.
    pub fn query_params(mut self, params: HashMap<String, String>) -> Self {
        for (name, val) in params {
            self.query_params.insert(name, val);
        }
        self
    }

    /// Starts a `GET` request.
    pub fn get(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::GET, url)
    }

    /// Starts a `POST` request.
    pub fn post(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::POST, url)
    }

    /// Starts a `PUT` request.
    pub fn put(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::PUT, url)
    }

    /// Starts a `DELETE` request.
    pub fn delete(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::DELETE, url)
    }

    /// Starts a `PATCH` request.
    pub fn patch(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::PATCH, url)
    }

    /// Starts a `HEAD` request.
    pub fn head(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::HEAD, url)
    }

    /// Starts an `OPTIONS` request.
    pub fn options(self, url: impl Into<String>) -> RequestBuilder {
        Self::request(self, http::Method::OPTIONS, url)
    }

    /// Starts a request with an arbitrary method, consuming the client and
    /// moving its base URL, headers and query parameters into the builder.
    pub fn request<M, U>(self, method: M, url: U) -> RequestBuilder
    where
        M: Into<http::Method>,
        U: Into<String>,
    {
        let Self { base_url, headers, query_params } = self;
        RequestBuilder::new(method, url, base_url, headers, query_params)
    }
}
impl Default for TestClient {
fn default() -> Self {
Self::new()
}
}
/// A single request under construction, produced by [`TestClient`]'s verb methods.
#[derive(Debug, Clone)]
pub struct RequestBuilder {
    /// HTTP method of the request.
    method: http::Method,
    /// Request URL as given; joined with `base_url` unless it starts with `http://` or `https://`.
    url: String,
    /// Base URL inherited from the client.
    base_url: Option<String>,
    /// Headers inherited from the client plus any added on the builder.
    headers: HeaderMap,
    /// Query parameters appended to the URL when it is built.
    query_params: HashMap<String, String>,
    /// Body set by `json`, `text` or `bytes`; a later call replaces an earlier one.
    body: Option<RequestBody>,
}
/// Request payload, tagged by the builder method that set it.
#[derive(Debug, Clone)]
enum RequestBody {
    /// JSON bytes produced by `RequestBuilder::json`.
    Json(Bytes),
    /// Text set by `RequestBuilder::text`.
    Text(String),
    /// Raw bytes set by `RequestBuilder::bytes`.
    Bytes(Bytes),
}
impl RequestBuilder {
    /// Creates a builder for `method` + `url`, inheriting the client's base URL,
    /// default headers and default query parameters. The body starts empty.
    fn new<M, U>(method: M, url: U, base_url: Option<String>, headers: HeaderMap, query_params: HashMap<String, String>) -> Self
    where
        M: Into<http::Method>,
        U: Into<String>,
    {
        Self { method: method.into(), url: url.into(), base_url, headers, query_params, body: None }
    }

    /// Appends a header to this request.
    ///
    /// Best-effort: if the name or value fails to convert, the header is
    /// silently skipped (same behavior as `TestClient::header`).
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        V: TryInto<HeaderValue>,
    {
        if let Ok(key) = key.try_into()
            && let Ok(value) = value.try_into()
        {
            self.headers.append(key, value);
        }
        self
    }

    /// Merges a whole header map into this request via `HeaderMap::extend`.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        self.headers.extend(headers);
        self
    }

    /// Sets a query parameter, replacing any previous value for the same key.
    pub fn query_param<K, V>(mut self, key: K, value: V) -> Self
    where
        K: Into<String>,
        V: Into<String>,
    {
        self.query_params.insert(key.into(), value.into());
        self
    }

    /// Inserts every pair from `params`, overwriting existing keys.
    pub fn query_params(mut self, params: HashMap<String, String>) -> Self {
        self.query_params.extend(params);
        self
    }

    /// Serializes `data` to JSON and uses it as the request body.
    ///
    /// # Errors
    /// Returns a `JsonError` if serialization fails.
    pub fn json<T: Serialize>(mut self, data: &T) -> TestingResult<Self> {
        let bytes = serde_json::to_vec(data).map_err(|e| WaeError::new(WaeErrorKind::JsonError { reason: e.to_string() }))?;
        self.body = Some(RequestBody::Json(Bytes::from(bytes)));
        Ok(self)
    }

    /// Uses `data` as a text request body.
    pub fn text(mut self, data: impl Into<String>) -> Self {
        self.body = Some(RequestBody::Text(data.into()));
        self
    }

    /// Uses `data` as a raw request body.
    pub fn bytes(mut self, data: impl Into<Bytes>) -> Self {
        self.body = Some(RequestBody::Bytes(data.into()));
        self
    }

    /// Builds the full request URL.
    ///
    /// * If a base URL is set and `url` is not absolute (`http://`/`https://`),
    ///   the two are concatenated. A duplicated `/` at the seam is collapsed.
    /// * Query parameters are appended in key-sorted order so the result is
    ///   reproducible; `HashMap` iteration order is unspecified.
    /// * Keys and values are inserted verbatim. They are NOT percent-encoded,
    ///   so callers must pre-encode reserved characters.
    fn build_url(&self) -> String {
        let is_absolute = self.url.starts_with("http://") || self.url.starts_with("https://");
        let mut url_str = match self.base_url.as_deref() {
            Some(base) if !is_absolute => match base.strip_suffix('/') {
                // Both sides carry the separator; keep only one.
                Some(trimmed) if self.url.starts_with('/') => format!("{}{}", trimmed, self.url),
                _ => format!("{}{}", base, self.url),
            },
            _ => self.url.clone(),
        };
        if !self.query_params.is_empty() {
            let mut pairs: Vec<(&String, &String)> = self.query_params.iter().collect();
            pairs.sort_unstable();
            let has_query = url_str.contains('?');
            // Don't add a separator when the URL already ends with one, e.g. "/search?".
            let ends_with_separator = url_str.ends_with('?') || (has_query && url_str.ends_with('&'));
            if !ends_with_separator {
                url_str.push(if has_query { '&' } else { '?' });
            }
            let query: Vec<String> = pairs.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
            url_str.push_str(&query.join("&"));
        }
        url_str
    }

    /// Returns the request method.
    pub fn method(&self) -> &http::Method {
        &self.method
    }

    /// Returns the fully built request URL (see `build_url` for the joining rules).
    pub fn url(&self) -> String {
        self.build_url()
    }

    /// Returns the request headers.
    pub fn get_headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Returns a copy of the request body as bytes, or `None` if no body was set.
    pub fn body(&self) -> Option<Bytes> {
        self.body.as_ref().map(|body| match body {
            RequestBody::Json(b) | RequestBody::Bytes(b) => b.clone(),
            RequestBody::Text(t) => Bytes::from(t.clone()),
        })
    }

    /// Builds a [`TestResponse`] that records this request's method and built URL.
    pub fn create_response(&self, status: StatusCode, headers: HeaderMap, body: Bytes) -> TestResponse {
        TestResponse { status, headers, body, request_method: self.method.clone(), request_url: self.build_url() }
    }
}
/// A response, together with the method and URL of the request that produced it.
#[derive(Debug, Clone)]
pub struct TestResponse {
    /// Response status code.
    pub status: StatusCode,
    /// Response headers.
    pub headers: HeaderMap,
    /// Raw response body.
    pub body: Bytes,
    /// Method of the originating request (`GET` when built via `TestResponse::new`).
    request_method: http::Method,
    /// Built URL of the originating request (empty when built via `TestResponse::new`).
    request_url: String,
}
impl TestResponse {
pub fn new(status: StatusCode, headers: HeaderMap, body: Bytes) -> Self {
Self { status, headers, body, request_method: http::Method::GET, request_url: String::new() }
}
pub fn status(&self) -> StatusCode {
self.status
}
pub fn is_success(&self) -> bool {
self.status.is_success()
}
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
pub fn header<K: TryInto<HeaderName>>(&self, key: K) -> Option<&HeaderValue> {
if let Ok(key) = key.try_into() { self.headers.get(key) } else { None }
}
pub fn body(&self) -> &Bytes {
&self.body
}
pub fn text(&self) -> TestingResult<String> {
String::from_utf8(self.body.to_vec())
.map_err(|e| WaeError::new(WaeErrorKind::ParseError { type_name: "String".to_string(), reason: e.to_string() }))
}
pub fn json<T: serde::de::DeserializeOwned>(&self) -> TestingResult<T> {
serde_json::from_slice(&self.body).map_err(|e| WaeError::new(WaeErrorKind::JsonError { reason: e.to_string() }))
}
pub fn assert_status(&self, status: StatusCode) -> TestingResult<&Self> {
if self.status != status {
return Err(WaeError::new(WaeErrorKind::AssertionFailed {
message: format!("Expected status {}, got {}", status, self.status),
}));
}
Ok(self)
}
pub fn assert_success(&self) -> TestingResult<&Self> {
if !self.is_success() {
return Err(WaeError::new(WaeErrorKind::AssertionFailed {
message: format!("Expected success status, got {}", self.status),
}));
}
Ok(self)
}
pub fn assert_header<K: TryInto<HeaderName>>(&self, key: K) -> TestingResult<&Self> {
if let Ok(key) = key.try_into()
&& self.headers.contains_key(&key)
{
return Ok(self);
}
Err(WaeError::new(WaeErrorKind::AssertionFailed { message: "Expected header not found".to_string() }))
}
pub fn assert_header_eq<K, V>(&self, key: K, value: V) -> TestingResult<&Self>
where
K: TryInto<HeaderName>,
V: AsRef<str>,
{
if let Ok(key) = key.try_into()
&& let Some(header_value) = self.headers.get(&key)
&& let Ok(hv_str) = header_value.to_str()
&& hv_str == value.as_ref()
{
return Ok(self);
}
Err(WaeError::new(WaeErrorKind::AssertionFailed { message: format!("Expected header value {}", value.as_ref()) }))
}
pub fn assert_body_contains(&self, text: &str) -> TestingResult<&Self> {
let body_str = self.text()?;
if !body_str.contains(text) {
return Err(WaeError::new(WaeErrorKind::AssertionFailed {
message: "Expected text not found in response body".to_string(),
}));
}
Ok(self)
}
pub fn request_method(&self) -> &http::Method {
&self.request_method
}
pub fn request_url(&self) -> &str {
&self.request_url
}
}