use crate::error::{ApiError, Result};
use crate::json;
use crate::request::Request;
use crate::response::IntoResponse;
use crate::stream::{StreamingBody, StreamingConfig};
use crate::validation::Validatable;
use bytes::Bytes;
use http::{header, StatusCode};
use rustapi_validate::v2::{AsyncValidate, ValidationContext};
use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
/// Extractor that builds a value from a shared (`&`) borrow of the request.
///
/// Implementations in this module read headers, the query string, path
/// parameters, app state and extensions through this borrow.
pub trait FromRequestParts: Sized {
    /// Builds `Self` from `req`, or returns an error that rejects the request.
    fn from_request_parts(req: &Request) -> Result<Self>;
}
/// Extractor that builds a value asynchronously from a mutable borrow of the
/// request; the mutable access is what lets body extractors load and take the body.
pub trait FromRequest: Sized {
    /// Builds `Self` from `req`; the returned future must be `Send`.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
/// Every parts-only extractor is usable wherever a full extractor is expected:
/// the async wrapper just delegates to the synchronous `from_request_parts`.
impl<T: FromRequestParts> FromRequest for T {
    async fn from_request(req: &mut Request) -> Result<Self> {
        T::from_request_parts(req)
    }
}
/// JSON wrapper: as an extractor it buffers the request body and deserializes it
/// into `T`; as a response it serializes `T` (see the `IntoResponse` impl).
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
    async fn from_request(req: &mut Request) -> Result<Self> {
        // The body must be buffered before it can be taken as a whole.
        req.load_body().await?;
        let Some(raw) = req.take_body() else {
            return Err(ApiError::internal("Body already consumed"));
        };
        let decoded: T = json::from_slice(&raw)?;
        Ok(Self(decoded))
    }
}

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}

impl<T> DerefMut for Json<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Json(inner) = self;
        inner
    }
}

impl<T> From<T> for Json<T> {
    /// Wraps `inner` without any conversion.
    fn from(inner: T) -> Self {
        Self(inner)
    }
}
/// Initial buffer capacity handed to the JSON serializer for response bodies.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
/// Serializes the wrapped value into a `200 OK` response with an
/// `application/json` content type.
///
/// A serialization failure does not panic; it is turned into an internal-error
/// response built from `ApiError::internal`.
impl<T: Serialize> IntoResponse for Json<T> {
    fn into_response(self) -> crate::response::Response {
        match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
            Ok(body) => http::Response::builder()
                .status(StatusCode::OK)
                .header(header::CONTENT_TYPE, "application/json")
                .body(crate::response::Body::from(body))
                // Status and header are compile-time constants that are always valid,
                // so the builder has no way to fail here.
                .expect("static status and content-type header are always valid"),
            Err(err) => {
                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
            }
        }
    }
}
/// JSON body extractor that runs the synchronous `Validatable` check after
/// deserializing; it also serializes like `Json` when returned as a response.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wraps an already-built value.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}

impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
    async fn from_request(req: &mut Request) -> Result<Self> {
        req.load_body().await?;
        let Some(raw) = req.take_body() else {
            return Err(ApiError::internal("Body already consumed"));
        };
        let parsed: T = json::from_slice(&raw)?;
        // Reject the request before the handler ever sees an invalid payload.
        parsed.do_validate()?;
        Ok(Self(parsed))
    }
}

impl<T> Deref for ValidatedJson<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ValidatedJson(inner) = self;
        inner
    }
}

impl<T> DerefMut for ValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut T {
        let ValidatedJson(inner) = self;
        inner
    }
}

impl<T> From<T> for ValidatedJson<T> {
    fn from(inner: T) -> Self {
        Self(inner)
    }
}

impl<T: Serialize> IntoResponse for ValidatedJson<T> {
    /// Responds exactly as `Json` would for the same value.
    fn into_response(self) -> crate::response::Response {
        let payload = self.into_inner();
        Json(payload).into_response()
    }
}
/// JSON body extractor that runs asynchronous validation (`AsyncValidate`)
/// after deserializing; it serializes like `Json` when returned as a response.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);

