1use bytes::Bytes;
2use http::{HeaderMap, Method, StatusCode, Uri};
3use std::any::Any;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request_id::RequestId;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
11pub enum TransportKind {
12 Http,
13 WebSocket,
14 Tcp,
15 Unix,
16 Sse,
17}
18
19#[derive(Debug, Clone)]
21pub enum ProxyBody {
22 Full(Bytes),
24 Empty,
26}
27
28impl ProxyBody {
29 pub fn is_empty(&self) -> bool {
30 match self {
31 ProxyBody::Full(b) => b.is_empty(),
32 ProxyBody::Empty => true,
33 }
34 }
35
36 pub fn len(&self) -> usize {
37 match self {
38 ProxyBody::Full(b) => b.len(),
39 ProxyBody::Empty => 0,
40 }
41 }
42
43 pub fn as_bytes(&self) -> &[u8] {
44 match self {
45 ProxyBody::Full(b) => b,
46 ProxyBody::Empty => &[],
47 }
48 }
49
50 pub fn into_bytes(self) -> Bytes {
51 match self {
52 ProxyBody::Full(b) => b,
53 ProxyBody::Empty => Bytes::new(),
54 }
55 }
56}
57
58impl From<Bytes> for ProxyBody {
59 fn from(b: Bytes) -> Self {
60 if b.is_empty() {
61 ProxyBody::Empty
62 } else {
63 ProxyBody::Full(b)
64 }
65 }
66}
67
68impl From<Vec<u8>> for ProxyBody {
69 fn from(v: Vec<u8>) -> Self {
70 Bytes::from(v).into()
71 }
72}
73
74impl From<String> for ProxyBody {
75 fn from(s: String) -> Self {
76 Bytes::from(s).into()
77 }
78}
79
80impl From<&str> for ProxyBody {
81 fn from(s: &str) -> Self {
82 Bytes::copy_from_slice(s.as_bytes()).into()
83 }
84}
85
/// A map holding at most one value per Rust type, keyed by the value's
/// [`std::any::TypeId`].
///
/// Values are stored behind `Arc`, so cloning an `Extensions` clones the
/// pointers rather than the stored values.
#[derive(Default, Clone)]
pub struct Extensions {
    map: HashMap<std::any::TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty set of extensions.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Stores `val` under its type.
    ///
    /// A value of the same type that was stored earlier is replaced.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let shared: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, shared);
    }

    /// Returns a reference to the stored value of type `T`, or `None` when no
    /// value of that type has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let key = std::any::TypeId::of::<T>();
        let shared = self.map.get(&key)?;
        shared.downcast_ref::<T>()
    }
}

