reqwest_auth/lib.rs
1//! # reqwest-auth
2//!
3//! A reqwest middleware to fill-in the authorization header using a token source.
4//!
5//! Uses the `token-source` crate to provide a common interface for token sources.
6
7#![warn(missing_docs)]
8
9use anyhow::anyhow;
10use http::Extensions;
11use reqwest_middleware::reqwest::header::HeaderValue;
12use reqwest_middleware::reqwest::header::AUTHORIZATION;
13use reqwest_middleware::reqwest::Request;
14use reqwest_middleware::reqwest::Response;
15use reqwest_middleware::Error;
16use reqwest_middleware::Middleware;
17use reqwest_middleware::Next;
18use std::sync::Arc;
19use token_source::TokenSource;
20
21/// AuthorizationHeaderMiddleware
22///
23/// Provided a [TokenSource](token_source::TokenSource) implementation, this middleware
24/// will set the Authorization header of the request with the token value obtained from this
25/// token source.
26///
27/// The token source is expected to provide a valid token (e.g including renewal), or an error if the token
28/// could not be obtained.
29///
30/// # How to use
31///
32/// ```rust
33/// use reqwest_middleware::ClientBuilder;
34/// use token_source::{TokenSource, TokenSourceProvider};
35/// use std::sync::Arc;
36/// use reqwest_auth::AuthorizationHeaderMiddleware;
37///
38/// // In real cases you should have a token source provider
39/// // that provides a token source implementation.
40/// // Here we are using a simple example with a hardcoded token value.
41///
42/// // For demonstration purposes.
43/// #[derive(Debug)]
44/// struct MyTokenSource {
45/// pub token: String,
46/// }
47///
48/// // For demonstration purposes.
49/// #[async_trait::async_trait]
50/// impl TokenSource for MyTokenSource {
51/// async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
52/// Ok(self.token.clone())
53/// }
54/// }
55///
56/// // For demonstration purposes.
57/// #[derive(Debug)]
58/// struct MyTokenProvider {
59/// pub ts: Arc<MyTokenSource>,
60/// }
61///
62/// // For demonstration purposes.
63/// impl TokenSourceProvider for MyTokenProvider {
64/// fn token_source(&self) -> Arc<dyn TokenSource> {
65/// self.ts.clone()
66/// }
67/// }
68///
69/// // For demonstration purposes.
70/// let ts_provider = MyTokenProvider {
71/// ts: Arc::new(MyTokenSource {
72/// token: "Bearer my-token".to_string(),
73/// }),
74/// };
75///
76/// // Create the middleware from the token source
77/// let auth_middleware = AuthorizationHeaderMiddleware::from(ts_provider.token_source());
78///
79/// // Create your reqwest client with the middleware
80/// let client = ClientBuilder::new(reqwest::Client::default())
81/// // Ideally, the authorization middleware should come last,
82/// // especially if you are using a retry middleware as well.
83/// // This way, your retry requests will benefit from the renewals of the token,
84/// // as long as your token source implementation is able to renew the token.
85/// .with(auth_middleware)
86/// .build();
87/// ```
88pub struct AuthorizationHeaderMiddleware {
89 ts: Arc<dyn TokenSource>,
90}
91
92impl From<Arc<dyn TokenSource>> for AuthorizationHeaderMiddleware {
93 fn from(ts: Arc<dyn TokenSource>) -> Self {
94 Self { ts }
95 }
96}
97
98impl From<Box<dyn TokenSource>> for AuthorizationHeaderMiddleware {
99 fn from(ts: Box<dyn TokenSource>) -> Self {
100 Self { ts: ts.into() }
101 }
102}
103
104#[async_trait::async_trait]
105impl Middleware for AuthorizationHeaderMiddleware {
106 async fn handle(
107 &self,
108 mut req: Request,
109 extensions: &mut Extensions,
110 next: Next<'_>,
111 ) -> reqwest_middleware::Result<Response> {
112 // Obtain (or regenerate) an auth token from the token source
113 let auth_token = self
114 .ts
115 .token()
116 .await
117 .map_err(|e| Error::Middleware(anyhow!(e.to_string())))?;
118
119 // Set the Authorization header with the auth token
120 // Note: any previous value of the Authorization header will be overwritten
121 req.headers_mut().insert(
122 AUTHORIZATION,
123 HeaderValue::from_str(auth_token.as_str())
124 .map_err(|e| Error::Middleware(anyhow!(format!("Invalid auth token value: {e}"))))?,
125 );
126
127 // Chain to next middleware in the stack
128 next.run(req, extensions).await
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::Extensions;
    use reqwest_middleware::reqwest;
    use reqwest_middleware::ClientBuilder;
    use reqwest_middleware::ClientWithMiddleware;
    use reqwest_middleware::Error;
    use reqwest_middleware::Middleware;
    use token_source::{TokenSource, TokenSourceProvider};

    use super::AuthorizationHeaderMiddleware;
    use reqwest_middleware::reqwest::header::HeaderValue;
    use reqwest_middleware::reqwest::header::AUTHORIZATION;
    use reqwest_middleware::reqwest::Request;
    use reqwest_middleware::reqwest::Response;
    use reqwest_middleware::Next;

    /// URL used by every test. It is never contacted, because `VerificationMiddleware`
    /// answers locally. `.invalid` is a reserved TLD, so an accidental network call
    /// fails instead of reaching a real host.
    const TEST_URL: &str = "https://example.invalid/";

    #[derive(Debug)]
    struct MyTokenSource {
        pub token: String,
    }

    #[async_trait::async_trait]
    impl TokenSource for MyTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Ok(self.token.clone())
        }
    }

    /// Token source that always fails, used to exercise the error path.
    #[derive(Debug)]
    struct FailingTokenSource;

    #[async_trait::async_trait]
    impl TokenSource for FailingTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Err("token unavailable".into())
        }
    }

    #[derive(Debug)]
    struct MyTokenProvider {
        pub ts: Arc<MyTokenSource>,
    }

    impl TokenSourceProvider for MyTokenProvider {
        fn token_source(&self) -> Arc<dyn TokenSource> {
            self.ts.clone()
        }
    }

    /// A simple middleware to verify the Authorization header
    /// is set correctly (exactly once, with the expected value).
    ///
    /// It ends the chain with an empty local response instead of calling `next`,
    /// so the tests never depend on network access.
    ///
    /// For testing purposes only.
    struct VerificationMiddleware {
        expected: &'static str,
    }

    #[async_trait::async_trait]
    impl Middleware for VerificationMiddleware {
        async fn handle(
            &self,
            req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> reqwest_middleware::Result<Response> {
            // Verify the Authorization header is set correctly, and only once
            let mut values = req.headers().get_all(AUTHORIZATION).iter();
            let token_value = values.next().expect("Authorization header should be set");
            assert_eq!(token_value, &HeaderValue::from_static(self.expected));
            assert!(
                values.next().is_none(),
                "Authorization header should be set only once"
            );

            // Answer locally: the request never leaves the process.
            Ok(Response::from(http::Response::new("")))
        }
    }

    /// Builds a client that runs the authorization middleware backed by `ts`,
    /// followed by a verification middleware expecting `expected`.
    fn client_with(ts: Arc<dyn TokenSource>, expected: &'static str) -> ClientWithMiddleware {
        ClientBuilder::new(reqwest::Client::default())
            // Authorization should come first
            .with(AuthorizationHeaderMiddleware::from(ts))
            // Verification should come next
            .with(VerificationMiddleware { expected })
            .build()
    }

    #[async_std::test]
    async fn test_middleware() {
        // Given - the Authorization middleware & test verification one
        let token_value = "Bearer my-token";
        let ts_provider = MyTokenProvider {
            ts: Arc::new(MyTokenSource {
                token: token_value.to_string(),
            }),
        };
        let client = client_with(ts_provider.token_source(), token_value);

        // When - making a request
        // Then - the request reached the verification middleware, which checked the header
        client
            .get(TEST_URL)
            .send()
            .await
            .expect("request should go through the whole middleware chain");
    }

    #[async_std::test]
    async fn test_middleware_overwrites_existing_header() {
        // Given - a request that already carries a stale Authorization header
        let token_value = "Bearer fresh-token";
        let ts: Arc<dyn TokenSource> = Arc::new(MyTokenSource {
            token: token_value.to_string(),
        });
        let client = client_with(ts, token_value);

        // When / Then - the stale value is replaced, not appended to
        client
            .get(TEST_URL)
            .header(AUTHORIZATION, "Bearer stale-token")
            .send()
            .await
            .expect("request should go through the whole middleware chain");
    }

    #[async_std::test]
    async fn test_middleware_token_source_error() {
        // Given - a token source that cannot produce a token
        let client = client_with(Arc::new(FailingTokenSource), "unused");

        // When
        let err = client
            .get(TEST_URL)
            .send()
            .await
            .expect_err("a failing token source should abort the request");

        // Then - the error is reported as a middleware error carrying the source message
        assert!(
            matches!(err, Error::Middleware(_)),
            "unexpected error: {:?}",
            err
        );
        assert!(err.to_string().contains("token unavailable"));
    }

    #[async_std::test]
    async fn test_middleware_invalid_token_value() {
        // Given - a token that is not a valid HTTP header value
        let ts: Arc<dyn TokenSource> = Arc::new(MyTokenSource {
            token: "Bearer bad\ntoken".to_string(),
        });
        let client = client_with(ts, "unused");

        // When
        let err = client
            .get(TEST_URL)
            .send()
            .await
            .expect_err("an invalid header value should abort the request");

        // Then
        assert!(
            matches!(err, Error::Middleware(_)),
            "unexpected error: {:?}",
            err
        );
        assert!(err.to_string().contains("Invalid auth token value"));
    }
}