impl<T> AsyncValidatedJson<T> {
    /// Wraps an already-built value.
    pub fn new(value: T) -> Self {
        AsyncValidatedJson(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> Deref for AsyncValidatedJson<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> DerefMut for AsyncValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> From<T> for AsyncValidatedJson<T> {
    fn from(inner: T) -> Self {
        Self(inner)
    }
}

impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
    /// Responds exactly as `Json` would for the same value.
    fn into_response(self) -> crate::response::Response {
        let payload = self.into_inner();
        Json(payload).into_response()
    }
}
impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
async fn from_request(req: &mut Request) -> Result<Self> {
req.load_body().await?;
let body = req
.take_body()
.ok_or_else(|| ApiError::internal("Body already consumed"))?;
let value: T = json::from_slice(&body)?;
let ctx = if let Some(ctx) = req.state().get::<ValidationContext>() {
ctx.clone()
} else {
ValidationContext::default()
};
if let Err(errors) = value.validate_full(&ctx).await {
let field_errors: Vec<crate::error::FieldError> = errors
.fields
.iter()
.flat_map(|(field, errs)| {
let field_name = field.to_string();
errs.iter().map(move |e| crate::error::FieldError {
field: field_name.clone(),
code: e.code.to_string(),
message: e.message.clone(),
})
})
.collect();
return Err(ApiError::validation(field_errors));
}
Ok(AsyncValidatedJson(value))
}
}
/// Deserializes the URL query string into `T` with `serde_urlencoded`.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T: DeserializeOwned> FromRequestParts for Query<T> {
    /// A request with no query string is parsed as an empty one; parse failures
    /// become a bad-request error.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let raw = match req.query_string() {
            Some(qs) => qs,
            None => "",
        };
        serde_urlencoded::from_str(raw)
            .map(Query)
            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}
/// Extracts a single path parameter and parses it with `FromStr`.
///
/// Only the first entry yielded by `path_params().iter()` is used.
/// NOTE(review): for routes with several parameters, which entry is "first"
/// depends on the iteration order of `path_params()` — confirm, or prefer `Typed`.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);

impl<T: FromStr> FromRequestParts for Path<T>
where
    T::Err: std::fmt::Display,
{
    /// A parse failure is a client error (bad request). A route with no
    /// parameters at all is reported as an internal error.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let params = req.path_params();
        let first = params.iter().next();
        match first {
            Some((_, raw)) => raw
                .parse::<T>()
                .map(Path)
                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e))),
            None => Err(ApiError::internal("Missing path parameter")),
        }
    }
}

impl<T> Deref for Path<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Path(inner) = self;
        inner
    }
}
/// Deserializes all path parameters at once into `T`, keyed by parameter name.
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);

impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
    fn from_request_parts(req: &Request) -> Result<Self> {
        // Every value is presented to serde as a JSON string; `T` decides how to read it.
        let object: serde_json::Map<String, serde_json::Value> = req
            .path_params()
            .iter()
            .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string())))
            .collect();
        serde_json::from_value(serde_json::Value::Object(object))
            .map(Typed)
            .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))
    }
}

impl<T> Deref for Typed<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Typed(inner) = self;
        inner
    }
}
/// Clones a value of type `T` out of the application state.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
    /// Fails with an internal error naming `T` when no such state was registered.
    fn from_request_parts(req: &Request) -> Result<Self> {
        match req.state().get::<T>() {
            Some(found) => Ok(State(found.clone())),
            None => Err(ApiError::internal(format!(
                "State of type `{}` not found. Did you forget to call .state()?",
                std::any::type_name::<T>()
            ))),
        }
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
/// The raw request body, fully buffered into `Bytes`.
#[derive(Debug, Clone)]
pub struct Body(pub Bytes);

impl FromRequest for Body {
    async fn from_request(req: &mut Request) -> Result<Self> {
        req.load_body().await?;
        match req.take_body() {
            Some(bytes) => Ok(Body(bytes)),
            None => Err(ApiError::internal("Body already consumed")),
        }
    }
}

impl Deref for Body {
    type Target = Bytes;

    fn deref(&self) -> &Bytes {
        let Body(bytes) = self;
        bytes
    }
}
/// The request body exposed as a stream of `Bytes` chunks.
///
/// Uses the request's underlying stream when one is available. Otherwise an
/// already-buffered body is replayed as a single chunk. Both paths use the
/// `max_body_size` from `StreamingConfig::default()`.
pub struct BodyStream(pub StreamingBody);

impl FromRequest for BodyStream {
    async fn from_request(req: &mut Request) -> Result<Self> {
        let config = StreamingConfig::default();
        if let Some(incoming) = req.take_stream() {
            return Ok(BodyStream(StreamingBody::new(incoming, config.max_body_size)));
        }
        match req.take_body() {
            Some(buffered) => {
                let single_chunk = futures_util::stream::once(async move { Ok(buffered) });
                Ok(BodyStream(StreamingBody::from_stream(
                    single_chunk,
                    config.max_body_size,
                )))
            }
            None => Err(ApiError::internal("Body already consumed")),
        }
    }
}

impl Deref for BodyStream {
    type Target = StreamingBody;

    fn deref(&self) -> &StreamingBody {
        let BodyStream(inner) = self;
        inner
    }
}

impl DerefMut for BodyStream {
    fn deref_mut(&mut self) -> &mut StreamingBody {
        let BodyStream(inner) = self;
        inner
    }
}

impl futures_util::Stream for BodyStream {
    type Item = Result<Bytes, ApiError>;

    /// Forwards polling to the wrapped `StreamingBody` (pinned via `Pin::new`,
    /// which requires it to be `Unpin`).
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let inner = std::pin::Pin::new(&mut self.0);
        inner.poll_next(cx)
    }
}
/// Makes any parts-based extractor optional. Extraction errors are discarded
/// and become `None`, so this extractor itself never fails.
impl<T: FromRequestParts> FromRequestParts for Option<T> {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let attempt = T::from_request_parts(req);
        Ok(match attempt {
            Ok(value) => Some(value),
            Err(_) => None,
        })
    }
}
/// An owned copy of the request's full header map.
#[derive(Debug, Clone)]
pub struct Headers(pub http::HeaderMap);

