clawspec_core/client/oauth2/
provider.rs1use std::time::Duration;
4
5use oauth2::{AccessToken, TokenResponse};
6
7use super::config::{OAuth2Config, OAuth2GrantType};
8use super::error::OAuth2Error;
9use super::token::OAuth2Token;
10
11impl OAuth2Config {
12 pub async fn acquire_token(&self) -> Result<OAuth2Token, OAuth2Error> {
25 match self.grant_type {
26 OAuth2GrantType::ClientCredentials => self.acquire_client_credentials_token().await,
27 OAuth2GrantType::PreAcquired => self.get_pre_acquired_token().await,
28 }
29 }
30
31 async fn acquire_client_credentials_token(&self) -> Result<OAuth2Token, OAuth2Error> {
33 let http_client = oauth2::reqwest::ClientBuilder::new()
36 .redirect(oauth2::reqwest::redirect::Policy::none())
37 .build()
38 .map_err(|e| OAuth2Error::TokenAcquisitionFailed {
39 reason: format!("Failed to create HTTP client: {e}"),
40 })?;
41
42 self.acquire_client_credentials_token_with_client(&http_client)
43 .await
44 }
45
46 pub(crate) async fn acquire_client_credentials_token_with_client(
51 &self,
52 http_client: &oauth2::reqwest::Client,
53 ) -> Result<OAuth2Token, OAuth2Error> {
54 use oauth2::basic::BasicClient;
55 use oauth2::{AuthUrl, ClientId, ClientSecret, Scope, TokenUrl};
56
57 let client_id = ClientId::new(self.client_id.clone());
58
59 let auth_url_str = self
61 .auth_url
62 .as_ref()
63 .map(|u| u.to_string())
64 .unwrap_or_else(|| format!("{}/../authorize", self.token_url));
65
66 let auth_url = AuthUrl::new(auth_url_str).map_err(|e| OAuth2Error::ConfigurationError {
67 reason: format!("Invalid authorization URL: {e}"),
68 })?;
69
70 let token_url = TokenUrl::new(self.token_url.to_string()).map_err(|e| {
71 OAuth2Error::ConfigurationError {
72 reason: format!("Invalid token URL: {e}"),
73 }
74 })?;
75
76 let mut client = BasicClient::new(client_id)
80 .set_auth_uri(auth_url)
81 .set_token_uri(token_url);
82
83 if let Some(ref secret) = self.client_secret {
85 client = client.set_client_secret(ClientSecret::new(secret.as_str().to_string()));
86 }
87
88 let mut request = client.exchange_client_credentials();
89
90 for scope in self.scopes.iter().map(|s| Scope::new(s.clone())) {
92 request = request.add_scope(scope);
93 }
94
95 let token_result = request.request_async(http_client).await.map_err(|e| {
97 OAuth2Error::TokenAcquisitionFailed {
98 reason: format!("{e}"),
99 }
100 })?;
101
102 let token =
104 Self::convert_token_response(token_result.access_token(), token_result.expires_in());
105
106 self.set_token(token.clone()).await;
108
109 Ok(token)
110 }
111
112 async fn get_pre_acquired_token(&self) -> Result<OAuth2Token, OAuth2Error> {
114 self.get_token().await.ok_or(OAuth2Error::TokenExpired)
115 }
116
117 fn convert_token_response(
119 access_token: &AccessToken,
120 expires_in: Option<Duration>,
121 ) -> OAuth2Token {
122 if let Some(duration) = expires_in {
123 OAuth2Token::with_expiry(access_token.secret().clone(), duration)
124 } else {
125 OAuth2Token::new(access_token.secret().clone())
126 }
127 }
128
129 pub async fn get_valid_token(&self) -> Result<OAuth2Token, OAuth2Error> {
134 if !self.needs_token().await
136 && let Some(token) = self.get_token().await
137 {
138 return Ok(token);
139 }
140
141 self.acquire_token().await
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// A pre-acquired token is handed back unchanged by `get_valid_token`.
    #[tokio::test]
    async fn should_return_pre_acquired_token() {
        let config = OAuth2Config::pre_acquired(
            "client-id",
            "https://auth.example.com/token",
            "pre-acquired-access-token",
        )
        .expect("Should create builder")
        .build()
        .expect("Should build config");

        let token = config.get_valid_token().await.expect("Should get token");
        assert_eq!(token.access_token(), "pre-acquired-access-token");
    }

    /// An empty cache makes the pre-acquired lookup fail with `TokenExpired`.
    #[tokio::test]
    async fn should_fail_when_no_pre_acquired_token() {
        let config =
            OAuth2Config::pre_acquired("client-id", "https://auth.example.com/token", "token")
                .expect("Should create builder")
                .build()
                .expect("Should build config");

        config.token_cache.clear().await;

        let result = config.get_pre_acquired_token().await;
        assert!(
            matches!(result, Err(OAuth2Error::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    /// The public `get_valid_token` entry point surfaces the same `TokenExpired` error
    /// when the pre-acquired token has been removed from the cache.
    #[tokio::test]
    async fn should_fail_valid_token_when_pre_acquired_cache_is_empty() {
        let config =
            OAuth2Config::pre_acquired("client-id", "https://auth.example.com/token", "token")
                .expect("Should create builder")
                .build()
                .expect("Should build config");

        config.token_cache.clear().await;

        let result = config.get_valid_token().await;
        assert!(
            matches!(result, Err(OAuth2Error::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    /// A cached, unexpired token is returned without performing a token request.
    #[tokio::test]
    async fn should_return_cached_token_without_network_call() {
        let config = OAuth2Config::client_credentials(
            "test-client",
            "test-secret",
            "https://auth.example.com/token",
        )
        .expect("Should create builder")
        .build()
        .expect("Should build config");

        let token = OAuth2Token::with_expiry("cached-valid-token", Duration::from_secs(3600));
        config.set_token(token).await;

        let result = config.get_valid_token().await;
        let token = result.expect("Should return cached token");
        assert_eq!(token.access_token(), "cached-valid-token");
    }

    /// A server-provided lifetime becomes an expiry on the converted token.
    #[test]
    fn should_convert_token_with_expiry() {
        let access_token = oauth2::AccessToken::new("test-token".to_string());
        let expires_in = Some(Duration::from_secs(3600));

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "test-token");
        assert!(token.time_until_expiry().is_some());
    }

    /// Without a server-provided lifetime the converted token has no expiry.
    #[test]
    fn should_convert_token_without_expiry() {
        let access_token = oauth2::AccessToken::new("no-expiry".to_string());
        let expires_in = None;

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "no-expiry");
        assert!(token.time_until_expiry().is_none());
    }

    /// Scopes added through the builder are kept in insertion order.
    #[tokio::test]
    async fn should_configure_scopes() {
        let config = OAuth2Config::client_credentials(
            "test-client",
            "test-secret",
            "https://auth.example.com/token",
        )
        .expect("Should create builder")
        .add_scope("read:users")
        .add_scope("write:users")
        .build()
        .expect("Should build config");

        assert_eq!(config.scopes, vec!["read:users", "write:users"]);
    }
}