use core::{
fmt,
ops::{Deref, DerefMut},
};
use axum::{
extract::{FromRequest, FromRequestParts, Request},
http::header,
response::{IntoResponse, Response},
};
use bytes::{Bytes, BytesMut};
use crate::{Accept, CodecDecode, CodecEncode, CodecRejection, ContentType, IntoCodecResponse};
/// Wrapper for a value that is encoded to, or decoded from, an HTTP body in
/// one of the formats described by [`ContentType`].
///
/// The wrapped value is reachable through the public `.0` field, through
/// dereferencing, or by value via [`Codec::into_inner`].
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Codec<T>(pub T);

impl<T> Codec<T> {
    /// Unwraps the codec, returning the owned inner value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
impl<T> Codec<T>
where
    T: CodecEncode,
{
    /// Encodes the wrapped value as `content_type` and builds a response that
    /// carries the matching `Content-Type` header.
    ///
    /// If encoding fails, the rejection's own response is returned instead.
    pub fn to_response<C: Into<ContentType>>(&self, content_type: C) -> Response {
        let content_type: ContentType = content_type.into();

        match self.to_bytes(content_type) {
            Ok(body) => {
                let headers = [(header::CONTENT_TYPE, content_type.into_header())];
                (headers, body).into_response()
            }
            Err(rejection) => rejection.into_response(),
        }
    }
}
impl<T> Deref for Codec<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Codec<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> fmt::Display for Codec<T>
where
    T: fmt::Display,
{
    /// Displays exactly as the wrapped value does; the wrapper adds nothing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
/// Extracts a [`Codec<T>`] from an incoming request.
///
/// The `Accept` header is resolved first, so every rejection can be rendered
/// through `into_codec_response` in a format the client accepts. The payload
/// format comes from the `Content-Type` header. It falls back to
/// `ContentType::default()` when that header is missing or not recognised.
/// With the `form` feature, a `GET` request whose content type is form data is
/// decoded from the URL query string instead of the body.
impl<T, S> FromRequest<S> for Codec<T>
where
    T: for<'de> CodecDecode<'de>,
    S: Send + Sync + 'static,
{
    type Rejection = Response;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // NOTE(review): `unwrap` panics if `Accept` extraction can fail;
        // presumably its rejection type is infallible. Confirm.
        let (mut head, body) = req.into_parts();
        let accept = Accept::from_request_parts(&mut head, state).await.unwrap();

        let content_type = head
            .headers
            .get(header::CONTENT_TYPE)
            .and_then(ContentType::from_header)
            .unwrap_or_default();

        let request = Request::from_parts(head, body);

        // Form-encoded GET requests carry their payload in the query string.
        #[cfg(feature = "form")]
        if content_type == ContentType::Form && request.method() == axum::http::Method::GET {
            let query = request.uri().query().unwrap_or("");
            return Codec::from_form(query.as_bytes())
                .map_err(|err| CodecRejection::from(err).into_codec_response(accept.into()));
        }

        let payload = Bytes::from_request(request, state)
            .await
            .map_err(|err| CodecRejection::from(err).into_codec_response(accept.into()))?;

        // NOTE(review): this extractor does not call `validate()` itself,
        // whereas the `BorrowCodec` extractor does. Confirm whether
        // `Codec::from_bytes` / `Codec::from_form` validate when the
        // `validator` feature is enabled.
        Codec::from_bytes(&payload, content_type)
            .map_err(|rejection| rejection.into_codec_response(accept.into()))
    }
}
/// OpenAPI request documentation for [`Codec<T>`]. It delegates entirely to
/// `axum::Json<T>`'s implementation.
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationInput for Codec<T>
where
    T: schemars::JsonSchema,
{
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        <axum::Json<T> as aide::operation::OperationInput>::operation_input(ctx, operation);
    }

