clawspec_core/client/oauth2/
provider.rs

1//! OAuth2 token provider for acquiring and refreshing tokens.
2
3use std::time::Duration;
4
5use oauth2::{AccessToken, TokenResponse};
6
7use super::config::{OAuth2Config, OAuth2GrantType};
8use super::error::OAuth2Error;
9use super::token::OAuth2Token;
10
11impl OAuth2Config {
12    /// Acquires a new access token using the configured grant type.
13    ///
14    /// This method handles:
15    /// - Client Credentials grant: fetches a new token from the token endpoint
16    /// - Pre-Acquired token: returns the cached token if available
17    ///
18    /// # Errors
19    ///
20    /// Returns an error if:
21    /// - Network request fails
22    /// - Token endpoint returns an error
23    /// - Response cannot be parsed
24    pub async fn acquire_token(&self) -> Result<OAuth2Token, OAuth2Error> {
25        match self.grant_type {
26            OAuth2GrantType::ClientCredentials => self.acquire_client_credentials_token().await,
27            OAuth2GrantType::PreAcquired => self.get_pre_acquired_token().await,
28        }
29    }
30
31    /// Acquires a token using the Client Credentials grant.
32    async fn acquire_client_credentials_token(&self) -> Result<OAuth2Token, OAuth2Error> {
33        // Create HTTP client with redirect disabled for SSRF prevention
34        // Use oauth2::reqwest to ensure version compatibility
35        let http_client = oauth2::reqwest::ClientBuilder::new()
36            .redirect(oauth2::reqwest::redirect::Policy::none())
37            .build()
38            .map_err(|e| OAuth2Error::TokenAcquisitionFailed {
39                reason: format!("Failed to create HTTP client: {e}"),
40            })?;
41
42        self.acquire_client_credentials_token_with_client(&http_client)
43            .await
44    }
45
46    /// Internal method for acquiring tokens with a custom HTTP client.
47    ///
48    /// This enables testing without making real network requests by injecting
49    /// mock HTTP clients that return predefined responses.
50    pub(crate) async fn acquire_client_credentials_token_with_client(
51        &self,
52        http_client: &oauth2::reqwest::Client,
53    ) -> Result<OAuth2Token, OAuth2Error> {
54        use oauth2::basic::BasicClient;
55        use oauth2::{AuthUrl, ClientId, ClientSecret, Scope, TokenUrl};
56
57        let client_id = ClientId::new(self.client_id.clone());
58
59        // Use a dummy auth URL if not specified (client_credentials doesn't need it)
60        let auth_url_str = self
61            .auth_url
62            .as_ref()
63            .map(|u| u.to_string())
64            .unwrap_or_else(|| format!("{}/../authorize", self.token_url));
65
66        let auth_url = AuthUrl::new(auth_url_str).map_err(|e| OAuth2Error::ConfigurationError {
67            reason: format!("Invalid authorization URL: {e}"),
68        })?;
69
70        let token_url = TokenUrl::new(self.token_url.to_string()).map_err(|e| {
71            OAuth2Error::ConfigurationError {
72                reason: format!("Invalid token URL: {e}"),
73            }
74        })?;
75
76        // Build client using the new builder pattern (oauth2 5.x)
77        // The type-state pattern ensures exchange_client_credentials() is available
78        // only after set_token_uri() is called
79        let mut client = BasicClient::new(client_id)
80            .set_auth_uri(auth_url)
81            .set_token_uri(token_url);
82
83        // Set client secret if provided
84        if let Some(ref secret) = self.client_secret {
85            client = client.set_client_secret(ClientSecret::new(secret.as_str().to_string()));
86        }
87
88        let mut request = client.exchange_client_credentials();
89
90        // Add scopes
91        for scope in self.scopes.iter().map(|s| Scope::new(s.clone())) {
92            request = request.add_scope(scope);
93        }
94
95        // Execute the request
96        let token_result = request.request_async(http_client).await.map_err(|e| {
97            OAuth2Error::TokenAcquisitionFailed {
98                reason: format!("{e}"),
99            }
100        })?;
101
102        // Convert to our token type
103        let token =
104            Self::convert_token_response(token_result.access_token(), token_result.expires_in());
105
106        // Cache the token
107        self.set_token(token.clone()).await;
108
109        Ok(token)
110    }
111
112    /// Returns the pre-acquired token if available.
113    async fn get_pre_acquired_token(&self) -> Result<OAuth2Token, OAuth2Error> {
114        self.get_token().await.ok_or(OAuth2Error::TokenExpired)
115    }
116
117    /// Converts an oauth2 token response to our token type.
118    fn convert_token_response(
119        access_token: &AccessToken,
120        expires_in: Option<Duration>,
121    ) -> OAuth2Token {
122        if let Some(duration) = expires_in {
123            OAuth2Token::with_expiry(access_token.secret().clone(), duration)
124        } else {
125            OAuth2Token::new(access_token.secret().clone())
126        }
127    }
128
129    /// Gets a valid token, acquiring a new one if necessary.
130    ///
131    /// This is the main entry point for getting an access token.
132    /// It checks the cache first and only acquires a new token if needed.
133    pub async fn get_valid_token(&self) -> Result<OAuth2Token, OAuth2Error> {
134        // Check if we have a valid cached token
135        if !self.needs_token().await
136            && let Some(token) = self.get_token().await
137        {
138            return Ok(token);
139        }
140
141        // Need to acquire a new token
142        self.acquire_token().await
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Token endpoint used by every configuration built in these tests.
    const TOKEN_URL: &str = "https://auth.example.com/token";

    // -----------------------------------------
    // Pre-acquired grant
    // -----------------------------------------

    /// A pre-acquired configuration hands back the token it was built with.
    #[tokio::test]
    async fn should_return_pre_acquired_token() {
        let builder =
            OAuth2Config::pre_acquired("client-id", TOKEN_URL, "pre-acquired-access-token")
                .expect("Should create builder");
        let config = builder.build().expect("Should build config");

        let obtained = config.get_valid_token().await.expect("Should get token");

        assert_eq!(obtained.access_token(), "pre-acquired-access-token");
    }

    /// Once the token cache is cleared, the pre-acquired lookup reports `TokenExpired`.
    #[tokio::test]
    async fn should_fail_when_no_pre_acquired_token() {
        let builder = OAuth2Config::pre_acquired("client-id", TOKEN_URL, "token")
            .expect("Should create builder");
        let config = builder.build().expect("Should build config");

        config.token_cache.clear().await;

        let outcome = config.get_pre_acquired_token().await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome.expect_err("Should fail"), OAuth2Error::TokenExpired),
            "Expected TokenExpired error"
        );
    }

    // -----------------------------------------
    // Cache behavior
    // -----------------------------------------

    /// A stored, long-lived token is returned by `get_valid_token`.
    #[tokio::test]
    async fn should_return_cached_token_without_network_call() {
        let builder = OAuth2Config::client_credentials("test-client", "test-secret", TOKEN_URL)
            .expect("Should create builder");
        let config = builder.build().expect("Should build config");

        // Seed the configuration with a token that stays valid for an hour.
        let one_hour = Duration::from_secs(3600);
        config
            .set_token(OAuth2Token::with_expiry("cached-valid-token", one_hour))
            .await;

        let served = config
            .get_valid_token()
            .await
            .expect("Should return cached token");

        assert_eq!(served.access_token(), "cached-valid-token");
    }

    // -----------------------------------------
    // convert_token_response
    // -----------------------------------------

    /// A response carrying `expires_in` yields a token with expiry information.
    #[test]
    fn should_convert_token_with_expiry() {
        let raw = oauth2::AccessToken::new("test-token".to_string());

        let converted =
            OAuth2Config::convert_token_response(&raw, Some(Duration::from_secs(3600)));

        assert_eq!(converted.access_token(), "test-token");
        assert!(converted.time_until_expiry().is_some());
    }

    /// A response without `expires_in` yields a token without expiry information.
    #[test]
    fn should_convert_token_without_expiry() {
        let raw = oauth2::AccessToken::new("no-expiry".to_string());

        let converted = OAuth2Config::convert_token_response(&raw, None);

        assert_eq!(converted.access_token(), "no-expiry");
        assert!(converted.time_until_expiry().is_none());
    }

    // -----------------------------------------
    // Scope configuration
    // -----------------------------------------

    /// Scopes added on the builder end up on the built configuration, in order.
    #[tokio::test]
    async fn should_configure_scopes() {
        let builder = OAuth2Config::client_credentials("test-client", "test-secret", TOKEN_URL)
            .expect("Should create builder")
            .add_scope("read:users")
            .add_scope("write:users");
        let config = builder.build().expect("Should build config");

        assert_eq!(config.scopes, vec!["read:users", "write:users"]);
    }
}