use bytes::Bytes;
use http::{header, HeaderMap, HeaderValue, Method, StatusCode};
use http_body_util::BodyExt;
use rustapi_core::middleware::{BodyLimitLayer, BoxedNext, LayerStack, DEFAULT_BODY_LIMIT};
use rustapi_core::{ApiError, BodyVariant, IntoResponse, Request, Response, RouteMatch, Router};
use serde::{de::DeserializeOwned, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Client for exercising an application in tests.
///
/// Requests are matched against `router` and then run through `layers`
/// with the matched handler as the innermost step (see [`TestClient::request`]).
pub struct TestClient {
/// Router used to resolve a request's path and method to a handler.
router: Arc<Router>,
/// Layer stack executed around the matched handler for every request.
layers: Arc<LayerStack>,
}
impl TestClient {
pub fn new(app: rustapi_core::RustApi) -> Self {
let layers = app.layers().clone();
let router = app.into_router();
let mut layers = layers;
layers.prepend(Box::new(BodyLimitLayer::new(DEFAULT_BODY_LIMIT)));
Self {
router: Arc::new(router),
layers: Arc::new(layers),
}
}
pub fn with_body_limit(app: rustapi_core::RustApi, limit: usize) -> Self {
let layers = app.layers().clone();
let router = app.into_router();
let mut layers = layers;
layers.prepend(Box::new(BodyLimitLayer::new(limit)));
Self {
router: Arc::new(router),
layers: Arc::new(layers),
}
}
pub async fn get(&self, path: &str) -> TestResponse {
self.request(TestRequest::get(path)).await
}
pub async fn post_json<T: Serialize>(&self, path: &str, body: &T) -> TestResponse {
self.request(TestRequest::post(path).json(body)).await
}
pub async fn request(&self, req: TestRequest) -> TestResponse {
let method = req.method.clone();
let path = req.path.clone();
let (handler, params) = match self.router.match_route(&path, &method) {
RouteMatch::Found { handler, params } => (handler.clone(), params),
RouteMatch::NotFound => {
let response =
ApiError::not_found(format!("No route found for {} {}", method, path))
.into_response();
return TestResponse::from_response(response).await;
}
RouteMatch::MethodNotAllowed { allowed } => {
let allowed_str: Vec<&str> = allowed.iter().map(|m| m.as_str()).collect();
let mut response = ApiError::new(
StatusCode::METHOD_NOT_ALLOWED,
"method_not_allowed",
format!("Method {} not allowed for {}", method, path),
)
.into_response();
response
.headers_mut()
.insert(header::ALLOW, allowed_str.join(", ").parse().unwrap());
return TestResponse::from_response(response).await;
}
};
let uri: http::Uri = path.parse().unwrap_or_else(|_| "/".parse().unwrap());
let mut builder = http::Request::builder().method(method).uri(uri);
for (key, value) in req.headers.iter() {
builder = builder.header(key, value);
}
let http_req = builder.body(()).unwrap();
let (parts, _) = http_req.into_parts();
let body_bytes = req.body.unwrap_or_default();
let request = Request::new(
parts,
BodyVariant::Buffered(body_bytes),
self.router.state_ref(),
params,
);
let final_handler: BoxedNext = Arc::new(move |req: Request| {
let handler = handler.clone();
Box::pin(async move { handler(req).await })
as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
});
let response = self.layers.execute(request, final_handler).await;
TestResponse::from_response(response).await
}
}
/// Description of a request to send through a [`TestClient`].
///
/// Built with the method constructors (`get`, `post`, ...) and refined with
/// the chained `header`, `json`, `body` and `content_type` methods.
#[derive(Debug, Clone)]
pub struct TestRequest {
/// HTTP method of the request.
method: Method,
/// Request path as given by the caller; it is parsed as the request URI when sent.
path: String,
/// Headers to send with the request.
headers: HeaderMap,
/// Optional request body; `None` is sent as an empty body.
body: Option<Bytes>,
}
impl TestRequest {
    /// Shared constructor: a request for `path` with the given `method`,
    /// no headers and no body.
    fn new(method: Method, path: &str) -> Self {
        Self {
            method,
            path: path.to_owned(),
            headers: HeaderMap::new(),
            body: None,
        }
    }

    /// Starts a `GET` request for `path`.
    pub fn get(path: &str) -> Self {
        Self::new(Method::GET, path)
    }

    /// Starts a `POST` request for `path`.
    pub fn post(path: &str) -> Self {
        Self::new(Method::POST, path)
    }

    /// Starts a `PUT` request for `path`.
    pub fn put(path: &str) -> Self {
        Self::new(Method::PUT, path)
    }

    /// Starts a `PATCH` request for `path`.
    pub fn patch(path: &str) -> Self {
        Self::new(Method::PATCH, path)
    }

    /// Starts a `DELETE` request for `path`.
    pub fn delete(path: &str) -> Self {
        Self::new(Method::DELETE, path)
    }

    /// Sets the header `key` to `value`.
    ///
    /// If `key` is not a valid header name or `value` is not a valid header
    /// value, the request is returned unchanged and no error is reported.
    pub fn header(mut self, key: &str, value: &str) -> Self {
        let parsed_name = key.parse::<http::header::HeaderName>().ok();
        let parsed_value = HeaderValue::from_str(value).ok();
        // Only insert when both halves parsed successfully.
        if let Some((name, val)) = parsed_name.zip(parsed_value) {
            self.headers.insert(name, val);
        }
        self
    }

    /// Serializes `body` as JSON, uses it as the request body and sets the
    /// `Content-Type` header to `application/json`.
    ///
    /// If serialization fails, the request is returned unchanged (neither
    /// the body nor the header is touched).
    pub fn json<T: Serialize>(mut self, body: &T) -> Self {
        if let Ok(encoded) = serde_json::to_vec(body) {
            self.body = Some(Bytes::from(encoded));
            self.headers.insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/json"),
            );
        }
        self
    }

    /// Uses `body` as the raw request body, replacing any previous body.
    pub fn body(mut self, body: impl Into<Bytes>) -> Self {
        let payload: Bytes = body.into();
        self.body = Some(payload);
        self
    }

    /// Sets the `content-type` header; shorthand for
    /// `header("content-type", content_type)`.
    pub fn content_type(self, content_type: &str) -> Self {
        self.header("content-type", content_type)
    }
}
/// Fully collected response returned by a [`TestClient`], with accessors
/// and assertion helpers for use in tests.
#[derive(Debug)]
pub struct TestResponse {
/// Status code of the response.
status: StatusCode,
/// Headers of the response.
headers: HeaderMap,
/// Response body collected into a single buffer.
body: Bytes,
}
impl TestResponse {
    /// Collects `response` into a [`TestResponse`].
    ///
    /// The body is read to completion; if reading it fails, the body is
    /// recorded as empty.
    async fn from_response(response: Response) -> Self {
        let (parts, body) = response.into_parts();
        let body_bytes = body
            .collect()
            .await
            .map(|collected| collected.to_bytes())
            .unwrap_or_default();
        Self {
            status: parts.status,
            headers: parts.headers,
            body: body_bytes,
        }
    }

