1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use crate::{
    headers::Headers,
    policies::{Policy, PolicyResult},
    Context, Request,
};
use std::sync::Arc;
use tracing::trace;

/// Newtype wrapper around a [`Headers`] collection.
///
/// [`CustomHeadersPolicy`] looks this type up in the request [`Context`]; when
/// a value is present, every header it holds is inserted into the outgoing
/// [`Request`].
#[derive(Debug, Clone)]
pub struct CustomHeaders(Headers);

impl From<Headers> for CustomHeaders {
    fn from(h: Headers) -> Self {
        Self(h)
    }
}

/// Pipeline policy that copies headers from a [`CustomHeaders`] value stored in
/// the [`Context`] onto the outgoing [`Request`], then hands the request to the
/// next policy in the chain.
///
/// The policy itself carries no state; the headers to inject come from the
/// context passed to each `send` call.
#[derive(Clone, Debug, Default)]
pub struct CustomHeadersPolicy {}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl Policy for CustomHeadersPolicy {
    /// Injects any [`CustomHeaders`] found in `ctx` into `request`, then
    /// forwards the request to the next policy.
    ///
    /// If the context holds no [`CustomHeaders`] value the request is
    /// forwarded unchanged. Each injected header is logged at `trace` level.
    ///
    /// # Panics
    ///
    /// Panics if `next` is empty, because the first remaining policy is
    /// accessed by index.
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        if let Some(custom) = ctx.get::<CustomHeaders>() {
            for (name, value) in custom.0.iter() {
                trace!(
                    "injecting custom context header {:?} with value {:?}",
                    name,
                    value
                );
                request.insert_header(name.clone(), value.clone());
            }
        }

        // Delegate to the first remaining policy, passing it the rest of the chain.
        let downstream = &next[0];
        let remaining = &next[1..];
        downstream.send(ctx, request, remaining).await
    }
}