hive_router_plan_executor/headers/
response.rs1use std::iter::once;
2
3use crate::{
4 execution::client_request_details::ClientRequestDetails,
5 headers::{
6 errors::HeaderRuleRuntimeError,
7 expression::vrl_value_to_header_value,
8 plan::{
9 HeaderAggregationStrategy, HeaderRulesPlan, ResponseHeaderAggregator,
10 ResponseHeaderRule, ResponseInsertExpression, ResponseInsertStatic,
11 ResponsePropagateNamed, ResponsePropagateRegex, ResponseRemoveNamed,
12 ResponseRemoveRegex,
13 },
14 sanitizer::is_denied_header,
15 },
16};
17use hive_router_internal::expressions::ExecutableProgram;
18
19use super::sanitizer::is_never_join_header;
20use http::{header::InvalidHeaderValue, HeaderMap, HeaderName, HeaderValue};
21
22pub fn apply_subgraph_response_headers(
23 header_rule_plan: &HeaderRulesPlan,
24 subgraph_name: &str,
25 subgraph_headers: &HeaderMap,
26 client_request_details: &ClientRequestDetails,
27 accumulator: &mut ResponseHeaderAggregator,
28) -> Result<(), HeaderRuleRuntimeError> {
29 let global_actions = &header_rule_plan.response.global;
30 let subgraph_actions = header_rule_plan.response.by_subgraph.get(subgraph_name);
31
32 let ctx = ResponseExpressionContext {
33 subgraph_name,
34 subgraph_headers,
35 client_request: client_request_details,
36 };
37
38 for action in global_actions
39 .iter()
40 .chain(subgraph_actions.into_iter().flatten())
41 {
42 action.apply_response_headers(&ctx, accumulator)?;
43 }
44
45 Ok(())
46}
47
48pub struct ResponseExpressionContext<'a> {
49 pub subgraph_name: &'a str,
50 pub client_request: &'a ClientRequestDetails<'a>,
51 pub subgraph_headers: &'a HeaderMap,
52}
53
54trait ApplyResponseHeader {
55 fn apply_response_headers(
56 &self,
57 ctx: &ResponseExpressionContext,
58 accumulator: &mut ResponseHeaderAggregator,
59 ) -> Result<(), HeaderRuleRuntimeError>;
60}
61
62impl ApplyResponseHeader for ResponseHeaderRule {
63 fn apply_response_headers(
64 &self,
65 ctx: &ResponseExpressionContext,
66 accumulator: &mut ResponseHeaderAggregator,
67 ) -> Result<(), HeaderRuleRuntimeError> {
68 match self {
69 ResponseHeaderRule::PropagateNamed(data) => {
70 data.apply_response_headers(ctx, accumulator)
71 }
72 ResponseHeaderRule::PropagateRegex(data) => {
73 data.apply_response_headers(ctx, accumulator)
74 }
75 ResponseHeaderRule::InsertStatic(data) => data.apply_response_headers(ctx, accumulator),
76 ResponseHeaderRule::InsertExpression(data) => {
77 data.apply_response_headers(ctx, accumulator)
78 }
79 ResponseHeaderRule::RemoveNamed(data) => data.apply_response_headers(ctx, accumulator),
80 ResponseHeaderRule::RemoveRegex(data) => data.apply_response_headers(ctx, accumulator),
81 }
82 }
83}
84
85impl ApplyResponseHeader for ResponsePropagateNamed {
86 fn apply_response_headers(
87 &self,
88 ctx: &ResponseExpressionContext,
89 accumulator: &mut ResponseHeaderAggregator,
90 ) -> Result<(), HeaderRuleRuntimeError> {
91 let mut matched = false;
92
93 for header_name in &self.names {
94 if is_denied_header(header_name) {
95 continue;
96 }
97
98 if let Some(header_value) = ctx.subgraph_headers.get(header_name) {
99 matched = true;
100 write_agg(
101 accumulator,
102 self.rename.as_ref().unwrap_or(header_name),
103 header_value,
104 self.strategy,
105 );
106 }
107 }
108
109 if !matched {
110 if let (Some(default_value), Some(first_name)) = (&self.default, self.names.first()) {
111 let destination_name = self.rename.as_ref().unwrap_or(first_name);
112
113 if is_denied_header(destination_name) {
114 return Ok(());
115 }
116
117 write_agg(accumulator, destination_name, default_value, self.strategy);
118 }
119 }
120
121 Ok(())
122 }
123}
124
125impl ApplyResponseHeader for ResponsePropagateRegex {
126 fn apply_response_headers(
127 &self,
128 ctx: &ResponseExpressionContext,
129 accumulator: &mut ResponseHeaderAggregator,
130 ) -> Result<(), HeaderRuleRuntimeError> {
131 for (header_name, header_value) in ctx.subgraph_headers {
132 if is_denied_header(header_name) {
133 continue;
134 }
135
136 let header_bytes = header_name.as_str().as_bytes();
137
138 let Some(include_regex) = &self.include else {
139 continue;
140 };
141
142 if !include_regex.is_match(header_bytes) {
143 continue;
144 }
145
146 if self
147 .exclude
148 .as_ref()
149 .is_some_and(|regex| regex.is_match(header_bytes))
150 {
151 continue;
152 }
153
154 write_agg(accumulator, header_name, header_value, self.strategy);
155 }
156
157 Ok(())
158 }
159}
160
161impl ApplyResponseHeader for ResponseInsertStatic {
162 fn apply_response_headers(
163 &self,
164 _ctx: &ResponseExpressionContext,
165 accumulator: &mut ResponseHeaderAggregator,
166 ) -> Result<(), HeaderRuleRuntimeError> {
167 if is_denied_header(&self.name) {
168 return Ok(());
169 }
170
171 let strategy = if is_never_join_header(&self.name) {
172 HeaderAggregationStrategy::Append
173 } else {
174 self.strategy
175 };
176
177 write_agg(accumulator, &self.name, &self.value, strategy);
178
179 Ok(())
180 }
181}
182
183impl ApplyResponseHeader for ResponseInsertExpression {
184 fn apply_response_headers(
185 &self,
186 ctx: &ResponseExpressionContext,
187 accumulator: &mut ResponseHeaderAggregator,
188 ) -> Result<(), HeaderRuleRuntimeError> {
189 if is_denied_header(&self.name) {
190 return Ok(());
191 }
192 let value = self.expression.execute(ctx.into()).map_err(|err| {
193 HeaderRuleRuntimeError::ExpressionEvaluation(self.name.to_string(), Box::new(err.0))
194 })?;
195 if let Some(header_value) = vrl_value_to_header_value(value) {
196 let strategy = if is_never_join_header(&self.name) {
197 HeaderAggregationStrategy::Append
198 } else {
199 self.strategy
200 };
201
202 write_agg(accumulator, &self.name, &header_value, strategy);
203 }
204
205 Ok(())
206 }
207}
208
209impl ApplyResponseHeader for ResponseRemoveNamed {
210 fn apply_response_headers(
211 &self,
212 _ctx: &ResponseExpressionContext,
213 accumulator: &mut ResponseHeaderAggregator,
214 ) -> Result<(), HeaderRuleRuntimeError> {
215 for header_name in &self.names {
216 if is_denied_header(header_name) {
217 continue;
218 }
219 accumulator.entries.remove(header_name);
220 }
221
222 Ok(())
223 }
224}
225
226impl ApplyResponseHeader for ResponseRemoveRegex {
227 fn apply_response_headers(
228 &self,
229 _ctx: &ResponseExpressionContext,
230 accumulator: &mut ResponseHeaderAggregator,
231 ) -> Result<(), HeaderRuleRuntimeError> {
232 accumulator.entries.retain(|name, _| {
233 if is_denied_header(name) {
234 return true;
237 }
238
239 !self.regex.is_match(name.as_str().as_bytes())
240 });
241
242 Ok(())
243 }
244}
245
246fn write_agg(
248 agg: &mut ResponseHeaderAggregator,
249 name: &HeaderName,
250 value: &HeaderValue,
251 strategy: HeaderAggregationStrategy,
252) {
253 let strategy = if is_never_join_header(name) {
254 HeaderAggregationStrategy::Append
255 } else {
256 strategy
257 };
258
259 if !agg.entries.contains_key(name) {
260 agg.entries
261 .insert(name.clone(), (strategy, once(value.clone()).collect()));
262 return;
263 }
264
265 let (strategy, values) = agg.entries.get_mut(name).expect("Expected entry to exist");
267
268 match (strategy, values.len()) {
269 (HeaderAggregationStrategy::First, 0) => {
270 values.push(value.clone());
271 }
272 (HeaderAggregationStrategy::Last, _) => {
273 values.clear();
274 values.push(value.clone());
275 }
276 (HeaderAggregationStrategy::Append, _) => {
277 values.push(value.clone());
278 }
279 (_, _) => {}
280 }
281}
282impl ResponseHeaderAggregator {
283 #[inline]
285 pub fn modify_client_response_headers(
286 self,
287 out: &mut ntex::http::ResponseBuilder,
288 ) -> Result<(), HeaderRuleRuntimeError> {
289 for (name, (agg_strategy, mut values)) in self.entries {
290 if values.is_empty() {
291 continue;
292 }
293
294 if is_never_join_header(&name) {
295 for value in values {
297 out.header(&name, value);
298 }
299 continue;
300 }
301
302 if values.len() == 1 {
303 out.set_header(name, values.pop().unwrap());
304 continue;
305 }
306
307 if matches!(agg_strategy, HeaderAggregationStrategy::Append) {
308 let joined = join_with_comma(&values)
309 .map_err(|_| HeaderRuleRuntimeError::BadHeaderValue(name.to_string()))?;
310 out.set_header(name, joined);
311 }
312 }
313
314 Ok(())
315 }
316}
317
318#[inline]
319fn join_with_comma(values: &[HeaderValue]) -> Result<HeaderValue, InvalidHeaderValue> {
320 let mut cap = 0usize;
322
323 for value in values {
324 cap += value.as_bytes().len();
325 }
326
327 if values.len() > 1 {
328 cap += 2 * (values.len() - 1);
329 }
330
331 let mut buf = Vec::with_capacity(cap);
332 for (idx, value) in values.iter().enumerate() {
333 if idx > 0 {
334 buf.extend_from_slice(b", ");
335 }
336 buf.extend_from_slice(value.as_bytes());
337 }
338 HeaderValue::from_bytes(&buf)
339}