1pub mod engine_trait;
7pub mod deno_engine;
8pub mod streams;
9
10use relay_core_lib::interceptor::{Interceptor, InterceptionResult, RequestAction, ResponseAction, WebSocketMessageAction, HttpBody, BoxError};
11use relay_core_api::flow::{Flow, WebSocketMessage};
12use crate::deno_engine::DenoScriptEngine;
13use crate::engine_trait::ScriptEngineTrait;
14use async_trait::async_trait;
15use tokio::sync::RwLock;
16use std::hash::{Hash, Hasher};
17use std::collections::hash_map::DefaultHasher;
18
19pub struct ScriptInterceptor {
20 engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
21}
22
23impl ScriptInterceptor {
24 pub async fn new() -> Result<Self, BoxError> {
25 let pool_size = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
26 let mut engines = Vec::with_capacity(pool_size);
27
28 for _ in 0..pool_size {
29 let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new());
30 engines.push(RwLock::new(engine));
31 }
32
33 Ok(Self {
34 engines,
35 })
36 }
37
38 pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
39 let pool_size = self.engines.len();
43 let mut new_engines = Vec::with_capacity(pool_size);
44
45 for _ in 0..pool_size {
46 let mut engine = DenoScriptEngine::new();
47 engine.load_script(script).await?;
48 new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
49 }
50
51 let mut new_engines_iter = new_engines.into_iter();
54 for engine_lock in &self.engines {
55 if let Some(new_engine) = new_engines_iter.next() {
56 let mut guard = engine_lock.write().await;
57 *guard = new_engine;
58 }
59 }
60
61 Ok(())
62 }
63
64 fn get_engine_index(&self) -> usize {
65 if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
67 return index % self.engines.len();
68 }
69
70 let thread_id = std::thread::current().id();
71 let mut hasher = DefaultHasher::new();
72 thread_id.hash(&mut hasher);
73 (hasher.finish() as usize) % self.engines.len()
74 }
75}
76
77#[async_trait]
78impl Interceptor for ScriptInterceptor {
79 async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
80 let index = self.get_engine_index();
81 let engine_lock = &self.engines[index];
82 let engine = engine_lock.read().await;
83
84 match engine.on_request_headers(flow).await {
85 Ok(Some(modified_flow)) => {
86 *flow = modified_flow;
87 InterceptionResult::ModifiedRequest(match &flow.layer {
88 relay_core_api::flow::Layer::Http(h) => h.request.clone(),
89 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
90 _ => return InterceptionResult::Continue,
91 })
92 },
93 Ok(None) => InterceptionResult::Continue,
94 Err(e) => {
95 tracing::error!("Script execution error (on_request_headers): {}", e);
96 flow.tags.push("script-error".to_string());
97 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
98 http.error = Some(format!("Script Error: {}", e));
99 }
100 InterceptionResult::Continue
101 }
102 }
103 }
104
105 async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
106 let index = self.get_engine_index();
107 let engine_lock = &self.engines[index];
108 let engine = engine_lock.read().await;
109
110 engine.on_request(flow, body).await
111 }
112
113 async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
114 let index = self.get_engine_index();
115 let engine_lock = &self.engines[index];
116 let engine = engine_lock.read().await;
117
118 match engine.on_response_headers(flow).await {
119 Ok(Some(modified_flow)) => {
120 *flow = modified_flow;
121 InterceptionResult::ModifiedResponse(match &flow.layer {
122 relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
123 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
124 _ => return InterceptionResult::Continue,
125 })
126 },
127 Ok(None) => InterceptionResult::Continue,
128 Err(e) => {
129 tracing::error!("Script execution error (on_response_headers): {}", e);
130 flow.tags.push("script-error".to_string());
131 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
132 http.error = Some(format!("Script Error: {}", e));
133 }
134 InterceptionResult::Continue
135 }
136 }
137 }
138
139 async fn on_response(&self, flow: &mut Flow, body: HttpBody) -> Result<ResponseAction, BoxError> {
140 let index = self.get_engine_index();
141 let engine_lock = &self.engines[index];
142 let engine = engine_lock.read().await;
143
144 engine.on_response(flow, body).await
145 }
146
147 async fn on_websocket_message(&self, flow: &mut Flow, mut message: WebSocketMessage) -> Result<WebSocketMessageAction, BoxError> {
148 let index = self.get_engine_index();
149 let engine_lock = &self.engines[index];
150 let engine = engine_lock.read().await;
151
152 match engine.on_websocket_message(flow, &mut message).await {
153 Ok(action) => Ok(action),
154 Err(e) => {
155 tracing::error!("Script execution error (on_websocket_message): {}", e);
156 flow.tags.push("script-error".to_string());
157 Ok(WebSocketMessageAction::Continue(message))
158 }
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use url::Url;
    use uuid::Uuid;
    use chrono::Utc;
    use std::collections::HashMap;
    use http_body_util::{BodyExt, Empty};
    use bytes::Bytes;

    /// Builds a plain-HTTP `GET http://example.com` flow with no response,
    /// no tags and no metadata.
    fn create_test_flow() -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };

        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            cookies: vec![],
            query: vec![],
            body: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// A throwing request-headers hook must yield `Continue`, tag the flow,
    /// and record the script's message on the HTTP layer.
    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;

        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );
        assert!(flow.tags.contains(&"script-error".to_string()));

        let http = match &flow.layer {
            Layer::Http(http) => http,
            _ => panic!("Expected Http layer"),
        };
        assert!(http.error.is_some());
        assert!(http.error.as_ref().unwrap().contains("Test Error 123"));
    }

    /// The request hook's script asserts that its first argument is a
    /// `RelayBody`; the call must complete without error.
    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let outcome = interceptor.on_request(&mut flow, empty_body).await;
        assert!(outcome.is_ok());
    }

    /// A failed reload must leave the previously loaded script in effect.
    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let reload = interceptor.load_script(bad_script).await;
        assert!(reload.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        let http = match &flow.layer {
            Layer::Http(http) => http,
            _ => panic!("Expected Http layer"),
        };
        let has_marker = http
            .request
            .headers
            .iter()
            .any(|(k, v)| k == "X-Good-Script" && v == "1");
        assert!(has_marker);
    }

    /// A throwing WebSocket hook must fall back to forwarding the original
    /// message and tag the flow.
    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let content = BodyData {
            encoding: "utf-8".to_string(),
            content: "hello".to_string(),
            size: 5,
        };
        let msg = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content,
            opcode: "Text".to_string(),
        };

        let action = interceptor
            .on_websocket_message(&mut flow, msg)
            .await
            .expect("websocket interception should not return hard error");

        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }

        let tagged = flow.tags.iter().any(|t| t == "script-error");
        assert!(tagged, "script error should be tagged for observability");
    }
}