1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Wrapper marking a value as a JSON request or response body.
///
/// With the `axum` feature, responses are serialized through `axum::Json`; the
/// request-side extractor lives in the `extract` module and parses with `sonic_rs`.
#[derive(Debug)]
pub struct Json<T>(pub T);

#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Json<T>
where
    T: serde::Serialize,
{
    fn into_response(self) -> axum::response::Response {
        // Response-side serialization is delegated entirely to axum's JSON support.
        axum::Json(self.0).into_response()
    }
}
20
/// Wrapper marking a `prost` message as a protobuf request or response body.
#[derive(Debug)]
pub struct Proto<T>(pub T);

#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Proto<T>
where
    T: prost::Message,
{
    /// Encodes the message and sends it with `Content-Type: application/protobuf`.
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let content_type = http::header::HeaderValue::from_static("application/protobuf");
        let encoded = message.encode_to_vec();
        ([(http::header::CONTENT_TYPE, content_type)], encoded).into_response()
    }
}
38
/// How binary data is represented as a string, selected by the `s2-format` header.
#[rustfmt::skip]
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Format {
    /// Data is treated as UTF-8 text (the default); encoding is lossy for non-UTF-8 bytes.
    #[default]
    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
    Raw,
    /// Data is carried as Base64 text.
    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
    Base64,
}
49
50impl s2_common::http::ParseableHeader for Format {
51 fn name() -> &'static http::HeaderName {
52 &FORMAT_HEADER
53 }
54}
55
56impl Format {
57 pub fn encode(self, bytes: &[u8]) -> String {
58 match self {
59 Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60 Format::Base64 => Base64::encode_string(bytes),
61 }
62 }
63
64 pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65 Ok(match self {
66 Format::Raw => s.into_bytes().into(),
67 Format::Base64 => Base64::decode_vec(&s)
68 .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69 .into(),
70 })
71 }
72}
73
74impl FromStr for Format {
75 type Err = ValidationError;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 match s.trim() {
79 "raw" | "json" => Ok(Self::Raw),
80 "base64" | "json-binsafe" => Ok(Self::Base64),
81 _ => Err(ValidationError(s.to_string())),
82 }
83 }
84}
85
86pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92 #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96 pub s2_format: Format,
97}
98
/// Header-parameter definition for the optional `s2-encryption-key` request header.
///
/// Derives `utoipa::IntoParams` (as a header parameter) when the `utoipa` feature is
/// enabled. `Debug` is implemented by hand so the key value is never printed.
#[rustfmt::skip]
#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
pub struct S2EncryptionKeyHeader {
    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-encryption-key", value_type = String))]
    pub s2_encryption_key: String,
}

