Skip to main content

clawspec_core/client/oauth2/
provider.rs

1//! OAuth2 token provider for acquiring and refreshing tokens.
2
3use std::time::Duration;
4
5use oauth2::{AccessToken, TokenResponse};
6
7use super::config::{OAuth2Config, OAuth2GrantType};
8use super::error::OAuth2Error;
9use super::token::OAuth2Token;
10
11impl OAuth2Config {
12    /// Acquires a new access token using the configured grant type.
13    ///
14    /// This method handles:
15    /// - Client Credentials grant: fetches a new token from the token endpoint
16    /// - Pre-Acquired token: returns the cached token if available
17    ///
18    /// # Errors
19    ///
20    /// Returns an error if:
21    /// - Network request fails
22    /// - Token endpoint returns an error
23    /// - Response cannot be parsed
24    pub async fn acquire_token(&self) -> Result<OAuth2Token, OAuth2Error> {
25        match self.grant_type {
26            OAuth2GrantType::ClientCredentials => self.acquire_client_credentials_token().await,
27            OAuth2GrantType::PreAcquired => self.get_pre_acquired_token().await,
28        }
29    }
30
31    /// Acquires a token using the Client Credentials grant.
32    async fn acquire_client_credentials_token(&self) -> Result<OAuth2Token, OAuth2Error> {
33        // Create HTTP client with redirect disabled for SSRF prevention
34        // Use oauth2::reqwest to ensure version compatibility
35        let http_client = oauth2::reqwest::ClientBuilder::new()
36            .redirect(oauth2::reqwest::redirect::Policy::none())
37            .build()
38            .map_err(|e| OAuth2Error::TokenAcquisitionFailed {
39                reason: format!("Failed to create HTTP client: {e}"),
40            })?;
41
42        self.acquire_client_credentials_token_with_client(&http_client)
43            .await
44    }
45
46    /// Internal method for acquiring tokens with a custom HTTP client.
47    ///
48    /// This enables testing without making real network requests by injecting
49    /// mock HTTP clients that return predefined responses.
50    pub(crate) async fn acquire_client_credentials_token_with_client(
51        &self,
52        http_client: &oauth2::reqwest::Client,
53    ) -> Result<OAuth2Token, OAuth2Error> {
54        use oauth2::basic::BasicClient;
55        use oauth2::{AuthUrl, ClientId, ClientSecret, Scope, TokenUrl};
56
57        let client_id = ClientId::new(self.client_id.clone());
58
59        // Use a dummy auth URL if not specified (client_credentials doesn't need it)
60        let auth_url_str = self
61            .auth_url
62            .as_ref()
63            .map(|u| u.to_string())
64            .unwrap_or_else(|| format!("{}/../authorize", self.token_url));
65
66        let auth_url = AuthUrl::new(auth_url_str).map_err(|e| OAuth2Error::ConfigurationError {
67            reason: format!("Invalid authorization URL: {e}"),
68        })?;
69
70        let token_url = TokenUrl::new(self.token_url.to_string()).map_err(|e| {
71            OAuth2Error::ConfigurationError {
72                reason: format!("Invalid token URL: {e}"),
73            }
74        })?;
75
76        // Build client using the new builder pattern (oauth2 5.x)
77        // The type-state pattern ensures exchange_client_credentials() is available
78        // only after set_token_uri() is called
79        let mut client = BasicClient::new(client_id)
80            .set_auth_uri(auth_url)
81            .set_token_uri(token_url);
82
83        // Set client secret if provided
84        if let Some(ref secret) = self.client_secret {
85            client = client.set_client_secret(ClientSecret::new(secret.as_str().to_string()));
86        }
87
88        let mut request = client.exchange_client_credentials();
89
90        // Add scopes
91        for scope in self.scopes.iter().map(|s| Scope::new(s.clone())) {
92            request = request.add_scope(scope);
93        }
94
95        // Execute the request
96        let token_result = request.request_async(http_client).await.map_err(|e| {
97            OAuth2Error::TokenAcquisitionFailed {
98                reason: format!("{e}"),
99            }
100        })?;
101
102        // Convert to our token type
103        let token =
104            Self::convert_token_response(token_result.access_token(), token_result.expires_in());
105
106        // Cache the token
107        self.set_token(token.clone()).await;
108
109        Ok(token)
110    }
111
112    /// Returns the pre-acquired token if available.
113    async fn get_pre_acquired_token(&self) -> Result<OAuth2Token, OAuth2Error> {
114        self.get_token().await.ok_or(OAuth2Error::TokenExpired)
115    }
116
117    /// Converts an oauth2 token response to our token type.
118    fn convert_token_response(
119        access_token: &AccessToken,
120        expires_in: Option<Duration>,
121    ) -> OAuth2Token {
122        if let Some(duration) = expires_in {
123            OAuth2Token::with_expiry(access_token.secret().clone(), duration)
124        } else {
125            OAuth2Token::new(access_token.secret().clone())
126        }
127    }
128
129    /// Gets a valid token, acquiring a new one if necessary.
130    ///
131    /// This is the main entry point for getting an access token.
132    /// It checks the cache first and only acquires a new token if needed.
133    pub async fn get_valid_token(&self) -> Result<OAuth2Token, OAuth2Error> {
134        // Check if we have a valid cached token
135        if !self.needs_token().await
136            && let Some(token) = self.get_token().await
137        {
138            return Ok(token);
139        }
140
141        // Need to acquire a new token
142        self.acquire_token().await
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================
    // Pre-acquired token tests
    // =========================================

    #[tokio::test]
    async fn should_return_pre_acquired_token() {
        let config = OAuth2Config::pre_acquired(
            "client-id",
            "https://auth.example.com/token",
            "pre-acquired-access-token",
        )
        .expect("Should create builder")
        .build()
        .expect("Should build config");

        let token = config.get_valid_token().await.expect("Should get token");
        assert_eq!(token.access_token(), "pre-acquired-access-token");
    }

    #[tokio::test]
    async fn should_fail_when_no_pre_acquired_token() {
        let config =
            OAuth2Config::pre_acquired("client-id", "https://auth.example.com/token", "token")
                .expect("Should create builder")
                .build()
                .expect("Should build config");

        // Drop the pre-acquired token so there is nothing left to return.
        config.token_cache.clear().await;

        let err = config
            .get_pre_acquired_token()
            .await
            .expect_err("Should fail without a cached token");
        assert!(
            matches!(err, OAuth2Error::TokenExpired),
            "Expected TokenExpired error, got {err:?}"
        );
    }

    // =========================================
    // convert_token_response tests
    // =========================================

    #[test]
    fn should_convert_token_with_expiry() {
        let access_token = oauth2::AccessToken::new("test-token".to_string());
        let expires_in = Some(Duration::from_secs(3600));

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "test-token");
        assert!(token.time_until_expiry().is_some());
    }

    #[test]
    fn should_convert_token_without_expiry() {
        let access_token = oauth2::AccessToken::new("no-expiry".to_string());
        let expires_in = None;

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "no-expiry");
        assert!(token.time_until_expiry().is_none());
    }

    // =========================================
    // Configuration tests
    // =========================================

    #[tokio::test]
    async fn should_configure_scopes() {
        let config = OAuth2Config::client_credentials(
            "test-client",
            "test-secret",
            "https://auth.example.com/token",
        )
        .expect("Should create builder")
        .add_scope("read:users")
        .add_scope("write:users")
        .build()
        .expect("Should build config");

        assert_eq!(config.scopes, vec!["read:users", "write:users"]);
    }

    #[test]
    fn should_handle_invalid_token_url() {
        // `client_credentials` rejects the malformed token URL itself, so no request is made.
        let result =
            OAuth2Config::client_credentials("test-client", "test-secret", "not-a-valid-url");

        assert!(result.is_err());
    }

    // =========================================
    // Mock server tests for token acquisition and caching
    // =========================================

    mod mock_server_tests {
        use super::*;
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        #[tokio::test]
        async fn should_return_cached_token_without_network_call() {
            let mock_server = MockServer::start().await;

            // Any request here means the cache was bypassed. `.expect(0)` fails the test on
            // drop if one arrives, and the 500 makes the acquisition itself fail as well.
            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(500))
                .expect(0)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            // Pre-populate the cache with a token that stays valid for the whole test.
            let cached = OAuth2Token::with_expiry("cached-valid-token", Duration::from_secs(3600));
            config.set_token(cached).await;

            let token = config
                .get_valid_token()
                .await
                .expect("Should return cached token");
            assert_eq!(token.access_token(), "cached-valid-token");
        }

        #[tokio::test]
        async fn should_acquire_client_credentials_token() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "test-access-token-12345",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token successfully");

            assert_eq!(token.access_token(), "test-access-token-12345");
            assert!(token.time_until_expiry().is_some());
        }

        #[tokio::test]
        async fn should_include_scopes_in_token_request() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .and(wiremock::matchers::body_string_contains(
                    "scope=read%3Ausers",
                ))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "scoped-token",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .add_scope("read:users")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token with scopes");

            assert_eq!(token.access_token(), "scoped-token");
        }

        #[tokio::test]
        async fn should_handle_token_request_failure() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                    "error": "invalid_client",
                    "error_description": "Client authentication failed"
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config =
                OAuth2Config::client_credentials("invalid-client", "wrong-secret", &token_url)
                    .expect("Should create builder")
                    .build()
                    .expect("Should build config");

            let err = config
                .acquire_token()
                .await
                .expect_err("Should fail with a client error");

            match err {
                OAuth2Error::TokenAcquisitionFailed { reason } => {
                    assert!(
                        reason.contains("invalid_client") || reason.contains("Client"),
                        "Error should contain client error info: {reason}"
                    );
                }
                other => panic!("Expected TokenAcquisitionFailed, got {other:?}"),
            }
        }

        #[tokio::test]
        async fn should_acquire_token_with_multiple_scopes() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .and(wiremock::matchers::body_string_contains(
                    "scope=read%3Ausers",
                ))
                .and(wiremock::matchers::body_string_contains("write%3Ausers"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "multi-scope-token",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .add_scope("read:users")
                .add_scope("write:users")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token with multiple scopes");

            assert_eq!(token.access_token(), "multi-scope-token");
        }

        #[tokio::test]
        async fn should_handle_token_without_expiry() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "no-expiry-token",
                    "token_type": "Bearer"
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token without expiry");

            assert_eq!(token.access_token(), "no-expiry-token");
            assert!(
                token.time_until_expiry().is_none(),
                "Token without expires_in should have no expiry"
            );
        }

        #[tokio::test]
        async fn should_cache_token_after_acquisition() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "cached-token-value",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1) // Only the first call may reach the server; the second is cached.
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            // First call always hits the server and fills the cache.
            let token1 = config
                .acquire_token()
                .await
                .expect("First token acquisition should succeed");

            // Second call goes through get_valid_token, which checks the cache first.
            let token2 = config
                .get_valid_token()
                .await
                .expect("Second call should use cached token");

            assert_eq!(token1.access_token(), "cached-token-value");
            assert_eq!(token2.access_token(), "cached-token-value");
        }
    }
}