1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use crate::RequestBuilder;

/// A hook that customises each [`RequestBuilder`] before the caller configures it further.
///
/// Initialisers are attached to a [`ClientWithMiddleware`] (generally using [`with_init`]) and
/// are run whenever the client starts building a request, in the order they were attached.
///
/// # Example
///
/// ```
/// use reqwest_middleware::{RequestInitialiser, RequestBuilder};
///
/// struct AuthInit;
///
/// impl RequestInitialiser for AuthInit {
///     fn init(&self, req: RequestBuilder) -> RequestBuilder {
///         req.bearer_auth("my_auth_token")
///     }
/// }
/// ```
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with_init`]: crate::ClientBuilder::with_init
pub trait RequestInitialiser: Send + Sync + 'static {
    /// Takes ownership of the builder for a new request and returns the builder that
    /// should be used from then on.
    fn init(&self, req: RequestBuilder) -> RequestBuilder;
}

/// Blanket implementation so that any thread-safe, `'static` function or closure mapping a
/// [`RequestBuilder`] to a [`RequestBuilder`] can be used directly as an initialiser.
impl<F> RequestInitialiser for F
where
    F: Fn(RequestBuilder) -> RequestBuilder + Send + Sync + 'static,
{
    /// Forwards the builder to the wrapped callable and returns its result unchanged.
    fn init(&self, req: RequestBuilder) -> RequestBuilder {
        self(req)
    }
}

/// A [`RequestInitialiser`] that inserts a clone of the wrapped value into the
/// [`Extensions`](task_local_extensions::Extensions) of every request during the call.
///
/// This is a good way to inject extensions to middleware deeper in the stack.
///
/// # Example
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension};
/// use task_local_extensions::Extensions;
///
/// #[derive(Clone)]
/// struct LogName(&'static str);
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
///     async fn handle(
///         &self,
///         req: Request,
///         extensions: &mut Extensions,
///         next: Next<'_>,
///     ) -> Result<Response> {
///         // get the log name or default to "unknown"
///         let name = extensions
///             .get()
///             .map(|&LogName(name)| name)
///             .unwrap_or("unknown");
///         println!("[{name}] Request started {req:?}");
///         let res = next.run(req, extensions).await;
///         println!("[{name}] Result: {res:?}");
///         res
///     }
/// }
///
/// async fn run() {
///     let reqwest_client = Client::builder().build().unwrap();
///     let client = ClientBuilder::new(reqwest_client)
///         .with_init(Extension(LogName("my-client")))
///         .with(LoggingMiddleware)
///         .build();
///     let resp = client.get("https://truelayer.com").send().await.unwrap();
///     println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
/// }
/// ```
pub struct Extension<T>(pub T);

impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> {
    fn init(&self, req: RequestBuilder) -> RequestBuilder {
        req.with_extension(self.0.clone())
    }
}