use crate::error::{ApiError, Result};
use crate::json;
use crate::request::Request;
use crate::response::IntoResponse;
use crate::stream::{StreamingBody, StreamingConfig};
use crate::validation::Validatable;
use bytes::Bytes;
use http::{header, StatusCode};
use rustapi_validate::v2::{AsyncValidate, ValidationContext};
use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
/// Extractor built from a shared borrow of the request (e.g. headers, query string, path
/// parameters, state, extensions). Because it only receives `&Request`, it cannot take the body.
pub trait FromRequestParts: Sized {
/// Attempts to extract `Self` from `req`; failures are reported as an `ApiError`.
fn from_request_parts(req: &Request) -> Result<Self>;
}
/// Extractor that receives mutable access to the request, e.g. to load and take the body.
///
/// The returned future must be `Send`.
pub trait FromRequest: Sized {
/// Asynchronously extracts `Self` from `req`.
fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
/// Every [`FromRequestParts`] extractor is also usable as a [`FromRequest`] extractor.
/// The body contains no `.await`, so the future resolves on its first poll.
impl<T: FromRequestParts> FromRequest for T {
async fn from_request(req: &mut Request) -> Result<Self> {
T::from_request_parts(req)
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
async fn from_request(req: &mut Request) -> Result<Self> {
req.load_body().await?;
let body = req
.take_body()
.ok_or_else(|| ApiError::internal("Body already consumed"))?;
let value: T = json::from_slice(&body)?;
Ok(Json(value))
}
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for Json<T> {
fn from(value: T) -> Self {
Json(value)
}
}
/// Initial buffer capacity handed to `json::to_vec_with_capacity` when serializing responses.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;

impl<T: Serialize> IntoResponse for Json<T> {
    /// Serializes the wrapped value into a `200 OK` response with
    /// `Content-Type: application/json`.
    ///
    /// If serialization fails, the error is turned into an internal `ApiError` response
    /// instead of panicking.
    fn into_response(self) -> crate::response::Response {
        let serialized = match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY)
        {
            Ok(bytes) => bytes,
            Err(err) => {
                return ApiError::internal(format!("Failed to serialize response: {}", err))
                    .into_response();
            }
        };

        let mut response = http::Response::new(crate::response::Body::from(serialized));
        *response.status_mut() = StatusCode::OK;
        response.headers_mut().insert(
            header::CONTENT_TYPE,
            http::HeaderValue::from_static("application/json"),
        );
        response
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
impl<T> ValidatedJson<T> {
pub fn new(value: T) -> Self {
Self(value)
}
pub fn into_inner(self) -> T {
self.0
}
}
impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
async fn from_request(req: &mut Request) -> Result<Self> {
req.load_body().await?;
let body = req
.take_body()
.ok_or_else(|| ApiError::internal("Body already consumed"))?;
let value: T = json::from_slice(&body)?;
value.do_validate()?;
Ok(ValidatedJson(value))
}
}
impl<T> Deref for ValidatedJson<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for ValidatedJson<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for ValidatedJson<T> {
fn from(value: T) -> Self {
ValidatedJson(value)
}
}
impl<T: Serialize> IntoResponse for ValidatedJson<T> {
fn into_response(self) -> crate::response::Response {
Json(self.0).into_response()
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);
impl<T> AsyncValidatedJson<T> {
pub fn new(value: T) -> Self {
Self(value)
}
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for AsyncValidatedJson<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for AsyncValidatedJson<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> From<T> for AsyncValidatedJson<T> {
fn from(value: T) -> Self {
AsyncValidatedJson(value)
}
}
impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
fn into_response(self) -> crate::response::Response {
Json(self.0).into_response()
}
}
impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
async fn from_request(req: &mut Request) -> Result<Self> {
req.load_body().await?;
let body = req
.take_body()
.ok_or_else(|| ApiError::internal("Body already consumed"))?;
let value: T = json::from_slice(&body)?;
let ctx = if let Some(ctx) = req.state().get::<ValidationContext>() {
ctx.clone()
} else {
ValidationContext::default()
};
if let Err(errors) = value.validate_full(&ctx).await {
let field_errors: Vec<crate::error::FieldError> = errors
.fields
.iter()
.flat_map(|(field, errs)| {
let field_name = field.to_string();
errs.iter().map(move |e| crate::error::FieldError {
field: field_name.clone(),
code: e.code.to_string(),
message: e.message.clone(),
})
})
.collect();
return Err(ApiError::validation(field_errors));
}
Ok(AsyncValidatedJson(value))
}
}
/// Deserializes the URL query string into `T` using `serde_urlencoded`.
///
/// A request without a query string is treated as an empty one, so `T` must accept an
/// empty field set (e.g. all-`Option` fields) for such requests to succeed.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T: DeserializeOwned> FromRequestParts for Query<T> {
    /// # Errors
    /// `ApiError::bad_request` describing the deserialization failure.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let raw_query = req.query_string().unwrap_or("");
        serde_urlencoded::from_str::<T>(raw_query)
            .map(Query)
            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}
/// Extracts a single path parameter and parses it with [`FromStr`].
///
/// Only the first entry yielded by `req.path_params().iter()` is used; any others are
/// ignored. [`Typed`] deserializes all path parameters at once.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);

impl<T: FromStr> FromRequestParts for Path<T>
where
    T::Err: std::fmt::Display,
{
    /// # Errors
    /// - `ApiError::bad_request` when the value does not parse as `T`;
    /// - an internal error when the request carries no path parameters.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let params = req.path_params();
        match params.iter().next() {
            Some((_, raw)) => raw
                .parse::<T>()
                .map(Path)
                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e))),
            None => Err(ApiError::internal("Missing path parameter")),
        }
    }
}

