1use {crate::{content_types::APPLICATION_PROTOBUF,
6 error::JetError},
7 axum::{async_trait,
8 body::Bytes,
9 extract::{FromRequest,
10 Request},
11 http::header::CONTENT_TYPE},
12 prost::Message,
13 std::marker::PhantomData};
14
/// Largest request body, in bytes (10 MiB), that `ProtobufRequest` and
/// `OptionalProtobufRequest` will accept before rejecting with `BodyTooLarge`.
const MAX_BODY_SIZE: usize = 10 << 20;
17
/// Axum extractor that decodes the request body as the protobuf message `T`.
///
/// The `Message + Default` requirement is expressed on the `FromRequest`
/// implementation rather than on the struct itself, so the wrapper can be
/// named and constructed for any `T` without repeating bounds everywhere.
pub struct ProtobufRequest<T>(pub T);
41
42#[async_trait]
43impl<S, T> FromRequest<S> for ProtobufRequest<T>
44where
45 S: Send + Sync,
46 T: Message + Default,
47{
48 type Rejection = JetError;
49
50 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
51 let content_type = req
53 .headers()
54 .get(CONTENT_TYPE)
55 .and_then(|v| v.to_str().ok())
56 .unwrap_or("");
57
58 if !content_type.starts_with(APPLICATION_PROTOBUF) {
59 return Err(JetError::InvalidContentType {
60 expected: APPLICATION_PROTOBUF.to_string(),
61 actual: content_type.to_string(),
62 });
63 }
64
65 let bytes = Bytes::from_request(req, state)
67 .await
68 .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
69
70 if bytes.len() > MAX_BODY_SIZE {
72 return Err(JetError::BodyTooLarge {
73 size: bytes.len(),
74 max: MAX_BODY_SIZE,
75 });
76 }
77
78 let message = T::decode(bytes)?;
80
81 Ok(ProtobufRequest(message))
82 }
83}
84
/// Axum extractor that decodes an optional protobuf body: an empty body
/// yields `None`, anything else is decoded as `T`.
///
/// The `Message + Default` requirement lives on the `FromRequest`
/// implementation rather than on the struct definition.
pub struct OptionalProtobufRequest<T>(pub Option<T>);
91
92#[async_trait]
93impl<S, T> FromRequest<S> for OptionalProtobufRequest<T>
94where
95 S: Send + Sync,
96 T: Message + Default,
97{
98 type Rejection = JetError;
99
100 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
101 let bytes = Bytes::from_request(req, state)
102 .await
103 .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
104
105 if bytes.is_empty() {
106 return Ok(OptionalProtobufRequest(None));
107 }
108
109 if bytes.len() > MAX_BODY_SIZE {
110 return Err(JetError::BodyTooLarge {
111 size: bytes.len(),
112 max: MAX_BODY_SIZE,
113 });
114 }
115
116 let message = T::decode(bytes)?;
117 Ok(OptionalProtobufRequest(Some(message)))
118 }
119}
120
/// Axum extractor that decodes a protobuf body as `T`, rejecting bodies larger
/// than `LIMIT` bytes.
///
/// The `Message + Default` requirement lives on the `FromRequest`
/// implementation rather than on the struct definition.
///
/// The second `PhantomData<T>` field is redundant (`T` is already stored in
/// field `.0`) but is kept for API compatibility. Because it is private,
/// callers outside this module cannot build the value with tuple syntax and
/// should read the decoded message via `.0`.
pub struct ProtobufRequestWithLimit<T, const LIMIT: usize>(pub T, PhantomData<T>);
125
126#[async_trait]
127impl<S, T, const LIMIT: usize> FromRequest<S> for ProtobufRequestWithLimit<T, LIMIT>
128where
129 S: Send + Sync,
130 T: Message + Default,
131{
132 type Rejection = JetError;
133
134 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
135 let bytes = Bytes::from_request(req, state)
136 .await
137 .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
138
139 if bytes.len() > LIMIT {
140 return Err(JetError::BodyTooLarge {
141 size: bytes.len(),
142 max: LIMIT,
143 });
144 }
145
146 let message = T::decode(bytes)?;
147 Ok(ProtobufRequestWithLimit(message, PhantomData))
148 }
149}