Skip to main content

brainwires_proxy/
types.rs

1use bytes::Bytes;
2use http::{HeaderMap, Method, StatusCode, Uri};
3use std::any::Any;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request_id::RequestId;
8
9/// The transport protocol over which a request arrived.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
11pub enum TransportKind {
12    Http,
13    WebSocket,
14    Tcp,
15    Unix,
16    Sse,
17}
18
19/// Body payload for proxied messages.
20#[derive(Debug, Clone)]
21pub enum ProxyBody {
22    /// Complete body available in memory.
23    Full(Bytes),
24    /// Empty body.
25    Empty,
26}
27
28impl ProxyBody {
29    pub fn is_empty(&self) -> bool {
30        match self {
31            ProxyBody::Full(b) => b.is_empty(),
32            ProxyBody::Empty => true,
33        }
34    }
35
36    pub fn len(&self) -> usize {
37        match self {
38            ProxyBody::Full(b) => b.len(),
39            ProxyBody::Empty => 0,
40        }
41    }
42
43    pub fn as_bytes(&self) -> &[u8] {
44        match self {
45            ProxyBody::Full(b) => b,
46            ProxyBody::Empty => &[],
47        }
48    }
49
50    pub fn into_bytes(self) -> Bytes {
51        match self {
52            ProxyBody::Full(b) => b,
53            ProxyBody::Empty => Bytes::new(),
54        }
55    }
56}
57
58impl From<Bytes> for ProxyBody {
59    fn from(b: Bytes) -> Self {
60        if b.is_empty() {
61            ProxyBody::Empty
62        } else {
63            ProxyBody::Full(b)
64        }
65    }
66}
67
68impl From<Vec<u8>> for ProxyBody {
69    fn from(v: Vec<u8>) -> Self {
70        Bytes::from(v).into()
71    }
72}
73
74impl From<String> for ProxyBody {
75    fn from(s: String) -> Self {
76        Bytes::from(s).into()
77    }
78}
79
80impl From<&str> for ProxyBody {
81    fn from(s: &str) -> Self {
82        Bytes::copy_from_slice(s.as_bytes()).into()
83    }
84}
85
/// Type-safe extension map for attaching arbitrary metadata to requests/responses.
///
/// Values are keyed by their concrete type, so at most one value per type is
/// stored; inserting another value of the same type replaces the previous one.
/// Values live behind an [`Arc`], so cloning an `Extensions` copies only the
/// map of pointers: clones share the stored values, but later inserts or
/// removals on one clone do not affect the other.
#[derive(Default, Clone)]
pub struct Extensions {
    map: HashMap<std::any::TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty extension map.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `val`, replacing any value of type `T` already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        self.map.insert(std::any::TypeId::of::<T>(), Arc::new(val));
    }

    /// Returns a reference to the stored value of type `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.map
            .get(&std::any::TypeId::of::<T>())
            .and_then(|v| v.downcast_ref::<T>())
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.map.contains_key(&std::any::TypeId::of::<T>())
    }

    /// Removes the stored value of type `T` and returns it.
    ///
    /// The value comes back as an `Arc<T>` rather than a `T` because clones of
    /// this map may still share it; `Arc::try_unwrap` recovers an owned `T`
    /// when no other clone holds a reference.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<Arc<T>> {
        self.map
            .remove(&std::any::TypeId::of::<T>())
            // The key is `TypeId::of::<T>()`, so this downcast cannot fail in
            // practice; `.ok()` keeps the method panic-free regardless.
            .and_then(|v| v.downcast::<T>().ok())
    }

    /// Number of distinct types currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no extensions are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