impl<T> Deref for Path<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Path(inner) = self;
        inner
    }
}
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);
impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
fn from_request_parts(req: &Request) -> Result<Self> {
let params = req.path_params();
let mut map = serde_json::Map::new();
for (k, v) in params.iter() {
map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
}
let value = serde_json::Value::Object(map);
let parsed: T = serde_json::from_value(value)
.map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
Ok(Typed(parsed))
}
}
impl<T> Deref for Typed<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Application state of type `T`, cloned out of the request's shared state.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
    /// # Errors
    /// An internal error naming `T` when the request state holds no value of that type.
    fn from_request_parts(req: &Request) -> Result<Self> {
        match req.state().get::<T>() {
            Some(shared) => Ok(State(shared.clone())),
            None => Err(ApiError::internal(format!(
                "State of type `{}` not found. Did you forget to call .state()?",
                std::any::type_name::<T>()
            ))),
        }
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
/// Raw request body, fully buffered into [`Bytes`].
#[derive(Debug, Clone)]
pub struct Body(pub Bytes);

impl FromRequest for Body {
    /// # Errors
    /// Body-loading errors, or an internal error if the body was already taken.
    async fn from_request(req: &mut Request) -> Result<Self> {
        req.load_body().await?;
        req.take_body()
            .map(Body)
            .ok_or_else(|| ApiError::internal("Body already consumed"))
    }
}

impl Deref for Body {
    type Target = Bytes;

    fn deref(&self) -> &Bytes {
        let Body(bytes) = self;
        bytes
    }
}
/// Streaming request-body extractor yielding `Bytes` chunks.
///
/// Uses `StreamingConfig::default().max_body_size` as the size limit handed to
/// `StreamingBody`. If the request has no live stream but an already-buffered body, that
/// buffer is exposed as a single-chunk stream.
pub struct BodyStream(pub StreamingBody);

impl FromRequest for BodyStream {
    /// # Errors
    /// An internal error when neither a stream nor a buffered body is left on the request.
    async fn from_request(req: &mut Request) -> Result<Self> {
        let config = StreamingConfig::default();

        if let Some(incoming) = req.take_stream() {
            return Ok(Self(StreamingBody::new(incoming, config.max_body_size)));
        }

        match req.take_body() {
            Some(buffered) => {
                let single_chunk = futures_util::stream::once(async move { Ok(buffered) });
                Ok(Self(StreamingBody::from_stream(
                    single_chunk,
                    config.max_body_size,
                )))
            }
            None => Err(ApiError::internal("Body already consumed")),
        }
    }
}

impl Deref for BodyStream {
    type Target = StreamingBody;

    fn deref(&self) -> &StreamingBody {
        let BodyStream(inner) = self;
        inner
    }
}

impl DerefMut for BodyStream {
    fn deref_mut(&mut self) -> &mut StreamingBody {
        let BodyStream(inner) = self;
        inner
    }
}

/// Forwards polling straight to the wrapped `StreamingBody`.
impl futures_util::Stream for BodyStream {
    type Item = Result<Bytes, ApiError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        task_cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        std::pin::Pin::new(&mut self.0).poll_next(task_cx)
    }
}
/// Makes any parts extractor optional: a failed extraction yields `Ok(None)`.
///
/// The underlying error is discarded, so this extractor never fails.
impl<T: FromRequestParts> FromRequestParts for Option<T> {
    fn from_request_parts(req: &Request) -> Result<Self> {
        match T::from_request_parts(req) {
            Ok(extracted) => Ok(Some(extracted)),
            Err(_) => Ok(None),
        }
    }
}
/// Owned snapshot of every request header (a clone of the request's `HeaderMap`).
#[derive(Debug, Clone)]
pub struct Headers(pub http::HeaderMap);

impl Headers {
    /// First value stored for header `name`, if any.
    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
        let Self(map) = self;
        map.get(name)
    }

    /// Whether header `name` is present at least once.
    pub fn contains(&self, name: &str) -> bool {
        let Self(map) = self;
        map.contains_key(name)
    }

    /// Number of header values (as reported by `HeaderMap::len`).
    pub fn len(&self) -> usize {
        let Self(map) = self;
        map.len()
    }

    /// Whether no headers were sent.
    pub fn is_empty(&self) -> bool {
        let Self(map) = self;
        map.is_empty()
    }

    /// Iterates over `(name, value)` pairs.
    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
        let Self(map) = self;
        map.iter()
    }
}

impl FromRequestParts for Headers {
    /// Never fails; clones the request's header map.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let snapshot = req.headers().clone();
        Ok(Self(snapshot))
    }
}

impl Deref for Headers {
    type Target = http::HeaderMap;

    fn deref(&self) -> &http::HeaderMap {
        let Self(map) = self;
        map
    }
}
/// The text value of one required request header, together with the header's name.
///
/// Note the tuple-field order is `(value, name)`, while [`HeaderValue::new`] takes
/// `(name, value)`.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);

impl HeaderValue {
    /// Builds a `HeaderValue` from a header `name` and its `value`.
    pub fn new(name: &'static str, value: String) -> Self {
        Self(value, name)
    }

    /// The header's value.
    pub fn value(&self) -> &str {
        &self.0
    }

