Skip to main content

gateway_runtime/
utilities.rs

1//! # Utilities
2//!
3//! ## Purpose
4//! Shared helper functions and types that don't fit neatly into other modules.
5//!
6//! ## Scope
7//! This module defines:
8//! -   `SyncService`: A thread-safe wrapper for services (enabled with `std` feature).
9//! -   `parse_path_param`: Helper for parsing path parameters.
10//! -   `parse_body`: Helper for parsing request bodies (handles multipart).
11
12#[allow(unused_imports)]
13use crate::alloc;
14use core::str::FromStr;
15
16#[cfg(feature = "std")]
17use std::sync::Mutex;
18
19#[cfg(feature = "std")]
20use crate::codec::Codec;
21#[cfg(feature = "std")]
22use crate::errors::GatewayError;
23#[cfg(feature = "std")]
24use bytes::Bytes;
25#[cfg(feature = "std")]
26use prost::Message;
27#[cfg(feature = "std")]
28use serde::de::DeserializeOwned;
29#[cfg(feature = "std")]
30use std::task::{Context, Poll};
31#[cfg(feature = "std")]
32use tower::Service;
33
/// Mutex-backed wrapper that lets a service be shared between threads.
///
/// The wrapped service lives inside a [`Mutex`], so `SyncService<S>` is `Sync`
/// whenever `S` is `Send`. This is handy for services such as
/// `BoxCloneService`, which are `!Sync` on their own, when they have to be
/// stored in a `Router` that is shared through an `Arc`.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct SyncService<S>(pub Mutex<S>);

#[cfg(feature = "std")]
impl<S> SyncService<S> {
    /// Wraps `service` in a fresh, unlocked mutex.
    pub fn new(service: S) -> Self {
        let guarded = Mutex::new(service);
        SyncService(guarded)
    }

