1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9pub fn apply_flow_modification(
16 flow: &Flow,
17 phase: &str,
18 mods: FlowModification,
19) -> InterceptionResult {
20 if phase.starts_with("request") {
21 let mut req = match &flow.layer {
22 Layer::Http(h) => h.request.clone(),
23 Layer::WebSocket(ws) => ws.handshake_request.clone(),
24 _ => return InterceptionResult::Continue,
25 };
26
27 if let Some(m) = mods.method {
28 req.method = m;
29 }
30 if let Some(u) = mods.url {
31 if let Ok(parsed) = Url::parse(&u) {
33 req.url = parsed;
34 }
35 }
36 if let Some(h) = mods.request_headers {
37 req.headers = h.into_iter().collect();
38 }
39 if let Some(b) = mods.request_body {
40 req.body = Some(BodyData {
41 encoding: "utf-8".to_string(),
42 size: b.len() as u64,
43 content: b,
44 });
45 }
46
47 InterceptionResult::ModifiedRequest(req)
48 } else if phase.starts_with("response") {
49 let mut res = match &flow.layer {
50 Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
51 status: 200,
52 status_text: "OK".to_string(),
53 version: "HTTP/1.1".to_string(),
54 headers: vec![],
55 body: None,
56 timing: ResponseTiming {
57 time_to_first_byte: None,
58 time_to_last_byte: None,
59 connect_time_ms: None,
60 ssl_time_ms: None,
61 },
62 cookies: vec![],
63 }),
64 Layer::WebSocket(ws) => ws.handshake_response.clone(),
65 _ => return InterceptionResult::Continue,
66 };
67
68 if let Some(s) = mods.status_code {
69 res.status = s;
70 }
71 if let Some(h) = mods.response_headers {
72 res.headers = h.into_iter().collect();
73 }
74 if let Some(b) = mods.response_body {
75 res.body = Some(BodyData {
76 encoding: "utf-8".to_string(),
77 size: b.len() as u64,
78 content: b,
79 });
80 }
81
82 InterceptionResult::ModifiedResponse(res)
83 } else {
84 InterceptionResult::Continue
85 }
86}
87
88pub fn apply_ws_modification(
93 message: &WebSocketMessage,
94 mods: FlowModification,
95) -> InterceptionResult {
96 let mut new_msg = message.clone();
97 if let Some(content) = mods.message_content {
98 new_msg.content.size = content.len() as u64;
99 new_msg.content.content = content;
100 }
101 InterceptionResult::ModifiedMessage(new_msg)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Plain-TCP connection from 127.0.0.1:12345 to 1.1.1.1:80, shared by every flow fixture.
    fn make_network_info() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// `GET <url> HTTP/1.1` carrying `headers`, with no body, cookies, or query params.
    fn make_request(url: &str, headers: Vec<(String, String)>) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).expect("fixture URL must be valid"),
            version: "HTTP/1.1".to_string(),
            headers,
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// `HTTP/1.1 200 OK` with no headers, body, cookies, or timing data.
    fn make_ok_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming {
                time_to_first_byte: None,
                time_to_last_byte: None,
                connect_time_ms: None,
                ssl_time_ms: None,
            },
            cookies: vec![],
        }
    }

    /// Wraps `layer` in a flow started now, with no tags, metadata, or rule state.
    fn make_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: make_network_info(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    fn make_http_flow(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_request(url, vec![]),
            response: None,
            error: None,
        }))
    }

    fn make_http_flow_with_response(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_request(url, vec![]),
            response: Some(make_ok_response()),
            error: None,
        }))
    }

    fn make_ws_flow(url: &str) -> Flow {
        make_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: make_request(
                url,
                vec![("Upgrade".to_string(), "websocket".to_string())],
            ),
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                ..make_ok_response()
            },
            messages: vec![],
            closed: false,
        }))
    }

    fn make_ws_message(content: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: content.to_string(),
                size: content.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "POST");
                assert_eq!(req.url.as_str(), "http://example.com/v2/api");
                assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
                assert_eq!(req.body.expect("request body should be set").content, "new-body");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = match &flow.layer {
            Layer::Http(h) => h.request.url.clone(),
            _ => panic!("expected http layer"),
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 404);
                assert!(res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json"));
                assert_eq!(
                    res.body.expect("response body should be set").content,
                    "{\"error\": \"not found\"}"
                );
            }
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_headers", mods) {
            InterceptionResult::ModifiedResponse(res) => assert_eq!(res.status, 503),
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
            }
            other => panic!("expected ModifiedRequest for WebSocket handshake, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification::default();
        let result = apply_flow_modification(&flow, "pre-request", mods);
        assert!(matches!(result, InterceptionResult::Continue));
    }

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "modified");
                assert_eq!(new_msg.content.size, 8);
                assert_eq!(new_msg.direction, Direction::ClientToServer);
                assert_eq!(new_msg.opcode, "Text");
            }
            other => panic!("expected ModifiedMessage, got {:?}", other),
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");
        let mods = FlowModification::default();

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "origin");
                assert_eq!(new_msg.content.size, 6);
            }
            other => panic!("expected ModifiedMessage, got {:?}", other),
        }
    }
}