// relay_core_runtime/interceptors/rule.rs

1use crate::CoreState;
2use async_trait::async_trait;
3use relay_core_api::flow::{Flow, Layer};
4use relay_core_api::rule::RuleStage;
5use relay_core_lib::intercept::{
6    BoxError, ConnectAction, ConnectionInfo, ConnectionStats, HttpBody, InterceptionResult,
7    Interceptor, RequestAction, ResponseAction, WebSocketMessageAction,
8};
9use std::sync::Arc;
10
11pub struct RuleInterceptor {
12    state: Arc<CoreState>,
13}
14
15impl RuleInterceptor {
16    pub fn new(state: Arc<CoreState>) -> Self {
17        Self { state }
18    }
19}
20
21#[async_trait]
22impl Interceptor for RuleInterceptor {
23    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
24        let engine = self.state.get_rule_engine().await;
25        if !engine.has_rules_for_stage(RuleStage::RequestHeaders) {
26            return InterceptionResult::Continue;
27        }
28        let had_response = matches!(&flow.layer, Layer::Http(http) if http.response.is_some());
29        let ctx = engine.execute(RuleStage::RequestHeaders, flow).await;
30        if ctx.is_terminated() {
31            let resp = match &flow.layer {
32                Layer::Http(http) if !had_response => http.response.clone(),
33                _ => None,
34            };
35            if let Some(resp) = resp {
36                return InterceptionResult::MockResponse(resp);
37            }
38            return InterceptionResult::Drop;
39        }
40        InterceptionResult::Continue
41    }
42
43    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
44        let engine = self.state.get_rule_engine().await;
45        if engine.has_rules_for_stage(RuleStage::RequestBody) {
46            let ctx = engine.execute(RuleStage::RequestBody, flow).await;
47            if ctx.is_terminated() {
48                return Ok(RequestAction::Drop);
49            }
50        }
51        Ok(RequestAction::Continue(body))
52    }
53
54    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
55        let engine = self.state.get_rule_engine().await;
56        if engine.has_rules_for_stage(RuleStage::ResponseHeaders) {
57            let ctx = engine.execute(RuleStage::ResponseHeaders, flow).await;
58            if ctx.is_terminated() {
59                return InterceptionResult::Drop;
60            }
61        }
62        InterceptionResult::Continue
63    }
64
65    async fn on_response(
66        &self,
67        flow: &mut Flow,
68        body: HttpBody,
69    ) -> Result<ResponseAction, BoxError> {
70        let engine = self.state.get_rule_engine().await;
71        if engine.has_rules_for_stage(RuleStage::ResponseBody) {
72            let ctx = engine.execute(RuleStage::ResponseBody, flow).await;
73            if ctx.is_terminated() {
74                return Ok(ResponseAction::Drop);
75            }
76        }
77        Ok(ResponseAction::Continue(body))
78    }
79
80    async fn on_websocket_message(
81        &self,
82        flow: &mut Flow,
83        message: relay_core_api::flow::WebSocketMessage,
84    ) -> Result<WebSocketMessageAction, BoxError> {
85        let engine = self.state.get_rule_engine().await;
86        if engine.has_rules_for_stage(RuleStage::WebSocketMessage) {
87            engine.execute(RuleStage::WebSocketMessage, flow).await;
88        }
89        Ok(WebSocketMessageAction::Continue(message))
90    }
91
92    async fn on_connect(&self, _conn: &ConnectionInfo) -> ConnectAction {
93        ConnectAction::Allow
94    }
95
96    async fn on_disconnect(&self, _conn: &ConnectionInfo, _stats: &ConnectionStats) {}
97
98    async fn on_websocket_start(&self, _flow: &mut Flow) {}
99
100    async fn on_websocket_end(&self, _flow: &mut Flow, _close_code: u16, _close_reason: &str) {}
101
102    async fn on_websocket_error(&self, _flow: &mut Flow, _error: &str) {}
103}