    fn inferred_early_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationInput>::inferred_early_responses(
            ctx, operation,
        )
    }
}
/// OpenAPI response documentation for [`Codec<T>`]. It delegates entirely to
/// `axum::Json<T>`'s implementation.
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationOutput for Codec<T>
where
    T: schemars::JsonSchema,
{
    type Inner = T;

    fn operation_response(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Option<aide::openapi::Response> {
        <axum::Json<T> as aide::operation::OperationOutput>::operation_response(ctx, operation)
    }

    fn inferred_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationOutput>::inferred_responses(ctx, operation)
    }
}
/// Validates a [`Codec<T>`] by validating the wrapped value. The wrapper adds
/// no rules of its own.
#[cfg(feature = "validator")]
impl<T> validator::Validate for Codec<T>
where
    T: validator::Validate,
{
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        validator::Validate::validate(&self.0)
    }
}
/// Owns a request buffer together with a value `T` decoded from it, so that
/// `T` can borrow directly from the buffer (zero-copy decoding) instead of
/// copying out of it.
pub struct BorrowCodec<T> {
    // Field order is load-bearing. Struct fields are dropped in declaration
    // order, so `data` (which may hold references into `bytes`, see
    // `BorrowCodec::from_bytes`) is dropped before the buffer it borrows from.
    // Do not reorder these fields.
    data: T,
    // Never read directly. It is kept only so the memory `data` may borrow
    // from stays alive as long as this value does.
    #[allow(dead_code)]
    #[doc(hidden)]
    bytes: BytesMut,
}
impl<T> BorrowCodec<T> {
    /// Returns a mutable reference to the decoded value.
    ///
    /// # Safety
    ///
    /// The decoded value may contain references into the buffer owned by this
    /// `BorrowCodec`, even when its type claims a longer lifetime (for example
    /// `'static`). The caller must not use the returned reference to move any
    /// such borrowed data out of `self` in a way that lets it outlive `self`.
    /// Examples are [`core::mem::take`], [`core::mem::replace`] or
    /// [`core::mem::swap`] on a field that borrows from the buffer.
    pub unsafe fn as_mut_unchecked(&mut self) -> &mut T {
        &mut self.data
    }
}
impl<T> AsRef<T> for BorrowCodec<T> {
    /// Borrows the decoded value. This is equivalent to dereferencing.
    fn as_ref(&self) -> &T {
        &self.data
    }
}
impl<T> Deref for BorrowCodec<T> {
    type Target = T;

    /// Gives shared access to the decoded value.
    fn deref(&self) -> &T {
        let Self { data, .. } = self;
        data
    }
}
impl<T> fmt::Debug for BorrowCodec<T>
where
    T: fmt::Debug,
{
    /// Formats only the decoded value. The backing buffer is omitted and
    /// indicated with a trailing `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("BorrowCodec");
        out.field("data", &self.data);
        out.finish_non_exhaustive()
    }
}
impl<T: PartialEq> PartialEq for BorrowCodec<T> {
    /// Compares the decoded values only. The backing buffers are ignored.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.data, &other.data);
        lhs == rhs
    }
}
/// Equality is total whenever it is total for the decoded value.
impl<T: Eq> Eq for BorrowCodec<T> {}
impl<T: PartialOrd> PartialOrd for BorrowCodec<T> {
    /// Orders by the decoded value only. The backing buffer is ignored.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.data, &other.data)
    }
}
impl<T: Ord> Ord for BorrowCodec<T> {
    /// Totally orders by the decoded value only. The backing buffer is ignored.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.data, &other.data)
    }
}
impl<T: std::hash::Hash> std::hash::Hash for BorrowCodec<T> {
    /// Hashes only the decoded value. This matches the equality comparison,
    /// which also ignores the backing buffer.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.data, state)
    }
}
impl<'de, T> BorrowCodec<T>
where
    T: CodecDecode<'de>,
{
    /// Decodes `bytes` as `content_type` into a `T` that may borrow from
    /// `bytes`, then stores the buffer alongside the decoded value.
    ///
    /// # Errors
    ///
    /// Returns the [`CodecRejection`] produced by `Codec::<T>::from_bytes` when
    /// decoding fails.
    // NOTE(review): this is a safe fn, but the caller chooses `'de` freely,
    // including `'static`. The decoded data actually borrows from `bytes`,
    // which lives only as long as the returned `BorrowCodec`. Any borrowed
    // piece of `T` copied out through `Deref` (e.g. cloning a
    // `Cow::Borrowed`) can therefore outlive the buffer. Fixing this needs an
    // interface change (e.g. making this fn `unsafe`).
    pub fn from_bytes(bytes: BytesMut, content_type: ContentType) -> Result<Self, CodecRejection> {
        let data = Codec::<T>::from_bytes(
            // SAFETY (intended invariant): this slice with an unbounded
            // lifetime points into the heap storage of `bytes`. `bytes` is
            // moved (not copied) into the returned value below. Moving a
            // `BytesMut` handle does not relocate its storage, and `data` is
            // dropped before `bytes` because it is declared first in
            // `BorrowCodec`.
            unsafe { std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) },
            content_type,
        )?
        .into_inner();
        Ok(Self { data, bytes })
    }
}
/// Extracts a [`BorrowCodec<T>`] from an incoming request, keeping the raw
/// payload buffer alive alongside the decoded value.
///
/// The `Accept` header is resolved first, so every rejection is rendered
/// through `into_codec_response`. The payload format comes from the
/// `Content-Type` header. It falls back to `ContentType::default()` when that
/// header is missing or not recognised. With the `form` feature, a form-encoded
/// `GET` request is decoded from its query string. With the `validator`
/// feature, the decoded value is validated before it is returned.
// NOTE(review): `T: CodecDecode<'static>` lets `T` claim `'static` borrows
// that actually point into the buffer owned by the returned value. Data copied
// out through `Deref` can outlive it. Review soundness.
impl<T, S> FromRequest<S> for BorrowCodec<T>
where
    T: CodecDecode<'static>,
    S: Send + Sync + 'static,
{
    type Rejection = Response;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // NOTE(review): `unwrap` panics if `Accept` extraction can fail;
        // presumably its rejection type is infallible. Confirm.
        let (mut head, body) = req.into_parts();
        let accept = Accept::from_request_parts(&mut head, state).await.unwrap();

        let content_type = head
            .headers
            .get(header::CONTENT_TYPE)
            .and_then(ContentType::from_header)
            .unwrap_or_default();

        let request = Request::from_parts(head, body);

        let buffer = match () {
            #[cfg(feature = "form")]
            () if content_type == ContentType::Form
                && request.method() == axum::http::Method::GET =>
            {
                match request.uri().query() {
                    Some(query) => BytesMut::from(query),
                    None => BytesMut::new(),
                }
            }
            () => match BytesMut::from_request(request, state).await {
                Ok(buffer) => buffer,
                Err(err) => {
                    return Err(CodecRejection::from(err).into_codec_response(accept.into()));
                }
            },
        };

        let decoded = match Self::from_bytes(buffer, content_type) {
            Ok(decoded) => decoded,
            Err(rejection) => return Err(rejection.into_codec_response(accept.into())),
        };

        #[cfg(feature = "validator")]
        if let Err(errors) = decoded.as_ref().validate() {
            return Err(CodecRejection::from(errors).into_codec_response(accept.into()));
        }

        Ok(decoded)
    }
}
/// OpenAPI request documentation for [`BorrowCodec<T>`]. It delegates entirely
/// to `axum::Json<T>`'s implementation.
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationInput for BorrowCodec<T>
where
    T: schemars::JsonSchema,
{
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        <axum::Json<T> as aide::operation::OperationInput>::operation_input(ctx, operation);
    }

    fn inferred_early_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationInput>::inferred_early_responses(
            ctx, operation,
        )
    }
}
#[cfg(test)]
mod test {
    use super::{Codec, ContentType};

