hive_router_plan_executor/headers/
compile.rs

1use crate::{
2    expressions::CompileExpression,
3    headers::{
4        errors::HeaderRuleCompileError,
5        plan::{
6            HeaderAggregationStrategy, HeaderRulesPlan, RequestHeaderRule, RequestHeaderRules,
7            RequestInsertExpression, RequestInsertStatic, RequestPropagateNamed,
8            RequestPropagateRegex, RequestRemoveNamed, RequestRemoveRegex, ResponseHeaderRule,
9            ResponseHeaderRules, ResponseInsertExpression, ResponseInsertStatic,
10            ResponsePropagateNamed, ResponsePropagateRegex, ResponseRemoveNamed,
11            ResponseRemoveRegex,
12        },
13    },
14};
15
16use hive_router_config::headers as config;
17use http::HeaderName;
18use regex_automata::{meta, util::syntax::Config as SyntaxConfig};
19
20pub trait HeaderRuleCompiler<A> {
21    fn compile(&self, actions: &mut A) -> Result<(), HeaderRuleCompileError>;
22}
23
24impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule {
25    fn compile(&self, actions: &mut Vec<RequestHeaderRule>) -> Result<(), HeaderRuleCompileError> {
26        match self {
27            config::RequestHeaderRule::Propagate(rule) => {
28                let spec = materialize_match_spec(
29                    &rule.spec,
30                    rule.rename.as_ref(),
31                    rule.default.as_ref(),
32                )?;
33
34                if !spec.header_names.is_empty() {
35                    actions.push(RequestHeaderRule::PropagateNamed(RequestPropagateNamed {
36                        names: spec.header_names,
37                        default: spec.default_header_value,
38                        rename: spec.rename_header,
39                    }));
40                }
41                if spec.include_regex.is_some() {
42                    actions.push(RequestHeaderRule::PropagateRegex(RequestPropagateRegex {
43                        include: spec.include_regex,
44                        exclude: spec.exclude_regex,
45                    }));
46                }
47            }
48            config::RequestHeaderRule::Insert(rule) => match &rule.source {
49                config::InsertSource::Value { value } => {
50                    actions.push(RequestHeaderRule::InsertStatic(RequestInsertStatic {
51                        name: build_header_name(&rule.name)?,
52                        value: build_header_value(&rule.name, value)?,
53                    }));
54                }
55                config::InsertSource::Expression { expression } => {
56                    let program = expression.compile_expression(None).map_err(|err| {
57                        HeaderRuleCompileError::ExpressionBuild(rule.name.clone(), err.diagnostics)
58                    })?;
59                    actions.push(RequestHeaderRule::InsertExpression(
60                        RequestInsertExpression {
61                            name: build_header_name(&rule.name)?,
62                            expression: Box::new(program),
63                        },
64                    ));
65                }
66            },
67            config::RequestHeaderRule::Remove(rule) => {
68                let spec = materialize_match_spec(&rule.spec, None, None)?;
69                if !spec.header_names.is_empty() {
70                    actions.push(RequestHeaderRule::RemoveNamed(RequestRemoveNamed {
71                        names: spec.header_names,
72                    }));
73                }
74                if let Some(regex_set) = spec.include_regex {
75                    actions.push(RequestHeaderRule::RemoveRegex(RequestRemoveRegex {
76                        regex: regex_set,
77                    }));
78                }
79            }
80        }
81
82        Ok(())
83    }
84}
85
86impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule {
87    fn compile(&self, actions: &mut Vec<ResponseHeaderRule>) -> Result<(), HeaderRuleCompileError> {
88        match self {
89            config::ResponseHeaderRule::Propagate(rule) => {
90                let aggregation_strategy = rule.algorithm.into();
91                let spec = materialize_match_spec(
92                    &rule.spec,
93                    rule.rename.as_ref(),
94                    rule.default.as_ref(),
95                )?;
96
97                if !spec.header_names.is_empty() {
98                    actions.push(ResponseHeaderRule::PropagateNamed(ResponsePropagateNamed {
99                        names: spec.header_names,
100                        rename: spec.rename_header,
101                        default: spec.default_header_value,
102                        strategy: aggregation_strategy,
103                    }));
104                }
105
106                if spec.include_regex.is_some() || spec.exclude_regex.is_some() {
107                    actions.push(ResponseHeaderRule::PropagateRegex(ResponsePropagateRegex {
108                        include: spec.include_regex,
109                        exclude: spec.exclude_regex,
110                        strategy: aggregation_strategy,
111                    }));
112                }
113            }
114            config::ResponseHeaderRule::Insert(rule) => {
115                let aggregation_strategy = rule.algorithm.into();
116                match &rule.source {
117                    config::InsertSource::Value { value } => {
118                        actions.push(ResponseHeaderRule::InsertStatic(ResponseInsertStatic {
119                            name: build_header_name(&rule.name)?,
120                            value: build_header_value(&rule.name, value)?,
121                            strategy: aggregation_strategy,
122                        }));
123                    }
124                    config::InsertSource::Expression { expression } => {
125                        // NOTE: In case we ever need to improve performance and not pass the whole context
126                        // to VRL expressions, we can use:
127                        // - compilation_result.program.info().target_assignments
128                        // - compilation_result.program.info().target_queries
129                        // to determine what parts of the context are actually needed by the expression
130                        let program = expression.compile_expression(None).map_err(|err| {
131                            HeaderRuleCompileError::ExpressionBuild(
132                                rule.name.clone(),
133                                err.diagnostics,
134                            )
135                        })?;
136                        actions.push(ResponseHeaderRule::InsertExpression(
137                            ResponseInsertExpression {
138                                name: build_header_name(&rule.name)?,
139                                expression: Box::new(program),
140                                strategy: aggregation_strategy,
141                            },
142                        ));
143                    }
144                }
145            }
146            config::ResponseHeaderRule::Remove(rule) => {
147                let spec = materialize_match_spec(&rule.spec, None, None)?;
148                if !spec.header_names.is_empty() {
149                    actions.push(ResponseHeaderRule::RemoveNamed(ResponseRemoveNamed {
150                        names: spec.header_names,
151                    }));
152                }
153                if let Some(regex_set) = spec.include_regex {
154                    actions.push(ResponseHeaderRule::RemoveRegex(ResponseRemoveRegex {
155                        regex: regex_set,
156                    }));
157                }
158            }
159        }
160
161        Ok(())
162    }
163}
164
165pub fn compile_headers_plan(
166    cfg: &config::HeadersConfig,
167) -> Result<HeaderRulesPlan, HeaderRuleCompileError> {
168    let mut request_plan = RequestHeaderRules::default();
169    let mut response_plan = ResponseHeaderRules::default();
170
171    if let Some(global_rules) = &cfg.all {
172        request_plan.global = compile_request_header_rules(global_rules)?;
173        response_plan.global = compile_response_header_rules(global_rules)?;
174    }
175
176    if let Some(subgraph_rules_map) = &cfg.subgraphs {
177        for (subgraph_name, subgraph_rules) in subgraph_rules_map {
178            let request_actions = compile_request_header_rules(subgraph_rules)?;
179            let response_actions = compile_response_header_rules(subgraph_rules)?;
180            request_plan
181                .by_subgraph
182                .insert(subgraph_name.clone(), request_actions);
183            response_plan
184                .by_subgraph
185                .insert(subgraph_name.clone(), response_actions);
186        }
187    }
188
189    Ok(HeaderRulesPlan {
190        request: request_plan,
191        response: response_plan,
192    })
193}
194
195fn compile_request_header_rules(
196    header_rules: &config::HeaderRules,
197) -> Result<Vec<RequestHeaderRule>, HeaderRuleCompileError> {
198    let mut request_actions = Vec::new();
199    if let Some(request_rule_entries) = &header_rules.request {
200        for request_rule in request_rule_entries {
201            request_rule.compile(&mut request_actions)?;
202        }
203    }
204    Ok(request_actions)
205}
206
207fn compile_response_header_rules(
208    header_rules: &config::HeaderRules,
209) -> Result<Vec<ResponseHeaderRule>, HeaderRuleCompileError> {
210    let mut response_actions = Vec::new();
211    if let Some(response_rule_entries) = &header_rules.response {
212        for response_rule in response_rule_entries {
213            response_rule.compile(&mut response_actions)?;
214        }
215    }
216    Ok(response_actions)
217}
218
219struct HeaderMatchSpecResult {
220    header_names: Vec<HeaderName>,
221    include_regex: Option<meta::Regex>,
222    exclude_regex: Option<meta::Regex>,
223    rename_header: Option<HeaderName>,
224    default_header_value: Option<http::HeaderValue>,
225}
226
227fn materialize_match_spec(
228    match_spec: &config::MatchSpec,
229    rename_to: Option<&String>,
230    default_value: Option<&String>,
231) -> Result<HeaderMatchSpecResult, HeaderRuleCompileError> {
232    let header_names = match &match_spec.named {
233        Some(config::OneOrMany::One(single_name)) => vec![build_header_name(single_name)?],
234        Some(config::OneOrMany::Many(many_names)) => many_names
235            .iter()
236            .map(|name| build_header_name(name))
237            .collect::<Result<Vec<_>, _>>()?,
238        None => Vec::new(),
239    };
240
241    let include_regex = match match_spec.matching.as_ref() {
242        None => None,
243        Some(config::OneOrMany::One(pattern)) => build_regex_many(std::slice::from_ref(pattern))?,
244        Some(config::OneOrMany::Many(pattern_vec)) => build_regex_many(pattern_vec)?,
245    };
246
247    let exclude_regex = match match_spec.exclude.as_deref() {
248        None => None,
249        Some(pattern_vec) => build_regex_many(pattern_vec)?,
250    };
251
252    let rename_header = rename_to
253        .map(|name| match header_names.len() == 1 {
254            true => build_header_name(name),
255            false => Err(HeaderRuleCompileError::InvalidRename),
256        })
257        .transpose()?;
258
259    let default_header_value = default_value
260        .map(|value| match header_names.len() == 1 {
261            true => build_header_value(header_names[0].as_str(), value),
262            false => Err(HeaderRuleCompileError::InvalidDefault),
263        })
264        .transpose()?;
265
266    Ok(HeaderMatchSpecResult {
267        header_names,
268        include_regex,
269        exclude_regex,
270        rename_header,
271        default_header_value,
272    })
273}
274
275fn build_header_name(header_name_str: &str) -> Result<http::HeaderName, HeaderRuleCompileError> {
276    http::HeaderName::from_bytes(header_name_str.as_bytes())
277        .map_err(|err| HeaderRuleCompileError::BadHeaderName(header_name_str.into(), err))
278}
279
280fn build_header_value(
281    header_name_str: &str,
282    header_value_str: &str,
283) -> Result<http::HeaderValue, HeaderRuleCompileError> {
284    http::HeaderValue::from_str(header_value_str)
285        .map_err(|err| HeaderRuleCompileError::BadHeaderValue(header_name_str.to_string(), err))
286}
287
288fn build_regex_many(patterns: &[String]) -> Result<Option<meta::Regex>, HeaderRuleCompileError> {
289    if patterns.is_empty() {
290        return Ok(None);
291    }
292    let mut regex_builder = meta::Regex::builder();
293    regex_builder.syntax(SyntaxConfig::new().unicode(false).utf8(false));
294    regex_builder
295        .build_many(patterns)
296        .map(Some)
297        .map_err(|e| Box::new(e).into())
298}
299
300impl From<config::AggregationAlgo> for HeaderAggregationStrategy {
301    fn from(algo: config::AggregationAlgo) -> Self {
302        match algo {
303            config::AggregationAlgo::First => HeaderAggregationStrategy::First,
304            config::AggregationAlgo::Last => HeaderAggregationStrategy::Last,
305            config::AggregationAlgo::Append => HeaderAggregationStrategy::Append,
306        }
307    }
308}
309
310impl From<Option<config::AggregationAlgo>> for HeaderAggregationStrategy {
311    fn from(algo: Option<config::AggregationAlgo>) -> Self {
312        match algo {
313            Some(config::AggregationAlgo::First) => HeaderAggregationStrategy::First,
314            Some(config::AggregationAlgo::Last) => HeaderAggregationStrategy::Last,
315            Some(config::AggregationAlgo::Append) => HeaderAggregationStrategy::Append,
316            None => HeaderAggregationStrategy::Last,
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use hive_router_config::headers as config;
    use http::HeaderName;

    use crate::headers::{
        compile::{build_header_value, HeaderRuleCompiler},
        errors::HeaderRuleCompileError,
        plan::{HeaderAggregationStrategy, RequestHeaderRule, ResponseHeaderRule},
    };

    /// Builds a `HeaderName` for assertions; panics on invalid input (test-only helper).
    fn header_name_owned(s: &str) -> HeaderName {
        HeaderName::from_bytes(s.as_bytes()).unwrap()
    }

    #[test]
    fn test_propagate_named_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-test".to_string())),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: None,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-test")]);
                assert!(data.default.is_none());
                assert!(data.rename.is_none());
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_set_request() {
        let rule = config::RequestHeaderRule::Insert(config::RequestInsertRule {
            name: "x-set".to_string(),
            source: config::InsertSource::Value {
                value: "abc".to_string(),
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::InsertStatic(data) => {
                assert_eq!(data.name, header_name_owned("x-set"));
                assert_eq!(data.value, build_header_value("x-set", "abc").unwrap());
            }
            // The compiled variant is `InsertStatic`; name it so failures are not misleading.
            _ => panic!("Expected InsertStatic"),
        }
    }

    #[test]
    fn test_remove_named_request() {
        let rule = config::RequestHeaderRule::Remove(config::RemoveRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-remove".to_string())),
                matching: None,
                exclude: None,
            },
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::RemoveNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-remove")]);
            }
            _ => panic!("Expected RemoveNamed"),
        }
    }

    #[test]
    fn test_invalid_default_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::Many(vec![
                    "x1".to_string(),
                    "x2".to_string(),
                ])),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: Some("def".to_string()),
        });
        let mut actions = Vec::new();
        let err = rule.compile(&mut actions).unwrap_err();
        match err {
            HeaderRuleCompileError::InvalidDefault => {}
            _ => panic!("Expected InvalidDefault error"),
        }
    }

    #[test]
    fn test_invalid_rename_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::Many(vec![
                    "x1".to_string(),
                    "x2".to_string(),
                ])),
                matching: None,
                exclude: None,
            },
            rename: Some("x-new".to_string()),
            default: None,
        });
        let mut actions = Vec::new();
        let err = rule.compile(&mut actions).unwrap_err();
        match err {
            HeaderRuleCompileError::InvalidRename => {}
            _ => panic!("Expected InvalidRename error"),
        }
        assert!(actions.is_empty());
    }

    #[test]
    fn test_rename_single_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-old".to_string())),
                matching: None,
                exclude: None,
            },
            rename: Some("x-new".to_string()),
            default: None,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-old")]);
                assert_eq!(data.rename, Some(header_name_owned("x-new")));
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_propagate_regex_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: None,
                matching: Some(config::OneOrMany::One("^x-".to_string())),
                exclude: None,
            },
            rename: None,
            default: None,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            RequestHeaderRule::PropagateRegex(data) => {
                let include = data.include.as_ref().unwrap();
                assert!(include.is_match("x-foo"));
                assert!(!include.is_match("authorization"));
                assert!(data.exclude.is_none());
            }
            _ => panic!("Expected PropagateRegex"),
        }
    }

    #[test]
    fn test_propagate_named_response() {
        let rule = config::ResponseHeaderRule::Propagate(config::ResponsePropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::One("x-resp".to_string())),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: None,
            algorithm: config::AggregationAlgo::First,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            ResponseHeaderRule::PropagateNamed(data) => {
                assert_eq!(data.names, vec![header_name_owned("x-resp")]);
                assert!(matches!(data.strategy, HeaderAggregationStrategy::First));
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_propagate_exclude_only_response() {
        let rule = config::ResponseHeaderRule::Propagate(config::ResponsePropagateRule {
            spec: config::MatchSpec {
                named: None,
                matching: None,
                exclude: Some(vec!["^x-internal-".to_string()]),
            },
            rename: None,
            default: None,
            algorithm: config::AggregationAlgo::Append,
        });
        let mut actions = Vec::new();
        rule.compile(&mut actions).unwrap();
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            ResponseHeaderRule::PropagateRegex(data) => {
                assert!(data.include.is_none());
                assert!(data.exclude.as_ref().unwrap().is_match("x-internal-id"));
                assert!(matches!(data.strategy, HeaderAggregationStrategy::Append));
            }
            _ => panic!("Expected PropagateRegex"),
        }
    }
}