    /// The header's name, as passed to [`HeaderValue::extract`] / [`HeaderValue::new`].
    pub fn name(&self) -> &'static str {
        self.1
    }

    /// Reads the first value of header `name` from `req`.
    ///
    /// # Errors
    /// - `ApiError::bad_request("Missing required header: ...")` when the header is absent;
    /// - `ApiError::bad_request("Invalid value for header: ...")` when it is present but not
    ///   representable as a `&str` (`http::HeaderValue::to_str` accepts only visible ASCII).
    ///   Reporting this case as "missing" would send clients chasing the wrong problem.
    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
        let raw = req
            .headers()
            .get(name)
            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))?;
        let text = raw
            .to_str()
            .map_err(|_| ApiError::bad_request(format!("Invalid value for header: {}", name)))?;
        Ok(HeaderValue(text.to_string(), name))
    }
}

impl Deref for HeaderValue {
    type Target = String;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// A value of type `T` cloned out of the request's extensions (typically inserted by
/// middleware).
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
    /// # Errors
    /// An internal error naming `T` when no extension of that type is present.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let Some(found) = req.extensions().get::<T>() else {
            return Err(ApiError::internal(format!(
                "Extension of type `{}` not found. Did middleware insert it?",
                std::any::type_name::<T>()
            )));
        };
        Ok(Extension(found.clone()))
    }
}

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}

impl<T> DerefMut for Extension<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Extension(inner) = self;
        inner
    }
}
/// Best-effort client IP address.
///
/// Resolution order in [`ClientIp::extract_with_config`]:
/// 1. if `trust_proxy` is set: the first comma-separated entry of `X-Forwarded-For`, when it
///    parses as an IP address (later entries are never tried);
/// 2. the `std::net::SocketAddr` stored in the request extensions;
/// 3. `127.0.0.1` as a last resort. This never fails.
///
/// # Security
/// The [`FromRequestParts`] impl passes `trust_proxy = true`. `X-Forwarded-For` is an
/// ordinary request header, so unless a trusted reverse proxy overwrites it, any client can
/// choose the address this extractor reports. Proxies typically *append* to the header, which
/// leaves the first entry under client control. Do not use the default extractor for access
/// control or rate limiting in such deployments; call `extract_with_config(req, false)`.
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);

impl ClientIp {
    /// Resolves the client IP as described on [`ClientIp`].
    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
        let forwarded_ip = if trust_proxy {
            req.headers()
                .get("x-forwarded-for")
                .and_then(|raw| raw.to_str().ok())
                .and_then(|list| list.split(',').next())
                .and_then(|first| first.trim().parse::<std::net::IpAddr>().ok())
        } else {
            None
        };
        if let Some(ip) = forwarded_ip {
            return Ok(ClientIp(ip));
        }

        let ip = req
            .extensions()
            .get::<std::net::SocketAddr>()
            .map(|addr| addr.ip())
            .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        Ok(ClientIp(ip))
    }
}

impl FromRequestParts for ClientIp {
    /// Trusts `X-Forwarded-For`; see the security note on [`ClientIp`].
    fn from_request_parts(req: &Request) -> Result<Self> {
        ClientIp::extract_with_config(req, true)
    }
}
/// Request cookies parsed into a `cookie::CookieJar` (requires the `cookies` feature).
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Cookie named `name`, if the request sent one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterates over all parsed cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Whether a cookie named `name` was sent.
    pub fn contains(&self, name: &str) -> bool {
        self.0.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses every `name=value` pair from every `Cookie` header field.
    ///
    /// HTTP/2 and HTTP/3 clients may split the cookie list across several `cookie` header
    /// fields, so all of them are read, not just the first. Header values that are not valid
    /// text, and pairs that fail to parse, are skipped. Never fails.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();
        for header_value in req.headers().get_all(header::COOKIE) {
            let Ok(cookie_str) = header_value.to_str() else {
                continue;
            };
            for pair in cookie_str
                .split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty())
            {
                if let Ok(cookie) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }
        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// Lets primitive types and `String` be used directly as handler arguments; each one is
/// extracted exactly like `Path<T>` (first path parameter, parsed with `FromStr`).
macro_rules! impl_from_request_parts_for_primitives {
    ($($ty:ty),*) => {
        $(
            impl FromRequestParts for $ty {
                fn from_request_parts(req: &Request) -> Result<Self> {
                    Path::<$ty>::from_request_parts(req).map(|Path(value)| value)
                }
            }
        )*
    };
}

impl_from_request_parts_for_primitives!(
    i8, i16, i32, i64, i128, isize,
    u8, u16, u32, u64, u128, usize,
    f32, f64, bool, String
);
use rustapi_openapi::{
MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
};
impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
fn update_operation(op: &mut Operation) {
let mut ctx = SchemaCtx::new();
let schema_ref = T::schema(&mut ctx);
let mut content = BTreeMap::new();
content.insert(
"application/json".to_string(),
MediaType {
schema: Some(schema_ref),
example: None,
},
);
op.request_body = Some(RequestBody {
description: None,
required: Some(true),
content,
});
let mut responses_content = BTreeMap::new();
responses_content.insert(
"application/json".to_string(),
MediaType {
schema: Some(SchemaRef::Ref {
reference: "#/components/schemas/ValidationErrorSchema".to_string(),
}),
example: None,
},
);
op.responses.insert(
"422".to_string(),
ResponseSpec {
description: "Validation Error".to_string(),
content: responses_content,
headers: BTreeMap::new(),
},
);
}
fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
spec.register_in_place::<T>();
spec.register_in_place::<rustapi_openapi::ValidationErrorSchema>();
spec.register_in_place::<rustapi_openapi::ValidationErrorBodySchema>();
spec.register_in_place::<rustapi_openapi::FieldErrorSchema>();
}
}
impl<T: RustApiSchema> OperationModifier for AsyncValidatedJson<T> {
fn update_operation(op: &mut Operation) {
let mut ctx = SchemaCtx::new();
let schema_ref = T::schema(&mut ctx);
let mut content = BTreeMap::new();
content.insert(
"application/json".to_string(),
MediaType {
schema: Some(schema_ref),
example: None,
},
);
op.request_body = Some(RequestBody {
description: None,
required: Some(true),
content,
});
let mut responses_content = BTreeMap::new();
responses_content.insert(
"application/json".to_string(),
MediaType {
schema: Some(SchemaRef::Ref {
reference: "#/components/schemas/ValidationErrorSchema".to_string(),
}),
example: None,
},
);
op.responses.insert(
"422".to_string(),
ResponseSpec {
description: "Validation Error".to_string(),
content: responses_content,
headers: BTreeMap::new(),
},
);
}
fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
spec.register_in_place::<T>();
spec.register_in_place::<rustapi_openapi::ValidationErrorSchema>();
spec.register_in_place::<rustapi_openapi::ValidationErrorBodySchema>();
spec.register_in_place::<rustapi_openapi::FieldErrorSchema>();
}
}
impl<T: RustApiSchema> OperationModifier for Json<T> {
fn update_operation(op: &mut Operation) {
let mut ctx = SchemaCtx::new();
let schema_ref = T::schema(&mut ctx);
let mut content = BTreeMap::new();
content.insert(
"application/json".to_string(),
MediaType {
schema: Some(schema_ref),
example: None,
},
);
op.request_body = Some(RequestBody {
description: None,
required: Some(true),
content,
});
}
fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
spec.register_in_place::<T>();
}
}
/// Path parameters are not added to the generated operation; this is a no-op.
impl<T> OperationModifier for Path<T> {
    fn update_operation(_operation: &mut Operation) {}
}

