1pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10use crate::deno_engine::DenoScriptEngine;
11use crate::engine_trait::ScriptEngineTrait;
12use async_trait::async_trait;
13use relay_core_api::flow::{Flow, WebSocketMessage};
14use relay_core_lib::interceptor::{
15 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
16 WebSocketMessageAction,
17};
18use std::collections::hash_map::DefaultHasher;
19use std::hash::{Hash, Hasher};
20use tokio::sync::RwLock;
21
22pub struct ScriptInterceptor {
23 engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
24}
25
26impl ScriptInterceptor {
27 pub async fn new() -> Result<Self, BoxError> {
28 let pool_size = std::thread::available_parallelism()
29 .map(|n| n.get())
30 .unwrap_or(4);
31 let mut engines = Vec::with_capacity(pool_size);
32
33 for _ in 0..pool_size {
34 let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new());
35 engines.push(RwLock::new(engine));
36 }
37
38 Ok(Self { engines })
39 }
40
41 pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
42 let pool_size = self.engines.len();
46 let mut new_engines = Vec::with_capacity(pool_size);
47
48 for _ in 0..pool_size {
49 let mut engine = DenoScriptEngine::new();
50 engine.load_script(script).await?;
51 new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
52 }
53
54 let mut new_engines_iter = new_engines.into_iter();
57 for engine_lock in &self.engines {
58 if let Some(new_engine) = new_engines_iter.next() {
59 let mut guard = engine_lock.write().await;
60 *guard = new_engine;
61 }
62 }
63
64 Ok(())
65 }
66
67 fn get_engine_index(&self) -> usize {
68 if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
70 return index % self.engines.len();
71 }
72
73 let thread_id = std::thread::current().id();
74 let mut hasher = DefaultHasher::new();
75 thread_id.hash(&mut hasher);
76 (hasher.finish() as usize) % self.engines.len()
77 }
78}
79
80#[async_trait]
81impl Interceptor for ScriptInterceptor {
82 async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
83 let index = self.get_engine_index();
84 let engine_lock = &self.engines[index];
85 let engine = engine_lock.read().await;
86
87 match engine.on_request_headers(flow).await {
88 Ok(Some(modified_flow)) => {
89 *flow = modified_flow;
90 InterceptionResult::ModifiedRequest(match &flow.layer {
91 relay_core_api::flow::Layer::Http(h) => h.request.clone(),
92 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
93 _ => return InterceptionResult::Continue,
94 })
95 }
96 Ok(None) => InterceptionResult::Continue,
97 Err(e) => {
98 tracing::error!("Script execution error (on_request_headers): {}", e);
99 flow.tags.push("script-error".to_string());
100 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
101 http.error = Some(format!("Script Error: {}", e));
102 }
103 InterceptionResult::Continue
104 }
105 }
106 }
107
108 async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
109 let index = self.get_engine_index();
110 let engine_lock = &self.engines[index];
111 let engine = engine_lock.read().await;
112
113 engine.on_request(flow, body).await
114 }
115
116 async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
117 let index = self.get_engine_index();
118 let engine_lock = &self.engines[index];
119 let engine = engine_lock.read().await;
120
121 match engine.on_response_headers(flow).await {
122 Ok(Some(modified_flow)) => {
123 *flow = modified_flow;
124 InterceptionResult::ModifiedResponse(match &flow.layer {
125 relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
126 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
127 _ => return InterceptionResult::Continue,
128 })
129 }
130 Ok(None) => InterceptionResult::Continue,
131 Err(e) => {
132 tracing::error!("Script execution error (on_response_headers): {}", e);
133 flow.tags.push("script-error".to_string());
134 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
135 http.error = Some(format!("Script Error: {}", e));
136 }
137 InterceptionResult::Continue
138 }
139 }
140 }
141
142 async fn on_response(
143 &self,
144 flow: &mut Flow,
145 body: HttpBody,
146 ) -> Result<ResponseAction, BoxError> {
147 let index = self.get_engine_index();
148 let engine_lock = &self.engines[index];
149 let engine = engine_lock.read().await;
150
151 engine.on_response(flow, body).await
152 }
153
154 async fn on_websocket_message(
155 &self,
156 flow: &mut Flow,
157 mut message: WebSocketMessage,
158 ) -> Result<WebSocketMessageAction, BoxError> {
159 let index = self.get_engine_index();
160 let engine_lock = &self.engines[index];
161 let engine = engine_lock.read().await;
162
163 match engine.on_websocket_message(flow, &mut message).await {
164 Ok(action) => Ok(action),
165 Err(e) => {
166 tracing::error!("Script execution error (on_websocket_message): {}", e);
167 flow.tags.push("script-error".to_string());
168 Ok(WebSocketMessageAction::Continue(message))
169 }
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain `GET http://example.com` HTTP flow with no response yet.
    fn create_test_flow() -> Flow {
        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            cookies: vec![],
            query: vec![],
            body: None,
        };
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );

        assert!(flow.tags.iter().any(|tag| tag == "script-error"));

        let http = match &flow.layer {
            Layer::Http(http) => http,
            _ => panic!("Expected Http layer"),
        };
        let recorded = http
            .error
            .as_deref()
            .expect("script error should be stored on the HTTP layer");
        assert!(recorded.contains("Test Error 123"));
    }

    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        // The hook returns undefined ("keep the original body"); it only checks
        // that the first argument is a RelayBody instance.
        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let outcome = interceptor.on_request(&mut flow, empty_body).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        assert!(
            interceptor.load_script(bad_script).await.is_err(),
            "bad script should be rejected"
        );

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        let http = match &flow.layer {
            Layer::Http(http) => http,
            _ => panic!("Expected Http layer"),
        };
        let has_marker = http
            .request
            .headers
            .iter()
            .any(|(k, v)| k == "X-Good-Script" && v == "1");
        assert!(has_marker);
    }

    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let original = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: "hello".to_string(),
                size: 5,
            },
            opcode: "Text".to_string(),
        };

        let action = interceptor
            .on_websocket_message(&mut flow, original.clone())
            .await
            .expect("websocket interception should not return hard error");

        let forwarded = match action {
            WebSocketMessageAction::Continue(forwarded) => forwarded,
            other => panic!("expected Continue fallback, got {:?}", other),
        };
        assert_eq!(forwarded.content.content, "hello");

        assert!(
            flow.tags.iter().any(|t| t == "script-error"),
            "script error should be tagged for observability"
        );
    }
}