impl Headers {
    /// Returns the first value stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
        let Headers(map) = self;
        map.get(name)
    }

    /// Reports whether any value is stored under `name`.
    pub fn contains(&self, name: &str) -> bool {
        let Headers(map) = self;
        map.contains_key(name)
    }

    /// Number of stored header values (a repeated header counts once per value).
    pub fn len(&self) -> usize {
        let Headers(map) = self;
        map.len()
    }

    /// True when no headers are stored.
    pub fn is_empty(&self) -> bool {
        let Headers(map) = self;
        map.is_empty()
    }

    /// Iterates over every (name, value) pair.
    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
        let Headers(map) = self;
        map.iter()
    }
}

impl FromRequestParts for Headers {
    /// Never fails; clones the request's headers.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let snapshot = req.headers().clone();
        Ok(Self(snapshot))
    }
}

impl Deref for Headers {
    type Target = http::HeaderMap;

    fn deref(&self) -> &http::HeaderMap {
        let Headers(map) = self;
        map
    }
}
/// A single header value captured as a `String`, together with the header name
/// it was read from. Fields are stored as `(value, name)`.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);
impl HeaderValue {
    /// Builds a value directly. Note that the argument order `(name, value)` is the
    /// reverse of the tuple field order.
    pub fn new(name: &'static str, value: String) -> Self {
        Self(value, name)
    }
    /// The header's value.
    pub fn value(&self) -> &str {
        &self.0
    }
    /// The header's name.
    pub fn name(&self) -> &'static str {
        self.1
    }
    /// Reads the required header `name` from `req`.
    ///
    /// # Errors
    /// Returns a bad-request error when the header is absent. A header that is
    /// present but is not visible ASCII gets a distinct bad-request error; it is
    /// no longer mis-reported as missing.
    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
        let raw = req
            .headers()
            .get(name)
            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))?;
        // `to_str` only succeeds for visible-ASCII values; report that case honestly.
        let text = raw.to_str().map_err(|_| {
            ApiError::bad_request(format!(
                "Invalid value for header {}: must be visible ASCII",
                name
            ))
        })?;
        Ok(HeaderValue(text.to_string(), name))
    }
}
impl Deref for HeaderValue {
    type Target = String;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// Clones a value of type `T` out of the request's extensions (typically
/// inserted by middleware).
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
    /// Fails with an internal error naming `T` when no such extension exists.
    fn from_request_parts(req: &Request) -> Result<Self> {
        match req.extensions().get::<T>() {
            Some(found) => Ok(Extension(found.clone())),
            None => Err(ApiError::internal(format!(
                "Extension of type `{}` not found. Did middleware insert it?",
                std::any::type_name::<T>()
            ))),
        }
    }
}

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}

impl<T> DerefMut for Extension<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Extension(inner) = self;
        inner
    }
}
/// Best-effort client IP address.
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);