/// Path parameters are not added to the generated operation; this is a no-op.
impl<T> OperationModifier for Typed<T> {
    fn update_operation(_operation: &mut Operation) {}
}

/// Appends one optional `query` parameter per field reported by `T::field_schemas`.
impl<T: RustApiSchema> OperationModifier for Query<T> {
    fn update_operation(op: &mut Operation) {
        let mut ctx = SchemaCtx::new();
        // `None` means `T` exposes no per-field schemas; nothing is documented then.
        let Some(fields) = T::field_schemas(&mut ctx) else {
            return;
        };
        op.parameters
            .extend(fields.into_iter().map(|(name, schema)| Parameter {
                name,
                location: "query".to_string(),
                required: false,
                deprecated: None,
                description: None,
                schema: Some(schema),
            }));
    }

    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}

/// Application state is not part of the HTTP contract; this is a no-op.
impl<T> OperationModifier for State<T> {
    fn update_operation(_operation: &mut Operation) {}
}
/// Sets `op.request_body` to a required `application/octet-stream` body documented as a
/// binary string.
///
/// Shared by the [`Body`] and [`BodyStream`] extractors, which accept the same wire format.
fn openapi_binary_request_body(op: &mut Operation) {
    let mut content = BTreeMap::new();
    content.insert(
        "application/octet-stream".to_string(),
        MediaType {
            schema: Some(SchemaRef::Inline(
                serde_json::json!({ "type": "string", "format": "binary" }),
            )),
            example: None,
        },
    );
    op.request_body = Some(RequestBody {
        description: None,
        required: Some(true),
        content,
    });
}

impl OperationModifier for Body {
    fn update_operation(op: &mut Operation) {
        openapi_binary_request_body(op);
    }
}

impl OperationModifier for BodyStream {
    fn update_operation(op: &mut Operation) {
        openapi_binary_request_body(op);
    }
}
/// Documents a `200 Successful response` whose `application/json` body is `T`'s schema.
impl<T: RustApiSchema> ResponseModifier for Json<T> {
    fn update_response(op: &mut Operation) {
        let mut schema_ctx = SchemaCtx::new();
        let body_media = MediaType {
            schema: Some(T::schema(&mut schema_ctx)),
            example: None,
        };
        let content = BTreeMap::from([("application/json".to_string(), body_media)]);
        let success = ResponseSpec {
            description: "Successful response".to_string(),
            content,
            headers: BTreeMap::new(),
        };
        op.responses.insert("200".to_string(), success);
    }

    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
// The extractor wrappers are transparent for schema purposes: each reports the schema of
// the wrapped type.
impl<T: RustApiSchema> RustApiSchema for Json<T> {
    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
        <T as RustApiSchema>::schema(ctx)
    }
}

impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
        <T as RustApiSchema>::schema(ctx)
    }
}

impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
        <T as RustApiSchema>::schema(ctx)
    }
}

/// Also forwards `field_schemas`, which the `Query` operation modifier uses to list
/// individual query parameters.
impl<T: RustApiSchema> RustApiSchema for Query<T> {
    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
        <T as RustApiSchema>::schema(ctx)
    }

    fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
        <T as RustApiSchema>::field_schemas(ctx)
    }
}
/// Page number used when the client does not supply one.
const DEFAULT_PAGE: u64 = 1;
/// Page size used when the client does not supply one.
const DEFAULT_PER_PAGE: u64 = 20;
/// Largest page size accepted by `Paginate::new` and `CursorPaginate::new`.
const MAX_PER_PAGE: u64 = 100;

