at_jet/
extractors.rs

1//! Request extractors for AT-Jet
2//!
3//! Provides Protobuf-aware request body extractors for axum handlers.
4
5use {crate::{content_types::APPLICATION_PROTOBUF,
6             error::JetError},
7     axum::{async_trait,
8            body::Bytes,
9            extract::{FromRequest,
10                      Request},
11            http::header::CONTENT_TYPE},
12     prost::Message,
13     std::marker::PhantomData};
14
/// Largest request body, in bytes, that `ProtobufRequest` and `OptionalProtobufRequest`
/// accept: 10 MiB (`10 * 1024 * 1024` = 10_485_760 bytes).
///
/// These extractors buffer the body through axum's `Bytes` extractor, so axum's own
/// `DefaultBodyLimit` (2 MB unless the router reconfigures it) can reject a body before
/// this limit is ever consulted.
const MAX_BODY_SIZE: usize = 10 << 20;
17
18/// Protobuf request body extractor
19///
20/// Extracts and decodes a Protobuf message from the request body.
21///
22/// # Example
23///
24/// ```rust,ignore
25/// use at_jet::prelude::*;
26///
27/// async fn create_user(
28///     ProtobufRequest(request): ProtobufRequest<CreateUserRequest>
29/// ) -> ProtobufResponse<User> {
30///     // request is already decoded
31///     let user = User {
32///         id: 1,
33///         name: request.name,
34///     };
35///     ProtobufResponse::ok(user)
36/// }
37/// ```
38pub struct ProtobufRequest<T>(pub T)
39where
40  T: Message + Default;
41
42#[async_trait]
43impl<S, T> FromRequest<S> for ProtobufRequest<T>
44where
45  S: Send + Sync,
46  T: Message + Default,
47{
48  type Rejection = JetError;
49
50  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
51    // Check content type
52    let content_type = req
53      .headers()
54      .get(CONTENT_TYPE)
55      .and_then(|v| v.to_str().ok())
56      .unwrap_or("");
57
58    if !content_type.starts_with(APPLICATION_PROTOBUF) {
59      return Err(JetError::InvalidContentType {
60        expected: APPLICATION_PROTOBUF.to_string(),
61        actual:   content_type.to_string(),
62      });
63    }
64
65    // Extract body
66    let bytes = Bytes::from_request(req, state)
67      .await
68      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
69
70    // Check size
71    if bytes.len() > MAX_BODY_SIZE {
72      return Err(JetError::BodyTooLarge {
73        size: bytes.len(),
74        max:  MAX_BODY_SIZE,
75      });
76    }
77
78    // Decode protobuf
79    let message = T::decode(bytes)?;
80
81    Ok(ProtobufRequest(message))
82  }
83}
84
85/// Optional Protobuf request body extractor
86///
87/// Like `ProtobufRequest`, but returns `None` if the body is empty.
88pub struct OptionalProtobufRequest<T>(pub Option<T>)
89where
90  T: Message + Default;
91
92#[async_trait]
93impl<S, T> FromRequest<S> for OptionalProtobufRequest<T>
94where
95  S: Send + Sync,
96  T: Message + Default,
97{
98  type Rejection = JetError;
99
100  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
101    let bytes = Bytes::from_request(req, state)
102      .await
103      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
104
105    if bytes.is_empty() {
106      return Ok(OptionalProtobufRequest(None));
107    }
108
109    if bytes.len() > MAX_BODY_SIZE {
110      return Err(JetError::BodyTooLarge {
111        size: bytes.len(),
112        max:  MAX_BODY_SIZE,
113      });
114    }
115
116    let message = T::decode(bytes)?;
117    Ok(OptionalProtobufRequest(Some(message)))
118  }
119}
120
121/// Protobuf request with configurable max size
122pub struct ProtobufRequestWithLimit<T, const LIMIT: usize>(pub T, PhantomData<T>)
123where
124  T: Message + Default;
125
126#[async_trait]
127impl<S, T, const LIMIT: usize> FromRequest<S> for ProtobufRequestWithLimit<T, LIMIT>
128where
129  S: Send + Sync,
130  T: Message + Default,
131{
132  type Rejection = JetError;
133
134  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
135    let bytes = Bytes::from_request(req, state)
136      .await
137      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
138
139    if bytes.len() > LIMIT {
140      return Err(JetError::BodyTooLarge {
141        size: bytes.len(),
142        max:  LIMIT,
143      });
144    }
145
146    let message = T::decode(bytes)?;
147    Ok(ProtobufRequestWithLimit(message, PhantomData))
148  }
149}