gateway_runtime/
utilities.rs1#[allow(unused_imports)]
13use crate::alloc;
14use core::str::FromStr;
15
16#[cfg(feature = "std")]
17use std::sync::Mutex;
18
19#[cfg(feature = "std")]
20use crate::codec::Codec;
21#[cfg(feature = "std")]
22use crate::errors::GatewayError;
23#[cfg(feature = "std")]
24use bytes::Bytes;
25#[cfg(feature = "std")]
26use prost::Message;
27#[cfg(feature = "std")]
28use serde::de::DeserializeOwned;
29#[cfg(feature = "std")]
30use std::task::{Context, Poll};
31#[cfg(feature = "std")]
32use tower::Service;
33
/// Wraps a service `S` in a [`Mutex`], so `SyncService<S>` is `Sync` whenever `S: Send`.
///
/// The inner service is reachable through the public field or through [`SyncService::get`].
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct SyncService<S>(pub Mutex<S>);

#[cfg(feature = "std")]
impl<S> SyncService<S> {
    /// Wraps `service` in a new, unlocked mutex.
    pub fn new(service: S) -> Self {
        Self(Mutex::new(service))
    }

    /// Locks the mutex and returns a guard giving access to the inner service.
    ///
    /// Blocks the current thread until the lock is available.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned (a previous holder panicked while holding the lock).
    pub fn get(&self) -> std::sync::MutexGuard<'_, S> {
        self.0.lock().expect("SyncService mutex poisoned")
    }
}

#[cfg(feature = "std")]
impl<S> From<S> for SyncService<S> {
    fn from(service: S) -> Self {
        Self::new(service)
    }
}

#[cfg(feature = "std")]
impl<S: Clone> Clone for SyncService<S> {
    /// Clones the inner service (under the lock) into a fresh, independent mutex.
    fn clone(&self) -> Self {
        Self::new(self.get().clone())
    }
}

#[cfg(feature = "std")]
impl<S, Request> Service<Request> for SyncService<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // `&mut self` already proves exclusive access, so `Mutex::get_mut` reaches the
        // inner service without taking the lock (it still reports poisoning).
        self.0
            .get_mut()
            .expect("SyncService mutex poisoned")
            .poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // The returned future is `S::Future`, so it does not borrow the mutex.
        self.0.get_mut().expect("SyncService mutex poisoned").call(req)
    }
}
90
/// Converts a raw path-parameter string into `T` using `T`'s [`FromStr`] implementation.
///
/// # Errors
///
/// Propagates `T::Err` unchanged when `value` is not a valid representation of `T`.
pub fn parse_path_param<T: FromStr>(value: &str) -> Result<T, T::Err> {
    T::from_str(value)
}
101
/// Decodes a request body into `T`.
///
/// When the `Content-Type` header names `multipart/form-data` (compared case-insensitively,
/// as media-type names are), each named part is collected into a JSON object that is then
/// deserialized with `serde_json`: text parts become JSON strings, while file parts and
/// non-UTF-8 text parts become JSON arrays of byte values. Every other body is handed to
/// `codec.decode` together with the content type, or `None` when the header is missing,
/// empty, or not visible ASCII.
///
/// # Errors
///
/// Returns [`GatewayError::Encoding`] when the multipart boundary or stream is malformed or
/// the assembled JSON does not deserialize into `T`; otherwise returns whatever
/// `codec.decode` returns.
#[cfg(feature = "std")]
pub async fn parse_body<T, C>(
    headers: &http::HeaderMap,
    body: alloc::vec::Vec<u8>,
    codec: &C,
) -> Result<T, GatewayError>
where
    T: Message + Default + DeserializeOwned,
    C: Codec,
{
    // A missing or non-visible-ASCII header is treated exactly like an empty one.
    let content_type = headers
        .get(http::header::CONTENT_TYPE)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("");

    if multipart_is_form_data(content_type) {
        let map = multipart_fields_to_json(content_type, body).await?;
        serde_json::from_value(serde_json::Value::Object(map))
            .map_err(|e| GatewayError::Encoding(Box::new(e)))
    } else {
        let content_type = if content_type.is_empty() {
            None
        } else {
            Some(content_type)
        };
        codec.decode(&body, content_type)
    }
}

/// Returns `true` when `content_type` starts with `multipart/form-data`, ignoring ASCII case.
#[cfg(feature = "std")]
fn multipart_is_form_data(content_type: &str) -> bool {
    const FORM_DATA: &str = "multipart/form-data";
    // `str::get` yields `None` for too-short input or a non-char-boundary cut.
    matches!(
        content_type.get(..FORM_DATA.len()),
        Some(prefix) if prefix.eq_ignore_ascii_case(FORM_DATA)
    )
}

/// Reads every named part of a buffered multipart body into a JSON object keyed by part name.
///
/// Unnamed parts are skipped; a later part with the same name overwrites an earlier one.
#[cfg(feature = "std")]
async fn multipart_fields_to_json(
    content_type: &str,
    body: alloc::vec::Vec<u8>,
) -> Result<serde_json::Map<String, serde_json::Value>, GatewayError> {
    let boundary =
        multer::parse_boundary(content_type).map_err(|e| GatewayError::Encoding(Box::new(e)))?;

    // The whole body is already in memory, so feed it to multer as a one-item stream.
    let stream = futures::stream::iter(vec![Ok::<Bytes, multer::Error>(Bytes::from(body))]);
    let mut multipart = multer::Multipart::new(stream, boundary);
    let mut map = serde_json::Map::new();

    while let Some(mut field) = multipart
        .next_field()
        .await
        .map_err(|e| GatewayError::Encoding(Box::new(e)))?
    {
        let name = match field.name() {
            Some(name) => name.to_string(),
            None => continue,
        };
        let is_file = field.file_name().is_some();

        let mut data = alloc::vec::Vec::new();
        while let Some(chunk) = field
            .chunk()
            .await
            .map_err(|e| GatewayError::Encoding(Box::new(e)))?
        {
            data.extend_from_slice(&chunk);
        }

        let value = if is_file {
            multipart_bytes_to_json(data)
        } else {
            match String::from_utf8(data) {
                Ok(text) => serde_json::Value::String(text),
                // Take the bytes back out of the error instead of cloning them beforehand.
                Err(err) => multipart_bytes_to_json(err.into_bytes()),
            }
        };
        map.insert(name, value);
    }

    Ok(map)
}

