Skip to main content

brainwires_proxy/middleware/
header_inject.rs

1//! Header add/remove/replace middleware.
2
3use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::HeaderValue;
7use http::header::HeaderName;
8
9/// Rule for modifying headers.
10#[derive(Clone)]
11pub enum HeaderRule {
12    /// Set a header, replacing any existing value.
13    Set(HeaderName, HeaderValue),
14    /// Append a header value (allows duplicates).
15    Append(HeaderName, HeaderValue),
16    /// Remove a header.
17    Remove(HeaderName),
18}
19
20/// Applies header rules to requests and/or responses.
21pub struct HeaderInjectLayer {
22    request_rules: Vec<HeaderRule>,
23    response_rules: Vec<HeaderRule>,
24}
25
26impl HeaderInjectLayer {
27    pub fn new() -> Self {
28        Self {
29            request_rules: Vec::new(),
30            response_rules: Vec::new(),
31        }
32    }
33
34    /// Add a rule applied to requests.
35    pub fn request_rule(mut self, rule: HeaderRule) -> Self {
36        self.request_rules.push(rule);
37        self
38    }
39
40    /// Add a rule applied to responses.
41    pub fn response_rule(mut self, rule: HeaderRule) -> Self {
42        self.response_rules.push(rule);
43        self
44    }
45
46    /// Set a request header.
47    pub fn set_request_header(self, name: HeaderName, value: HeaderValue) -> Self {
48        self.request_rule(HeaderRule::Set(name, value))
49    }
50
51    /// Remove a request header.
52    pub fn remove_request_header(self, name: HeaderName) -> Self {
53        self.request_rule(HeaderRule::Remove(name))
54    }
55
56    /// Set a response header.
57    pub fn set_response_header(self, name: HeaderName, value: HeaderValue) -> Self {
58        self.response_rule(HeaderRule::Set(name, value))
59    }
60}
61
62impl Default for HeaderInjectLayer {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68fn apply_rules(headers: &mut http::HeaderMap, rules: &[HeaderRule]) {
69    for rule in rules {
70        match rule {
71            HeaderRule::Set(name, value) => {
72                headers.insert(name.clone(), value.clone());
73            }
74            HeaderRule::Append(name, value) => {
75                headers.append(name.clone(), value.clone());
76            }
77            HeaderRule::Remove(name) => {
78                headers.remove(name);
79            }
80        }
81    }
82}
83
84#[async_trait::async_trait]
85impl ProxyLayer for HeaderInjectLayer {
86    async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
87        apply_rules(&mut request.headers, &self.request_rules);
88        Ok(LayerAction::Forward(request))
89    }
90
91    async fn on_response(&self, mut response: ProxyResponse) -> ProxyResult<ProxyResponse> {
92        apply_rules(&mut response.headers, &self.response_rules);
93        Ok(response)
94    }
95
96    fn name(&self) -> &str {
97        "header_inject"
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/test".parse().unwrap())
    }

    /// Unwrap a `LayerAction::Forward`, panicking on any other action.
    fn expect_forward(action: LayerAction) -> ProxyRequest {
        match action {
            LayerAction::Forward(req) => req,
            _ => panic!("expected forward"),
        }
    }

    /// Collect every value stored under `name` as `&str`, in map order.
    fn values<'a>(headers: &'a http::HeaderMap, name: &str) -> Vec<&'a str> {
        headers
            .get_all(name)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn set_request_header() {
        let layer = HeaderInjectLayer::new().set_request_header(
            HeaderName::from_static("x-custom"),
            HeaderValue::from_static("value"),
        );

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert_eq!(req.headers.get("x-custom").unwrap(), "value");
    }

    #[tokio::test]
    async fn set_replaces_existing_values() {
        let mut req = make_request();
        req.headers.append(
            HeaderName::from_static("x-tag"),
            HeaderValue::from_static("old1"),
        );
        req.headers.append(
            HeaderName::from_static("x-tag"),
            HeaderValue::from_static("old2"),
        );

        let layer = HeaderInjectLayer::new().request_rule(HeaderRule::Set(
            HeaderName::from_static("x-tag"),
            HeaderValue::from_static("new"),
        ));

        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert_eq!(values(&req.headers, "x-tag"), vec!["new"]);
    }

    #[tokio::test]
    async fn remove_request_header() {
        let mut req = make_request();
        req.headers.insert(
            HeaderName::from_static("x-remove-me"),
            HeaderValue::from_static("bye"),
        );

        let layer =
            HeaderInjectLayer::new().remove_request_header(HeaderName::from_static("x-remove-me"));

        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert!(req.headers.get("x-remove-me").is_none());
    }

    #[tokio::test]
    async fn rules_apply_in_insertion_order() {
        let set_then_remove = HeaderInjectLayer::new()
            .request_rule(HeaderRule::Set(
                HeaderName::from_static("x-order"),
                HeaderValue::from_static("1"),
            ))
            .request_rule(HeaderRule::Remove(HeaderName::from_static("x-order")));
        let req = expect_forward(set_then_remove.on_request(make_request()).await.unwrap());
        assert!(req.headers.get("x-order").is_none());

        let remove_then_set = HeaderInjectLayer::new()
            .request_rule(HeaderRule::Remove(HeaderName::from_static("x-order")))
            .request_rule(HeaderRule::Set(
                HeaderName::from_static("x-order"),
                HeaderValue::from_static("1"),
            ));
        let req = expect_forward(remove_then_set.on_request(make_request()).await.unwrap());
        assert_eq!(req.headers.get("x-order").unwrap(), "1");
    }

    #[tokio::test]
    async fn set_response_header() {
        let layer = HeaderInjectLayer::new().set_response_header(
            HeaderName::from_static("x-proxy"),
            HeaderValue::from_static("brainwires"),
        );

        let resp = ProxyResponse::new(StatusCode::OK);
        let resp = layer.on_response(resp).await.unwrap();
        assert_eq!(resp.headers.get("x-proxy").unwrap(), "brainwires");
    }

    #[tokio::test]
    async fn response_rules_do_not_touch_requests() {
        let layer = HeaderInjectLayer::new().set_response_header(
            HeaderName::from_static("x-proxy"),
            HeaderValue::from_static("brainwires"),
        );

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert!(req.headers.get("x-proxy").is_none());
    }

    #[tokio::test]
    async fn append_creates_multiple_values() {
        let layer = HeaderInjectLayer::new()
            .request_rule(HeaderRule::Append(
                HeaderName::from_static("x-tag"),
                HeaderValue::from_static("a"),
            ))
            .request_rule(HeaderRule::Append(
                HeaderName::from_static("x-tag"),
                HeaderValue::from_static("b"),
            ));

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert_eq!(values(&req.headers, "x-tag"), vec!["a", "b"]);
    }

    #[test]
    fn layer_name() {
        assert_eq!(HeaderInjectLayer::default().name(), "header_inject");
    }
}