/// Prints only the number of stored entries; the values themselves are
/// type-erased (`dyn Any`) and are not required to implement `Debug`.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("count", &count);
        out.finish()
    }
}
115
116#[derive(Debug, Clone)]
118pub struct ProxyRequest {
119 pub id: RequestId,
120 pub method: Method,
121 pub uri: Uri,
122 pub headers: HeaderMap,
123 pub body: ProxyBody,
124 pub transport: TransportKind,
125 pub timestamp: chrono::DateTime<chrono::Utc>,
126 pub extensions: Extensions,
127}
128
129impl ProxyRequest {
130 pub fn new(method: Method, uri: Uri) -> Self {
131 Self {
132 id: RequestId::new(),
133 method,
134 uri,
135 headers: HeaderMap::new(),
136 body: ProxyBody::Empty,
137 transport: TransportKind::Http,
138 timestamp: chrono::Utc::now(),
139 extensions: Extensions::new(),
140 }
141 }
142
143 pub fn with_body(mut self, body: impl Into<ProxyBody>) -> Self {
144 self.body = body.into();
145 self
146 }
147
148 pub fn with_headers(mut self, headers: HeaderMap) -> Self {
149 self.headers = headers;
150 self
151 }
152
153 pub fn with_transport(mut self, transport: TransportKind) -> Self {
154 self.transport = transport;
155 self
156 }
157}
158
159#[derive(Debug, Clone)]
161pub struct ProxyResponse {
162 pub id: RequestId,
163 pub status: StatusCode,
164 pub headers: HeaderMap,
165 pub body: ProxyBody,
166 pub timestamp: chrono::DateTime<chrono::Utc>,
167 pub extensions: Extensions,
168}
169
170impl ProxyResponse {
171 pub fn new(status: StatusCode) -> Self {
172 Self {
173 id: RequestId::new(),
174 status,
175 headers: HeaderMap::new(),
176 body: ProxyBody::Empty,
177 timestamp: chrono::Utc::now(),
178 extensions: Extensions::new(),
179 }
180 }
181
182 pub fn for_request(request_id: RequestId, status: StatusCode) -> Self {
183 Self {
184 id: request_id,
185 status,
186 headers: HeaderMap::new(),
187 body: ProxyBody::Empty,
188 timestamp: chrono::Utc::now(),
189 extensions: Extensions::new(),
190 }
191 }
192
193 pub fn with_body(mut self, body: impl Into<ProxyBody>) -> Self {
194 self.body = body.into();
195 self
196 }
197
198 pub fn with_headers(mut self, headers: HeaderMap) -> Self {
199 self.headers = headers;
200 self
201 }
202}
203
204#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
206pub struct FormatId(pub String);
207
208impl FormatId {
209 pub fn new(id: impl Into<String>) -> Self {
210 Self(id.into())
211 }
212}
213
214impl std::fmt::Display for FormatId {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.write_str(&self.0)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // `Empty` reports "no bytes" through every accessor.
    #[test]
    fn proxy_body_empty() {
        let body = ProxyBody::Empty;
        assert!(body.is_empty());
        assert_eq!(body.len(), 0);
        assert_eq!(body.as_bytes(), &[] as &[u8]);
        assert!(body.into_bytes().is_empty());
    }

    // Conversion from a borrowed string copies the bytes.
    #[test]
    fn proxy_body_from_str() {
        let body = ProxyBody::from("hello");
        assert!(!body.is_empty());
        assert_eq!(body.len(), 5);
        assert_eq!(body.as_bytes(), b"hello");
    }

    // Conversion from an owned string keeps the content.
    #[test]
    fn proxy_body_from_string() {
        let body = ProxyBody::from("world".to_string());
        assert_eq!(body.as_bytes(), b"world");
    }

    // Conversion from an owned byte vector keeps the content.
    #[test]
    fn proxy_body_from_vec() {
        let body = ProxyBody::from(vec![1, 2, 3]);
        assert_eq!(body.len(), 3);
        assert_eq!(body.as_bytes(), &[1, 2, 3]);
    }

    // An empty buffer is normalised to the `Empty` variant, not `Full`.
    #[test]
    fn proxy_body_from_empty_bytes_is_empty_variant() {
        let body = ProxyBody::from(Bytes::new());
        assert!(body.is_empty());
        assert!(matches!(body, ProxyBody::Empty));
    }

    // Values are looked up by their type; absent types yield `None`.
    #[test]
    fn extensions_insert_and_get() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        ext.insert("hello".to_string());

        assert_eq!(ext.get::<u32>(), Some(&42));
        assert_eq!(ext.get::<String>(), Some(&"hello".to_string()));
        assert_eq!(ext.get::<bool>(), None);
    }

    // Pins the defaults that `ProxyRequest::new` fills in, so the builder
    // test below can tell an overridden value from an untouched one.
    #[test]
    fn proxy_request_defaults() {
        let req = ProxyRequest::new(Method::GET, "/".parse().unwrap());

        assert!(req.headers.is_empty());
        assert!(req.body.is_empty());
        assert_eq!(req.transport, TransportKind::Http);
    }

    // Uses a transport that differs from the constructor default (`Http`);
    // otherwise a `with_transport` that did nothing would still pass.
    #[test]
    fn proxy_request_builder() {
        let req = ProxyRequest::new(Method::GET, "/api/test".parse().unwrap())
            .with_body("request body")
            .with_transport(TransportKind::WebSocket);

        assert_eq!(req.method, Method::GET);
        assert_eq!(req.uri, "/api/test");
        assert_eq!(req.body.as_bytes(), b"request body");
        assert_eq!(req.transport, TransportKind::WebSocket);
    }

    // A response built with `for_request` carries the request's id.
    #[test]
    fn proxy_response_for_request() {
        let req = ProxyRequest::new(Method::POST, "/submit".parse().unwrap());
        let req_id = req.id.clone();
        let resp = ProxyResponse::for_request(req.id, StatusCode::OK).with_body("ok");

        assert_eq!(resp.id, req_id);
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_bytes(), b"ok");
    }

    // `Display` prints the wrapped string unchanged.
    #[test]
    fn format_id_display() {
        let fmt = FormatId::new("application/json");
        assert_eq!(fmt.to_string(), "application/json");
    }

    // A value survives a JSON round trip through the derived serde impls.
    #[test]
    fn transport_kind_serde() {
        let kind = TransportKind::WebSocket;
        let json = serde_json::to_string(&kind).unwrap();
        let deserialized: TransportKind = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, TransportKind::WebSocket);
    }
}