1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9pub fn apply_flow_modification(
16 flow: &Flow,
17 phase: &str,
18 mods: FlowModification,
19) -> InterceptionResult {
20 if phase.starts_with("request") {
21 let mut req = match &flow.layer {
22 Layer::Http(h) => h.request.clone(),
23 Layer::WebSocket(ws) => ws.handshake_request.clone(),
24 _ => return InterceptionResult::Continue,
25 };
26
27 if let Some(m) = mods.method {
28 req.method = m;
29 }
30 if let Some(u) = mods.url {
31 if let Ok(parsed) = Url::parse(&u) {
33 req.url = parsed;
34 }
35 }
36 if let Some(h) = mods.request_headers {
37 req.headers = h.into_iter().collect();
38 }
39 if let Some(b) = mods.request_body {
40 req.body = Some(BodyData {
41 encoding: "utf-8".to_string(),
42 size: b.len() as u64,
43 content: b,
44 });
45 }
46
47 InterceptionResult::ModifiedRequest(req)
48 } else if phase.starts_with("response") {
49 let mut res = match &flow.layer {
50 Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
51 status: 200,
52 status_text: "OK".to_string(),
53 version: "HTTP/1.1".to_string(),
54 headers: vec![],
55 body: None,
56 timing: ResponseTiming {
57 time_to_first_byte: None,
58 time_to_last_byte: None,
59 connect_time_ms: None,
60 ssl_time_ms: None,
61 },
62 cookies: vec![],
63 }),
64 Layer::WebSocket(ws) => ws.handshake_response.clone(),
65 _ => return InterceptionResult::Continue,
66 };
67
68 if let Some(s) = mods.status_code {
69 res.status = s;
70 }
71 if let Some(h) = mods.response_headers {
72 res.headers = h.into_iter().collect();
73 }
74 if let Some(b) = mods.response_body {
75 res.body = Some(BodyData {
76 encoding: "utf-8".to_string(),
77 size: b.len() as u64,
78 content: b,
79 });
80 }
81
82 InterceptionResult::ModifiedResponse(res)
83 } else {
84 InterceptionResult::Continue
85 }
86}
87
88pub fn apply_ws_modification(
93 message: &WebSocketMessage,
94 mods: FlowModification,
95) -> InterceptionResult {
96 let mut new_msg = message.clone();
97 if let Some(content) = mods.message_content {
98 new_msg.content.size = content.len() as u64;
99 new_msg.content.content = content;
100 }
101 InterceptionResult::ModifiedMessage(new_msg)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Wraps `layer` in a fresh plain-TCP flow from 127.0.0.1:12345 to 1.1.1.1:80.
    fn make_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Bodyless `GET` request for `url` with no headers, cookies or query pairs.
    fn make_get_request(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Bodyless `HTTP/1.1 200 OK` response with no headers, cookies or timing data.
    fn make_ok_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming {
                time_to_first_byte: None,
                time_to_last_byte: None,
                connect_time_ms: None,
                ssl_time_ms: None,
            },
            cookies: vec![],
        }
    }

    /// HTTP flow for a `GET url` request that has not received a response yet.
    fn make_http_flow(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_get_request(url),
            response: None,
            error: None,
        }))
    }

    /// HTTP flow for a `GET url` request that already has a `200 OK` response.
    fn make_http_flow_with_response(url: &str) -> Flow {
        let mut flow = make_http_flow(url);
        if let Layer::Http(ref mut h) = flow.layer {
            h.response = Some(make_ok_response());
        }
        flow
    }

    /// Open WebSocket flow whose handshake was `GET url` + `101 Switching Protocols`.
    fn make_ws_flow(url: &str) -> Flow {
        make_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: HttpRequest {
                headers: vec![("Upgrade".to_string(), "websocket".to_string())],
                ..make_get_request(url)
            },
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                ..make_ok_response()
            },
            messages: vec![],
            closed: false,
        }))
    }

    /// Client-to-server UTF-8 text frame carrying `content`.
    fn make_ws_message(content: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: content.to_string(),
                size: content.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.method, "POST");
            assert_eq!(req.url.as_str(), "http://example.com/v2/api");
            assert!(
                req.headers
                    .iter()
                    .any(|(k, v)| k == "X-Custom" && v == "123")
            );
            assert_eq!(req.body.unwrap().content, "new-body");
        } else {
            panic!("expected ModifiedRequest");
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = match &flow.layer {
            Layer::Http(h) => h.request.url.clone(),
            _ => panic!("expected http layer"),
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        match result {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_modification_preserves_unmodified_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_headers: Some(HashMap::from([("X-Only".to_string(), "1".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request_headers", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "GET");
                assert_eq!(req.url.as_str(), "http://example.com/api");
                assert!(req.body.is_none());
                assert!(req.headers.iter().any(|(k, v)| k == "X-Only" && v == "1"));
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_body_replacement_is_utf8_with_byte_size() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("héllo".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request_body", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                let body = req.body.expect("body should be set");
                assert_eq!(body.content, "héllo");
                assert_eq!(body.encoding, "utf-8");
                // `size` counts bytes, not characters: "é" is two bytes in UTF-8.
                assert_eq!(body.size, 6);
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 404);
            assert!(
                res.headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json")
            );
            assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response_headers", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 503);
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
        } else {
            panic!("expected ModifiedRequest for WebSocket handshake");
        }
    }

    #[test]
    fn test_ws_handshake_response_modification_keeps_status() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            response_headers: Some(HashMap::from([("X-Relay".to_string(), "yes".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 101);
                assert_eq!(res.status_text, "Switching Protocols");
                assert!(res.headers.iter().any(|(k, v)| k == "X-Relay" && v == "yes"));
            }
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification::default();
        let result = apply_flow_modification(&flow, "pre-request", mods);
        assert!(matches!(result, InterceptionResult::Continue));
    }

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "modified");
            assert_eq!(new_msg.content.size, 8);
            assert_eq!(new_msg.direction, Direction::ClientToServer);
            assert_eq!(new_msg.opcode, "Text");
        } else {
            panic!("expected ModifiedMessage");
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");
        let mods = FlowModification::default();

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "origin");
            assert_eq!(new_msg.content.size, 6);
        } else {
            panic!("expected ModifiedMessage");
        }
    }
}
420}