impl ClientIp {
    /// Resolves the client IP, checking sources in this order:
    /// 1. if `trust_proxy`, the leftmost `X-Forwarded-For` entry (when it parses as an IP);
    /// 2. the `SocketAddr` stored in the request extensions;
    /// 3. `127.0.0.1`.
    ///
    /// NOTE(review): `X-Forwarded-For` is client-controlled. Unless a trusted proxy
    /// overwrites it, enabling `trust_proxy` lets callers spoof their address.
    /// The fallback to `127.0.0.1` can also make an unknown peer look local.
    /// Confirm both are acceptable before using this for access control or rate limiting.
    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
        if trust_proxy {
            let leftmost = req
                .headers()
                .get("x-forwarded-for")
                .and_then(|raw| raw.to_str().ok())
                .and_then(|list| list.split(',').next())
                .and_then(|candidate| candidate.trim().parse::<std::net::IpAddr>().ok());
            if let Some(ip) = leftmost {
                return Ok(ClientIp(ip));
            }
        }
        let peer = req
            .extensions()
            .get::<std::net::SocketAddr>()
            .map(|addr| addr.ip());
        let loopback = std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1));
        Ok(ClientIp(peer.unwrap_or(loopback)))
    }
}

impl FromRequestParts for ClientIp {
    /// Extracts with `trust_proxy = true` (see the review note on `extract_with_config`).
    fn from_request_parts(req: &Request) -> Result<Self> {
        let trust_proxy = true;
        Self::extract_with_config(req, trust_proxy)
    }
}
/// Cookies parsed from the request's `Cookie` header.
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over all parsed cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Reports whether a cookie named `name` was sent.
    pub fn contains(&self, name: &str) -> bool {
        let Cookies(jar) = self;
        jar.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Never fails. A missing or non-ASCII `Cookie` header yields an empty jar,
    /// and individual pairs that fail to parse are skipped.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();
        let header_text = req
            .headers()
            .get(header::COOKIE)
            .and_then(|raw| raw.to_str().ok());
        if let Some(text) = header_text {
            let parsed = text
                .split(';')
                .map(str::trim)
                .filter(|piece| !piece.is_empty())
                .filter_map(|piece| cookie::Cookie::parse(piece.to_string()).ok());
            for c in parsed {
                jar.add_original(c.into_owned());
            }
        }
        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
/// Lets a bare primitive (or `String`) handler argument act as a single
/// path-parameter extractor by delegating to `Path`.
macro_rules! impl_from_request_parts_for_primitives {
    ($($prim:ty),*) => {
        $(
            impl FromRequestParts for $prim {
                fn from_request_parts(req: &Request) -> Result<Self> {
                    Path::<$prim>::from_request_parts(req).map(|Path(inner)| inner)
                }
            }
        )*
    };
}
impl_from_request_parts_for_primitives!(
    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
);
use rustapi_openapi::{
MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
};
/// Sets a required `application/json` request body whose schema is `T`.
fn apply_json_request_body<T: RustApiSchema>(op: &mut Operation) {
    let mut ctx = SchemaCtx::new();
    let schema_ref = T::schema(&mut ctx);
    let mut content = BTreeMap::new();
    content.insert(
        "application/json".to_string(),
        MediaType {
            schema: Some(schema_ref),
            example: None,
        },
    );
    op.request_body = Some(RequestBody {
        description: None,
        required: Some(true),
        content,
    });
}

/// Documents the `422` response produced when request validation fails, using a
/// reference to the `ValidationErrorSchema` component.
fn apply_validation_error_response(op: &mut Operation) {
    let mut responses_content = BTreeMap::new();
    responses_content.insert(
        "application/json".to_string(),
        MediaType {
            schema: Some(SchemaRef::Ref {
                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
            }),
            example: None,
        },
    );
    op.responses.insert(
        "422".to_string(),
        ResponseSpec {
            description: "Validation Error".to_string(),
            content: responses_content,
            headers: BTreeMap::new(),
        },
    );
}

/// Registers the component schemas that the `422` validation response refers to.
fn register_validation_error_schemas(spec: &mut rustapi_openapi::OpenApiSpec) {
    spec.register_in_place::<rustapi_openapi::ValidationErrorSchema>();
    spec.register_in_place::<rustapi_openapi::ValidationErrorBodySchema>();
    spec.register_in_place::<rustapi_openapi::FieldErrorSchema>();
}

/// Documents a required JSON body of type `T` plus the `422` validation-error response.
impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
    fn update_operation(op: &mut Operation) {
        apply_json_request_body::<T>(op);
        apply_validation_error_response(op);
    }
    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
        register_validation_error_schemas(spec);
    }
}

