1pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10use crate::deno_engine::DenoScriptEngine;
11use crate::engine_trait::ScriptEngineTrait;
12use async_trait::async_trait;
13use relay_core_api::flow::{Flow, WebSocketMessage};
14use relay_core_lib::interceptor::{
15 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
16 WebSocketMessageAction,
17};
18use std::collections::HashSet;
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::Instant;
23use tokio::sync::RwLock;
24
/// Cumulative per-hook counters for script execution, rendered by
/// [`ScriptMetrics::prometheus_lines`].
///
/// Each script hook has three counters: accumulated execution time in microseconds, an
/// invocation count, and an error count. All counters start at zero.
#[derive(Debug, Default)]
pub struct ScriptMetrics {
    /// Accumulated microseconds spent in the `onRequestHeaders` hook.
    pub on_request_headers_duration_us: AtomicU64,
    /// Invocation count for the `onRequestHeaders` hook.
    pub on_request_headers_invocations: AtomicU64,
    /// Failure count for the `onRequestHeaders` hook.
    pub on_request_headers_errors: AtomicU64,

    /// Accumulated microseconds spent in the `onRequest` hook.
    pub on_request_duration_us: AtomicU64,
    /// Invocation count for the `onRequest` hook.
    pub on_request_invocations: AtomicU64,
    /// Failure count for the `onRequest` hook.
    pub on_request_errors: AtomicU64,

    /// Accumulated microseconds spent in the `onResponseHeaders` hook.
    pub on_response_headers_duration_us: AtomicU64,
    /// Invocation count for the `onResponseHeaders` hook.
    pub on_response_headers_invocations: AtomicU64,
    /// Failure count for the `onResponseHeaders` hook.
    pub on_response_headers_errors: AtomicU64,

    /// Accumulated microseconds spent in the `onResponse` hook.
    pub on_response_duration_us: AtomicU64,
    /// Invocation count for the `onResponse` hook.
    pub on_response_invocations: AtomicU64,
    /// Failure count for the `onResponse` hook.
    pub on_response_errors: AtomicU64,

    /// Accumulated microseconds spent in the `onWebSocketMessage` hook.
    pub on_websocket_message_duration_us: AtomicU64,
    /// Invocation count for the `onWebSocketMessage` hook.
    pub on_websocket_message_invocations: AtomicU64,
    /// Failure count for the `onWebSocketMessage` hook.
    pub on_websocket_message_errors: AtomicU64,
}

