1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9pub fn apply_flow_modification(flow: &Flow, phase: &str, mods: FlowModification) -> InterceptionResult {
16 if phase.starts_with("request") {
17 let mut req = match &flow.layer {
18 Layer::Http(h) => h.request.clone(),
19 Layer::WebSocket(ws) => ws.handshake_request.clone(),
20 _ => return InterceptionResult::Continue,
21 };
22
23 if let Some(m) = mods.method {
24 req.method = m;
25 }
26 if let Some(u) = mods.url {
27 if let Ok(parsed) = Url::parse(&u) {
29 req.url = parsed;
30 }
31 }
32 if let Some(h) = mods.request_headers {
33 req.headers = h.into_iter().collect();
34 }
35 if let Some(b) = mods.request_body {
36 req.body = Some(BodyData {
37 encoding: "utf-8".to_string(),
38 size: b.len() as u64,
39 content: b,
40 });
41 }
42
43 InterceptionResult::ModifiedRequest(req)
44 } else if phase.starts_with("response") {
45 let mut res = match &flow.layer {
46 Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
47 status: 200,
48 status_text: "OK".to_string(),
49 version: "HTTP/1.1".to_string(),
50 headers: vec![],
51 body: None,
52 timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None,
53 connect_time_ms: None,
54 ssl_time_ms: None,
55 },
56 cookies: vec![],
57 }),
58 Layer::WebSocket(ws) => ws.handshake_response.clone(),
59 _ => return InterceptionResult::Continue,
60 };
61
62 if let Some(s) = mods.status_code {
63 res.status = s;
64 }
65 if let Some(h) = mods.response_headers {
66 res.headers = h.into_iter().collect();
67 }
68 if let Some(b) = mods.response_body {
69 res.body = Some(BodyData {
70 encoding: "utf-8".to_string(),
71 size: b.len() as u64,
72 content: b,
73 });
74 }
75
76 InterceptionResult::ModifiedResponse(res)
77 } else {
78 InterceptionResult::Continue
79 }
80}
81
82pub fn apply_ws_modification(message: &WebSocketMessage, mods: FlowModification) -> InterceptionResult {
87 let mut new_msg = message.clone();
88 if let Some(content) = mods.message_content {
89 new_msg.content.size = content.len() as u64;
90 new_msg.content.content = content;
91 }
92 InterceptionResult::ModifiedMessage(new_msg)
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Wraps `layer` in a flow from a loopback client to 1.1.1.1:80 over plain TCP.
    fn make_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Bare `GET` HTTP/1.1 request for `url` with no headers, body, cookies or query.
    fn make_request(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).expect("fixture URL must parse"),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Bare `HTTP/1.1 200 OK` response with no headers, body, cookies or timings.
    fn ok_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming {
                time_to_first_byte: None,
                time_to_last_byte: None,
                connect_time_ms: None,
                ssl_time_ms: None,
            },
            cookies: vec![],
        }
    }

    /// HTTP flow that has a request but no response yet.
    fn make_http_flow(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_request(url),
            response: None,
            error: None,
        }))
    }

    /// HTTP flow with a request and a `200 OK` response.
    fn make_http_flow_with_response(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_request(url),
            response: Some(ok_response()),
            error: None,
        }))
    }

    /// WebSocket flow with an `Upgrade: websocket` handshake answered by `101`.
    fn make_ws_flow(url: &str) -> Flow {
        make_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: HttpRequest {
                headers: vec![("Upgrade".to_string(), "websocket".to_string())],
                ..make_request(url)
            },
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                ..ok_response()
            },
            messages: vec![],
            closed: false,
        }))
    }

    /// Client-to-server UTF-8 text frame carrying `content`.
    fn make_ws_message(content: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: content.to_string(),
                size: content.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    /// Method, URL, headers and body overrides are all applied in the request phase.
    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "POST");
                assert_eq!(req.url.as_str(), "http://example.com/v2/api");
                assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
                assert_eq!(req.body.unwrap().content, "new-body");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    /// An unparseable URL override is dropped while other overrides still apply.
    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = match &flow.layer {
            Layer::Http(h) => h.request.url.clone(),
            _ => panic!("expected http layer"),
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    /// An empty modification still yields the request, with every field unchanged.
    #[test]
    fn test_empty_request_modification_keeps_request_fields() {
        let flow = make_http_flow("http://example.com/api?x=1");

        match apply_flow_modification(&flow, "request", FlowModification::default()) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "GET");
                assert_eq!(req.url.as_str(), "http://example.com/api?x=1");
                assert!(req.headers.is_empty());
                assert!(req.body.is_none());
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    /// Header overrides replace the whole header list instead of merging into it.
    #[test]
    fn test_request_headers_override_replaces_existing_headers() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            request_headers: Some(HashMap::from([("X-Only".to_string(), "1".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.headers, vec![("X-Only".to_string(), "1".to_string())]);
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    /// Phases are matched by prefix: `request_headers` targets the request.
    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    /// Phases are matched by prefix: `request_body` targets the request.
    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    /// Status, headers and body overrides are all applied in the response phase.
    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 404);
                assert!(res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json"));
                assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
            }
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    /// Body size is recorded in UTF-8 bytes, not characters.
    #[test]
    fn test_response_body_size_counts_utf8_bytes() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            response_body: Some("héllo".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                let body = res.body.expect("body override should be set");
                assert_eq!(body.content, "héllo");
                assert_eq!(body.size, 6);
                assert_eq!(body.encoding, "utf-8");
            }
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    /// With no recorded response, a default response is synthesized and then modified.
    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_headers", mods) {
            InterceptionResult::ModifiedResponse(res) => assert_eq!(res.status, 503),
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    /// Request-phase edits on a WebSocket flow target the handshake request.
    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
            }
            other => panic!("expected ModifiedRequest for WebSocket handshake, got {:?}", other),
        }
    }

    /// Response-phase edits on a WebSocket flow target the handshake response.
    #[test]
    fn test_ws_handshake_response_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            response_headers: Some(HashMap::from([(
                "Sec-WebSocket-Protocol".to_string(),
                "chat".to_string(),
            )])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 101);
                assert!(res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Sec-WebSocket-Protocol" && v == "chat"));
            }
            other => panic!("expected ModifiedResponse for WebSocket handshake, got {:?}", other),
        }
    }

    /// A phase that starts with neither `request` nor `response` leaves the flow alone.
    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let result = apply_flow_modification(&flow, "pre-request", FlowModification::default());
        assert!(matches!(result, InterceptionResult::Continue));
    }

    /// Message content overrides replace the payload and update its size only.
    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "modified");
                assert_eq!(new_msg.content.size, 8);
                assert_eq!(new_msg.direction, Direction::ClientToServer);
                assert_eq!(new_msg.opcode, "Text");
            }
            other => panic!("expected ModifiedMessage, got {:?}", other),
        }
    }

    /// Without a content override the returned message mirrors the original.
    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");

        match apply_ws_modification(&msg, FlowModification::default()) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "origin");
                assert_eq!(new_msg.content.size, 6);
            }
            other => panic!("expected ModifiedMessage, got {:?}", other),
        }
    }
}