Skip to main content

s2_api/
data.rs

1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Newtype wrapper marking a value as a JSON request/response body.
///
/// With the `axum` feature enabled it can be returned from handlers and
/// extracted from requests (see the `extract` module).
#[derive(Debug)]
pub struct Json<T>(pub T);
9
#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Json<T>
where
    T: serde::Serialize,
{
    /// Hands the wrapped value to `axum::Json`, which builds the response.
    fn into_response(self) -> axum::response::Response {
        axum::Json(self.0).into_response()
    }
}
20
/// Newtype wrapper marking a value as a protobuf (`prost`) request/response body.
#[derive(Debug)]
pub struct Proto<T>(pub T);
23
#[cfg(feature = "axum")]
impl<T> axum::response::IntoResponse for Proto<T>
where
    T: prost::Message,
{
    /// Encodes the message with `prost` and responds with
    /// `Content-Type: application/protobuf`.
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let payload = message.encode_to_vec();
        let content_type = http::header::HeaderValue::from_static("application/protobuf");
        ([(http::header::CONTENT_TYPE, content_type)], payload).into_response()
    }
}
38
// Interpretation of record data (header names, header values, bodies) in JSON
// payloads; chosen per request via the `s2-format` header.
// `PartialEq`/`Eq` let callers compare formats directly with `==`.
// NOTE: plain `//` comments on purpose — `///` doc comments would presumably be
// picked up by the utoipa `ToSchema` derive and change the OpenAPI description.
#[rustfmt::skip]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Format {
    // Data carried as text (lossy UTF-8 when encoding); the default.
    #[default]
    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
    Raw,
    // Data carried as Base64 text.
    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
    Base64,
}
49
50impl s2_common::http::ParseableHeader for Format {
51    fn name() -> &'static http::HeaderName {
52        &FORMAT_HEADER
53    }
54}
55
56impl Format {
57    pub fn encode(self, bytes: &[u8]) -> String {
58        match self {
59            Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60            Format::Base64 => Base64::encode_string(bytes),
61        }
62    }
63
64    pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65        Ok(match self {
66            Format::Raw => s.into_bytes().into(),
67            Format::Base64 => Base64::decode_vec(&s)
68                .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69                .into(),
70        })
71    }
72}
73
74impl FromStr for Format {
75    type Err = ValidationError;
76
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        match s.trim() {
79            "raw" | "json" => Ok(Self::Raw),
80            "base64" | "json-binsafe" => Ok(Self::Base64),
81            _ => Err(ValidationError(s.to_string())),
82        }
83    }
84}
85
86pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92    /// Defines the interpretation of record data (header name, header value, and body) with the JSON content type.
93    /// Use `raw` (default) for efficient transmission and storage of Unicode data — storage will be in UTF-8.
94    /// Use `base64` for safe transmission with efficient storage of binary data.
95    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96    pub s2_format: Format,
97}
98
// Request-header parameter set describing the optional `s2-encryption` header.
// The value is kept as an unparsed `String` here. The `///` docs on the field
// presumably become the OpenAPI parameter description, so they are kept as-is.
#[rustfmt::skip]
#[derive(Debug)]
#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
pub struct S2EncryptionHeader {
    /// Optional per-request encryption spec for append/read operations.
    /// Use `plain` for plaintext, or `<alg>; <base64-key>` where `<alg>` is `aegis-256` or `aes-256-gcm`.
    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-encryption", value_type = String))]
    pub s2_encryption: String,
}
109
#[cfg(feature = "axum")]
pub mod extract {
    //! Axum extractors for [`Json`](super::Json) and [`Proto`](super::Proto)
    //! request bodies, together with the rejection types they return.

    use std::borrow::Cow;

    use axum::{
        extract::{FromRequest, OptionalFromRequest, Request, rejection::BytesRejection},
        response::{IntoResponse, Response},
    };
    use bytes::Bytes;
    use serde::de::DeserializeOwned;

