Skip to main content

relay_core_runtime/interceptors/
rule.rs

1use crate::CoreState;
2use crate::interceptors::inspect::handle_rule_termination;
3use async_trait::async_trait;
4use relay_core_api::flow::{Flow, Layer};
5use relay_core_api::rule::{RuleStage, RuleTraceSummary};
6use relay_core_lib::intercept::{
7    BoxError, ConnectAction, ConnectionInfo, ConnectionStats, HttpBody, InterceptionResult,
8    Interceptor, RequestAction, ResponseAction, WebSocketMessageAction,
9};
10use relay_core_lib::proxy::http_utils::mock_to_response;
11use std::sync::Arc;
12
13pub struct RuleInterceptor {
14    state: Arc<CoreState>,
15}
16
17impl RuleInterceptor {
18    pub fn new(state: Arc<CoreState>) -> Self {
19        Self { state }
20    }
21}
22
23#[async_trait]
24impl Interceptor for RuleInterceptor {
25    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
26        let engine = self.state.get_rule_engine().await;
27        if !engine.has_rules_for_stage(RuleStage::RequestHeaders) {
28            return InterceptionResult::Continue;
29        }
30
31        let ctx = engine.execute(RuleStage::RequestHeaders, flow).await;
32
33        if let RuleTraceSummary::Terminated { reason, .. } = &ctx.summary {
34            return handle_rule_termination(&self.state, reason, flow, "request_headers", None)
35                .await;
36        }
37
38        InterceptionResult::Continue
39    }
40
41    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
42        let engine = self.state.get_rule_engine().await;
43        if engine.has_rules_for_stage(RuleStage::RequestBody) {
44            let ctx = engine.execute(RuleStage::RequestBody, flow).await;
45            if let RuleTraceSummary::Terminated { reason, .. } = &ctx.summary {
46                let result =
47                    handle_rule_termination(&self.state, reason, flow, "request_body", None).await;
48                return Ok(match result {
49                    InterceptionResult::Drop => RequestAction::Drop,
50                    InterceptionResult::MockResponse(res) => {
51                        RequestAction::MockResponse(mock_to_response(res))
52                    }
53                    _ => RequestAction::Drop,
54                });
55            }
56        }
57        Ok(RequestAction::Continue(body))
58    }
59
60    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
61        let engine = self.state.get_rule_engine().await;
62        if !engine.has_rules_for_stage(RuleStage::ResponseHeaders) {
63            return InterceptionResult::Continue;
64        }
65
66        let ctx = engine.execute(RuleStage::ResponseHeaders, flow).await;
67        if let RuleTraceSummary::Terminated { reason, .. } = &ctx.summary {
68            return handle_rule_termination(&self.state, reason, flow, "response_headers", None)
69                .await;
70        }
71
72        InterceptionResult::Continue
73    }
74
75    async fn on_response(
76        &self,
77        flow: &mut Flow,
78        body: HttpBody,
79    ) -> Result<ResponseAction, BoxError> {
80        let engine = self.state.get_rule_engine().await;
81        if engine.has_rules_for_stage(RuleStage::ResponseBody) {
82            let ctx = engine.execute(RuleStage::ResponseBody, flow).await;
83            if let RuleTraceSummary::Terminated { reason, .. } = &ctx.summary {
84                let result =
85                    handle_rule_termination(&self.state, reason, flow, "response_body", None).await;
86                return Ok(match result {
87                    InterceptionResult::Drop => ResponseAction::Drop,
88                    InterceptionResult::MockResponse(res) => {
89                        ResponseAction::ModifiedResponse(mock_to_response(res))
90                    }
91                    _ => ResponseAction::Drop,
92                });
93            }
94        }
95        Ok(ResponseAction::Continue(body))
96    }
97
98    async fn on_websocket_message(
99        &self,
100        flow: &mut Flow,
101        message: relay_core_api::flow::WebSocketMessage,
102    ) -> Result<WebSocketMessageAction, BoxError> {
103        let engine = self.state.get_rule_engine().await;
104        if engine.has_rules_for_stage(RuleStage::WebSocketMessage) {
105            if let Layer::WebSocket(ws) = &mut flow.layer {
106                ws.messages.push(message.clone());
107            }
108            let ctx = engine.execute(RuleStage::WebSocketMessage, flow).await;
109            if let RuleTraceSummary::Terminated { reason, .. } = &ctx.summary {
110                let result =
111                    handle_rule_termination(&self.state, reason, flow, "ws_msg", Some(&message))
112                        .await;
113                return Ok(match result {
114                    InterceptionResult::Drop => WebSocketMessageAction::Drop,
115                    InterceptionResult::ModifiedMessage(msg) => {
116                        WebSocketMessageAction::Continue(msg)
117                    }
118                    _ => WebSocketMessageAction::Continue(message),
119                });
120            }
121        }
122        Ok(WebSocketMessageAction::Continue(message))
123    }
124
125    async fn on_connect(&self, _conn: &ConnectionInfo) -> ConnectAction {
126        ConnectAction::Allow
127    }
128
129    async fn on_disconnect(&self, _conn: &ConnectionInfo, _stats: &ConnectionStats) {}
130
131    async fn on_websocket_start(&self, _flow: &mut Flow) {}
132
133    async fn on_websocket_end(&self, _flow: &mut Flow, _close_code: u16, _close_reason: &str) {}
134
135    async fn on_websocket_error(&self, _flow: &mut Flow, _error: &str) {}
136}