    /// Returns the response status code.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// Returns the response headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Returns the raw response body.
    pub fn body(&self) -> &Bytes {
        &self.body
    }

    /// Returns the body as text, replacing invalid UTF-8 sequences with
    /// U+FFFD (lossy conversion).
    pub fn text(&self) -> String {
        // `into_owned` reuses the buffer when the lossy conversion already
        // allocated, instead of copying it a second time.
        String::from_utf8_lossy(&self.body).into_owned()
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the body is not valid JSON for `T`.
    pub fn json<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
        serde_json::from_slice(&self.body)
    }

    /// Asserts that the status code equals `expected`.
    ///
    /// # Panics
    ///
    /// Panics, reporting both status codes and the body text, when the
    /// status codes differ.
    #[track_caller]
    pub fn assert_status<S: Into<StatusCode>>(&self, expected: S) -> &Self {
        let expected = expected.into();
        assert_eq!(
            self.status,
            expected,
            "Expected status {}, got {}. Body: {}",
            expected,
            self.status,
            self.text()
        );
        self
    }

    /// Asserts that the header `key` has the value `expected`.
    ///
    /// A header that is absent, or whose value cannot be read as text, is
    /// compared as the empty string.
    ///
    /// # Panics
    ///
    /// Panics when the values differ; the message notes whether the header
    /// was missing or unreadable.
    #[track_caller]
    pub fn assert_header(&self, key: &str, expected: &str) -> &Self {
        let raw = self.headers.get(key);
        let actual = raw.and_then(|v| v.to_str().ok()).unwrap_or("");
        // Tell an empty value apart from "no usable value" in the failure message.
        let note = match raw {
            None => " (header is not present)",
            Some(v) if v.to_str().is_err() => " (header value is not valid text)",
            Some(_) => "",
        };
        assert_eq!(
            actual, expected,
            "Expected header '{}' to be '{}', got '{}'{}",
            key, expected, actual, note
        );
        self
    }

    /// Asserts that the body deserializes as JSON into a `T` equal to
    /// `expected`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be parsed as `T` (the message includes the
    /// parse error and the body text) or if the parsed value differs from
    /// `expected`.
    #[track_caller]
    pub fn assert_json<T: DeserializeOwned + PartialEq + std::fmt::Debug>(
        &self,
        expected: &T,
    ) -> &Self {
        let actual: T = match self.json() {
            Ok(value) => value,
            Err(err) => panic!(
                "Failed to parse response body as JSON: {}. Body: {}",
                err,
                self.text()
            ),
        };
        assert_eq!(&actual, expected, "JSON body mismatch");
        self
    }

    /// Asserts that the body text contains `expected` as a substring.
    ///
    /// # Panics
    ///
    /// Panics, reporting the full body text, when the substring is absent.
    #[track_caller]
    pub fn assert_body_contains(&self, expected: &str) -> &Self {
        let body = self.text();
        assert!(
            body.contains(expected),
            "Expected body to contain '{}', got '{}'",
            expected,
            body
        );
        self
    }
}