Skip to main content

hive_router_plan_executor/headers/
response.rs

1use std::iter::once;
2
3use crate::{
4    execution::client_request_details::ClientRequestDetails,
5    headers::{
6        errors::HeaderRuleRuntimeError,
7        expression::vrl_value_to_header_value,
8        plan::{
9            HeaderAggregationStrategy, HeaderRulesPlan, ResponseHeaderAggregator,
10            ResponseHeaderRule, ResponseInsertExpression, ResponseInsertStatic,
11            ResponsePropagateNamed, ResponsePropagateRegex, ResponseRemoveNamed,
12            ResponseRemoveRegex,
13        },
14        sanitizer::is_denied_header,
15    },
16};
17use hive_router_internal::expressions::ExecutableProgram;
18
19use super::sanitizer::is_never_join_header;
20use http::{header::InvalidHeaderValue, HeaderMap, HeaderName, HeaderValue};
21
22pub fn apply_subgraph_response_headers(
23    header_rule_plan: &HeaderRulesPlan,
24    subgraph_name: &str,
25    subgraph_headers: &HeaderMap,
26    client_request_details: &ClientRequestDetails,
27    accumulator: &mut ResponseHeaderAggregator,
28) -> Result<(), HeaderRuleRuntimeError> {
29    let global_actions = &header_rule_plan.response.global;
30    let subgraph_actions = header_rule_plan.response.by_subgraph.get(subgraph_name);
31
32    let ctx = ResponseExpressionContext {
33        subgraph_name,
34        subgraph_headers,
35        client_request: client_request_details,
36    };
37
38    for action in global_actions
39        .iter()
40        .chain(subgraph_actions.into_iter().flatten())
41    {
42        action.apply_response_headers(&ctx, accumulator)?;
43    }
44
45    Ok(())
46}
47
48pub struct ResponseExpressionContext<'a> {
49    pub subgraph_name: &'a str,
50    pub client_request: &'a ClientRequestDetails<'a>,
51    pub subgraph_headers: &'a HeaderMap,
52}
53
54trait ApplyResponseHeader {
55    fn apply_response_headers(
56        &self,
57        ctx: &ResponseExpressionContext,
58        accumulator: &mut ResponseHeaderAggregator,
59    ) -> Result<(), HeaderRuleRuntimeError>;
60}
61
62impl ApplyResponseHeader for ResponseHeaderRule {
63    fn apply_response_headers(
64        &self,
65        ctx: &ResponseExpressionContext,
66        accumulator: &mut ResponseHeaderAggregator,
67    ) -> Result<(), HeaderRuleRuntimeError> {
68        match self {
69            ResponseHeaderRule::PropagateNamed(data) => {
70                data.apply_response_headers(ctx, accumulator)
71            }
72            ResponseHeaderRule::PropagateRegex(data) => {
73                data.apply_response_headers(ctx, accumulator)
74            }
75            ResponseHeaderRule::InsertStatic(data) => data.apply_response_headers(ctx, accumulator),
76            ResponseHeaderRule::InsertExpression(data) => {
77                data.apply_response_headers(ctx, accumulator)
78            }
79            ResponseHeaderRule::RemoveNamed(data) => data.apply_response_headers(ctx, accumulator),
80            ResponseHeaderRule::RemoveRegex(data) => data.apply_response_headers(ctx, accumulator),
81        }
82    }
83}
84
85impl ApplyResponseHeader for ResponsePropagateNamed {
86    fn apply_response_headers(
87        &self,
88        ctx: &ResponseExpressionContext,
89        accumulator: &mut ResponseHeaderAggregator,
90    ) -> Result<(), HeaderRuleRuntimeError> {
91        let mut matched = false;
92
93        for header_name in &self.names {
94            if is_denied_header(header_name) {
95                continue;
96            }
97
98            if let Some(header_value) = ctx.subgraph_headers.get(header_name) {
99                matched = true;
100                write_agg(
101                    accumulator,
102                    self.rename.as_ref().unwrap_or(header_name),
103                    header_value,
104                    self.strategy,
105                );
106            }
107        }
108
109        if !matched {
110            if let (Some(default_value), Some(first_name)) = (&self.default, self.names.first()) {
111                let destination_name = self.rename.as_ref().unwrap_or(first_name);
112
113                if is_denied_header(destination_name) {
114                    return Ok(());
115                }
116
117                write_agg(accumulator, destination_name, default_value, self.strategy);
118            }
119        }
120
121        Ok(())
122    }
123}
124
125impl ApplyResponseHeader for ResponsePropagateRegex {
126    fn apply_response_headers(
127        &self,
128        ctx: &ResponseExpressionContext,
129        accumulator: &mut ResponseHeaderAggregator,
130    ) -> Result<(), HeaderRuleRuntimeError> {
131        for (header_name, header_value) in ctx.subgraph_headers {
132            if is_denied_header(header_name) {
133                continue;
134            }
135
136            let header_bytes = header_name.as_str().as_bytes();
137
138            let Some(include_regex) = &self.include else {
139                continue;
140            };
141
142            if !include_regex.is_match(header_bytes) {
143                continue;
144            }
145
146            if self
147                .exclude
148                .as_ref()
149                .is_some_and(|regex| regex.is_match(header_bytes))
150            {
151                continue;
152            }
153
154            write_agg(accumulator, header_name, header_value, self.strategy);
155        }
156
157        Ok(())
158    }
159}
160
161impl ApplyResponseHeader for ResponseInsertStatic {
162    fn apply_response_headers(
163        &self,
164        _ctx: &ResponseExpressionContext,
165        accumulator: &mut ResponseHeaderAggregator,
166    ) -> Result<(), HeaderRuleRuntimeError> {
167        if is_denied_header(&self.name) {
168            return Ok(());
169        }
170
171        let strategy = if is_never_join_header(&self.name) {
172            HeaderAggregationStrategy::Append
173        } else {
174            self.strategy
175        };
176
177        write_agg(accumulator, &self.name, &self.value, strategy);
178
179        Ok(())
180    }
181}
182
183impl ApplyResponseHeader for ResponseInsertExpression {
184    fn apply_response_headers(
185        &self,
186        ctx: &ResponseExpressionContext,
187        accumulator: &mut ResponseHeaderAggregator,
188    ) -> Result<(), HeaderRuleRuntimeError> {
189        if is_denied_header(&self.name) {
190            return Ok(());
191        }
192        let value = self.expression.execute(ctx.into()).map_err(|err| {
193            HeaderRuleRuntimeError::ExpressionEvaluation(self.name.to_string(), Box::new(err.0))
194        })?;
195        if let Some(header_value) = vrl_value_to_header_value(value) {
196            let strategy = if is_never_join_header(&self.name) {
197                HeaderAggregationStrategy::Append
198            } else {
199                self.strategy
200            };
201
202            write_agg(accumulator, &self.name, &header_value, strategy);
203        }
204
205        Ok(())
206    }
207}
208
209impl ApplyResponseHeader for ResponseRemoveNamed {
210    fn apply_response_headers(
211        &self,
212        _ctx: &ResponseExpressionContext,
213        accumulator: &mut ResponseHeaderAggregator,
214    ) -> Result<(), HeaderRuleRuntimeError> {
215        for header_name in &self.names {
216            if is_denied_header(header_name) {
217                continue;
218            }
219            accumulator.entries.remove(header_name);
220        }
221
222        Ok(())
223    }
224}
225
226impl ApplyResponseHeader for ResponseRemoveRegex {
227    fn apply_response_headers(
228        &self,
229        _ctx: &ResponseExpressionContext,
230        accumulator: &mut ResponseHeaderAggregator,
231    ) -> Result<(), HeaderRuleRuntimeError> {
232        accumulator.entries.retain(|name, _| {
233            if is_denied_header(name) {
234                // Denied headers (hop-by–hop) are never inserted in the first place
235                // and should not be removed here.
236                return true;
237            }
238
239            !self.regex.is_match(name.as_str().as_bytes())
240        });
241
242        Ok(())
243    }
244}
245
246/// Write a header to the aggregator according to the specified strategy.
247fn write_agg(
248    agg: &mut ResponseHeaderAggregator,
249    name: &HeaderName,
250    value: &HeaderValue,
251    strategy: HeaderAggregationStrategy,
252) {
253    let strategy = if is_never_join_header(name) {
254        HeaderAggregationStrategy::Append
255    } else {
256        strategy
257    };
258
259    if !agg.entries.contains_key(name) {
260        agg.entries
261            .insert(name.clone(), (strategy, once(value.clone()).collect()));
262        return;
263    }
264
265    // The `expect` is safe because we just inserted the entry if it didn't exist
266    let (strategy, values) = agg.entries.get_mut(name).expect("Expected entry to exist");
267
268    match (strategy, values.len()) {
269        (HeaderAggregationStrategy::First, 0) => {
270            values.push(value.clone());
271        }
272        (HeaderAggregationStrategy::Last, _) => {
273            values.clear();
274            values.push(value.clone());
275        }
276        (HeaderAggregationStrategy::Append, _) => {
277            values.push(value.clone());
278        }
279        (_, _) => {}
280    }
281}
282impl ResponseHeaderAggregator {
283    /// Modify the outgoing client response headers based on the aggregated headers from subgraphs.
284    #[inline]
285    pub fn modify_client_response_headers(
286        self,
287        out: &mut ntex::http::ResponseBuilder,
288    ) -> Result<(), HeaderRuleRuntimeError> {
289        for (name, (agg_strategy, mut values)) in self.entries {
290            if values.is_empty() {
291                continue;
292            }
293
294            if is_never_join_header(&name) {
295                // never-join headers must be emitted as multiple header fields
296                for value in values {
297                    out.header(&name, value);
298                }
299                continue;
300            }
301
302            if values.len() == 1 {
303                out.set_header(name, values.pop().unwrap());
304                continue;
305            }
306
307            if matches!(agg_strategy, HeaderAggregationStrategy::Append) {
308                let joined = join_with_comma(&values)
309                    .map_err(|_| HeaderRuleRuntimeError::BadHeaderValue(name.to_string()))?;
310                out.set_header(name, joined);
311            }
312        }
313
314        Ok(())
315    }
316}
317
318#[inline]
319fn join_with_comma(values: &[HeaderValue]) -> Result<HeaderValue, InvalidHeaderValue> {
320    // Compute capacity: sum of lengths + ", ".len() * (n-1)
321    let mut cap = 0usize;
322
323    for value in values {
324        cap += value.as_bytes().len();
325    }
326
327    if values.len() > 1 {
328        cap += 2 * (values.len() - 1);
329    }
330
331    let mut buf = Vec::with_capacity(cap);
332    for (idx, value) in values.iter().enumerate() {
333        if idx > 0 {
334            buf.extend_from_slice(b", ");
335        }
336        buf.extend_from_slice(value.as_bytes());
337    }
338    HeaderValue::from_bytes(&buf)
339}