    /// Rejection type for JSON extraction, owned by s2-api.
    ///
    /// Each variant resolves to an HTTP status ([`status`](Self::status)) and a
    /// plain-text message ([`body_text`](Self::body_text)).
    #[derive(Debug)]
    #[non_exhaustive]
    pub enum JsonExtractionRejection {
        /// Converted from axum's `JsonSyntaxError`.
        SyntaxError {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
        /// Converted from axum's `JsonDataError`.
        DataError {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
        /// The request's `Content-Type` is missing or not JSON; always answered
        /// with `415 Unsupported Media Type`.
        MissingContentType,
        /// Any other failure, e.g. the request body could not be buffered.
        Other {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
    }

    const MISSING_CONTENT_TYPE_MSG: &str = "Expected request with `Content-Type: application/json`";

    impl JsonExtractionRejection {
        /// Plain-text description of the rejection, as sent in the response body.
        pub fn body_text(&self) -> &str {
            match self {
                Self::SyntaxError { message, .. }
                | Self::DataError { message, .. }
                | Self::Other { message, .. } => message,
                Self::MissingContentType => MISSING_CONTENT_TYPE_MSG,
            }
        }

        /// HTTP status code to respond with.
        pub fn status(&self) -> http::StatusCode {
            match *self {
                Self::MissingContentType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
                Self::SyntaxError { status, .. }
                | Self::DataError { status, .. }
                | Self::Other { status, .. } => status,
            }
        }
    }

    impl std::fmt::Display for JsonExtractionRejection {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.body_text())
        }
    }

    impl std::error::Error for JsonExtractionRejection {}

    impl IntoResponse for JsonExtractionRejection {
        /// Responds with [`status`](Self::status) and the message as a
        /// plain-text body.
        fn into_response(self) -> Response {
            let status = self.status();
            // `Cow<'static, str>` is itself `IntoResponse`; borrowed and owned
            // messages both end up as the same plain-text body.
            let body: Cow<'static, str> = match self {
                Self::SyntaxError { message, .. }
                | Self::DataError { message, .. }
                | Self::Other { message, .. } => message,
                Self::MissingContentType => Cow::Borrowed(MISSING_CONTENT_TYPE_MSG),
            };
            (status, body).into_response()
        }
    }

    // TODO: remove when we stop delegating to axum::Json.
    impl From<axum::extract::rejection::JsonRejection> for JsonExtractionRejection {
        fn from(rejection: axum::extract::rejection::JsonRejection) -> Self {
            use axum::extract::rejection::JsonRejection as R;

            match rejection {
                R::JsonSyntaxError(inner) => Self::SyntaxError {
                    status: inner.status(),
                    message: inner.body_text().into(),
                },
                R::JsonDataError(inner) => Self::DataError {
                    status: inner.status(),
                    message: inner.body_text().into(),
                },
                R::MissingJsonContentType(_) => Self::MissingContentType,
                // `JsonRejection` is non-exhaustive, so keep a catch-all.
                other => Self::Other {
                    status: other.status(),
                    message: other.body_text().into(),
                },
            }
        }
    }

    impl<S, T> FromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonExtractionRejection;