impl std::fmt::Debug for Extensions {
    /// Reports only how many extensions are stored; the values themselves are
    /// type-erased and not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Extensions")
            .field("count", &self.map.len())
            .finish()
    }
}
115
116/// A request flowing through the proxy.
117#[derive(Debug, Clone)]
118pub struct ProxyRequest {
119    pub id: RequestId,
120    pub method: Method,
121    pub uri: Uri,
122    pub headers: HeaderMap,
123    pub body: ProxyBody,
124    pub transport: TransportKind,
125    pub timestamp: chrono::DateTime<chrono::Utc>,
126    pub extensions: Extensions,
127}
128
129impl ProxyRequest {
130    pub fn new(method: Method, uri: Uri) -> Self {
131        Self {
132            id: RequestId::new(),
133            method,
134            uri,
135            headers: HeaderMap::new(),
136            body: ProxyBody::Empty,
137            transport: TransportKind::Http,
138            timestamp: chrono::Utc::now(),
139            extensions: Extensions::new(),
140        }
141    }
142
143    pub fn with_body(mut self, body: impl Into<ProxyBody>) -> Self {
144        self.body = body.into();
145        self
146    }
147
148    pub fn with_headers(mut self, headers: HeaderMap) -> Self {
149        self.headers = headers;
150        self
151    }
152
153    pub fn with_transport(mut self, transport: TransportKind) -> Self {
154        self.transport = transport;
155        self
156    }
157}
158
159/// A response flowing back through the proxy.
160#[derive(Debug, Clone)]
161pub struct ProxyResponse {
162    pub id: RequestId,
163    pub status: StatusCode,
164    pub headers: HeaderMap,
165    pub body: ProxyBody,
166    pub timestamp: chrono::DateTime<chrono::Utc>,
167    pub extensions: Extensions,
168}
169
170impl ProxyResponse {
171    pub fn new(status: StatusCode) -> Self {
172        Self {
173            id: RequestId::new(),
174            status,
175            headers: HeaderMap::new(),
176            body: ProxyBody::Empty,
177            timestamp: chrono::Utc::now(),
178            extensions: Extensions::new(),
179        }
180    }
181
182    pub fn for_request(request_id: RequestId, status: StatusCode) -> Self {
183        Self {
184            id: request_id,
185            status,
186            headers: HeaderMap::new(),
187            body: ProxyBody::Empty,
188            timestamp: chrono::Utc::now(),
189            extensions: Extensions::new(),
190        }
191    }
192
193    pub fn with_body(mut self, body: impl Into<ProxyBody>) -> Self {
194        self.body = body.into();
195        self
196    }
197
198    pub fn with_headers(mut self, headers: HeaderMap) -> Self {
199        self.headers = headers;
200        self
201    }
202}
203
204/// Identifier for a body format used by the conversion system.
205#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
206pub struct FormatId(pub String);
207
208impl FormatId {
209    pub fn new(id: impl Into<String>) -> Self {
210        Self(id.into())
211    }
212}
213
214impl std::fmt::Display for FormatId {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.write_str(&self.0)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn proxy_body_empty() {
        let body = ProxyBody::Empty;
        assert!(body.is_empty());
        assert_eq!(body.len(), 0);
        assert_eq!(body.as_bytes(), &[] as &[u8]);
        assert!(body.into_bytes().is_empty());
    }

    #[test]
    fn proxy_body_full_with_no_bytes_reports_empty() {
        // A `Full` built directly bypasses the `From` normalization; the
        // accessors must still treat it as empty.
        let body = ProxyBody::Full(Bytes::new());
        assert!(body.is_empty());
        assert_eq!(body.len(), 0);
        assert_eq!(body.as_bytes(), &[] as &[u8]);
        assert!(body.into_bytes().is_empty());
    }

    #[test]
    fn proxy_body_from_str() {
        let body = ProxyBody::from("hello");
        assert!(!body.is_empty());
        assert_eq!(body.len(), 5);
        assert_eq!(body.as_bytes(), b"hello");
    }

    #[test]
    fn proxy_body_from_string() {
        let body = ProxyBody::from("world".to_string());
        assert_eq!(body.as_bytes(), b"world");
    }

    #[test]
    fn proxy_body_from_vec() {
        let body = ProxyBody::from(vec![1, 2, 3]);
        assert_eq!(body.len(), 3);
        assert_eq!(body.as_bytes(), &[1, 2, 3]);
    }

    #[test]
    fn proxy_body_from_empty_bytes_is_empty_variant() {
        let body = ProxyBody::from(Bytes::new());
        assert!(body.is_empty());
        assert!(matches!(body, ProxyBody::Empty));
    }

    #[test]
    fn proxy_body_from_other_empty_inputs_is_empty_variant() {
        assert!(matches!(ProxyBody::from(""), ProxyBody::Empty));
        assert!(matches!(ProxyBody::from(String::new()), ProxyBody::Empty));
        assert!(matches!(ProxyBody::from(Vec::<u8>::new()), ProxyBody::Empty));
    }

    #[test]
    fn extensions_insert_and_get() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        ext.insert("hello".to_string());

        assert_eq!(ext.get::<u32>(), Some(&42));
        assert_eq!(ext.get::<String>(), Some(&"hello".to_string()));
        assert_eq!(ext.get::<bool>(), None);
    }

    #[test]
    fn extensions_insert_replaces_same_type() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        ext.insert(2u32);
        assert_eq!(ext.get::<u32>(), Some(&2));
        assert_eq!(format!("{:?}", ext), "Extensions { count: 1 }");
    }

    #[test]
    fn extensions_clone_is_independent_map() {
        let mut ext = Extensions::new();
        ext.insert(5u64);
        let snapshot = ext.clone();
        ext.insert(6u64);
        assert_eq!(snapshot.get::<u64>(), Some(&5));
        assert_eq!(ext.get::<u64>(), Some(&6));
    }

    #[test]
    fn proxy_request_builder() {
        let req = ProxyRequest::new(Method::GET, "/api/test".parse().unwrap())
            .with_body("request body")
            .with_transport(TransportKind::Http);

        assert_eq!(req.method, Method::GET);
        assert_eq!(req.uri, "/api/test");
        assert_eq!(req.body.as_bytes(), b"request body");
        assert_eq!(req.transport, TransportKind::Http);
    }

    #[test]
    fn proxy_request_defaults() {
        let req = ProxyRequest::new(Method::PUT, "/".parse().unwrap());
        assert!(req.headers.is_empty());
        assert!(req.body.is_empty());
        assert_eq!(req.transport, TransportKind::Http);
    }

    #[test]
    fn proxy_response_for_request() {
        let req = ProxyRequest::new(Method::POST, "/submit".parse().unwrap());
        let req_id = req.id.clone();
        let resp = ProxyResponse::for_request(req.id, StatusCode::OK).with_body("ok");

        assert_eq!(resp.id, req_id);
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_bytes(), b"ok");
    }

    #[test]
    fn proxy_response_new_defaults() {
        let resp = ProxyResponse::new(StatusCode::NOT_FOUND);
        assert_eq!(resp.status, StatusCode::NOT_FOUND);
        assert!(resp.headers.is_empty());
        assert!(resp.body.is_empty());
    }

    #[test]
    fn format_id_display() {
        let fmt = FormatId::new("application/json");
        assert_eq!(fmt.to_string(), "application/json");
    }

    #[test]
    fn format_id_equality() {
        assert_eq!(FormatId::new("json"), FormatId::new(String::from("json")));
        assert_ne!(FormatId::new("json"), FormatId::new("JSON"));
    }

    #[test]
    fn transport_kind_serde() {
        let kind = TransportKind::WebSocket;
        let json = serde_json::to_string(&kind).unwrap();
        let deserialized: TransportKind = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, TransportKind::WebSocket);
    }

    #[test]
    fn transport_kind_serde_roundtrips_every_variant() {
        let all = [
            TransportKind::Http,
            TransportKind::WebSocket,
            TransportKind::Tcp,
            TransportKind::Unix,
            TransportKind::Sse,
        ];
        for &kind in &all {
            let json = serde_json::to_string(&kind).unwrap();
            let back: TransportKind = serde_json::from_str(&json).unwrap();
            assert_eq!(back, kind);
        }
    }
}