impl ScriptMetrics {
    /// Table of `(hook label, [duration_us, invocations, errors])` in exposition order.
    fn hook_counters(&self) -> [(&'static str, [&AtomicU64; 3]); 5] {
        [
            (
                "onRequestHeaders",
                [
                    &self.on_request_headers_duration_us,
                    &self.on_request_headers_invocations,
                    &self.on_request_headers_errors,
                ],
            ),
            (
                "onRequest",
                [
                    &self.on_request_duration_us,
                    &self.on_request_invocations,
                    &self.on_request_errors,
                ],
            ),
            (
                "onResponseHeaders",
                [
                    &self.on_response_headers_duration_us,
                    &self.on_response_headers_invocations,
                    &self.on_response_headers_errors,
                ],
            ),
            (
                "onResponse",
                [
                    &self.on_response_duration_us,
                    &self.on_response_invocations,
                    &self.on_response_errors,
                ],
            ),
            (
                "onWebSocketMessage",
                [
                    &self.on_websocket_message_duration_us,
                    &self.on_websocket_message_invocations,
                    &self.on_websocket_message_errors,
                ],
            ),
        ]
    }

    /// Renders every counter in the Prometheus text exposition format.
    ///
    /// The output contains three counter families, each introduced by a single
    /// `# TYPE <name> counter` line followed by one sample per hook labelled `hook="<name>"`:
    /// - `relay_core_script_hook_duration_us`
    /// - `relay_core_script_hook_invocations_total`
    /// - `relay_core_script_hook_errors_total`
    ///
    /// Every line, including the last, is terminated by `\n`.
    pub fn prometheus_lines(&self) -> String {
        use std::fmt::Write as _;

        // Family name for each counter slot of `hook_counters`, in the same order.
        const FAMILIES: [&str; 3] = [
            "relay_core_script_hook_duration_us",
            "relay_core_script_hook_invocations_total",
            "relay_core_script_hook_errors_total",
        ];

        let hooks = self.hook_counters();
        let mut out = String::new();
        // The exposition format requires all samples of one family to be contiguous, with the
        // TYPE line first, so iterate family-major rather than hook-major.
        for (slot, family) in FAMILIES.iter().enumerate() {
            writeln!(out, "# TYPE {} counter", family)
                .expect("writing to a String cannot fail");
            for (hook, counters) in &hooks {
                writeln!(
                    out,
                    "{}{{hook=\"{}\"}} {}",
                    family,
                    hook,
                    counters[slot].load(Ordering::Relaxed)
                )
                .expect("writing to a String cannot fail");
            }
        }
        out
    }
}
125
126pub struct ScriptInterceptor {
127 engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
128 pub metrics: ScriptMetrics,
129 env_allow: RwLock<HashSet<String>>,
130}
131
132impl ScriptInterceptor {
133 pub async fn new() -> Result<Self, BoxError> {
134 Self::new_with_env(HashSet::new()).await
135 }
136
137 pub async fn new_with_env(env_allow: HashSet<String>) -> Result<Self, BoxError> {
138 let pool_size = std::thread::available_parallelism()
139 .map(|n| n.get())
140 .unwrap_or(4);
141 let mut engines = Vec::with_capacity(pool_size);
142
143 for _ in 0..pool_size {
144 let engine: Box<dyn ScriptEngineTrait> =
145 Box::new(DenoScriptEngine::new(env_allow.clone()));
146 engines.push(RwLock::new(engine));
147 }
148
149 Ok(Self {
150 engines,
151 metrics: ScriptMetrics::default(),
152 env_allow: RwLock::new(env_allow),
153 })
154 }
155
156 pub async fn set_env_allow(&self, env_allow: HashSet<String>) {
157 let mut guard = self.env_allow.write().await;
158 *guard = env_allow;
159 }
160
161 pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
162 let pool_size = self.engines.len();
164 let mut new_engines = Vec::with_capacity(pool_size);
165
166 let env = self.env_allow.read().await.clone();
167 for _ in 0..pool_size {
168 let mut engine = DenoScriptEngine::new(env.clone());
169 engine.load_script(script).await?;
170 new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
171 }
172
173 let mut new_engines_iter = new_engines.into_iter();
176 for engine_lock in &self.engines {
177 if let Some(new_engine) = new_engines_iter.next() {
178 let mut guard = engine_lock.write().await;
179 *guard = new_engine;
180 }
181 }
182
183 Ok(())
184 }
185
186 fn get_engine_index(&self) -> usize {
187 if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
189 return index % self.engines.len();
190 }
191
192 let thread_id = std::thread::current().id();
193 let mut hasher = DefaultHasher::new();
194 thread_id.hash(&mut hasher);
195 (hasher.finish() as usize) % self.engines.len()
196 }
197}
198
199#[async_trait]
200impl Interceptor for ScriptInterceptor {
201 async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
202 let start = Instant::now();
203 let index = self.get_engine_index();
204 let engine_lock = &self.engines[index];
205 let engine = engine_lock.read().await;
206
207 let result = match engine.on_request_headers(flow).await {
208 Ok(Some(modified_flow)) => {
209 *flow = modified_flow;
210 InterceptionResult::ModifiedRequest(match &flow.layer {
211 relay_core_api::flow::Layer::Http(h) => h.request.clone(),
212 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
213 _ => {
214 self.metrics
215 .on_request_headers_errors
216 .fetch_add(1, Ordering::Relaxed);
217 return InterceptionResult::Continue;
218 }
219 })
220 }
221 Ok(None) => InterceptionResult::Continue,
222 Err(e) => {
223 tracing::error!("Script execution error (on_request_headers): {}", e);
224 flow.tags.push("script-error".to_string());
225 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
226 http.error = Some(format!("Script Error: {}", e));
227 }
228 self.metrics
229 .on_request_headers_errors
230 .fetch_add(1, Ordering::Relaxed);
231 InterceptionResult::Continue
232 }
233 };
234 let dur_us = start.elapsed().as_micros() as u64;
235 self.metrics
236 .on_request_headers_duration_us
237 .fetch_add(dur_us, Ordering::Relaxed);
238 self.metrics
239 .on_request_headers_invocations
240 .fetch_add(1, Ordering::Relaxed);
241 result
242 }
243
244 async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
245 let start = Instant::now();
246 let index = self.get_engine_index();
247 let engine_lock = &self.engines[index];
248 let engine = engine_lock.read().await;
249
250 match engine.on_request(flow, body).await {
251 Ok(action) => {
252 let dur_us = start.elapsed().as_micros() as u64;
253 self.metrics
254 .on_request_duration_us
255 .fetch_add(dur_us, Ordering::Relaxed);
256 self.metrics
257 .on_request_invocations
258 .fetch_add(1, Ordering::Relaxed);
259 Ok(action)
260 }
261 Err(e) => {
262 self.metrics
263 .on_request_errors
264 .fetch_add(1, Ordering::Relaxed);
265 Err(e)
266 }
267 }
268 }
269
270 async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
271 let start = Instant::now();
272 let index = self.get_engine_index();
273 let engine_lock = &self.engines[index];
274 let engine = engine_lock.read().await;
275
276 let result = match engine.on_response_headers(flow).await {
277 Ok(Some(modified_flow)) => {
278 *flow = modified_flow;
279 InterceptionResult::ModifiedResponse(match &flow.layer {
280 relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
281 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
282 _ => {
283 self.metrics
284 .on_response_headers_errors
285 .fetch_add(1, Ordering::Relaxed);
286 return InterceptionResult::Continue;
287 }
288 })
289 }
290 Ok(None) => InterceptionResult::Continue,
291 Err(e) => {
292 tracing::error!("Script execution error (on_response_headers): {}", e);
293 flow.tags.push("script-error".to_string());
294 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
295 http.error = Some(format!("Script Error: {}", e));
296 }
297 self.metrics
298 .on_response_headers_errors
299 .fetch_add(1, Ordering::Relaxed);
300 InterceptionResult::Continue
301 }
302 };
303 let dur_us = start.elapsed().as_micros() as u64;
304 self.metrics
305 .on_response_headers_duration_us
306 .fetch_add(dur_us, Ordering::Relaxed);
307 self.metrics
308 .on_response_headers_invocations
309 .fetch_add(1, Ordering::Relaxed);
310 result
311 }
312
313 async fn on_response(
314 &self,
315 flow: &mut Flow,
316 body: HttpBody,
317 ) -> Result<ResponseAction, BoxError> {
318 let start = Instant::now();
319 let index = self.get_engine_index();
320 let engine_lock = &self.engines[index];
321 let engine = engine_lock.read().await;
322
323 match engine.on_response(flow, body).await {
324 Ok(action) => {
325 let dur_us = start.elapsed().as_micros() as u64;
326 self.metrics
327 .on_response_duration_us
328 .fetch_add(dur_us, Ordering::Relaxed);
329 self.metrics
330 .on_response_invocations
331 .fetch_add(1, Ordering::Relaxed);
332 Ok(action)
333 }
334 Err(e) => {
335 self.metrics
336 .on_response_errors
337 .fetch_add(1, Ordering::Relaxed);
338 Err(e)
339 }
340 }
341 }
342
343 async fn on_websocket_message(
344 &self,
345 flow: &mut Flow,
346 mut message: WebSocketMessage,
347 ) -> Result<WebSocketMessageAction, BoxError> {
348 let start = Instant::now();
349 let index = self.get_engine_index();
350 let engine_lock = &self.engines[index];
351 let engine = engine_lock.read().await;
352
353 match engine.on_websocket_message(flow, &mut message).await {
354 Ok(action) => {
355 let dur_us = start.elapsed().as_micros() as u64;
356 self.metrics
357 .on_websocket_message_duration_us
358 .fetch_add(dur_us, Ordering::Relaxed);
359 self.metrics
360 .on_websocket_message_invocations
361 .fetch_add(1, Ordering::Relaxed);
362 Ok(action)
363 }
364 Err(e) => {
365 tracing::error!("Script execution error (on_websocket_message): {}", e);
366 flow.tags.push("script-error".to_string());
367 self.metrics
368 .on_websocket_message_errors
369 .fetch_add(1, Ordering::Relaxed);
370 Ok(WebSocketMessageAction::Continue(message))
371 }
372 }
373 }
374}
375
/// Unit tests for [`ScriptInterceptor`]: error propagation from script hooks, the `RelayBody`
/// argument passed to `onRequest`, reload failure handling, and WebSocket error fallback.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no response, tags, or
    /// error, from 127.0.0.1:12345 to 1.1.1.1:80.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// A throwing `onRequestHeaders` must not abort the request. The interceptor returns
    /// `Continue`, tags the flow `script-error`, and stores the thrown message in the HTTP
    /// layer's `error` field.
    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();

        let result = interceptor.on_request_headers(&mut flow).await;

        match result {
            InterceptionResult::Continue => {}
            _ => panic!("Expected Continue"),
        }

        assert!(flow.tags.contains(&"script-error".to_string()));

        if let Layer::Http(http) = &flow.layer {
            assert!(http.error.is_some());
            let err = http.error.as_ref().unwrap();
            assert!(err.contains("Test Error 123"));
        } else {
            panic!("Expected Http layer");
        }
    }

    /// `onRequest` receives a `RelayBody` instance as its first argument. The script throws
    /// otherwise, which presumably surfaces as an `Err` from `on_request` (confirm in the
    /// engine). Returning `undefined` must still succeed.
    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        // An empty body; its error type can never be produced, hence `unreachable!`.
        let body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let result = interceptor.on_request(&mut flow, body).await;
        assert!(result.is_ok());
    }

    /// A reload with a syntactically invalid script must return an error and leave the
    /// previously loaded (good) script active. The good script adds an `X-Good-Script: 1`
    /// header.
    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let bad = interceptor.load_script(bad_script).await;
        assert!(bad.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let result = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(result, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );
        if let Layer::Http(http) = &flow.layer {
            assert!(
                http.request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Good-Script" && v == "1")
            );
        } else {
            panic!("Expected Http layer");
        }
    }

    /// A throwing `onWebSocketMessage` must not fail the connection. The interceptor returns
    /// `Ok(Continue(message))` with the original content and tags the flow `script-error`.
    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let msg = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: "hello".to_string(),
                size: 5,
            },
            opcode: "Text".to_string(),
        };

        let result = interceptor
            .on_websocket_message(&mut flow, msg.clone())
            .await
            .expect("websocket interception should not return hard error");
        match result {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }
        assert!(
            flow.tags.iter().any(|t| t == "script-error"),
            "script error should be tagged for observability"
        );
    }
}