impl std::fmt::Debug for S2EncryptionKeyHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redacted on purpose: a derived `Debug` would copy the raw key into any log
        // line, trace, or panic message that formats this struct.
        f.debug_struct("S2EncryptionKeyHeader")
            .field("s2_encryption_key", &"<redacted>")
            .finish()
    }
}
109
110#[cfg(feature = "axum")]
111pub mod extract {
112 use std::borrow::Cow;
113
114 use axum::{
115 extract::{FromRequest, OptionalFromRequest, Request, rejection::BytesRejection},
116 response::{IntoResponse, Response},
117 };
118 use bytes::Bytes;
119 use serde::de::DeserializeOwned;
120
121 #[derive(Debug)]
123 #[non_exhaustive]
124 pub enum JsonExtractionRejection {
125 SyntaxError {
126 status: http::StatusCode,
127 message: Cow<'static, str>,
128 },
129 DataError {
130 status: http::StatusCode,
131 message: Cow<'static, str>,
132 },
133 MissingContentType,
134 Other {
135 status: http::StatusCode,
136 message: Cow<'static, str>,
137 },
138 }
139
140 const MISSING_CONTENT_TYPE_MSG: &str = "Expected request with `Content-Type: application/json`";
141
142 impl JsonExtractionRejection {
143 pub fn body_text(&self) -> &str {
144 match self {
145 Self::SyntaxError { message, .. }
146 | Self::DataError { message, .. }
147 | Self::Other { message, .. } => message,
148 Self::MissingContentType => MISSING_CONTENT_TYPE_MSG,
149 }
150 }
151
152 pub fn status(&self) -> http::StatusCode {
153 match self {
154 Self::SyntaxError { status, .. }
155 | Self::DataError { status, .. }
156 | Self::Other { status, .. } => *status,
157 Self::MissingContentType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
158 }
159 }
160 }
161
162 impl std::fmt::Display for JsonExtractionRejection {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.write_str(self.body_text())
165 }
166 }
167
168 impl std::error::Error for JsonExtractionRejection {}
169
170 impl IntoResponse for JsonExtractionRejection {
171 fn into_response(self) -> Response {
172 let status = self.status();
173 match self {
174 Self::SyntaxError { message, .. }
175 | Self::DataError { message, .. }
176 | Self::Other { message, .. } => match message {
177 Cow::Borrowed(s) => (status, s).into_response(),
178 Cow::Owned(s) => (status, s).into_response(),
179 },
180 Self::MissingContentType => (status, MISSING_CONTENT_TYPE_MSG).into_response(),
181 }
182 }
183 }
184
185 fn classify_sonic_error(err: sonic_rs::Error) -> JsonExtractionRejection {
186 use sonic_rs::error::Category;
187 match err.classify() {
188 Category::TypeUnmatched | Category::NotFound => JsonExtractionRejection::DataError {
189 status: http::StatusCode::UNPROCESSABLE_ENTITY,
190 message: err.to_string().into(),
191 },
192 Category::Io => JsonExtractionRejection::Other {
193 status: http::StatusCode::INTERNAL_SERVER_ERROR,
194 message: err.to_string().into(),
195 },
196 _ => JsonExtractionRejection::SyntaxError {
197 status: http::StatusCode::BAD_REQUEST,
198 message: err.to_string().into(),
199 },
200 }
201 }
202
203 impl<S, T> FromRequest<S> for super::Json<T>
204 where
205 S: Send + Sync,
206 T: DeserializeOwned,
207 {
208 type Rejection = JsonExtractionRejection;
209
210 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
211 let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
212 return Err(JsonExtractionRejection::MissingContentType);
213 };
214 if !crate::mime::parse(ctype)
215 .as_ref()
216 .is_some_and(crate::mime::is_json)
217 {
218 return Err(JsonExtractionRejection::MissingContentType);
219 }
220 let bytes = Bytes::from_request(req, state).await.map_err(|e| {
221 JsonExtractionRejection::Other {
222 status: e.status(),
223 message: e.body_text().into(),
224 }
225 })?;
226 sonic_rs::from_slice(&bytes)
227 .map(Self)
228 .map_err(classify_sonic_error)
229 }
230 }
231
232 impl<S, T> OptionalFromRequest<S> for super::Json<T>
233 where
234 S: Send + Sync,
235 T: DeserializeOwned,
236 {
237 type Rejection = JsonExtractionRejection;
238
239 async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
240 let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
241 return Ok(None);
242 };
243 if !crate::mime::parse(ctype)
244 .as_ref()
245 .is_some_and(crate::mime::is_json)
246 {
247 return Err(JsonExtractionRejection::MissingContentType);
248 }
249 let bytes = Bytes::from_request(req, state).await.map_err(|e| {
250 JsonExtractionRejection::Other {
251 status: e.status(),
252 message: e.body_text().into(),
253 }
254 })?;
255 if bytes.is_empty() {
256 return Ok(None);
257 }
258 sonic_rs::from_slice(&bytes)
259 .map(|v| Some(Self(v)))
260 .map_err(classify_sonic_error)
261 }
262 }
263
264 #[derive(Debug)]
266 pub struct JsonOpt<T>(pub Option<T>);
267
268 impl<S, T> FromRequest<S> for JsonOpt<T>
269 where
270 S: Send + Sync,
271 T: DeserializeOwned,
272 {
273 type Rejection = JsonExtractionRejection;
274
275 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
276 match <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await {
277 Ok(Some(super::Json(value))) => Ok(Self(Some(value))),
278 Ok(None) => Ok(Self(None)),
279 Err(e) => Err(e),
280 }
281 }
282 }
283
284 #[derive(Debug, thiserror::Error)]
285 pub enum ProtoRejection {
286 #[error(transparent)]
287 BytesRejection(#[from] BytesRejection),
288 #[error(transparent)]
289 Decode(#[from] prost::DecodeError),
290 }
291
292 impl IntoResponse for ProtoRejection {
293 fn into_response(self) -> Response {
294 match self {
295 ProtoRejection::BytesRejection(e) => e.into_response(),
296 ProtoRejection::Decode(e) => (
297 http::StatusCode::BAD_REQUEST,
298 format!("Invalid protobuf body: {e}"),
299 )
300 .into_response(),
301 }
302 }
303 }
304
305 impl<S, T> FromRequest<S> for super::Proto<T>
306 where
307 S: Send + Sync,
308 T: prost::Message + Default,
309 {
310 type Rejection = ProtoRejection;
311
312 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
313 let bytes = Bytes::from_request(req, state).await?;
314 Ok(super::Proto(T::decode(bytes)?))
315 }
316 }
317
318 #[cfg(test)]
319 mod tests {
320 use super::*;
321 use crate::v1::{
322 config::{BasinReconfiguration, StreamReconfiguration},
323 stream::{AppendInput, AppendRecord, Header},
324 };
325
326 fn classify_json_error<T: DeserializeOwned>(
327 json: &[u8],
328 ) -> Result<T, JsonExtractionRejection> {
329 sonic_rs::from_slice(json).map_err(classify_sonic_error)
330 }
331
332 #[test]
336 fn json_error_classification() {
337 let cases: &[(&[u8], http::StatusCode)] = &[
338 (b"not json", http::StatusCode::BAD_REQUEST),
340 (b"{} trailing", http::StatusCode::UNPROCESSABLE_ENTITY),
343 (b"", http::StatusCode::BAD_REQUEST),
344 (b"{truncated", http::StatusCode::BAD_REQUEST),
345 (b"{}", http::StatusCode::UNPROCESSABLE_ENTITY),
347 (
348 br#"{"records": "nope"}"#,
349 http::StatusCode::UNPROCESSABLE_ENTITY,
350 ),
351 (
352 br#"{"records": [{"body": 123}]}"#,
353 http::StatusCode::UNPROCESSABLE_ENTITY,
354 ),
355 ];
356
357 for (input, expected_status) in cases {
358 let err = classify_json_error::<AppendInput>(input).expect_err(&format!(
359 "expected error for {:?}",
360 String::from_utf8_lossy(input)
361 ));
362 assert_eq!(
363 err.status(),
364 *expected_status,
365 "wrong status for {:?}: got {}, body: {}",
366 String::from_utf8_lossy(input),
367 err.status(),
368 err.body_text(),
369 );
370 }
371 }
372
373 #[test]
374 fn valid_json_parses_successfully() {
375 let input = br#"{"records": [], "match_seq_num": null}"#;
376 let result = classify_json_error::<AppendInput>(input);
377 assert!(result.is_ok());
378 }
379
380 #[test]
383 fn serde_json_sonic_rs_roundtrip() {
384 fn assert_roundtrip<T>(input: &T)
385 where
386 T: serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug,
387 {
388 let json = serde_json::to_vec(input).unwrap();
389 let from_serde: T = serde_json::from_slice(&json).unwrap();
390 let from_sonic: T = sonic_rs::from_slice(&json).unwrap();
391 assert_eq!(
392 format!("{from_serde:?}"),
393 format!("{from_sonic:?}"),
394 "roundtrip mismatch for {}",
395 String::from_utf8_lossy(&json),
396 );
397 }
398
399 assert_roundtrip(&AppendInput {
401 records: vec![],
402 match_seq_num: None,
403 fencing_token: None,
404 });
405 assert_roundtrip(&AppendInput {
406 records: vec![AppendRecord {
407 timestamp: None,
408 headers: vec![Header("key".into(), "val".into())],
409 body: "hello world".into(),
410 }],
411 match_seq_num: Some(42),
412 fencing_token: Some("token".parse().unwrap()),
413 });
414
415 use s2_common::maybe::Maybe;
417
418 use crate::v1::config::{StorageClass, TimestampingMode, TimestampingReconfiguration};
419
420 assert_roundtrip(&StreamReconfiguration {
422 storage_class: Maybe::Unspecified,
423 retention_policy: Maybe::Unspecified,
424 timestamping: Maybe::Unspecified,
425 delete_on_empty: Maybe::Unspecified,
426 });
427 assert_roundtrip(&StreamReconfiguration {
429 storage_class: Maybe::Specified(Some(StorageClass::Express)),
430 retention_policy: Maybe::Specified(None),
431 timestamping: Maybe::Specified(Some(TimestampingReconfiguration {
432 mode: Maybe::Specified(Some(TimestampingMode::ClientRequire)),
433 uncapped: Maybe::Specified(Some(true)),
434 })),
435 delete_on_empty: Maybe::Unspecified,
436 });
437
438 assert_roundtrip(&BasinReconfiguration {
440 default_stream_config: Maybe::Specified(None),
441 stream_cipher: Maybe::Unspecified,
442 create_stream_on_append: Maybe::Specified(true),
443 create_stream_on_read: Maybe::Unspecified,
444 });
445 }
446 }
447}