hive_router_plan_executor/headers/compile.rs

1use crate::headers::{
2    errors::HeaderRuleCompileError,
3    plan::{
4        HeaderAggregationStrategy, HeaderRulesPlan, RequestHeaderRule, RequestHeaderRules,
5        RequestInsertExpression, RequestInsertStatic, RequestPropagateNamed, RequestPropagateRegex,
6        RequestRemoveNamed, RequestRemoveRegex, ResponseHeaderRule, ResponseHeaderRules,
7        ResponseInsertExpression, ResponseInsertStatic, ResponsePropagateNamed,
8        ResponsePropagateRegex, ResponseRemoveNamed, ResponseRemoveRegex,
9    },
10};
11use hive_router_internal::expressions::CompileExpression;
12
13use hive_router_config::headers as config;
14use http::HeaderName;
15use regex_automata::{meta, util::syntax::Config as SyntaxConfig};
16
17pub trait HeaderRuleCompiler<A> {
18    fn compile(&self, actions: &mut A) -> Result<(), HeaderRuleCompileError>;
19}
20
21impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule {
22    fn compile(&self, actions: &mut Vec<RequestHeaderRule>) -> Result<(), HeaderRuleCompileError> {
23        match self {
24            config::RequestHeaderRule::Propagate(rule) => {
25                let spec = materialize_match_spec(
26                    &rule.spec,
27                    rule.rename.as_ref(),
28                    rule.default.as_ref(),
29                )?;
30
31                if !spec.header_names.is_empty() {
32                    actions.push(RequestHeaderRule::PropagateNamed(RequestPropagateNamed {
33                        names: spec.header_names,
34                        default: spec.default_header_value,
35                        rename: spec.rename_header,
36                    }));
37                }
38                if spec.include_regex.is_some() {
39                    actions.push(RequestHeaderRule::PropagateRegex(RequestPropagateRegex {
40                        include: spec.include_regex,
41                        exclude: spec.exclude_regex,
42                    }));
43                }
44            }
45            config::RequestHeaderRule::Insert(rule) => match &rule.source {
46                config::InsertSource::Value { value } => {
47                    actions.push(RequestHeaderRule::InsertStatic(RequestInsertStatic {
48                        name: build_header_name(&rule.name)?,
49                        value: build_header_value(&rule.name, value)?,
50                    }));
51                }
52                config::InsertSource::Expression { expression } => {
53                    let program = expression.compile_expression(None).map_err(|err| {
54                        HeaderRuleCompileError::ExpressionBuild(rule.name.clone(), err.diagnostics)
55                    })?;
56                    actions.push(RequestHeaderRule::InsertExpression(
57                        RequestInsertExpression {
58                            name: build_header_name(&rule.name)?,
59                            expression: Box::new(program),
60                        },
61                    ));
62                }
63            },
64            config::RequestHeaderRule::Remove(rule) => {
65                let spec = materialize_match_spec(&rule.spec, None, None)?;
66                if !spec.header_names.is_empty() {
67                    actions.push(RequestHeaderRule::RemoveNamed(RequestRemoveNamed {
68                        names: spec.header_names,
69                    }));
70                }
71                if let Some(regex_set) = spec.include_regex {
72                    actions.push(RequestHeaderRule::RemoveRegex(RequestRemoveRegex {
73                        regex: regex_set,
74                    }));
75                }
76            }
77        }
78
79        Ok(())
80    }
81}
82
83impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule {
84    fn compile(&self, actions: &mut Vec<ResponseHeaderRule>) -> Result<(), HeaderRuleCompileError> {
85        match self {
86            config::ResponseHeaderRule::Propagate(rule) => {
87                let aggregation_strategy = rule.algorithm.into();
88                let spec = materialize_match_spec(
89                    &rule.spec,
90                    rule.rename.as_ref(),
91                    rule.default.as_ref(),
92                )?;
93
94                if !spec.header_names.is_empty() {
95                    actions.push(ResponseHeaderRule::PropagateNamed(ResponsePropagateNamed {
96                        names: spec.header_names,
97                        rename: spec.rename_header,
98                        default: spec.default_header_value,
99                        strategy: aggregation_strategy,
100                    }));
101                }
102
103                if spec.include_regex.is_some() || spec.exclude_regex.is_some() {
104                    actions.push(ResponseHeaderRule::PropagateRegex(ResponsePropagateRegex {
105                        include: spec.include_regex,
106                        exclude: spec.exclude_regex,
107                        strategy: aggregation_strategy,
108                    }));
109                }
110            }
111            config::ResponseHeaderRule::Insert(rule) => {
112                let aggregation_strategy = rule.algorithm.into();
113                match &rule.source {
114                    config::InsertSource::Value { value } => {
115                        actions.push(ResponseHeaderRule::InsertStatic(ResponseInsertStatic {
116                            name: build_header_name(&rule.name)?,
117                            value: build_header_value(&rule.name, value)?,
118                            strategy: aggregation_strategy,
119                        }));
120                    }
121                    config::InsertSource::Expression { expression } => {
122                        // NOTE: In case we ever need to improve performance and not pass the whole context
123                        // to VRL expressions, we can use:
124                        // - compilation_result.program.info().target_assignments
125                        // - compilation_result.program.info().target_queries
126                        // to determine what parts of the context are actually needed by the expression
127                        let program = expression.compile_expression(None).map_err(|err| {
128                            HeaderRuleCompileError::ExpressionBuild(
129                                rule.name.clone(),
130                                err.diagnostics,
131                            )
132                        })?;
133                        actions.push(ResponseHeaderRule::InsertExpression(
134                            ResponseInsertExpression {
135                                name: build_header_name(&rule.name)?,
136                                expression: Box::new(program),
137                                strategy: aggregation_strategy,
138                            },
139                        ));
140                    }
141                }
142            }
143            config::ResponseHeaderRule::Remove(rule) => {
144                let spec = materialize_match_spec(&rule.spec, None, None)?;
145                if !spec.header_names.is_empty() {
146                    actions.push(ResponseHeaderRule::RemoveNamed(ResponseRemoveNamed {
147                        names: spec.header_names,
148                    }));
149                }
150                if let Some(regex_set) = spec.include_regex {
151                    actions.push(ResponseHeaderRule::RemoveRegex(ResponseRemoveRegex {
152                        regex: regex_set,
153                    }));
154                }
155            }
156        }
157
158        Ok(())
159    }
160}
161
162pub fn compile_headers_plan(
163    cfg: &config::HeadersConfig,
164) -> Result<HeaderRulesPlan, HeaderRuleCompileError> {
165    let mut request_plan = RequestHeaderRules::default();
166    let mut response_plan = ResponseHeaderRules::default();
167
168    if let Some(global_rules) = &cfg.all {
169        request_plan.global = compile_request_header_rules(global_rules)?;
170        response_plan.global = compile_response_header_rules(global_rules)?;
171    }
172
173    if let Some(subgraph_rules_map) = &cfg.subgraphs {
174        for (subgraph_name, subgraph_rules) in subgraph_rules_map {
175            let request_actions = compile_request_header_rules(subgraph_rules)?;
176            let response_actions = compile_response_header_rules(subgraph_rules)?;
177            request_plan
178                .by_subgraph
179                .insert(subgraph_name.clone(), request_actions);
180            response_plan
181                .by_subgraph
182                .insert(subgraph_name.clone(), response_actions);
183        }
184    }
185
186    Ok(HeaderRulesPlan {
187        request: request_plan,
188        response: response_plan,
189    })
190}
191
192fn compile_request_header_rules(
193    header_rules: &config::HeaderRules,
194) -> Result<Vec<RequestHeaderRule>, HeaderRuleCompileError> {
195    let mut request_actions = Vec::new();
196    if let Some(request_rule_entries) = &header_rules.request {
197        for request_rule in request_rule_entries {
198            request_rule.compile(&mut request_actions)?;
199        }
200    }
201    Ok(request_actions)
202}
203
204fn compile_response_header_rules(
205    header_rules: &config::HeaderRules,
206) -> Result<Vec<ResponseHeaderRule>, HeaderRuleCompileError> {
207    let mut response_actions = Vec::new();
208    if let Some(response_rule_entries) = &header_rules.response {
209        for response_rule in response_rule_entries {
210            response_rule.compile(&mut response_actions)?;
211        }
212    }
213    Ok(response_actions)
214}
215
216struct HeaderMatchSpecResult {
217    header_names: Vec<HeaderName>,
218    include_regex: Option<meta::Regex>,
219    exclude_regex: Option<meta::Regex>,
220    rename_header: Option<HeaderName>,
221    default_header_value: Option<http::HeaderValue>,
222}
223
224fn materialize_match_spec(
225    match_spec: &config::MatchSpec,
226    rename_to: Option<&String>,
227    default_value: Option<&String>,
228) -> Result<HeaderMatchSpecResult, HeaderRuleCompileError> {
229    let header_names = match &match_spec.named {
230        Some(config::OneOrMany::One(single_name)) => vec![build_header_name(single_name)?],
231        Some(config::OneOrMany::Many(many_names)) => many_names
232            .iter()
233            .map(|name| build_header_name(name))
234            .collect::<Result<Vec<_>, _>>()?,
235        None => Vec::new(),
236    };
237
238    let include_regex = match match_spec.matching.as_ref() {
239        None => None,
240        Some(config::OneOrMany::One(pattern)) => build_regex_many(std::slice::from_ref(pattern))?,
241        Some(config::OneOrMany::Many(pattern_vec)) => build_regex_many(pattern_vec)?,
242    };
243
244    let exclude_regex = match match_spec.exclude.as_deref() {
245        None => None,
246        Some(pattern_vec) => build_regex_many(pattern_vec)?,
247    };
248
249    let rename_header = rename_to
250        .map(|name| match header_names.len() == 1 {
251            true => build_header_name(name),
252            false => Err(HeaderRuleCompileError::InvalidRename),
253        })
254        .transpose()?;
255
256    let default_header_value = default_value
257        .map(|value| match header_names.len() == 1 {
258            true => build_header_value(header_names[0].as_str(), value),
259            false => Err(HeaderRuleCompileError::InvalidDefault),
260        })
261        .transpose()?;
262
263    Ok(HeaderMatchSpecResult {
264        header_names,
265        include_regex,
266        exclude_regex,
267        rename_header,
268        default_header_value,
269    })
270}
271
272fn build_header_name(header_name_str: &str) -> Result<http::HeaderName, HeaderRuleCompileError> {
273    http::HeaderName::from_bytes(header_name_str.as_bytes())
274        .map_err(|err| HeaderRuleCompileError::BadHeaderName(header_name_str.into(), err))
275}
276
277fn build_header_value(
278    header_name_str: &str,
279    header_value_str: &str,
280) -> Result<http::HeaderValue, HeaderRuleCompileError> {
281    http::HeaderValue::from_str(header_value_str)
282        .map_err(|err| HeaderRuleCompileError::BadHeaderValue(header_name_str.to_string(), err))
283}
284
285fn build_regex_many(patterns: &[String]) -> Result<Option<meta::Regex>, HeaderRuleCompileError> {
286    if patterns.is_empty() {
287        return Ok(None);
288    }
289    let mut regex_builder = meta::Regex::builder();
290    regex_builder.syntax(SyntaxConfig::new().unicode(false).utf8(false));
291    regex_builder
292        .build_many(patterns)
293        .map(Some)
294        .map_err(|e| Box::new(e).into())
295}
296
297impl From<config::AggregationAlgo> for HeaderAggregationStrategy {
298    fn from(algo: config::AggregationAlgo) -> Self {
299        match algo {
300            config::AggregationAlgo::First => HeaderAggregationStrategy::First,
301            config::AggregationAlgo::Last => HeaderAggregationStrategy::Last,
302            config::AggregationAlgo::Append => HeaderAggregationStrategy::Append,
303        }
304    }
305}
306
307impl From<Option<config::AggregationAlgo>> for HeaderAggregationStrategy {
308    fn from(algo: Option<config::AggregationAlgo>) -> Self {
309        match algo {
310            Some(config::AggregationAlgo::First) => HeaderAggregationStrategy::First,
311            Some(config::AggregationAlgo::Last) => HeaderAggregationStrategy::Last,
312            Some(config::AggregationAlgo::Append) => HeaderAggregationStrategy::Append,
313            None => HeaderAggregationStrategy::Last,
314        }
315    }
316}
317
#[cfg(test)]
mod tests {
    use hive_router_config::headers as config;
    use http::HeaderName;

