brainwires_proxy/middleware/
header_inject.rs1use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::HeaderValue;
7use http::header::HeaderName;
8
9#[derive(Clone)]
11pub enum HeaderRule {
12 Set(HeaderName, HeaderValue),
14 Append(HeaderName, HeaderValue),
16 Remove(HeaderName),
18}
19
20pub struct HeaderInjectLayer {
22 request_rules: Vec<HeaderRule>,
23 response_rules: Vec<HeaderRule>,
24}
25
26impl HeaderInjectLayer {
27 pub fn new() -> Self {
28 Self {
29 request_rules: Vec::new(),
30 response_rules: Vec::new(),
31 }
32 }
33
34 pub fn request_rule(mut self, rule: HeaderRule) -> Self {
36 self.request_rules.push(rule);
37 self
38 }
39
40 pub fn response_rule(mut self, rule: HeaderRule) -> Self {
42 self.response_rules.push(rule);
43 self
44 }
45
46 pub fn set_request_header(self, name: HeaderName, value: HeaderValue) -> Self {
48 self.request_rule(HeaderRule::Set(name, value))
49 }
50
51 pub fn remove_request_header(self, name: HeaderName) -> Self {
53 self.request_rule(HeaderRule::Remove(name))
54 }
55
56 pub fn set_response_header(self, name: HeaderName, value: HeaderValue) -> Self {
58 self.response_rule(HeaderRule::Set(name, value))
59 }
60}
61
62impl Default for HeaderInjectLayer {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68fn apply_rules(headers: &mut http::HeaderMap, rules: &[HeaderRule]) {
69 for rule in rules {
70 match rule {
71 HeaderRule::Set(name, value) => {
72 headers.insert(name.clone(), value.clone());
73 }
74 HeaderRule::Append(name, value) => {
75 headers.append(name.clone(), value.clone());
76 }
77 HeaderRule::Remove(name) => {
78 headers.remove(name);
79 }
80 }
81 }
82}
83
84#[async_trait::async_trait]
85impl ProxyLayer for HeaderInjectLayer {
86 async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
87 apply_rules(&mut request.headers, &self.request_rules);
88 Ok(LayerAction::Forward(request))
89 }
90
91 async fn on_response(&self, mut response: ProxyResponse) -> ProxyResult<ProxyResponse> {
92 apply_rules(&mut response.headers, &self.response_rules);
93 Ok(response)
94 }
95
96 fn name(&self) -> &str {
97 "header_inject"
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ProxyRequest, ProxyResponse};
    use http::{Method, StatusCode, header};

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/test".parse().unwrap())
    }

    /// Unwraps a `LayerAction::Forward`, failing the test on any other action.
    fn expect_forward(action: LayerAction) -> ProxyRequest {
        match action {
            LayerAction::Forward(req) => req,
            _ => panic!("expected forward"),
        }
    }

    /// Collects all values stored under `name` as strings, in map order.
    fn values_of<'a>(headers: &'a http::HeaderMap, name: &str) -> Vec<&'a str> {
        headers
            .get_all(name)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn set_request_header() {
        let layer = HeaderInjectLayer::new().set_request_header(
            header::HeaderName::from_static("x-custom"),
            HeaderValue::from_static("value"),
        );

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert_eq!(req.headers.get("x-custom").unwrap(), "value");
    }

    #[tokio::test]
    async fn set_replaces_all_existing_values() {
        let name = header::HeaderName::from_static("x-multi");
        let mut req = make_request();
        req.headers.append(name.clone(), HeaderValue::from_static("old-1"));
        req.headers.append(name.clone(), HeaderValue::from_static("old-2"));

        let layer =
            HeaderInjectLayer::new().set_request_header(name, HeaderValue::from_static("new"));

        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert_eq!(values_of(&req.headers, "x-multi"), ["new"]);
    }

    #[tokio::test]
    async fn remove_request_header() {
        let mut req = make_request();
        req.headers.insert(
            header::HeaderName::from_static("x-remove-me"),
            HeaderValue::from_static("bye"),
        );

        let layer = HeaderInjectLayer::new()
            .remove_request_header(header::HeaderName::from_static("x-remove-me"));

        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert!(req.headers.get("x-remove-me").is_none());
    }

    #[tokio::test]
    async fn set_response_header() {
        let layer = HeaderInjectLayer::new().set_response_header(
            header::HeaderName::from_static("x-proxy"),
            HeaderValue::from_static("brainwires"),
        );

        let resp = ProxyResponse::new(StatusCode::OK);
        let resp = layer.on_response(resp).await.unwrap();
        assert_eq!(resp.headers.get("x-proxy").unwrap(), "brainwires");
    }

    #[tokio::test]
    async fn remove_response_header() {
        let mut resp = ProxyResponse::new(StatusCode::OK);
        resp.headers
            .insert(header::SERVER, HeaderValue::from_static("upstream"));

        let layer = HeaderInjectLayer::new().response_rule(HeaderRule::Remove(header::SERVER));

        let resp = layer.on_response(resp).await.unwrap();
        assert!(resp.headers.get(header::SERVER).is_none());
    }

    #[tokio::test]
    async fn append_creates_multiple_values() {
        let layer = HeaderInjectLayer::new()
            .request_rule(HeaderRule::Append(
                header::HeaderName::from_static("x-tag"),
                HeaderValue::from_static("a"),
            ))
            .request_rule(HeaderRule::Append(
                header::HeaderName::from_static("x-tag"),
                HeaderValue::from_static("b"),
            ));

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        // Both values must survive, in the order the rules were added.
        assert_eq!(values_of(&req.headers, "x-tag"), ["a", "b"]);
    }

    #[tokio::test]
    async fn rules_apply_in_insertion_order() {
        let name = header::HeaderName::from_static("x-order");
        let layer = HeaderInjectLayer::new()
            .set_request_header(name.clone(), HeaderValue::from_static("first"))
            .remove_request_header(name.clone())
            .request_rule(HeaderRule::Append(name, HeaderValue::from_static("last")));

        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert_eq!(values_of(&req.headers, "x-order"), ["last"]);
    }
}