/// Same documentation as `ValidatedJson`: the validation is async, but the
/// request/response contract is identical.
impl<T: RustApiSchema> OperationModifier for AsyncValidatedJson<T> {
    fn update_operation(op: &mut Operation) {
        apply_json_request_body::<T>(op);
        apply_validation_error_response(op);
    }
    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
        register_validation_error_schemas(spec);
    }
}

/// Documents a required JSON body of type `T` (no validation response).
impl<T: RustApiSchema> OperationModifier for Json<T> {
    fn update_operation(op: &mut Operation) {
        apply_json_request_body::<T>(op);
    }
    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
/// `Path` contributes nothing to the OpenAPI operation.
impl<T> OperationModifier for Path<T> {
    fn update_operation(_operation: &mut Operation) {
        // intentionally empty
    }
}
/// `Typed` contributes nothing to the OpenAPI operation.
impl<T> OperationModifier for Typed<T> {
    fn update_operation(_operation: &mut Operation) {
        // intentionally empty
    }
}
/// Documents each field of `T` as an optional `query` parameter.
impl<T: RustApiSchema> OperationModifier for Query<T> {
    /// Adds one parameter per entry of `T::field_schemas`; does nothing when `T`
    /// reports no field schemas.
    fn update_operation(op: &mut Operation) {
        let Some(fields) = T::field_schemas(&mut SchemaCtx::new()) else {
            return;
        };
        op.parameters.extend(fields.into_iter().map(|(name, schema)| Parameter {
            name,
            location: "query".to_string(),
            required: false,
            deprecated: None,
            description: None,
            schema: Some(schema),
        }));
    }

    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
/// Application state is server-side only and contributes nothing to the OpenAPI operation.
impl<T> OperationModifier for State<T> {
    fn update_operation(_operation: &mut Operation) {
        // intentionally empty
    }
}
/// Sets a required `application/octet-stream` request body, described as a
/// binary string.
fn apply_binary_request_body(op: &mut Operation) {
    let mut content = BTreeMap::new();
    content.insert(
        "application/octet-stream".to_string(),
        MediaType {
            schema: Some(SchemaRef::Inline(
                serde_json::json!({ "type": "string", "format": "binary" }),
            )),
            example: None,
        },
    );
    op.request_body = Some(RequestBody {
        description: None,
        required: Some(true),
        content,
    });
}

/// A buffered raw body is documented as a required binary payload.
impl OperationModifier for Body {
    fn update_operation(op: &mut Operation) {
        apply_binary_request_body(op);
    }
}

/// A streamed raw body has the same wire contract as `Body`.
impl OperationModifier for BodyStream {
    fn update_operation(op: &mut Operation) {
        apply_binary_request_body(op);
    }
}
/// Documents a `200` response whose `application/json` body has schema `T`.
impl<T: RustApiSchema> ResponseModifier for Json<T> {
    fn update_response(op: &mut Operation) {
        let schema = T::schema(&mut SchemaCtx::new());
        let media = MediaType {
            schema: Some(schema),
            example: None,
        };
        let success = ResponseSpec {
            description: "Successful response".to_string(),
            content: BTreeMap::from([("application/json".to_string(), media)]),
            headers: BTreeMap::new(),
        };
        op.responses.insert("200".to_string(), success);
    }

    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
impl<T: RustApiSchema> RustApiSchema for Json<T> {
fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
T::schema(ctx)
}
}
impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
T::schema(ctx)
}
}
impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
T::schema(ctx)
}
}
impl<T: RustApiSchema> RustApiSchema for Query<T> {
fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
T::schema(ctx)
}
fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
T::field_schemas(ctx)
}
}
/// Page used when the query omits `page` or cannot be parsed.
const DEFAULT_PAGE: u64 = 1;
/// Page size used when the query omits the size parameter or cannot be parsed.
const DEFAULT_PER_PAGE: u64 = 20;
/// Upper bound that constructors clamp page sizes to.
const MAX_PER_PAGE: u64 = 100;
/// Offset/limit pagination read from the `page` and `per_page` query parameters.
#[derive(Debug, Clone, Copy)]
pub struct Paginate {
    /// Page number; `new` raises values below 1 to 1.
    pub page: u64,
    /// Page size; `new` clamps it into `1..=MAX_PER_PAGE`.
    pub per_page: u64,
}
impl Paginate {
    /// Builds pagination parameters. `page` is raised to at least 1 and
    /// `per_page` is clamped into `1..=MAX_PER_PAGE`.
    pub fn new(page: u64, per_page: u64) -> Self {
        Self {
            page: page.max(1),
            per_page: per_page.clamp(1, MAX_PER_PAGE),
        }
    }
    /// Number of items to skip: `(page - 1) * per_page`.
    ///
    /// Uses saturating arithmetic. `page` comes straight from the query string and
    /// is only lower-bounded by `new`. A very large value would otherwise overflow
    /// the multiplication, which panics in debug builds and silently wraps in
    /// release builds. The fields are public, so a hand-built `page: 0` would
    /// underflow the subtraction; it is treated like page 1.
    pub fn offset(&self) -> u64 {
        self.page.saturating_sub(1).saturating_mul(self.per_page)
    }
    /// Maximum number of items to return (the page size).
    pub fn limit(&self) -> u64 {
        self.per_page
    }
    /// Wraps one page of `items`, plus the overall `total`, into a `Paginated` result.
    pub fn paginate<T>(self, items: Vec<T>, total: u64) -> crate::hateoas::Paginated<T> {
        crate::hateoas::Paginated {
            items,
            page: self.page,
            per_page: self.per_page,
            total,
        }
    }
}
impl Default for Paginate {
    fn default() -> Self {
        Self {
            page: DEFAULT_PAGE,
            per_page: DEFAULT_PER_PAGE,
        }
    }
}
impl FromRequestParts for Paginate {
    /// Never fails: a missing or unparseable query falls back to the defaults.
    fn from_request_parts(req: &Request) -> Result<Self> {
        #[derive(serde::Deserialize, Default)]
        struct PaginateQuery {
            page: Option<u64>,
            per_page: Option<u64>,
        }
        let query = req.query_string().unwrap_or("");
        // Best-effort parsing: if either field is malformed (e.g. `page=abc`), the
        // whole query is ignored and both values use their defaults.
        let params: PaginateQuery = serde_urlencoded::from_str(query).unwrap_or_default();
        Ok(Paginate::new(
            params.page.unwrap_or(DEFAULT_PAGE),
            params.per_page.unwrap_or(DEFAULT_PER_PAGE),
        ))
    }
}
/// Cursor-based pagination read from the `cursor` and `limit` query parameters.
#[derive(Debug, Clone)]
pub struct CursorPaginate {
    /// Opaque position to continue after; `None` means the first page.
    pub cursor: Option<String>,
    /// Page size; `new` clamps it into `1..=MAX_PER_PAGE`.
    pub per_page: u64,
}