    use crate::headers::{
        compile::{build_header_value, HeaderRuleCompiler},
        errors::HeaderRuleCompileError,
        plan::{HeaderAggregationStrategy, RequestHeaderRule, ResponseHeaderRule},
    };

    fn header_name_owned(s: &str) -> HeaderName {
        HeaderName::from_bytes(s.as_bytes()).unwrap()
    }

    #[test]
    fn test_propagate_named_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-test".to_string())),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: None,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-test")]);
                assert!(data.default.is_none());
                assert!(data.rename.is_none());
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_propagate_named_with_rename_and_default_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-src".to_string())),
                matching: None,
                exclude: None,
            },
            rename: Some("x-dst".to_string()),
            default: Some("fallback".to_string()),
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-src")]);
                assert_eq!(data.rename, Some(header_name_owned("x-dst")));
                assert_eq!(
                    data.default,
                    Some(build_header_value("x-src", "fallback").unwrap())
                );
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_propagate_regex_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: None,
                matching: Some(config::OneOrMany::One("^x-.*".to_string())),
                exclude: Some(vec!["^x-internal-.*".to_string()]),
            },
            rename: None,
            default: None,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateRegex(data) => {
                let include = data.include.as_ref().expect("include regex");
                assert!(include.is_match("x-foo"));
                assert!(!include.is_match("y-foo"));
                let exclude = data.exclude.as_ref().expect("exclude regex");
                assert!(exclude.is_match("x-internal-a"));
                assert!(!exclude.is_match("x-foo"));
            }
            _ => panic!("Expected PropagateRegex"),
        }
    }

    #[test]
    fn test_set_request() {
        let rule = config::RequestHeaderRule::Insert(config::RequestInsertRule {
            name: "x-set".to_string(),
            source: config::InsertSource::Value {
                value: "abc".to_string(),
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::InsertStatic(data) => {
                assert_eq!(data.name, header_name_owned("x-set"));
                assert_eq!(data.value, build_header_value("x-set", "abc").unwrap());
            }
            _ => panic!("Expected InsertStatic"),
        }
    }

    #[test]
    fn test_remove_named_request() {
        let rule = config::RequestHeaderRule::Remove(config::RemoveRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-remove".to_string())),
                matching: None,
                exclude: None,
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::RemoveNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-remove")]);
            }
            _ => panic!("Expected RemoveNamed"),
        }
    }

    #[test]
    fn test_remove_regex_request() {
        let rule = config::RequestHeaderRule::Remove(config::RemoveRule {
            spec: config::MatchSpec {
                named: None,
                matching: Some(config::OneOrMany::Many(vec![
                    "^x-a$".to_string(),
                    "^x-b$".to_string(),
                ])),
                exclude: None,
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::RemoveRegex(data) => {
                assert!(data.regex.is_match("x-a"));
                assert!(data.regex.is_match("x-b"));
                assert!(!data.regex.is_match("x-c"));
            }
            _ => panic!("Expected RemoveRegex"),
        }
    }

    #[test]
    fn test_empty_spec_produces_no_actions() {
        let rule = config::RequestHeaderRule::Remove(config::RemoveRule {
            spec: config::MatchSpec {
                named: None,
                matching: None,
                exclude: None,
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert!(actions.is_empty());
    }

    #[test]
    fn test_invalid_default_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::Many(vec![
                    "x1".to_string(),
                    "x2".to_string(),
                ])),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: Some("def".to_string()),
        });
        let mut actions = Vec::new();
        let err = rule.compile(&mut actions).unwrap_err();
        match err {
            HeaderRuleCompileError::InvalidDefault => {}
            _ => panic!("Expected InvalidDefault error"),
        }
    }

    #[test]
    fn test_invalid_rename_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::Many(vec![
                    "x1".to_string(),
                    "x2".to_string(),
                ])),
                matching: None,
                exclude: None,
            },
            rename: Some("x-renamed".to_string()),
            default: None,
        });
        let mut actions = Vec::new();
        let err = rule.compile(&mut actions).unwrap_err();
        assert!(matches!(err, HeaderRuleCompileError::InvalidRename));
        assert!(actions.is_empty());
    }

    #[test]
    fn test_propagate_named_response() {
        let rule = config::ResponseHeaderRule::Propagate(config::ResponsePropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-resp".to_string())),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: None,
            algorithm: config::AggregationAlgo::First,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            ResponseHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-resp")]);
                assert!(matches!(data.strategy, HeaderAggregationStrategy::First));
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_aggregation_strategy_from_option() {
        assert!(matches!(
            HeaderAggregationStrategy::from(None::<config::AggregationAlgo>),
            HeaderAggregationStrategy::Last
        ));
        assert!(matches!(
            HeaderAggregationStrategy::from(Some(config::AggregationAlgo::Append)),
            HeaderAggregationStrategy::Append
        ));
        assert!(matches!(
            HeaderAggregationStrategy::from(Some(config::AggregationAlgo::First)),
            HeaderAggregationStrategy::First
        ));
    }
}