Skip to main content

relay_core_script/
lib.rs

1//! Internal Deno/V8 scripting engine for [relay-core](https://crates.io/crates/relay-core).
2//! Provides the `ScriptInterceptor` that implements runtime script modification of traffic.
3//!
4//! **This is a feature backend for `relay-core`.** Enable with `relay-core = { features = ["script"] }`.
5
6pub mod engine_trait;
7pub mod deno_engine;
8pub mod streams;
9
10use relay_core_lib::interceptor::{Interceptor, InterceptionResult, RequestAction, ResponseAction, WebSocketMessageAction, HttpBody, BoxError};
11use relay_core_api::flow::{Flow, WebSocketMessage};
12use crate::deno_engine::DenoScriptEngine;
13use crate::engine_trait::ScriptEngineTrait;
14use async_trait::async_trait;
15use tokio::sync::RwLock;
16use std::hash::{Hash, Hasher};
17use std::collections::hash_map::DefaultHasher;
18
19pub struct ScriptInterceptor {
20    engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
21}
22
23impl ScriptInterceptor {
24    pub async fn new() -> Result<Self, BoxError> {
25        let pool_size = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
26        let mut engines = Vec::with_capacity(pool_size);
27
28        for _ in 0..pool_size {
29            let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new());
30            engines.push(RwLock::new(engine));
31        }
32
33        Ok(Self { 
34            engines,
35        })
36    }
37
38    pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
39        // Load script into ALL engines to keep them consistent.
40        // Optimization: Load into new engines first to avoid blocking request processing,
41        // then swap them in quickly.
42        let pool_size = self.engines.len();
43        let mut new_engines = Vec::with_capacity(pool_size);
44
45        for _ in 0..pool_size {
46            let mut engine = DenoScriptEngine::new();
47            engine.load_script(script).await?;
48            new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
49        }
50
51        // We acquire write locks only for the swap.
52        // The new engines are already prepared, so the critical section is very short.
53        let mut new_engines_iter = new_engines.into_iter();
54        for engine_lock in &self.engines {
55             if let Some(new_engine) = new_engines_iter.next() {
56                 let mut guard = engine_lock.write().await;
57                 *guard = new_engine;
58             }
59        }
60        
61        Ok(())
62    }
63
64    fn get_engine_index(&self) -> usize {
65        // Optimization: Allow task-local override for engine index (e.g. for testing or specific routing)
66        if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
67            return index % self.engines.len();
68        }
69
70        let thread_id = std::thread::current().id();
71        let mut hasher = DefaultHasher::new();
72        thread_id.hash(&mut hasher);
73        (hasher.finish() as usize) % self.engines.len()
74    }
75}
76
77#[async_trait]
78impl Interceptor for ScriptInterceptor {
79    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
80        let index = self.get_engine_index();
81        let engine_lock = &self.engines[index];
82        let engine = engine_lock.read().await;
83        
84        match engine.on_request_headers(flow).await {
85            Ok(Some(modified_flow)) => {
86                *flow = modified_flow;
87                InterceptionResult::ModifiedRequest(match &flow.layer {
88                    relay_core_api::flow::Layer::Http(h) => h.request.clone(),
89                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
90                    _ => return InterceptionResult::Continue,
91                })
92            },
93            Ok(None) => InterceptionResult::Continue,
94            Err(e) => {
95                tracing::error!("Script execution error (on_request_headers): {}", e);
96                flow.tags.push("script-error".to_string());
97                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
98                    http.error = Some(format!("Script Error: {}", e));
99                }
100                InterceptionResult::Continue
101            }
102        }
103    }
104
105    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
106        let index = self.get_engine_index();
107        let engine_lock = &self.engines[index];
108        let engine = engine_lock.read().await;
109        
110        engine.on_request(flow, body).await
111    }
112
113    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
114        let index = self.get_engine_index();
115        let engine_lock = &self.engines[index];
116        let engine = engine_lock.read().await;
117        
118        match engine.on_response_headers(flow).await {
119            Ok(Some(modified_flow)) => {
120                *flow = modified_flow;
121                InterceptionResult::ModifiedResponse(match &flow.layer {
122                    relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
123                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
124                    _ => return InterceptionResult::Continue,
125                })
126            },
127            Ok(None) => InterceptionResult::Continue,
128            Err(e) => {
129                tracing::error!("Script execution error (on_response_headers): {}", e);
130                flow.tags.push("script-error".to_string());
131                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
132                    http.error = Some(format!("Script Error: {}", e));
133                }
134                InterceptionResult::Continue
135            }
136        }
137    }
138
139    async fn on_response(&self, flow: &mut Flow, body: HttpBody) -> Result<ResponseAction, BoxError> {
140        let index = self.get_engine_index();
141        let engine_lock = &self.engines[index];
142        let engine = engine_lock.read().await;
143
144        engine.on_response(flow, body).await
145    }
146    
147    async fn on_websocket_message(&self, flow: &mut Flow, mut message: WebSocketMessage) -> Result<WebSocketMessageAction, BoxError> {
148        let index = self.get_engine_index();
149        let engine_lock = &self.engines[index];
150        let engine = engine_lock.read().await;
151
152        match engine.on_websocket_message(flow, &mut message).await {
153            Ok(action) => Ok(action),
154            Err(e) => {
155                tracing::error!("Script execution error (on_websocket_message): {}", e);
156                flow.tags.push("script-error".to_string());
157                Ok(WebSocketMessageAction::Continue(message))
158            }
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use url::Url;
    use uuid::Uuid;
    use chrono::Utc;
    use std::collections::HashMap;
    use http_body_util::{BodyExt, Empty};
    use bytes::Bytes;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no response.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// True when `flow` carries the `script-error` tag (checked without allocating).
    fn has_script_error_tag(flow: &Flow) -> bool {
        flow.tags.iter().any(|t| t == "script-error")
    }

    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();

        let result = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(result, InterceptionResult::Continue),
            "Expected Continue"
        );

        assert!(has_script_error_tag(&flow));

        match &flow.layer {
            Layer::Http(http) => {
                let err = http
                    .error
                    .as_deref()
                    .expect("script error should be recorded on the HTTP layer");
                assert!(
                    err.contains("Test Error 123"),
                    "unexpected error text: {}",
                    err
                );
            }
            _ => panic!("Expected Http layer"),
        }
    }

    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let body = Empty::<Bytes>::new().map_err(|_| -> BoxError { unreachable!() }).boxed();

        // Surface the engine's error text instead of a bare `is_ok()` failure.
        if let Err(e) = interceptor.on_request(&mut flow, body).await {
            panic!("on_request failed: {}", e);
        }
    }

    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let bad = interceptor.load_script(bad_script).await;
        assert!(bad.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let result = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(result, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );
        match &flow.layer {
            Layer::Http(http) => assert!(
                http.request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Good-Script" && v == "1"),
                "header added by the good script is missing"
            ),
            _ => panic!("Expected Http layer"),
        }
    }

    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let msg = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: "hello".to_string(),
                size: 5,
            },
            opcode: "Text".to_string(),
        };

        let result = interceptor
            .on_websocket_message(&mut flow, msg.clone())
            .await
            .expect("websocket interception should not return hard error");
        match result {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }
        assert!(
            has_script_error_tag(&flow),
            "script error should be tagged for observability"
        );
    }
}
339            flow.tags.iter().any(|t| t == "script-error"),
340            "script error should be tagged for observability"
341        );
342    }
343}