    /// Locks the inner service and returns the guard.
    ///
    /// The lock is released when the returned guard is dropped.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn get(&self) -> std::sync::MutexGuard<'_, S> {
        let Self(inner) = self;
        inner.lock().unwrap()
    }
}
58
/// Allows `service.into()` wherever a `SyncService` is expected.
#[cfg(feature = "std")]
impl<S> From<S> for SyncService<S> {
    /// Wraps `service`; equivalent to [`SyncService::new`].
    fn from(service: S) -> Self {
        SyncService::new(service)
    }
}
65
/// Cloning produces an independent wrapper with its own mutex.
#[cfg(feature = "std")]
impl<S: Clone> Clone for SyncService<S> {
    /// Clones the inner service while holding the lock and wraps the copy in
    /// a new mutex.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn clone(&self) -> Self {
        let duplicate = self.get().clone();
        Self::new(duplicate)
    }
}
72
/// Forwards the `tower::Service` contract to the wrapped service.
///
/// Both methods take `&mut self`, which already guarantees exclusive access to
/// the mutex. They therefore reach the inner service through
/// [`Mutex::get_mut`], which never locks or blocks, instead of acquiring the
/// lock via [`SyncService::get`].
#[cfg(feature = "std")]
impl<S, Request> Service<Request> for SyncService<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Polls the inner service for readiness.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Exclusive borrow: no locking required to reach the inner service.
        self.0.get_mut().unwrap().poll_ready(cx)
    }

    /// Dispatches `req` to the inner service and returns its future.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn call(&mut self, req: Request) -> Self::Future {
        // Exclusive borrow: no locking required to reach the inner service.
        self.0.get_mut().unwrap().call(req)
    }
}
90
/// Converts the textual value of a path parameter into `T`.
///
/// # Parameters
/// *   `value`: The raw path-parameter string.
///
/// # Returns
/// `Ok` with the converted value, or `Err` with the error produced by `T`'s
/// [`FromStr`] implementation.
pub fn parse_path_param<T: FromStr>(value: &str) -> Result<T, T::Err> {
    <T as FromStr>::from_str(value)
}
101
/// Parses the request body into a Protobuf message.
///
/// The `Content-Type` header selects the decoding strategy:
/// - `multipart/form-data` (matched case-insensitively, since media types are
///   case-insensitive): every named part becomes one entry of a JSON object,
///   which is then deserialized into `T`. Parts without a name are skipped,
///   and a later part with the same name replaces an earlier one.
/// - anything else: the raw body is handed to `codec.decode`. The content type
///   is passed as `None` when the header is missing, empty, or cannot be read
///   as a string.
///
/// Multipart parts are mapped with a heuristic: a part that carries a filename
/// is stored as an array of byte values; a part without one is stored as a
/// string when its content is valid UTF-8 and as an array of byte values
/// otherwise.
///
/// # Parameters
/// *   `headers`: The request headers.
/// *   `body`: The request body as bytes.
/// *   `codec`: The codec to use for non-multipart bodies.
///
/// # Errors
/// Returns `GatewayError::Encoding` when the multipart boundary cannot be
/// parsed, when reading a multipart field or chunk fails, or when the
/// assembled JSON object does not deserialize into `T`. For non-multipart
/// bodies the result of `codec.decode` is returned unchanged.
#[cfg(feature = "std")]
pub async fn parse_body<T, C>(
    headers: &http::HeaderMap,
    body: alloc::vec::Vec<u8>,
    codec: &C,
) -> Result<T, GatewayError>
where
    T: Message + Default + DeserializeOwned,
    C: Codec,
{
    const MULTIPART_FORM_DATA: &str = "multipart/form-data";

    let content_type = headers
        .get(http::header::CONTENT_TYPE)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("");

    // `str::get` yields `None` when the header is shorter than the prefix (or
    // the cut is not on a char boundary), so this never panics.
    let is_multipart = content_type
        .get(..MULTIPART_FORM_DATA.len())
        .map_or(false, |prefix| {
            prefix.eq_ignore_ascii_case(MULTIPART_FORM_DATA)
        });

    if !is_multipart {
        let content_type = if content_type.is_empty() {
            None
        } else {
            Some(content_type)
        };
        return codec.decode(&body, content_type);
    }

    let boundary =
        multer::parse_boundary(content_type).map_err(|e| GatewayError::Encoding(Box::new(e)))?;

    // The body is already fully buffered, so expose it as a one-item stream.
    let stream = futures::stream::iter(vec![Ok::<Bytes, multer::Error>(Bytes::from(body))]);
    let mut multipart = multer::Multipart::new(stream, boundary);
    let mut map = serde_json::Map::new();

    while let Some(mut field) = multipart
        .next_field()
        .await
        .map_err(|e| GatewayError::Encoding(Box::new(e)))?
    {
        // Parts without a name cannot be mapped to a message field.
        let name = match field.name() {
            Some(name) => name.to_string(),
            None => continue,
        };
        let has_file_name = field.file_name().is_some();

        let mut data = alloc::vec::Vec::new();
        while let Some(chunk) = field
            .chunk()
            .await
            .map_err(|e| GatewayError::Encoding(Box::new(e)))?
        {
            data.extend_from_slice(&chunk);
        }

        // Heuristic: a filename marks binary content (array of numbers);
        // otherwise prefer a string and fall back to bytes on invalid UTF-8.
        let value = if has_file_name {
            bytes_to_json_array(data)
        } else {
            match String::from_utf8(data) {
                Ok(text) => serde_json::Value::String(text),
                // `into_bytes` hands the original buffer back without a copy.
                Err(not_utf8) => bytes_to_json_array(not_utf8.into_bytes()),
            }
        };
        map.insert(name, value);
    }

    serde_json::from_value(serde_json::Value::Object(map))
        .map_err(|e| GatewayError::Encoding(Box::new(e)))
}