/// Offset-based pagination parameters (`?page=...&per_page=...`).
///
/// [`Paginate::new`] (and the extractor) normalize the values. Because the fields are public,
/// a hand-built value may still hold `page == 0`; [`Paginate::offset`] tolerates that.
#[derive(Debug, Clone, Copy)]
pub struct Paginate {
    /// 1-based page number.
    pub page: u64,
    /// Items per page.
    pub per_page: u64,
}

impl Paginate {
    /// Creates pagination parameters, raising `page` to at least 1 and clamping `per_page`
    /// to `1..=MAX_PER_PAGE`.
    pub fn new(page: u64, per_page: u64) -> Self {
        Self {
            page: page.max(1),
            per_page: per_page.clamp(1, MAX_PER_PAGE),
        }
    }

    /// Number of items to skip: `(page - 1) * per_page`.
    ///
    /// Uses saturating arithmetic. `page` comes straight from the query string, so a huge
    /// value such as `?page=18446744073709551615` would otherwise overflow `u64`. That is a
    /// panic in debug builds and a silently wrapped, wrong offset in release builds. Saturated
    /// results still point past the end of any realistic data set. A hand-built `page == 0`
    /// is treated like page 1.
    pub fn offset(&self) -> u64 {
        self.page.saturating_sub(1).saturating_mul(self.per_page)
    }

    /// Maximum number of items to return (equal to `per_page`).
    pub fn limit(&self) -> u64 {
        self.per_page
    }

    /// Packages one page of `items` together with this request's page metadata and the
    /// overall `total`.
    pub fn paginate<T>(self, items: Vec<T>, total: u64) -> crate::hateoas::Paginated<T> {
        crate::hateoas::Paginated {
            items,
            page: self.page,
            per_page: self.per_page,
            total,
        }
    }
}

impl Default for Paginate {
    /// Page `DEFAULT_PAGE` with `DEFAULT_PER_PAGE` items.
    fn default() -> Self {
        Self {
            page: DEFAULT_PAGE,
            per_page: DEFAULT_PER_PAGE,
        }
    }
}
impl FromRequestParts for Paginate {
    /// Reads `page` and `per_page` from the query string and normalizes them through
    /// [`Paginate::new`].
    ///
    /// Lenient: if the query string cannot be deserialized (e.g. `page=abc`), *both* values
    /// fall back to their defaults rather than failing the request. Never returns `Err`.
    fn from_request_parts(req: &Request) -> Result<Self> {
        #[derive(serde::Deserialize, Default)]
        struct PaginateQuery {
            page: Option<u64>,
            per_page: Option<u64>,
        }

        let raw_query = req.query_string().unwrap_or("");
        let PaginateQuery { page, per_page } =
            serde_urlencoded::from_str::<PaginateQuery>(raw_query).unwrap_or_default();

        Ok(Paginate::new(
            page.unwrap_or(DEFAULT_PAGE),
            per_page.unwrap_or(DEFAULT_PER_PAGE),
        ))
    }
}
/// Cursor-based pagination parameters (`?cursor=...&limit=...`).
#[derive(Debug, Clone)]
pub struct CursorPaginate {
    /// Cursor supplied by the client; `None` means the first page.
    pub cursor: Option<String>,
    /// Page size, clamped to `1..=MAX_PER_PAGE` by [`CursorPaginate::new`].
    pub per_page: u64,
}

impl CursorPaginate {
    /// Creates cursor pagination parameters, clamping `per_page` to `1..=MAX_PER_PAGE`.
    pub fn new(cursor: Option<String>, per_page: u64) -> Self {
        let per_page = per_page.clamp(1, MAX_PER_PAGE);
        Self { cursor, per_page }
    }

    /// The cursor to continue after, if any.
    pub fn after(&self) -> Option<&str> {
        self.cursor.as_ref().map(String::as_str)
    }

    /// Maximum number of items to return (equal to `per_page`).
    pub fn limit(&self) -> u64 {
        self.per_page
    }

    /// `true` when no cursor was supplied.
    pub fn is_first_page(&self) -> bool {
        self.after().is_none()
    }
}

impl Default for CursorPaginate {
    /// No cursor, `DEFAULT_PER_PAGE` items.
    fn default() -> Self {
        CursorPaginate {
            per_page: DEFAULT_PER_PAGE,
            cursor: None,
        }
    }
}

