1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Newtype wrapper for a value carried as a JSON body.
///
/// With the `axum` feature enabled, it can be returned from handlers; the
/// wrapped value is serialized by delegating to `axum::Json`.
#[derive(Debug)]
pub struct Json<T>(pub T);

#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Json<T>
where
    T: serde::Serialize,
{
    /// Serializes the wrapped value through `axum::Json`, inheriting its
    /// content type and error handling.
    fn into_response(self) -> axum::response::Response {
        axum::Json(self.0).into_response()
    }
}
20
/// Newtype wrapper for a value carried as a protobuf-encoded body.
#[derive(Debug)]
pub struct Proto<T>(pub T);

#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Proto<T>
where
    T: prost::Message,
{
    /// Encodes the message and responds with `Content-Type: application/protobuf`.
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let content_type = http::header::HeaderValue::from_static("application/protobuf");
        let encoded = message.encode_to_vec();
        ([(http::header::CONTENT_TYPE, content_type)], encoded).into_response()
    }
}
38
/// How byte payloads are represented as strings (see `Format::encode` / `Format::decode`).
#[rustfmt::skip]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Format {
    /// Bytes are treated as UTF-8 text; encoding non-UTF-8 bytes is lossy.
    #[default]
    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
    Raw,
    /// Bytes are Base64-encoded, so arbitrary binary data can be represented.
    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
    Base64,
}
49
50impl s2_common::http::ParseableHeader for Format {
51 fn name() -> &'static http::HeaderName {
52 &FORMAT_HEADER
53 }
54}
55
56impl Format {
57 pub fn encode(self, bytes: &[u8]) -> String {
58 match self {
59 Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60 Format::Base64 => Base64::encode_string(bytes),
61 }
62 }
63
64 pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65 Ok(match self {
66 Format::Raw => s.into_bytes().into(),
67 Format::Base64 => Base64::decode_vec(&s)
68 .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69 .into(),
70 })
71 }
72}
73
74impl FromStr for Format {
75 type Err = ValidationError;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 match s.trim() {
79 "raw" | "json" => Ok(Self::Raw),
80 "base64" | "json-binsafe" => Ok(Self::Base64),
81 _ => Err(ValidationError(s.to_string())),
82 }
83 }
84}
85
86pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92 #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96 pub s2_format: Format,
97}
98
#[cfg(feature = "axum")]
pub mod extract {
    //! Axum request extractors for `Json`, `JsonOpt` and `Proto` bodies.

    use axum::{
        extract::{
            FromRequest, OptionalFromRequest, Request,
            rejection::{BytesRejection, JsonRejection},
        },
        response::{IntoResponse, Response},
    };
    use bytes::Bytes;
    use serde::de::DeserializeOwned;

    /// Extracts a required JSON body by delegating to `axum::Json`, so
    /// content-type checks and rejections are exactly axum's.
    impl<S, T> FromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let axum::Json(value) =
                <axum::Json<T> as FromRequest<S>>::from_request(req, state).await?;
            Ok(Self(value))
        }
    }

    /// Extracts an optional JSON body:
    ///
    /// - no `Content-Type` header: `Ok(None)`;
    /// - a `Content-Type` that is not JSON: `MissingJsonContentType` rejection;
    /// - a JSON `Content-Type` with an empty body: `Ok(None)`;
    /// - otherwise the body is deserialized and any failure is rejected.
    impl<S, T> OptionalFromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
            let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
                return Ok(None);
            };
            // A content type was declared: refuse non-JSON bodies instead of ignoring them.
            let is_json = crate::mime::parse(ctype)
                .as_ref()
                .is_some_and(crate::mime::is_json);
            if !is_json {
                return Err(JsonRejection::MissingJsonContentType(Default::default()));
            }
            let bytes = Bytes::from_request(req, state)
                .await
                .map_err(JsonRejection::BytesRejection)?;
            // An empty body counts as absent even when a JSON content type was sent.
            if bytes.is_empty() {
                return Ok(None);
            }
            let axum::Json(value) = axum::Json::<T>::from_bytes(&bytes)?;
            Ok(Some(Self(value)))
        }
    }

    /// Extractor wrapping the optional [`super::Json`] extraction: `JsonOpt(None)`
    /// when no JSON body is present, `JsonOpt(Some(value))` otherwise.
    #[derive(Debug)]
    pub struct JsonOpt<T>(pub Option<T>);

    impl<S, T> FromRequest<S> for JsonOpt<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let json =
                <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await?;
            Ok(Self(json.map(|super::Json(value)| value)))
        }
    }

    /// Rejection returned when a [`super::Proto`] body cannot be extracted.
    #[derive(Debug, thiserror::Error)]
    pub enum ProtoRejection {
        /// Reading the request body failed.
        #[error(transparent)]
        BytesRejection(#[from] BytesRejection),
        /// The body was read but could not be decoded as the target message.
        #[error(transparent)]
        Decode(#[from] prost::DecodeError),
    }

    impl IntoResponse for ProtoRejection {
        /// Body-read failures keep axum's own response; decode failures become
        /// `400 Bad Request` with the decoder's message.
        fn into_response(self) -> Response {
            match self {
                ProtoRejection::BytesRejection(e) => e.into_response(),
                ProtoRejection::Decode(e) => (
                    http::StatusCode::BAD_REQUEST,
                    format!("Invalid protobuf body: {e}"),
                )
                    .into_response(),
            }
        }
    }

    /// Reads the whole body and decodes it as `T`. The `Content-Type` header is
    /// not inspected here.
    impl<S, T> FromRequest<S> for super::Proto<T>
    where
        S: Send + Sync,
        T: prost::Message + Default,
    {
        type Rejection = ProtoRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let bytes = Bytes::from_request(req, state).await?;
            Ok(super::Proto(T::decode(bytes)?))
        }
    }
}