/// Encodes raw bytes as a JSON array of numbers (one element per byte).
#[cfg(feature = "std")]
fn multipart_bytes_to_json(bytes: alloc::vec::Vec<u8>) -> serde_json::Value {
    serde_json::Value::Array(
        bytes
            .into_iter()
            .map(|b| serde_json::Value::Number(serde_json::Number::from(b)))
            .collect(),
    )
}
187
188#[cfg(not(feature = "std"))]
190pub async fn parse_body<T, C>(
191 headers: &http::HeaderMap,
192 body: alloc::vec::Vec<u8>,
193 codec: &C,
194) -> Result<T, crate::errors::GatewayError>
195where
196 T: prost::Message + Default + serde::de::DeserializeOwned,
197 C: crate::codec::Codec,
198{
199 let content_type = headers
200 .get(http::header::CONTENT_TYPE)
201 .and_then(|h| h.to_str().ok());
202 codec.decode(&body, content_type)
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_path_param_string() {
        let res: Result<String, _> = parse_path_param("abc");
        assert_eq!(res.unwrap(), "abc");
    }

    #[test]
    fn test_parse_path_param_int() {
        let res: Result<i32, _> = parse_path_param("123");
        assert_eq!(res.unwrap(), 123);
    }

    #[test]
    fn test_parse_path_param_invalid() {
        let res: Result<i32, _> = parse_path_param("abc");
        assert!(res.is_err());
    }

    // Everything below uses `Codec`, `GatewayError`, `Message`, `Service`, `Context` and
    // `Poll`, which the parent module only imports when `feature = "std"` is enabled.

    /// Codec stub: `decode` yields `T::default()` when the body mentions "foo", else an error.
    #[cfg(feature = "std")]
    struct MockCodec;

    #[cfg(feature = "std")]
    impl Codec for MockCodec {
        fn encode<T: Message + serde::Serialize>(
            &self,
            _item: &T,
            _buf: Option<&str>,
        ) -> Result<crate::bytes::Bytes, GatewayError> {
            unimplemented!()
        }

        fn decode<T: Message + Default + serde::de::DeserializeOwned>(
            &self,
            buf: &[u8],
            _content_type: Option<&str>,
        ) -> Result<T, GatewayError> {
            let s = String::from_utf8(buf.to_vec()).unwrap();
            if s.contains("foo") {
                Ok(T::default())
            } else {
                Err(GatewayError::Encoding(Box::new(std::fmt::Error)))
            }
        }

        fn encoder_content_type(&self, _accept: Option<&str>) -> String {
            "application/json".to_string()
        }
    }

    /// Minimal message type used as the decode target in the `parse_body` tests.
    // `prost::Message` has `Debug + Send + Sync` supertraits, so `Debug` must be derived.
    #[cfg(feature = "std")]
    #[derive(Debug, Default, serde::Deserialize)]
    struct Dummy {
        #[serde(default)]
        foo: String,
    }

    #[cfg(feature = "std")]
    impl prost::Message for Dummy {
        fn encode_raw(&self, _buf: &mut impl bytes::BufMut) {}
        fn merge_field(
            &mut self,
            _tag: u32,
            _wire_type: prost::encoding::WireType,
            _buf: &mut impl bytes::Buf,
            _ctx: prost::encoding::DecodeContext,
        ) -> Result<(), prost::DecodeError> {
            Ok(())
        }
        fn encoded_len(&self) -> usize {
            0
        }
        fn clear(&mut self) {
            self.foo.clear();
        }
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_parse_body_json() {
        let body = r#"{"foo": "bar"}"#.as_bytes().to_vec();
        let headers = http::HeaderMap::new();
        let res: Result<Dummy, _> = parse_body(&headers, body, &MockCodec).await;
        assert!(res.is_ok());
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_parse_body_multipart_text_field() {
        let body =
            b"--X\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--X--\r\n".to_vec();
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("multipart/form-data; boundary=X"),
        );
        let res: Result<Dummy, _> = parse_body(&headers, body, &MockCodec).await;
        assert_eq!(res.ok().map(|d| d.foo).as_deref(), Some("bar"));
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_sync_service() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Service that counts how many times `call` has been invoked.
        #[derive(Clone)]
        struct S(Arc<AtomicUsize>);

        impl Service<()> for S {
            type Response = ();
            type Error = ();
            type Future = std::future::Ready<Result<(), ()>>;
            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
                Poll::Ready(Ok(()))
            }
            fn call(&mut self, _req: ()) -> Self::Future {
                self.0.fetch_add(1, Ordering::SeqCst);
                std::future::ready(Ok(()))
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let mut sync_s = SyncService::new(S(Arc::clone(&counter)));
        // The clone wraps a clone of `S`, which shares the same counter.
        let mut cloned = sync_s.clone();

        // Honour the tower contract: drive `poll_ready` to readiness before each `call`.
        std::future::poll_fn(|cx| sync_s.poll_ready(cx)).await.unwrap();
        sync_s.call(()).await.unwrap();
        std::future::poll_fn(|cx| cloned.poll_ready(cx)).await.unwrap();
        cloned.call(()).await.unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}