impl FromRequestParts for CursorPaginate {
    /// Reads `cursor` and `limit` (mapped to `per_page`) from the query string.
    ///
    /// Lenient: a query string that cannot be deserialized is treated as if neither value
    /// were given. Never returns `Err`.
    fn from_request_parts(req: &Request) -> Result<Self> {
        #[derive(serde::Deserialize, Default)]
        struct CursorQuery {
            cursor: Option<String>,
            limit: Option<u64>,
        }

        let raw_query = req.query_string().unwrap_or("");
        let CursorQuery { cursor, limit } =
            serde_urlencoded::from_str::<CursorQuery>(raw_query).unwrap_or_default();

        Ok(CursorPaginate::new(
            cursor,
            limit.unwrap_or(DEFAULT_PER_PAGE),
        ))
    }
}
// Tests for the extractors in this module: proptest properties for headers, single header
// values, client IP, extensions and cookies, plus example-based unit tests and one async
// validation test.
#[cfg(test)]
mod tests {
use super::*;
use crate::path_params::PathParams;
use bytes::Bytes;
use http::{Extensions, Method};
use proptest::prelude::*;
use proptest::test_runner::TestCaseError;
use std::sync::Arc;
// Builds a `Request` for `method`/`path` with the given headers, an empty buffered body,
// an empty `Extensions` as the third `Request::new` argument, and no path parameters.
// (The async-validation test at the bottom puts a `ValidationContext` into that third
// argument, and `AsyncValidatedJson` reads it via `req.state()`.)
fn create_test_request_with_headers(
method: Method,
path: &str,
headers: Vec<(&str, &str)>,
) -> Request {
let uri: http::Uri = path.parse().unwrap();
let mut builder = http::Request::builder().method(method).uri(uri);
for (name, value) in headers {
builder = builder.header(name, value);
}
let req = builder.body(()).unwrap();
let (parts, _) = req.into_parts();
Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::new()),
Arc::new(Extensions::new()),
PathParams::new(),
)
}
// Like `create_test_request_with_headers`, but with no headers and with `extension` inserted
// into the http `parts.extensions`. The `Extension` tests below expect values there to be
// visible through `req.extensions()`.
fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
method: Method,
path: &str,
extension: T,
) -> Request {
let uri: http::Uri = path.parse().unwrap();
let builder = http::Request::builder().method(method).uri(uri);
let req = builder.body(()).unwrap();
let (mut parts, _) = req.into_parts();
parts.extensions.insert(extension);
Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::new()),
Arc::new(Extensions::new()),
PathParams::new(),
)
}
// Property: every generated header (including repeated names) can be found, with its value,
// in the extracted `Headers`.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_headers_extractor_completeness(
headers in prop::collection::vec(
(
"[a-z][a-z0-9-]{0,20}", // Valid header name pattern
"[a-zA-Z0-9 ]{1,50}" // Valid header value pattern
),
0..10
)
) {
let result: Result<(), TestCaseError> = (|| {
let header_tuples: Vec<(&str, &str)> = headers
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect();
let request = create_test_request_with_headers(
Method::GET,
"/test",
header_tuples.clone(),
);
let extracted = Headers::from_request_parts(&request)
.map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
for (name, value) in &headers {
// `get_all` (via Deref to HeaderMap) so repeated names are all checked.
let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
prop_assert!(
!all_values.is_empty(),
"Header '{}' not found",
name
);
let value_found = all_values.iter().any(|v| {
v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
});
prop_assert!(
value_found,
"Header '{}' value '{}' not found in extracted values",
name,
value
);
}
Ok(())
})();
result?;
}
}
// Property: `HeaderValue::extract` returns the value when the header is present and errors
// when it is absent.
// NOTE(review): `header_name` only feeds `_request`, which is never used; the assertions
// always use the fixed `x-test-header`, so the random name adds no coverage.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_header_value_extractor_correctness(
header_name in "[a-z][a-z0-9-]{0,20}",
header_value in "[a-zA-Z0-9 ]{1,50}",
has_header in prop::bool::ANY,
) {
let result: Result<(), TestCaseError> = (|| {
let headers = if has_header {
vec![(header_name.as_str(), header_value.as_str())]
} else {
vec![]
};
let _request = create_test_request_with_headers(Method::GET, "/test", headers);
let test_header = "x-test-header";
let request_with_known_header = if has_header {
create_test_request_with_headers(
Method::GET,
"/test",
vec![(test_header, header_value.as_str())],
)
} else {
create_test_request_with_headers(Method::GET, "/test", vec![])
};
let result = HeaderValue::extract(&request_with_known_header, test_header);
if has_header {
let extracted = result
.map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
prop_assert_eq!(
extracted.value(),
header_value.as_str(),
"Header value mismatch"
);
} else {
prop_assert!(
result.is_err(),
"Expected error when header is missing"
);
}
Ok(())
})();
result?;
}
}
// Property: with `trust_proxy` and an X-Forwarded-For header, the header's IP is reported;
// otherwise the `SocketAddr` placed in the request extensions is used.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_client_ip_extractor_with_forwarding(
forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
.prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
.prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
has_forwarded_header in prop::bool::ANY,
trust_proxy in prop::bool::ANY,
) {
let result: Result<(), TestCaseError> = (|| {
let headers = if has_forwarded_header {
vec![("x-forwarded-for", forwarded_ip.as_str())]
} else {
vec![]
};
let uri: http::Uri = "/test".parse().unwrap();
let mut builder = http::Request::builder().method(Method::GET).uri(uri);
for (name, value) in &headers {
builder = builder.header(*name, *value);
}
let req = builder.body(()).unwrap();
let (mut parts, _) = req.into_parts();
let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
parts.extensions.insert(socket_addr);
let request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::new()),
Arc::new(Extensions::new()),
PathParams::new(),
);
let extracted = ClientIp::extract_with_config(&request, trust_proxy)
.map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
if trust_proxy && has_forwarded_header {
let expected_ip: std::net::IpAddr = forwarded_ip.parse()
.map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
prop_assert_eq!(
extracted.0,
expected_ip,
"Should use X-Forwarded-For IP when trust_proxy is enabled"
);
} else {
prop_assert_eq!(
extracted.0,
socket_ip,
"Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
);
}
Ok(())
})();
result?;
}
}
// Property: `Extension<T>` returns the inserted value when present and errors otherwise.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_extension_extractor_retrieval(
value in any::<i64>(),
has_extension in prop::bool::ANY,
) {
let result: Result<(), TestCaseError> = (|| {
#[derive(Clone, Debug, PartialEq)]
struct TestExtension(i64);
let uri: http::Uri = "/test".parse().unwrap();
let builder = http::Request::builder().method(Method::GET).uri(uri);
let req = builder.body(()).unwrap();
let (mut parts, _) = req.into_parts();
if has_extension {
parts.extensions.insert(TestExtension(value));
}
let request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::new()),
Arc::new(Extensions::new()),
PathParams::new(),
);
let result = Extension::<TestExtension>::from_request_parts(&request);
if has_extension {
let extracted = result
.map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
prop_assert_eq!(
extracted.0,
TestExtension(value),
"Extension value mismatch"
);
} else {
prop_assert!(
result.is_err(),
"Expected error when extension is missing"
);
}
Ok(())
})();
result?;
}
}
// `Headers` exposes exactly the headers that were sent.
#[test]
fn test_headers_extractor_basic() {
let request = create_test_request_with_headers(
Method::GET,
"/test",
vec![
("content-type", "application/json"),
("accept", "text/html"),
],
);
let headers = Headers::from_request_parts(&request).unwrap();
assert!(headers.contains("content-type"));
assert!(headers.contains("accept"));
assert!(!headers.contains("x-custom"));
assert_eq!(headers.len(), 2);
}
// A present header is returned verbatim.
#[test]
fn test_header_value_extractor_present() {
let request = create_test_request_with_headers(
Method::GET,
"/test",
vec![("authorization", "Bearer token123")],
);
let result = HeaderValue::extract(&request, "authorization");
assert!(result.is_ok());
assert_eq!(result.unwrap().value(), "Bearer token123");
}
// A missing header is an error.
#[test]
fn test_header_value_extractor_missing() {
let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
let result = HeaderValue::extract(&request, "authorization");
assert!(result.is_err());
}
// Only the first entry of a comma-separated X-Forwarded-For list is used.
#[test]
fn test_client_ip_from_forwarded_header() {
let request = create_test_request_with_headers(
Method::GET,
"/test",
vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
);
let ip = ClientIp::extract_with_config(&request, true).unwrap();
assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
}
// With `trust_proxy == false`, X-Forwarded-For is ignored in favor of the socket address.
#[test]
fn test_client_ip_ignores_forwarded_when_not_trusted() {
let uri: http::Uri = "/test".parse().unwrap();
let builder = http::Request::builder()
.method(Method::GET)
.uri(uri)
.header("x-forwarded-for", "192.168.1.100");
let req = builder.body(()).unwrap();
let (mut parts, _) = req.into_parts();
let socket_addr = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
8080,
);
parts.extensions.insert(socket_addr);
let request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::new()),
Arc::new(Extensions::new()),
PathParams::new(),
);
let ip = ClientIp::extract_with_config(&request, false).unwrap();
assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
}
// An inserted extension is returned by value.
#[test]
fn test_extension_extractor_present() {
#[derive(Clone, Debug, PartialEq)]
struct MyData(String);
let request =
create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
let result = Extension::<MyData>::from_request_parts(&request);
assert!(result.is_ok());
assert_eq!(result.unwrap().0, MyData("hello".to_string()));
}
// A missing extension is an error.
#[test]
fn test_extension_extractor_missing() {
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct MyData(String);
let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
let result = Extension::<MyData>::from_request_parts(&request);
assert!(result.is_err());
}
// Cookie extractor tests; compiled only with the `cookies` feature.
#[cfg(feature = "cookies")]
mod cookies_tests {
use super::*;
// Property: every `name=value` pair in a single `Cookie` header is extracted. Repeated
// names are expected to resolve to their last occurrence (the `HashMap::insert` below
// overwrites, and the assertion requires the jar to agree).
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_cookies_extractor_parsing(
cookies in prop::collection::vec(
(
"[a-zA-Z][a-zA-Z0-9_]{0,15}", // Valid cookie name pattern
"[a-zA-Z0-9]{1,30}" // Valid cookie value pattern (no special chars)
),
0..5
)
) {
let result: Result<(), TestCaseError> = (|| {
let cookie_header = cookies
.iter()
.map(|(name, value)| format!("{}={}", name, value))
.collect::<Vec<_>>()
.join("; ");
let headers = if !cookies.is_empty() {
vec![("cookie", cookie_header.as_str())]
} else {
vec![]
};
let request = create_test_request_with_headers(Method::GET, "/test", headers);
let extracted = Cookies::from_request_parts(&request)
.map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
for (name, value) in &cookies {
expected_cookies.insert(name.as_str(), value.as_str());
}
for (name, expected_value) in &expected_cookies {
let cookie = extracted.get(name)
.ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
prop_assert_eq!(
cookie.value(),
*expected_value,
"Cookie '{}' value mismatch",
name
);
}
let extracted_count = extracted.iter().count();
prop_assert_eq!(
extracted_count,
expected_cookies.len(),
"Expected {} unique cookies, got {}",
expected_cookies.len(),
extracted_count
);
Ok(())
})();
result?;
}
}
// Two cookies in one header are both parsed.
#[test]
fn test_cookies_extractor_basic() {
let request = create_test_request_with_headers(
Method::GET,
"/test",
vec![("cookie", "session=abc123; user=john")],
);
let cookies = Cookies::from_request_parts(&request).unwrap();
assert!(cookies.contains("session"));
assert!(cookies.contains("user"));
assert!(!cookies.contains("other"));
assert_eq!(cookies.get("session").unwrap().value(), "abc123");
assert_eq!(cookies.get("user").unwrap().value(), "john");
}
// No `Cookie` header yields an empty jar rather than an error.
#[test]
fn test_cookies_extractor_empty() {
let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
let cookies = Cookies::from_request_parts(&request).unwrap();
assert_eq!(cookies.iter().count(), 0);
}
// A header without `;` separators is a single cookie.
#[test]
fn test_cookies_extractor_single() {
let request = create_test_request_with_headers(
Method::GET,
"/test",
vec![("cookie", "token=xyz789")],
);
let cookies = Cookies::from_request_parts(&request).unwrap();
assert_eq!(cookies.iter().count(), 1);
assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
}
}
// `AsyncValidatedJson` picks up the `ValidationContext` from the request state. Three phases:
// 1. no context: the unique rule fails because no database validator is configured;
// 2. context with a mock DB, unique email: extraction succeeds;
// 3. context with a mock DB, already-taken email: extraction fails validation.
#[tokio::test]
async fn test_async_validated_json_with_state_context() {
use async_trait::async_trait;
use rustapi_validate::prelude::*;
use rustapi_validate::v2::{
AsyncValidationRule, DatabaseValidator, ValidationContextBuilder,
};
use serde::{Deserialize, Serialize};
// In-memory stand-in for a database: values in `unique_values` count as already taken.
struct MockDbValidator {
unique_values: Vec<String>,
}
#[async_trait]
impl DatabaseValidator for MockDbValidator {
async fn exists(
&self,
_table: &str,
_column: &str,
_value: &str,
) -> Result<bool, String> {
Ok(true)
}
async fn is_unique(
&self,
_table: &str,
_column: &str,
value: &str,
) -> Result<bool, String> {
Ok(!self.unique_values.contains(&value.to_string()))
}
async fn is_unique_except(
&self,
_table: &str,
_column: &str,
value: &str,
_except_id: &str,
) -> Result<bool, String> {
Ok(!self.unique_values.contains(&value.to_string()))
}
}
#[derive(Debug, Deserialize, Serialize)]
struct TestUser {
email: String,
}
// No synchronous rules; all checking happens in the async impl below.
impl Validate for TestUser {
fn validate_with_group(
&self,
_group: rustapi_validate::v2::ValidationGroup,
) -> Result<(), rustapi_validate::v2::ValidationErrors> {
Ok(())
}
}
// Requires `email` to be unique in `users.email` via the context's database validator.
#[async_trait]
impl AsyncValidate for TestUser {
async fn validate_async_with_group(
&self,
ctx: &ValidationContext,
_group: rustapi_validate::v2::ValidationGroup,
) -> Result<(), rustapi_validate::v2::ValidationErrors> {
let mut errors = rustapi_validate::v2::ValidationErrors::new();
let rule = AsyncUniqueRule::new("users", "email");
if let Err(e) = rule.validate_async(&self.email, ctx).await {
errors.add("email", e);
}
errors.into_result()
}
}
// Phase 1: no ValidationContext in state.
let uri: http::Uri = "/test".parse().unwrap();
let user = TestUser {
email: "new@example.com".to_string(),
};
let body_bytes = serde_json::to_vec(&user).unwrap();
let builder = http::Request::builder()
.method(Method::POST)
.uri(uri.clone())
.header("content-type", "application/json");
let req = builder.body(()).unwrap();
let (parts, _) = req.into_parts();
let mut request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
Arc::new(Extensions::new()),
PathParams::new(),
);
let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
assert!(result.is_err(), "Expected error when validator is missing");
let err = result.unwrap_err();
let err_str = format!("{:?}", err);
assert!(
err_str.contains("Database validator not configured")
|| err_str.contains("async_unique"),
"Error should mention missing configuration or rule: {:?}",
err_str
);
// Phase 2: context with a DB validator; "new@example.com" is not taken.
let db_validator = MockDbValidator {
unique_values: vec!["taken@example.com".to_string()],
};
let ctx = ValidationContextBuilder::new()
.database(db_validator)
.build();
let mut extensions = Extensions::new();
extensions.insert(ctx);
let builder = http::Request::builder()
.method(Method::POST)
.uri(uri.clone())
.header("content-type", "application/json");
let req = builder.body(()).unwrap();
let (parts, _) = req.into_parts();
let mut request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
Arc::new(extensions),
PathParams::new(),
);
let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
assert!(
result.is_ok(),
"Expected success when validator is present and value is unique. Error: {:?}",
result.err()
);
// Phase 3: same setup, but the submitted email is already taken.
let user_taken = TestUser {
email: "taken@example.com".to_string(),
};
let body_taken = serde_json::to_vec(&user_taken).unwrap();
let db_validator = MockDbValidator {
unique_values: vec!["taken@example.com".to_string()],
};
let ctx = ValidationContextBuilder::new()
.database(db_validator)
.build();
let mut extensions = Extensions::new();
extensions.insert(ctx);
let builder = http::Request::builder()
.method(Method::POST)
.uri("/test")
.header("content-type", "application/json");
let req = builder.body(()).unwrap();
let (parts, _) = req.into_parts();
let mut request = Request::new(
parts,
crate::request::BodyVariant::Buffered(Bytes::from(body_taken)),
Arc::new(extensions),
PathParams::new(),
);
let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
assert!(result.is_err(), "Expected validation error for taken email");
}
}