        /// Delegates to `axum::Json` and converts its rejection.
        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            <axum::Json<T> as FromRequest<S>>::from_request(req, state)
                .await
                .map(|axum::Json(value)| Self(value))
                .map_err(JsonExtractionRejection::from)
        }
    }

    impl<S, T> OptionalFromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonExtractionRejection;

        /// Extracts an optional JSON body.
        ///
        /// Yields `Ok(None)` when there is no `Content-Type` header or the body
        /// is empty. Rejects with `MissingContentType` when a non-JSON
        /// `Content-Type` is present.
        async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
            let content_type = match req.headers().get(http::header::CONTENT_TYPE) {
                Some(value) => value,
                None => return Ok(None),
            };
            let is_json = crate::mime::parse(content_type)
                .as_ref()
                .is_some_and(crate::mime::is_json);
            if !is_json {
                return Err(JsonExtractionRejection::MissingContentType);
            }

            let body = Bytes::from_request(req, state)
                .await
                .map_err(|rejection| JsonExtractionRejection::Other {
                    status: rejection.status(),
                    message: rejection.body_text().into(),
                })?;
            // An empty body with a JSON content type counts as "absent".
            if body.is_empty() {
                return Ok(None);
            }

            let axum::Json(value) = axum::Json::<T>::from_bytes(&body)?;
            Ok(Some(Self(value)))
        }
    }

    /// Optional JSON body extractor.
    ///
    /// Workaround for <https://github.com/tokio-rs/axum/issues/3623>.
    #[derive(Debug)]
    pub struct JsonOpt<T>(pub Option<T>);

    impl<S, T> FromRequest<S> for JsonOpt<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonExtractionRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let extracted =
                <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await?;
            Ok(Self(extracted.map(|super::Json(value)| value)))
        }
    }

    /// Rejection produced when extracting a [`Proto`](super::Proto) body.
    #[derive(Debug, thiserror::Error)]
    pub enum ProtoRejection {
        #[error(transparent)]
        BytesRejection(#[from] BytesRejection),
        #[error(transparent)]
        Decode(#[from] prost::DecodeError),
    }

    impl IntoResponse for ProtoRejection {
        /// Body-read failures keep axum's own response; decode failures become
        /// `400 Bad Request` with a descriptive message.
        fn into_response(self) -> Response {
            match self {
                Self::BytesRejection(rejection) => rejection.into_response(),
                Self::Decode(error) => {
                    let message = format!("Invalid protobuf body: {error}");
                    (http::StatusCode::BAD_REQUEST, message).into_response()
                }
            }
        }
    }

    impl<S, T> FromRequest<S> for super::Proto<T>
    where
        S: Send + Sync,
        T: prost::Message + Default,
    {
        type Rejection = ProtoRejection;

        /// Buffers the whole body and decodes it as a `T` message.
        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let body = Bytes::from_request(req, state).await?;
            let message = T::decode(body)?;
            Ok(Self(message))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::v1::stream::AppendInput;

        /// Parses `json` with axum and converts any failure into a
        /// [`JsonExtractionRejection`].
        fn classify_json_error<T: DeserializeOwned>(
            json: &[u8],
        ) -> Result<T, JsonExtractionRejection> {
            let axum::Json(value) = axum::Json::<T>::from_bytes(json)?;
            Ok(value)
        }

        /// Checks that the rejection wrapper preserves axum's status-code
        /// classification for a range of invalid JSON payloads. The table is
        /// meant to be reused when switching the JSON backend (e.g. sonic-rs).
        #[test]
        fn json_error_classification() {
            let cases: &[(&[u8], http::StatusCode)] = &[
                // Syntax errors → 400
                (b"not json", http::StatusCode::BAD_REQUEST),
                // `{}` is valid JSON but lacks `records`; axum reports the data
                // error before it looks at the trailing characters.
                (b"{} trailing", http::StatusCode::UNPROCESSABLE_ENTITY),
                (b"", http::StatusCode::BAD_REQUEST),
                (b"{truncated", http::StatusCode::BAD_REQUEST),
                // Data errors → 422
                (b"{}", http::StatusCode::UNPROCESSABLE_ENTITY),
                (
                    br#"{"records": "nope"}"#,
                    http::StatusCode::UNPROCESSABLE_ENTITY,
                ),
                (
                    br#"{"records": [{"body": 123}]}"#,
                    http::StatusCode::UNPROCESSABLE_ENTITY,
                ),
            ];

            for &(input, expected) in cases {
                let err = match classify_json_error::<AppendInput>(input) {
                    Ok(_) => panic!("expected error for {:?}", String::from_utf8_lossy(input)),
                    Err(err) => err,
                };
                assert_eq!(
                    err.status(),
                    expected,
                    "wrong status for {:?}: got {}, body: {}",
                    String::from_utf8_lossy(input),
                    err.status(),
                    err.body_text(),
                );
            }
        }

        #[test]
        fn valid_json_parses_successfully() {
            let input = br#"{"records": [], "match_seq_num": null}"#;
            assert!(classify_json_error::<AppendInput>(input).is_ok());
        }
    }
}