hive_router_plan_executor/headers/
compile.rs

1use crate::headers::{
2    errors::HeaderRuleCompileError,
3    plan::{
4        HeaderAggregationStrategy, HeaderRulesPlan, RequestHeaderRule, RequestHeaderRules,
5        RequestInsertExpression, RequestInsertStatic, RequestPropagateNamed, RequestPropagateRegex,
6        RequestRemoveNamed, RequestRemoveRegex, ResponseHeaderRule, ResponseHeaderRules,
7        ResponseInsertExpression, ResponseInsertStatic, ResponsePropagateNamed,
8        ResponsePropagateRegex, ResponseRemoveNamed, ResponseRemoveRegex,
9    },
10};
11
12use hive_router_config::headers as config;
13use http::HeaderName;
14use regex_automata::{meta, util::syntax::Config as SyntaxConfig};
15use vrl::{
16    compiler::compile as vrl_compile, prelude::Function as VrlFunction,
17    stdlib::all as vrl_build_functions,
18};
19
20pub struct HeaderRuleCompilerContext {
21    vrl_functions: Vec<Box<dyn VrlFunction>>,
22}
23
24impl Default for HeaderRuleCompilerContext {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl HeaderRuleCompilerContext {
31    pub fn new() -> Self {
32        Self {
33            vrl_functions: vrl_build_functions(),
34        }
35    }
36}
37
38pub trait HeaderRuleCompiler<A> {
39    fn compile(
40        &self,
41        ctx: &HeaderRuleCompilerContext,
42        actions: &mut A,
43    ) -> Result<(), HeaderRuleCompileError>;
44}
45
46impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule {
47    fn compile(
48        &self,
49        ctx: &HeaderRuleCompilerContext,
50        actions: &mut Vec<RequestHeaderRule>,
51    ) -> Result<(), HeaderRuleCompileError> {
52        match self {
53            config::RequestHeaderRule::Propagate(rule) => {
54                let spec = materialize_match_spec(
55                    &rule.spec,
56                    rule.rename.as_ref(),
57                    rule.default.as_ref(),
58                )?;
59
60                if !spec.header_names.is_empty() {
61                    actions.push(RequestHeaderRule::PropagateNamed(RequestPropagateNamed {
62                        names: spec.header_names,
63                        default: spec.default_header_value,
64                        rename: spec.rename_header,
65                    }));
66                }
67                if spec.include_regex.is_some() {
68                    actions.push(RequestHeaderRule::PropagateRegex(RequestPropagateRegex {
69                        include: spec.include_regex,
70                        exclude: spec.exclude_regex,
71                    }));
72                }
73            }
74            config::RequestHeaderRule::Insert(rule) => match &rule.source {
75                config::InsertSource::Value { value } => {
76                    actions.push(RequestHeaderRule::InsertStatic(RequestInsertStatic {
77                        name: build_header_name(&rule.name)?,
78                        value: build_header_value(&rule.name, value)?,
79                    }));
80                }
81                config::InsertSource::Expression { expression } => {
82                    let compilation_result =
83                        vrl_compile(expression, &ctx.vrl_functions).map_err(|e| {
84                            HeaderRuleCompileError::new_expression_build(rule.name.clone(), e)
85                        })?;
86
87                    actions.push(RequestHeaderRule::InsertExpression(
88                        RequestInsertExpression {
89                            name: build_header_name(&rule.name)?,
90                            expression: Box::new(compilation_result.program),
91                        },
92                    ));
93                }
94            },
95            config::RequestHeaderRule::Remove(rule) => {
96                let spec = materialize_match_spec(&rule.spec, None, None)?;
97                if !spec.header_names.is_empty() {
98                    actions.push(RequestHeaderRule::RemoveNamed(RequestRemoveNamed {
99                        names: spec.header_names,
100                    }));
101                }
102                if let Some(regex_set) = spec.include_regex {
103                    actions.push(RequestHeaderRule::RemoveRegex(RequestRemoveRegex {
104                        regex: regex_set,
105                    }));
106                }
107            }
108        }
109
110        Ok(())
111    }
112}
113
114impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule {
115    fn compile(
116        &self,
117        ctx: &HeaderRuleCompilerContext,
118        actions: &mut Vec<ResponseHeaderRule>,
119    ) -> Result<(), HeaderRuleCompileError> {
120        match self {
121            config::ResponseHeaderRule::Propagate(rule) => {
122                let aggregation_strategy = rule.algorithm.into();
123                let spec = materialize_match_spec(
124                    &rule.spec,
125                    rule.rename.as_ref(),
126                    rule.default.as_ref(),
127                )?;
128
129                if !spec.header_names.is_empty() {
130                    actions.push(ResponseHeaderRule::PropagateNamed(ResponsePropagateNamed {
131                        names: spec.header_names,
132                        rename: spec.rename_header,
133                        default: spec.default_header_value,
134                        strategy: aggregation_strategy,
135                    }));
136                }
137
138                if spec.include_regex.is_some() || spec.exclude_regex.is_some() {
139                    actions.push(ResponseHeaderRule::PropagateRegex(ResponsePropagateRegex {
140                        include: spec.include_regex,
141                        exclude: spec.exclude_regex,
142                        strategy: aggregation_strategy,
143                    }));
144                }
145            }
146            config::ResponseHeaderRule::Insert(rule) => {
147                let aggregation_strategy = rule.algorithm.into();
148                match &rule.source {
149                    config::InsertSource::Value { value } => {
150                        actions.push(ResponseHeaderRule::InsertStatic(ResponseInsertStatic {
151                            name: build_header_name(&rule.name)?,
152                            value: build_header_value(&rule.name, value)?,
153                            strategy: aggregation_strategy,
154                        }));
155                    }
156                    config::InsertSource::Expression { expression } => {
157                        // NOTE: In case we ever need to improve performance and not pass the whole context
158                        // to VRL expressions, we can use:
159                        // - compilation_result.program.info().target_assignments
160                        // - compilation_result.program.info().target_queries
161                        // to determine what parts of the context are actually needed by the expression
162                        let compilation_result = vrl_compile(expression, &ctx.vrl_functions)
163                            .map_err(|e| {
164                                HeaderRuleCompileError::new_expression_build(rule.name.clone(), e)
165                            })?;
166
167                        actions.push(ResponseHeaderRule::InsertExpression(
168                            ResponseInsertExpression {
169                                name: build_header_name(&rule.name)?,
170                                expression: Box::new(compilation_result.program),
171                                strategy: aggregation_strategy,
172                            },
173                        ));
174                    }
175                }
176            }
177            config::ResponseHeaderRule::Remove(rule) => {
178                let spec = materialize_match_spec(&rule.spec, None, None)?;
179                if !spec.header_names.is_empty() {
180                    actions.push(ResponseHeaderRule::RemoveNamed(ResponseRemoveNamed {
181                        names: spec.header_names,
182                    }));
183                }
184                if let Some(regex_set) = spec.include_regex {
185                    actions.push(ResponseHeaderRule::RemoveRegex(ResponseRemoveRegex {
186                        regex: regex_set,
187                    }));
188                }
189            }
190        }
191
192        Ok(())
193    }
194}
195
196pub fn compile_headers_plan(
197    cfg: &config::HeadersConfig,
198) -> Result<HeaderRulesPlan, HeaderRuleCompileError> {
199    let ctx = HeaderRuleCompilerContext::new();
200    let mut request_plan = RequestHeaderRules::default();
201    let mut response_plan = ResponseHeaderRules::default();
202
203    if let Some(global_rules) = &cfg.all {
204        request_plan.global = compile_request_header_rules(&ctx, global_rules)?;
205        response_plan.global = compile_response_header_rules(&ctx, global_rules)?;
206    }
207
208    if let Some(subgraph_rules_map) = &cfg.subgraphs {
209        for (subgraph_name, subgraph_rules) in subgraph_rules_map {
210            let request_actions = compile_request_header_rules(&ctx, subgraph_rules)?;
211            let response_actions = compile_response_header_rules(&ctx, subgraph_rules)?;
212            request_plan
213                .by_subgraph
214                .insert(subgraph_name.clone(), request_actions);
215            response_plan
216                .by_subgraph
217                .insert(subgraph_name.clone(), response_actions);
218        }
219    }
220
221    Ok(HeaderRulesPlan {
222        request: request_plan,
223        response: response_plan,
224    })
225}
226
227fn compile_request_header_rules(
228    ctx: &HeaderRuleCompilerContext,
229    header_rules: &config::HeaderRules,
230) -> Result<Vec<RequestHeaderRule>, HeaderRuleCompileError> {
231    let mut request_actions = Vec::new();
232    if let Some(request_rule_entries) = &header_rules.request {
233        for request_rule in request_rule_entries {
234            request_rule.compile(ctx, &mut request_actions)?;
235        }
236    }
237    Ok(request_actions)
238}
239
240fn compile_response_header_rules(
241    ctx: &HeaderRuleCompilerContext,
242    header_rules: &config::HeaderRules,
243) -> Result<Vec<ResponseHeaderRule>, HeaderRuleCompileError> {
244    let mut response_actions = Vec::new();
245    if let Some(response_rule_entries) = &header_rules.response {
246        for response_rule in response_rule_entries {
247            response_rule.compile(ctx, &mut response_actions)?;
248        }
249    }
250    Ok(response_actions)
251}
252
253struct HeaderMatchSpecResult {
254    header_names: Vec<HeaderName>,
255    include_regex: Option<meta::Regex>,
256    exclude_regex: Option<meta::Regex>,
257    rename_header: Option<HeaderName>,
258    default_header_value: Option<http::HeaderValue>,
259}
260
261fn materialize_match_spec(
262    match_spec: &config::MatchSpec,
263    rename_to: Option<&String>,
264    default_value: Option<&String>,
265) -> Result<HeaderMatchSpecResult, HeaderRuleCompileError> {
266    let header_names = match &match_spec.named {
267        Some(config::OneOrMany::One(single_name)) => vec![build_header_name(single_name)?],
268        Some(config::OneOrMany::Many(many_names)) => many_names
269            .iter()
270            .map(|name| build_header_name(name))
271            .collect::<Result<Vec<_>, _>>()?,
272        None => Vec::new(),
273    };
274
275    let include_regex = match match_spec.matching.as_ref() {
276        None => None,
277        Some(config::OneOrMany::One(pattern)) => build_regex_many(std::slice::from_ref(pattern))?,
278        Some(config::OneOrMany::Many(pattern_vec)) => build_regex_many(pattern_vec)?,
279    };
280
281    let exclude_regex = match match_spec.exclude.as_deref() {
282        None => None,
283        Some(pattern_vec) => build_regex_many(pattern_vec)?,
284    };
285
286    let rename_header = rename_to
287        .map(|name| match header_names.len() == 1 {
288            true => build_header_name(name),
289            false => Err(HeaderRuleCompileError::InvalidRename),
290        })
291        .transpose()?;
292
293    let default_header_value = default_value
294        .map(|value| match header_names.len() == 1 {
295            true => build_header_value(header_names[0].as_str(), value),
296            false => Err(HeaderRuleCompileError::InvalidDefault),
297        })
298        .transpose()?;
299
300    Ok(HeaderMatchSpecResult {
301        header_names,
302        include_regex,
303        exclude_regex,
304        rename_header,
305        default_header_value,
306    })
307}
308
309fn build_header_name(header_name_str: &str) -> Result<http::HeaderName, HeaderRuleCompileError> {
310    http::HeaderName::from_bytes(header_name_str.as_bytes())
311        .map_err(|err| HeaderRuleCompileError::BadHeaderName(header_name_str.into(), err))
312}
313
314fn build_header_value(
315    header_name_str: &str,
316    header_value_str: &str,
317) -> Result<http::HeaderValue, HeaderRuleCompileError> {
318    http::HeaderValue::from_str(header_value_str)
319        .map_err(|err| HeaderRuleCompileError::BadHeaderValue(header_name_str.to_string(), err))
320}
321
322fn build_regex_many(patterns: &[String]) -> Result<Option<meta::Regex>, HeaderRuleCompileError> {
323    if patterns.is_empty() {
324        return Ok(None);
325    }
326    let mut regex_builder = meta::Regex::builder();
327    regex_builder.syntax(SyntaxConfig::new().unicode(false).utf8(false));
328    regex_builder
329        .build_many(patterns)
330        .map(Some)
331        .map_err(|e| Box::new(e).into())
332}
333
334impl From<config::AggregationAlgo> for HeaderAggregationStrategy {
335    fn from(algo: config::AggregationAlgo) -> Self {
336        match algo {
337            config::AggregationAlgo::First => HeaderAggregationStrategy::First,
338            config::AggregationAlgo::Last => HeaderAggregationStrategy::Last,
339            config::AggregationAlgo::Append => HeaderAggregationStrategy::Append,
340        }
341    }
342}
343
344impl From<Option<config::AggregationAlgo>> for HeaderAggregationStrategy {
345    fn from(algo: Option<config::AggregationAlgo>) -> Self {
346        match algo {
347            Some(config::AggregationAlgo::First) => HeaderAggregationStrategy::First,
348            Some(config::AggregationAlgo::Last) => HeaderAggregationStrategy::Last,
349            Some(config::AggregationAlgo::Append) => HeaderAggregationStrategy::Append,
350            None => HeaderAggregationStrategy::Last,
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use hive_router_config::headers as config;
    use http::HeaderName;

    use crate::headers::{
        compile::{build_header_value, HeaderRuleCompiler, HeaderRuleCompilerContext},
        errors::HeaderRuleCompileError,
        plan::{HeaderAggregationStrategy, RequestHeaderRule, ResponseHeaderRule},
    };

    /// Parses a header-name literal that is known to be valid.
    fn header_name_owned(s: &str) -> HeaderName {
        HeaderName::from_bytes(s.as_bytes()).unwrap()
    }

    /// A `MatchSpec` selecting exactly one header by name, with no regex parts.
    fn single_named(name: &str) -> config::MatchSpec {
        config::MatchSpec {
            named: Some(config::OneOrMany::One(name.to_string())),
            matching: None,
            exclude: None,
        }
    }

    /// Compiles a request rule with a fresh context and returns the emitted actions.
    fn compile_request(
        rule: &config::RequestHeaderRule,
    ) -> Result<Vec<RequestHeaderRule>, HeaderRuleCompileError> {
        let ctx = HeaderRuleCompilerContext::new();
        let mut actions = Vec::new();
        rule.compile(&ctx, &mut actions)?;
        Ok(actions)
    }

    /// Compiles a response rule with a fresh context and returns the emitted actions.
    fn compile_response(
        rule: &config::ResponseHeaderRule,
    ) -> Result<Vec<ResponseHeaderRule>, HeaderRuleCompileError> {
        let ctx = HeaderRuleCompilerContext::new();
        let mut actions = Vec::new();
        rule.compile(&ctx, &mut actions)?;
        Ok(actions)
    }

    #[test]
    fn test_propagate_named_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: single_named("x-test"),
            rename: None,
            default: None,
        });
        let actions = compile_request(&rule).unwrap();
        assert_eq!(actions.len(), 1);
        match actions.first() {
            Some(RequestHeaderRule::PropagateNamed(data)) => {
                assert_eq!(data.names, vec![header_name_owned("x-test")]);
                assert!(data.default.is_none());
                assert!(data.rename.is_none());
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }

    #[test]
    fn test_set_request() {
        let rule = config::RequestHeaderRule::Insert(config::RequestInsertRule {
            name: "x-set".to_string(),
            source: config::InsertSource::Value {
                value: "abc".to_string(),
            },
        });
        let actions = compile_request(&rule).unwrap();
        assert_eq!(actions.len(), 1);
        match actions.first() {
            Some(RequestHeaderRule::InsertStatic(data)) => {
                assert_eq!(data.name, header_name_owned("x-set"));
                assert_eq!(data.value, build_header_value("x-set", "abc").unwrap());
            }
            _ => panic!("Expected SetStatic"),
        }
    }

    #[test]
    fn test_remove_named_request() {
        let rule = config::RequestHeaderRule::Remove(config::RemoveRule {
            spec: single_named("x-remove"),
        });
        let actions = compile_request(&rule).unwrap();
        assert_eq!(actions.len(), 1);
        match actions.first() {
            Some(RequestHeaderRule::RemoveNamed(data)) => {
                assert_eq!(data.names, vec![header_name_owned("x-remove")]);
            }
            _ => panic!("Expected RemoveNamed"),
        }
    }

    #[test]
    fn test_invalid_default_request() {
        let rule = config::RequestHeaderRule::Propagate(config::RequestPropagateRule {
            spec: config::MatchSpec {
                named: Some(config::OneOrMany::Many(vec![
                    "x1".to_string(),
                    "x2".to_string(),
                ])),
                matching: None,
                exclude: None,
            },
            rename: None,
            default: Some("def".to_string()),
        });
        let outcome = compile_request(&rule);
        assert!(
            matches!(outcome, Err(HeaderRuleCompileError::InvalidDefault)),
            "Expected InvalidDefault error"
        );
    }

    #[test]
    fn test_propagate_named_response() {
        let rule = config::ResponseHeaderRule::Propagate(config::ResponsePropagateRule {
            spec: single_named("x-resp"),
            rename: None,
            default: None,
            algorithm: config::AggregationAlgo::First,
        });
        let actions = compile_response(&rule).unwrap();
        assert_eq!(actions.len(), 1);
        match actions.first() {
            Some(ResponseHeaderRule::PropagateNamed(data)) => {
                assert_eq!(data.names, vec![header_name_owned("x-resp")]);
                assert!(matches!(data.strategy, HeaderAggregationStrategy::First));
            }
            _ => panic!("Expected PropagateNamed"),
        }
    }
}