/// Converts raw bytes into a JSON array holding one number per byte.
#[cfg(feature = "std")]
fn bytes_to_json_array(bytes: alloc::vec::Vec<u8>) -> serde_json::Value {
    let items: alloc::vec::Vec<serde_json::Value> = bytes
        .into_iter()
        .map(|b| serde_json::Value::Number(serde_json::Number::from(b)))
        .collect();
    serde_json::Value::Array(items)
}
187
188/// Parses the request body into a Protobuf message (no_std fallback).
189#[cfg(not(feature = "std"))]
190pub async fn parse_body<T, C>(
191    headers: &http::HeaderMap,
192    body: alloc::vec::Vec<u8>,
193    codec: &C,
194) -> Result<T, crate::errors::GatewayError>
195where
196    T: prost::Message + Default + serde::de::DeserializeOwned,
197    C: crate::codec::Codec,
198{
199    let content_type = headers
200        .get(http::header::CONTENT_TYPE)
201        .and_then(|h| h.to_str().ok());
202    codec.decode(&body, content_type)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A plain string parameter is returned unchanged.
    #[test]
    fn test_parse_path_param_string() {
        let parsed = parse_path_param::<String>("abc");
        assert_eq!(parsed.unwrap(), "abc");
    }

    /// A numeric string parses into the requested integer type.
    #[test]
    fn test_parse_path_param_int() {
        let parsed = parse_path_param::<i32>("123");
        assert_eq!(parsed.unwrap(), 123);
    }

    /// A non-numeric string is rejected when an integer is requested.
    #[test]
    fn test_parse_path_param_invalid() {
        let parsed = parse_path_param::<i32>("abc");
        assert!(parsed.is_err());
    }

    /// Without a `Content-Type` header the body must be routed to the codec.
    #[tokio::test]
    async fn test_parse_body_json() {
        // Codec stand-in: decoding succeeds only when the payload mentions "foo".
        struct MockCodec;
        impl Codec for MockCodec {
            fn encode<T: Message + serde::Serialize>(
                &self,
                _item: &T,
                _buf: Option<&str>,
            ) -> Result<crate::bytes::Bytes, GatewayError> {
                unimplemented!()
            }
            fn decode<T: Message + Default + serde::de::DeserializeOwned>(
                &self,
                buf: &[u8],
                _content_type: Option<&str>,
            ) -> Result<T, GatewayError> {
                let text = String::from_utf8(buf.to_vec()).unwrap();
                if !text.contains("foo") {
                    return Err(GatewayError::Encoding(Box::new(std::fmt::Error)));
                }
                Ok(T::default())
            }
            fn encoder_content_type(&self, _accept: Option<&str>) -> String {
                "application/json".to_string()
            }
        }

        // Minimal message type; the trait impls are written by hand to avoid
        // conflicts between derives.
        #[derive(serde::Deserialize)]
        struct Dummy {
            #[serde(default)]
            foo: String,
        }
        impl Default for Dummy {
            fn default() -> Self {
                Dummy { foo: String::new() }
            }
        }
        impl prost::Message for Dummy {
            fn encode_raw(&self, _buf: &mut impl bytes::BufMut) {}
            fn merge_field(
                &mut self,
                _tag: u32,
                _wire_type: prost::encoding::WireType,
                _buf: &mut impl bytes::Buf,
                _ctx: prost::encoding::DecodeContext,
            ) -> Result<(), prost::DecodeError> {
                Ok(())
            }
            fn encoded_len(&self) -> usize {
                0
            }
            fn clear(&mut self) {
                self.foo.clear();
            }
        }

        let headers = HeaderMap::new();
        let payload = r#"{"foo": "bar"}"#.as_bytes().to_vec();

        // The mock only succeeds if `decode` actually receives the payload.
        let outcome: Result<Dummy, _> = parse_body(&headers, payload, &MockCodec).await;
        assert!(outcome.is_ok());
    }

    /// A `SyncService` and its clone both forward calls to the wrapped service.
    #[test]
    fn test_sync_service() {
        // Mock service that counts how often `call` is invoked. Clones share
        // the same counter through the `Arc`.
        #[derive(Clone)]
        struct Counting(Arc<AtomicUsize>);
        impl Service<()> for Counting {
            type Response = ();
            type Error = ();
            type Future = std::future::Ready<Result<(), ()>>;
            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
                Poll::Ready(Ok(()))
            }
            fn call(&mut self, _req: ()) -> Self::Future {
                self.0.fetch_add(1, Ordering::SeqCst);
                std::future::ready(Ok(()))
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let observer = Arc::clone(&counter);

        let mut original = SyncService::new(Counting(counter));
        let mut duplicate = original.clone();

        let _ = original.call(());
        let _ = duplicate.call(());

        // One call through each wrapper.
        assert_eq!(observer.load(Ordering::SeqCst), 2);
    }
}