    #[crate::apply(decode)]
    #[derive(Debug, PartialEq, Eq)]
    struct Data {
        hello: String,
    }

    /// The value every fixture below is expected to decode to.
    fn expected() -> Data {
        Data {
            hello: "world".into(),
        }
    }

    #[test]
    fn test_json_codec() {
        let decoded = Codec::<Data>::from_bytes(b"{\"hello\": \"world\"}", ContentType::Json)
            .unwrap()
            .into_inner();
        assert_eq!(decoded, expected());
    }

    #[test]
    fn test_msgpack_codec() {
        let decoded = Codec::<Data>::from_bytes(b"\x81\xa5hello\xa5world", ContentType::MsgPack)
            .unwrap()
            .into_inner();
        assert_eq!(decoded, expected());
    }
}
#[cfg(any(test, miri))]
mod miri {
    use std::borrow::Cow;

    use bytes::Bytes;

    use super::*;

    #[crate::apply(decode, crate = "crate")]
    #[derive(Debug, PartialEq, Eq)]
    struct BorrowData<'a> {
        #[serde(borrow)]
        hello: Cow<'a, str>,
    }

    /// Decodes JSON through `BorrowCodec` and checks that the string field
    /// really borrows from the codec's own buffer instead of being copied.
    #[test]
    fn test_zero_copy() {
        let bytes = b"{\"hello\": \"world\"}".to_vec();
        let data = BorrowCodec::<BorrowData>::from_bytes(
            BytesMut::from(Bytes::from(bytes)),
            ContentType::Json,
        )
        .unwrap();

        // `Cow`'s `PartialEq` compares contents only, so an `assert_eq!`
        // against `Cow::Borrowed("world")` would also pass for an owned copy.
        // Match on the variant to prove the value was actually borrowed.
        let Cow::Borrowed(hello) = &data.hello else {
            panic!("expected a borrowed string, got {:?}", data.hello);
        };
        assert_eq!(*hello, "world");

        // The borrowed slice must point into the buffer owned by `data`.
        assert!(data.bytes.as_ptr_range().contains(&hello.as_ptr()));
    }
}