Skip to main content

gatel_core/hoops/
stream_replace.rs

1use bytes::Bytes;
2use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
3use salvo::http::ResBody;
4use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
5use tracing::debug;
6
/// Response-body text-replacement middleware.
///
/// After the rest of the handler chain has run, the complete response body is
/// collected into memory. Every `(search, replacement)` rule is then applied in
/// order, so later rules see the output of earlier ones. The rewritten body is
/// sent with a recomputed `Content-Length`.
///
/// Despite the type name, the body is **not** rewritten chunk by chunk. It is
/// fully buffered before any replacement runs, so peak memory grows with the
/// response size, as with [`super::replace::ReplaceHoop`].
///
/// Replacements are only applied when the response `Content-Type` is a
/// text-like MIME type. Rules whose search string is empty are ignored.
#[derive(Debug, Clone)]
pub struct StreamReplaceHoop {
    /// `(search, replacement)` byte pairs, applied in order.
    rules: Vec<(Vec<u8>, Vec<u8>)>,
    /// When `true`, each rule replaces only its first occurrence.
    once: bool,
}
19
20impl StreamReplaceHoop {
21    pub fn new(rules: Vec<(String, String)>, once: bool) -> Self {
22        let rules = rules
23            .into_iter()
24            .map(|(s, r)| (s.into_bytes(), r.into_bytes()))
25            .collect();
26        Self { rules, once }
27    }
28
29    fn is_text_content(&self, headers: &http::HeaderMap) -> bool {
30        headers
31            .get(CONTENT_TYPE)
32            .and_then(|v| v.to_str().ok())
33            .map(|ct| {
34                ct.contains("text/")
35                    || ct.contains("application/json")
36                    || ct.contains("application/xml")
37                    || ct.contains("application/javascript")
38            })
39            .unwrap_or(false)
40    }
41}
42
43#[async_trait]
44impl salvo::Handler for StreamReplaceHoop {
45    async fn handle(
46        &self,
47        req: &mut Request,
48        depot: &mut Depot,
49        res: &mut Response,
50        ctrl: &mut FlowCtrl,
51    ) {
52        ctrl.call_next(req, depot, res).await;
53
54        if !self.is_text_content(res.headers()) {
55            return;
56        }
57
58        // Collect the body — for streaming we process chunks.
59        let body = res.take_body();
60        let body_bytes = match collect_body(body).await {
61            Ok(b) => b,
62            Err(_) => return,
63        };
64
65        if body_bytes.is_empty() {
66            return;
67        }
68
69        let original_len = body_bytes.len();
70        let mut output = body_bytes;
71        for (search, replacement) in &self.rules {
72            if search.is_empty() {
73                continue;
74            }
75            output = if self.once {
76                replace_first(&output, search, replacement)
77            } else {
78                replace_all(&output, search, replacement)
79            };
80        }
81
82        debug!(
83            original = original_len,
84            replaced = output.len(),
85            rules = self.rules.len(),
86            "streaming body replacement applied"
87        );
88
89        res.headers_mut().remove(CONTENT_LENGTH);
90        res.headers_mut()
91            .insert(CONTENT_LENGTH, output.len().into());
92        res.body(ResBody::Once(Bytes::from(output)));
93    }
94}
95
/// Replaces every non-overlapping occurrence of `needle` in `haystack` and
/// returns the rewritten bytes.
///
/// Matches are found left to right, and scanning resumes just past each match.
/// An empty `needle` is treated as "nothing to replace" and returns a copy of
/// the input.
fn replace_all(haystack: &[u8], needle: &[u8], replacement: &[u8]) -> Vec<u8> {
    // Guard first: `windows(0)` would panic.
    if needle.is_empty() {
        return haystack.to_vec();
    }
    let width = needle.len();
    let mut out = Vec::with_capacity(haystack.len());
    let mut rest = haystack;
    while let Some(hit) = rest.windows(width).position(|window| window == needle) {
        out.extend_from_slice(&rest[..hit]);
        out.extend_from_slice(replacement);
        rest = &rest[hit + width..];
    }
    out.extend_from_slice(rest);
    out
}
113
/// Replaces only the leftmost occurrence of `needle` in `haystack`.
///
/// Returns an unchanged copy of `haystack` when `needle` is empty or does not
/// occur, including when `needle` is longer than `haystack`.
fn replace_first(haystack: &[u8], needle: &[u8], replacement: &[u8]) -> Vec<u8> {
    // Guard first: `windows(0)` would panic.
    if needle.is_empty() {
        return haystack.to_vec();
    }
    match haystack.windows(needle.len()).position(|window| window == needle) {
        Some(start) => {
            let end = start + needle.len();
            [&haystack[..start], replacement, &haystack[end..]].concat()
        }
        None => haystack.to_vec(),
    }
}
129
130async fn collect_body(body: ResBody) -> Result<Vec<u8>, ()> {
131    super::compress::collect_res_body_bytes(body).await
132}