reqwest_auth/
lib.rs

//! # reqwest-auth
//!
//! A reqwest middleware that fills in the `Authorization` header of outgoing requests using a token source.
//!
//! Uses the `token-source` crate to provide a common interface for token sources.
6
7#![warn(missing_docs)]
8
9use anyhow::anyhow;
10use http::Extensions;
11use reqwest_middleware::reqwest::header::HeaderValue;
12use reqwest_middleware::reqwest::header::AUTHORIZATION;
13use reqwest_middleware::reqwest::Request;
14use reqwest_middleware::reqwest::Response;
15use reqwest_middleware::Error;
16use reqwest_middleware::Middleware;
17use reqwest_middleware::Next;
18use std::sync::Arc;
19use token_source::TokenSource;
20
21/// AuthorizationHeaderMiddleware
22///
23/// Provided a [TokenSource](token_source::TokenSource) implementation, this middleware
24/// will set the Authorization header of the request with the token value obtained from this
25/// token source.
26///
27/// The token source is expected to provide a valid token (e.g including renewal), or an error if the token
28/// could not be obtained.
29///
30/// # How to use
31///
32/// ```rust
33///  use reqwest_middleware::ClientBuilder;
34///  use token_source::{TokenSource, TokenSourceProvider};
35///  use std::sync::Arc;
36///  use reqwest_auth::AuthorizationHeaderMiddleware;
37///  
38///  // In real cases you should have a token source provider
39///  // that provides a token source implementation.
40///  // Here we are using a simple example with a hardcoded token value.
41///
42///  // For demonstration purposes.
43///  #[derive(Debug)]
44///  struct MyTokenSource {
45///    pub token: String,
46///  }
47///
48///  // For demonstration purposes.
49///  #[async_trait::async_trait]
50///  impl TokenSource for MyTokenSource {
51///    async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
52///       Ok(self.token.clone())
53///    }
54///  }
55///
56///  // For demonstration purposes.
57///  #[derive(Debug)]
58///  struct MyTokenProvider {
59///    pub ts: Arc<MyTokenSource>,
60///  }
61///
62///  // For demonstration purposes.
63///  impl TokenSourceProvider for MyTokenProvider {
64///    fn token_source(&self) -> Arc<dyn TokenSource> {
65///      self.ts.clone()
66///    }
67///  }
68///
69///  // For demonstration purposes.
70///  let ts_provider = MyTokenProvider {
71///    ts: Arc::new(MyTokenSource {
72///      token: "Bearer my-token".to_string(),
73///    }),
74///  };
75///
76///  // Create the middleware from the token source
77///  let auth_middleware = AuthorizationHeaderMiddleware::from(ts_provider.token_source());
78///
79///  // Create your reqwest client with the middleware
80///  let client = ClientBuilder::new(reqwest::Client::default())
81///    // Ideally, the authorization middleware should come last,
82///    // especially if you are using a retry middleware as well.
83///    // This way, your retry requests will benefit from the renewals of the token,
84///    // as long as your token source implementation is able to renew the token.
85///    .with(auth_middleware)
86///    .build();
87/// ```
88pub struct AuthorizationHeaderMiddleware {
89    ts: Arc<dyn TokenSource>,
90}
91
92impl From<Arc<dyn TokenSource>> for AuthorizationHeaderMiddleware {
93    fn from(ts: Arc<dyn TokenSource>) -> Self {
94        Self { ts }
95    }
96}
97
98impl From<Box<dyn TokenSource>> for AuthorizationHeaderMiddleware {
99    fn from(ts: Box<dyn TokenSource>) -> Self {
100        Self { ts: ts.into() }
101    }
102}
103
104#[async_trait::async_trait]
105impl Middleware for AuthorizationHeaderMiddleware {
106    async fn handle(
107        &self,
108        mut req: Request,
109        extensions: &mut Extensions,
110        next: Next<'_>,
111    ) -> reqwest_middleware::Result<Response> {
112        // Obtain (or regenerate) an auth token from the token source
113        let auth_token = self
114            .ts
115            .token()
116            .await
117            .map_err(|e| Error::Middleware(anyhow!(e.to_string())))?;
118
119        // Set the Authorization header with the auth token
120        // Note: any previous value of the Authorization header will be overwritten
121        req.headers_mut().insert(
122            AUTHORIZATION,
123            HeaderValue::from_str(auth_token.as_str())
124                .map_err(|e| Error::Middleware(anyhow!(format!("Invalid auth token value: {e}"))))?,
125        );
126
127        // Chain to next middleware in the stack
128        next.run(req, extensions).await
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::Extensions;
    use reqwest_middleware::reqwest;
    use reqwest_middleware::reqwest::header::HeaderValue;
    use reqwest_middleware::reqwest::header::AUTHORIZATION;
    use reqwest_middleware::reqwest::Request;
    use reqwest_middleware::reqwest::Response;
    use reqwest_middleware::ClientBuilder;
    use reqwest_middleware::Middleware;
    use reqwest_middleware::Next;
    use token_source::{TokenSource, TokenSourceProvider};

    use super::AuthorizationHeaderMiddleware;

    /// URL used by the tests. No request reaches the network: the terminal
    /// `VerificationMiddleware` answers every request itself.
    const TEST_URL: &str = "https://example.com/";

    /// Token source returning a fixed token. For testing purposes only.
    #[derive(Debug)]
    struct MyTokenSource {
        pub token: String,
    }

    #[async_trait::async_trait]
    impl TokenSource for MyTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Ok(self.token.clone())
        }
    }

    /// Token source that always fails. For testing purposes only.
    #[derive(Debug)]
    struct FailingTokenSource;

    #[async_trait::async_trait]
    impl TokenSource for FailingTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Err("token unavailable".into())
        }
    }

    #[derive(Debug)]
    struct MyTokenProvider {
        pub ts: Arc<MyTokenSource>,
    }

    impl TokenSourceProvider for MyTokenProvider {
        fn token_source(&self) -> Arc<dyn TokenSource> {
            self.ts.clone()
        }
    }

    /// Terminal middleware verifying that exactly one Authorization header is set,
    /// with the expected value.
    ///
    /// It answers with an empty `200 OK` response instead of forwarding the request, so the
    /// tests never depend on network access. For testing purposes only.
    struct VerificationMiddleware {
        expected: &'static str,
    }

    #[async_trait::async_trait]
    impl Middleware for VerificationMiddleware {
        async fn handle(
            &self,
            req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> reqwest_middleware::Result<Response> {
            let mut values = req.headers().get_all(AUTHORIZATION).iter();
            let value = values.next().expect("Authorization header should be set");
            assert_eq!(value, &HeaderValue::from_static(self.expected));
            assert!(
                values.next().is_none(),
                "Authorization header should be set only once"
            );

            Ok(Response::from(http::Response::new("")))
        }
    }

    /// Sends a request through `auth_middleware` followed by a `VerificationMiddleware`
    /// expecting `expected`. When `stale_header` is given, it is set on the request first.
    async fn send_through(
        auth_middleware: AuthorizationHeaderMiddleware,
        expected: &'static str,
        stale_header: Option<&'static str>,
    ) -> reqwest_middleware::Result<Response> {
        let client = ClientBuilder::new(reqwest::Client::default())
            // Authorization should come first
            .with(auth_middleware)
            // Verification should come next
            .with(VerificationMiddleware { expected })
            .build();

        let mut request = client.get(TEST_URL);
        if let Some(value) = stale_header {
            request = request.header(AUTHORIZATION, value);
        }
        request.send().await
    }

    #[async_std::test]
    async fn test_middleware() {
        // Given - the Authorization middleware built from a token source provider
        let token_value = "Bearer my-token";
        let ts_provider = MyTokenProvider {
            ts: Arc::new(MyTokenSource {
                token: token_value.to_string(),
            }),
        };
        let auth_middleware = AuthorizationHeaderMiddleware::from(ts_provider.token_source());

        // When - making a request
        let result = send_through(auth_middleware, token_value, None).await;

        // Then - the request succeeded, so the verification middleware saw the header
        result.expect("request should succeed");
    }

    #[async_std::test]
    async fn test_middleware_from_boxed_token_source() {
        let token_value = "Bearer boxed-token";
        let ts: Box<dyn TokenSource> = Box::new(MyTokenSource {
            token: token_value.to_string(),
        });

        send_through(AuthorizationHeaderMiddleware::from(ts), token_value, None)
            .await
            .expect("request should succeed");
    }

    #[async_std::test]
    async fn test_middleware_overwrites_existing_header() {
        let token_value = "Bearer fresh-token";
        let ts: Arc<dyn TokenSource> = Arc::new(MyTokenSource {
            token: token_value.to_string(),
        });

        send_through(
            AuthorizationHeaderMiddleware::from(ts),
            token_value,
            Some("Bearer stale-token"),
        )
        .await
        .expect("request should succeed");
    }

    #[async_std::test]
    async fn test_middleware_propagates_token_source_error() {
        let ts: Arc<dyn TokenSource> = Arc::new(FailingTokenSource);

        let result = send_through(AuthorizationHeaderMiddleware::from(ts), "unused", None).await;

        assert!(matches!(
            result,
            Err(reqwest_middleware::Error::Middleware(_))
        ));
    }
}