1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9pub fn apply_flow_modification(flow: &Flow, phase: &str, mods: FlowModification) -> InterceptionResult {
16 if phase.starts_with("request") {
17 let mut req = match &flow.layer {
18 Layer::Http(h) => h.request.clone(),
19 Layer::WebSocket(ws) => ws.handshake_request.clone(),
20 _ => return InterceptionResult::Continue,
21 };
22
23 if let Some(m) = mods.method {
24 req.method = m;
25 }
26 if let Some(u) = mods.url {
27 if let Ok(parsed) = Url::parse(&u) {
29 req.url = parsed;
30 }
31 }
32 if let Some(h) = mods.request_headers {
33 req.headers = h.into_iter().collect();
34 }
35 if let Some(b) = mods.request_body {
36 req.body = Some(BodyData {
37 encoding: "utf-8".to_string(),
38 size: b.len() as u64,
39 content: b,
40 });
41 }
42
43 InterceptionResult::ModifiedRequest(req)
44 } else if phase.starts_with("response") {
45 let mut res = match &flow.layer {
46 Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
47 status: 200,
48 status_text: "OK".to_string(),
49 version: "HTTP/1.1".to_string(),
50 headers: vec![],
51 body: None,
52 timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
53 cookies: vec![],
54 }),
55 Layer::WebSocket(ws) => ws.handshake_response.clone(),
56 _ => return InterceptionResult::Continue,
57 };
58
59 if let Some(s) = mods.status_code {
60 res.status = s;
61 }
62 if let Some(h) = mods.response_headers {
63 res.headers = h.into_iter().collect();
64 }
65 if let Some(b) = mods.response_body {
66 res.body = Some(BodyData {
67 encoding: "utf-8".to_string(),
68 size: b.len() as u64,
69 content: b,
70 });
71 }
72
73 InterceptionResult::ModifiedResponse(res)
74 } else {
75 InterceptionResult::Continue
76 }
77}
78
79pub fn apply_ws_modification(message: &WebSocketMessage, mods: FlowModification) -> InterceptionResult {
84 let mut new_msg = message.clone();
85 if let Some(content) = mods.message_content {
86 new_msg.content.size = content.len() as u64;
87 new_msg.content.content = content;
88 }
89 InterceptionResult::ModifiedMessage(new_msg)
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Wraps `layer` in a fresh flow with fixed plain-TCP network details, no tags and no
    /// metadata.
    fn flow_around(layer: Layer) -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// A header-less, body-less `GET` request for `url`. Panics if `url` does not parse.
    fn plain_get(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// An empty `HTTP/1.1 200 OK` response without timing information.
    fn empty_ok_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
            cookies: vec![],
        }
    }

    /// An HTTP flow holding only a `GET` request for `url` (no response yet).
    fn make_http_flow(url: &str) -> Flow {
        flow_around(Layer::Http(HttpLayer {
            request: plain_get(url),
            response: None,
            error: None,
        }))
    }

    /// Same as `make_http_flow`, but with an empty `200 OK` response already attached.
    fn make_http_flow_with_response(url: &str) -> Flow {
        flow_around(Layer::Http(HttpLayer {
            request: plain_get(url),
            response: Some(empty_ok_response()),
            error: None,
        }))
    }

    /// A WebSocket flow for `url` with an `Upgrade: websocket` handshake request, a
    /// `101 Switching Protocols` handshake response and no messages.
    fn make_ws_flow(url: &str) -> Flow {
        let mut handshake_request = plain_get(url);
        handshake_request
            .headers
            .push(("Upgrade".to_string(), "websocket".to_string()));

        let handshake_response = HttpResponse {
            status: 101,
            status_text: "Switching Protocols".to_string(),
            ..empty_ok_response()
        };

        flow_around(Layer::WebSocket(WebSocketLayer {
            handshake_request,
            handshake_response,
            messages: vec![],
            closed: false,
        }))
    }

    /// A client-to-server text message carrying `content` as UTF-8.
    fn make_ws_message(content: &str) -> WebSocketMessage {
        let body = BodyData {
            encoding: "utf-8".to_string(),
            content: content.to_string(),
            size: content.len() as u64,
        };

        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: body,
            opcode: "Text".to_string(),
        }
    }

    // ---------------------------------------------------------------------
    // Request phase
    // ---------------------------------------------------------------------

    /// Method, URL, headers and body overrides are all reflected in the modified request.
    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "POST");
                assert_eq!(req.url.as_str(), "http://example.com/v2/api");
                assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
                assert_eq!(req.body.unwrap().content, "new-body");
            }
            _ => panic!("expected ModifiedRequest"),
        }
    }

    /// An unparsable URL override is ignored while the other overrides still apply.
    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = if let Layer::Http(http) = &flow.layer {
            http.request.url.clone()
        } else {
            panic!("expected http layer")
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    /// The `request_headers` phase is routed to the request side by its prefix.
    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };

        let outcome = apply_flow_modification(&flow, "request_headers", mods);

        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    /// The `request_body` phase is routed to the request side by its prefix.
    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };

        let outcome = apply_flow_modification(&flow, "request_body", mods);

        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    // ---------------------------------------------------------------------
    // Response phase
    // ---------------------------------------------------------------------

    /// Status, headers and body overrides are all reflected in the modified response.
    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 404);
                assert!(res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json"));
                assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
            }
            _ => panic!("expected ModifiedResponse"),
        }
    }

    /// With no response recorded on the flow, overrides are applied on top of a default one.
    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_headers", mods) {
            InterceptionResult::ModifiedResponse(res) => assert_eq!(res.status, 503),
            _ => panic!("expected ModifiedResponse"),
        }
    }

    // ---------------------------------------------------------------------
    // WebSocket handshake
    // ---------------------------------------------------------------------

    /// Request-phase overrides on a WebSocket flow rewrite the handshake request.
    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
            }
            _ => panic!("expected ModifiedRequest for WebSocket handshake"),
        }
    }

    // ---------------------------------------------------------------------
    // Unknown phase
    // ---------------------------------------------------------------------

    /// A phase that starts with neither `request` nor `response` leaves the flow alone.
    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");

        let outcome = apply_flow_modification(&flow, "pre-request", FlowModification::default());

        assert!(matches!(outcome, InterceptionResult::Continue));
    }

    // ---------------------------------------------------------------------
    // WebSocket messages
    // ---------------------------------------------------------------------

    /// Replacing the content updates text and size and keeps direction and opcode.
    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "modified");
                assert_eq!(new_msg.content.size, 8);
                assert_eq!(new_msg.direction, Direction::ClientToServer);
                assert_eq!(new_msg.opcode, "Text");
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }

    /// Without a content override the returned message carries the original payload.
    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");

        match apply_ws_modification(&msg, FlowModification::default()) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "origin");
                assert_eq!(new_msg.content.size, 6);
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }
}