impl CursorPaginate {
    /// Builds cursor pagination, clamping `per_page` into `1..=MAX_PER_PAGE`.
    pub fn new(cursor: Option<String>, per_page: u64) -> Self {
        let per_page = per_page.clamp(1, MAX_PER_PAGE);
        Self { cursor, per_page }
    }

    /// The cursor to resume after, if any.
    pub fn after(&self) -> Option<&str> {
        self.cursor.as_deref()
    }

    /// Maximum number of items to return.
    pub fn limit(&self) -> u64 {
        self.per_page
    }

    /// True when no cursor was supplied.
    pub fn is_first_page(&self) -> bool {
        self.after().is_none()
    }
}

impl Default for CursorPaginate {
    fn default() -> Self {
        Self {
            cursor: None,
            per_page: DEFAULT_PER_PAGE,
        }
    }
}

impl FromRequestParts for CursorPaginate {
    /// Never fails: a missing or unparseable query falls back to no cursor and the
    /// default page size. The size is read from the `limit` parameter.
    fn from_request_parts(req: &Request) -> Result<Self> {
        #[derive(serde::Deserialize, Default)]
        struct CursorQuery {
            cursor: Option<String>,
            limit: Option<u64>,
        }
        let raw = req.query_string().unwrap_or("");
        let parsed: CursorQuery = serde_urlencoded::from_str(raw).unwrap_or_default();
        let size = parsed.limit.unwrap_or(DEFAULT_PER_PAGE);
        Ok(CursorPaginate::new(parsed.cursor, size))
    }
}
// Unit tests for this module are kept in `tests/support/extract_lib.rs` and
// spliced in at compile time with `include!`. Because they land inside this
// child module, they can reach this file's private items through `super::`.
#[cfg(test)]
mod tests {
    include!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/support/extract_lib.rs"
    ));
}