Skip to main content

busbar_sf_auth/
oauth.rs

1//! OAuth 2.0 authentication flows.
2//!
3//! This module provides secure OAuth 2.0 flows for Salesforce authentication:
4//! - **Web Server Flow** - For web applications with user interaction
5//! - **JWT Bearer Flow** - For server-to-server integration (see jwt.rs)
6//! - **Refresh Token** - For refreshing expired access tokens
7//!
8//! Note: Device Code Flow has been intentionally excluded as it is being
9//! deprecated due to security concerns.
10
11use serde::{Deserialize, Serialize};
12use tracing::instrument;
13
14use crate::credentials::SalesforceCredentials;
15use crate::error::{Error, ErrorKind, Result};
16
/// OAuth 2.0 configuration for a connected app.
///
/// Sensitive fields like `consumer_secret` are redacted in Debug output
/// to prevent accidental exposure in logs.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Consumer key (client_id).
    pub consumer_key: String,
    /// Consumer secret (client_secret). Optional for some flows.
    ///
    /// Kept private so it is only readable through the crate-internal
    /// `consumer_secret()` accessor and never printed by `Debug`.
    consumer_secret: Option<String>,
    /// Redirect URI for web flow.
    pub redirect_uri: Option<String>,
    /// Scopes to request.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("consumer_key", &self.consumer_key)
            // Reveal only whether a secret is configured (`Some("[REDACTED]")`
            // vs `None`), never its value.
            .field(
                "consumer_secret",
                &self.consumer_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .finish()
    }
}

impl OAuthConfig {
    /// Create a new OAuth config for the given consumer key.
    ///
    /// Starts with no secret, no redirect URI, and the scopes
    /// `api` and `refresh_token`.
    pub fn new(consumer_key: impl Into<String>) -> Self {
        Self {
            consumer_key: consumer_key.into(),
            consumer_secret: None,
            redirect_uri: None,
            scopes: vec!["api".to_string(), "refresh_token".to_string()],
        }
    }

    /// Set the consumer secret.
    #[must_use]
    pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
        self.consumer_secret = Some(secret.into());
        self
    }

    /// Get the consumer secret (for internal use).
    ///
    /// Returns `None` when no secret was configured.
    // `dead_code` is allowed so builds stay warning-free even when no
    // non-test code in the crate calls this accessor.
    #[allow(dead_code)]
    pub(crate) fn consumer_secret(&self) -> Option<&str> {
        self.consumer_secret.as_deref()
    }

    /// Set the redirect URI (required by the web server flow).
    #[must_use]
    pub fn with_redirect_uri(mut self, uri: impl Into<String>) -> Self {
        self.redirect_uri = Some(uri.into());
        self
    }

    /// Replace the requested scopes (the defaults are discarded).
    #[must_use]
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }
}
79
80/// OAuth client for authenticating with Salesforce.
81#[derive(Clone)]
82pub struct OAuthClient {
83    config: OAuthConfig,
84    http_client: reqwest::Client,
85}
86
87impl std::fmt::Debug for OAuthClient {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("OAuthClient")
90            .field("config", &self.config)
91            .finish_non_exhaustive()
92    }
93}
94
95impl OAuthClient {
96    /// Create a new OAuth client.
97    pub fn new(config: OAuthConfig) -> Self {
98        Self {
99            config,
100            http_client: reqwest::Client::new(),
101        }
102    }
103
104    /// Get the OAuth config.
105    pub fn config(&self) -> &OAuthConfig {
106        &self.config
107    }
108
109    /// Refresh an access token using a refresh token.
110    ///
111    /// The refresh_token parameter is not logged to prevent credential exposure.
112    #[instrument(skip(self, refresh_token))]
113    pub async fn refresh_token(
114        &self,
115        refresh_token: &str,
116        login_url: &str,
117    ) -> Result<TokenResponse> {
118        let mut params = vec![
119            ("grant_type", "refresh_token"),
120            ("refresh_token", refresh_token),
121            ("client_id", &self.config.consumer_key),
122        ];
123
124        if let Some(ref secret) = self.config.consumer_secret {
125            params.push(("client_secret", secret));
126        }
127
128        let body = serde_urlencoded::to_string(params)?;
129
130        let response = self
131            .http_client
132            .post(format!("{}/services/oauth2/token", login_url))
133            .header("Content-Type", "application/x-www-form-urlencoded")
134            .body(body)
135            .send()
136            .await?;
137
138        self.handle_token_response(response).await
139    }
140
141    /// Validate an access token.
142    ///
143    /// The token parameter is not logged to prevent credential exposure.
144    /// Uses POST with token in body to avoid exposing token in URL/logs.
145    #[instrument(skip(self, token))]
146    pub async fn validate_token(&self, token: &str, login_url: &str) -> Result<TokenInfo> {
147        // Use POST with token in body instead of GET with query param
148        // This prevents the token from appearing in server logs
149        let form_data = [("access_token", token)];
150        let body = serde_urlencoded::to_string(form_data)?;
151
152        let response = self
153            .http_client
154            .post(format!("{}/services/oauth2/tokeninfo", login_url))
155            .header("Content-Type", "application/x-www-form-urlencoded")
156            .body(body)
157            .send()
158            .await?;
159
160        if !response.status().is_success() {
161            return Err(Error::new(ErrorKind::TokenInvalid(
162                "Token validation failed".to_string(),
163            )));
164        }
165
166        let info: TokenInfo = response.json().await?;
167        Ok(info)
168    }
169
170    /// Revoke an access token or refresh token.
171    ///
172    /// This method implements [RFC 7009](https://datatracker.ietf.org/doc/html/rfc7009)
173    /// OAuth 2.0 Token Revocation. It programmatically invalidates tokens, enabling
174    /// clean session management and security-sensitive applications.
175    ///
176    /// # Token Type Behavior
177    ///
178    /// - **Revoking a refresh token**: Invalidates the refresh token AND all associated
179    ///   access tokens that were issued from it. Use this for complete session termination.
180    /// - **Revoking an access token**: Invalidates only that specific access token. The
181    ///   refresh token and other access tokens remain valid.
182    ///
183    /// # Idempotency
184    ///
185    /// This endpoint is idempotent - revoking an already invalid or non-existent token
186    /// will still return success (HTTP 200). This prevents information leakage about
187    /// token validity.
188    ///
189    /// # Example
190    ///
191    /// ```no_run
192    /// # use busbar_sf_auth::{OAuthClient, OAuthConfig};
193    /// # async fn example() -> Result<(), busbar_sf_auth::Error> {
194    /// let config = OAuthConfig::new("consumer_key");
195    /// let client = OAuthClient::new(config);
196    ///
197    /// // Revoke a refresh token (also revokes all its access tokens)
198    /// client.revoke_token("refresh_token_here", "https://login.salesforce.com").await?;
199    /// # Ok(())
200    /// # }
201    /// ```
202    ///
203    /// The token parameter is not logged to prevent credential exposure.
204    #[instrument(skip(self, token))]
205    pub async fn revoke_token(&self, token: &str, login_url: &str) -> Result<()> {
206        let form_data = [("token", token)];
207        let body = serde_urlencoded::to_string(form_data)?;
208
209        let response = self
210            .http_client
211            .post(format!("{}/services/oauth2/revoke", login_url))
212            .header("Content-Type", "application/x-www-form-urlencoded")
213            .body(body)
214            .send()
215            .await?;
216
217        if !response.status().is_success() {
218            // Try to parse error response; Salesforce may return non-JSON (HTML, empty body)
219            let status = response.status();
220            let body = response.text().await.unwrap_or_default();
221            if let Ok(error) = serde_json::from_str::<OAuthErrorResponse>(&body) {
222                return Err(Error::new(ErrorKind::OAuth {
223                    error: error.error,
224                    description: error.error_description,
225                }));
226            }
227            return Err(Error::new(ErrorKind::Http(format!(
228                "Token revocation failed with status {status}"
229            ))));
230        }
231
232        Ok(())
233    }
234
235    /// Handle a token response, checking for errors.
236    async fn handle_token_response(&self, response: reqwest::Response) -> Result<TokenResponse> {
237        if !response.status().is_success() {
238            let error: OAuthErrorResponse = response.json().await?;
239            return Err(Error::new(ErrorKind::OAuth {
240                error: error.error,
241                description: error.error_description,
242            }));
243        }
244
245        let token: TokenResponse = response.json().await?;
246        Ok(token)
247    }
248}
249
250/// Web Server OAuth flow for web applications.
251#[derive(Clone)]
252pub struct WebFlowAuth {
253    config: OAuthConfig,
254    http_client: reqwest::Client,
255}
256
257impl std::fmt::Debug for WebFlowAuth {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("WebFlowAuth")
260            .field("config", &self.config)
261            .finish_non_exhaustive()
262    }
263}
264
265impl WebFlowAuth {
266    /// Create a new web flow authenticator.
267    pub fn new(config: OAuthConfig) -> Result<Self> {
268        if config.redirect_uri.is_none() {
269            return Err(Error::new(ErrorKind::Config(
270                "redirect_uri is required for web flow".to_string(),
271            )));
272        }
273
274        Ok(Self {
275            config,
276            http_client: reqwest::Client::new(),
277        })
278    }
279
280    /// Generate the authorization URL to redirect users to.
281    pub fn authorization_url(&self, login_url: &str, state: Option<&str>) -> String {
282        let redirect_uri = self.config.redirect_uri.as_ref().unwrap();
283        let scopes = self.config.scopes.join(" ");
284
285        let mut url = format!(
286            "{}/services/oauth2/authorize?response_type=code&client_id={}&redirect_uri={}",
287            login_url,
288            urlencoding::encode(&self.config.consumer_key),
289            urlencoding::encode(redirect_uri),
290        );
291
292        if !scopes.is_empty() {
293            url.push_str(&format!("&scope={}", urlencoding::encode(&scopes)));
294        }
295
296        if let Some(state) = state {
297            url.push_str(&format!("&state={}", urlencoding::encode(state)));
298        }
299
300        url
301    }
302
303    /// Exchange an authorization code for tokens.
304    ///
305    /// The code parameter is not logged to prevent credential exposure.
306    #[instrument(skip(self, code))]
307    pub async fn exchange_code(&self, code: &str, login_url: &str) -> Result<TokenResponse> {
308        let redirect_uri = self.config.redirect_uri.as_ref().unwrap();
309
310        let mut params = vec![
311            ("grant_type", "authorization_code"),
312            ("code", code),
313            ("client_id", &self.config.consumer_key),
314            ("redirect_uri", redirect_uri),
315        ];
316
317        if let Some(ref secret) = self.config.consumer_secret {
318            params.push(("client_secret", secret));
319        }
320
321        let body = serde_urlencoded::to_string(params)?;
322
323        let response = self
324            .http_client
325            .post(format!("{}/services/oauth2/token", login_url))
326            .header("Content-Type", "application/x-www-form-urlencoded")
327            .body(body)
328            .send()
329            .await?;
330
331        if !response.status().is_success() {
332            let error: OAuthErrorResponse = response.json().await?;
333            return Err(Error::new(ErrorKind::OAuth {
334                error: error.error,
335                description: error.error_description,
336            }));
337        }
338
339        let token: TokenResponse = response.json().await?;
340        Ok(token)
341    }
342}
343
/// Token response from the OAuth token endpoint.
///
/// Produced by the refresh-token and authorization-code flows in this module.
/// Optional fields missing from the JSON deserialize as `None`.
///
/// Sensitive fields (`access_token`, `refresh_token`, `signature`) are redacted
/// in Debug output to prevent accidental exposure in logs. `Serialize` is
/// derived without any skips, so serializing this value DOES include them.
#[derive(Clone, Deserialize, Serialize)]
pub struct TokenResponse {
    /// Access token used to authenticate API calls.
    pub access_token: String,
    /// Refresh token; present only if the server returned one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Instance URL; becomes the credentials' instance URL in `to_credentials`.
    pub instance_url: String,
    /// User ID (identity) URL.
    #[serde(default)]
    pub id: Option<String>,
    /// Token type (usually "Bearer").
    #[serde(default)]
    pub token_type: Option<String>,
    /// Scopes granted, as the raw string from the response.
    #[serde(default)]
    pub scope: Option<String>,
    /// Signature for verification (redacted in Debug output).
    #[serde(default)]
    pub signature: Option<String>,
    /// Issued-at timestamp, kept as the raw string from the response.
    #[serde(default)]
    pub issued_at: Option<String>,
}
373
374impl std::fmt::Debug for TokenResponse {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        f.debug_struct("TokenResponse")
377            .field("access_token", &"[REDACTED]")
378            .field(
379                "refresh_token",
380                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
381            )
382            .field("instance_url", &self.instance_url)
383            .field("id", &self.id)
384            .field("token_type", &self.token_type)
385            .field("scope", &self.scope)
386            .field("signature", &self.signature.as_ref().map(|_| "[REDACTED]"))
387            .field("issued_at", &self.issued_at)
388            .finish()
389    }
390}
391
392impl TokenResponse {
393    /// Convert to SalesforceCredentials.
394    pub fn to_credentials(&self, api_version: &str) -> SalesforceCredentials {
395        let mut creds =
396            SalesforceCredentials::new(&self.instance_url, &self.access_token, api_version);
397
398        if let Some(ref rt) = self.refresh_token {
399            creds = creds.with_refresh_token(rt);
400        }
401
402        creds
403    }
404}
405
/// Token info returned by `OAuthClient::validate_token`.
///
/// Field names mirror the RFC 7662 token-introspection response members.
/// Every field except `active` is optional and deserializes as `None` when
/// absent.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenInfo {
    /// Whether the token is active. This field has no default, so a response
    /// without it fails to deserialize.
    pub active: bool,
    /// Scopes associated with the token.
    #[serde(default)]
    pub scope: Option<String>,
    /// Client ID of the client the token was issued to.
    #[serde(default)]
    pub client_id: Option<String>,
    /// Username associated with the token.
    #[serde(default)]
    pub username: Option<String>,
    /// Token type.
    #[serde(default)]
    pub token_type: Option<String>,
    /// Expiration time. NOTE(review): presumably seconds since the Unix epoch
    /// (as RFC 7662 `exp`); confirm against the server response.
    #[serde(default)]
    pub exp: Option<u64>,
    /// Issued-at time; presumably the same unit as `exp` (unconfirmed).
    #[serde(default)]
    pub iat: Option<u64>,
    /// Subject of the token.
    #[serde(default)]
    pub sub: Option<String>,
}
433
434/// OAuth error response.
435#[derive(Debug, Deserialize)]
436struct OAuthErrorResponse {
437    error: String,
438    error_description: String,
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::Credentials;
    use wiremock::matchers::{body_string_contains, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path of the token revocation endpoint the client posts to.
    const REVOKE_PATH: &str = "/services/oauth2/revoke";

    /// Build a `TokenResponse` fixture; only the fields the tests vary are parameters.
    fn token_fixture(
        access: &str,
        refresh: Option<&str>,
        signature: Option<&str>,
    ) -> TokenResponse {
        TokenResponse {
            access_token: access.to_string(),
            refresh_token: refresh.map(str::to_string),
            instance_url: "https://na1.salesforce.com".to_string(),
            id: None,
            token_type: Some("Bearer".to_string()),
            scope: None,
            signature: signature.map(str::to_string),
            issued_at: None,
        }
    }

    /// Client with only a consumer key, which is all revocation needs.
    fn revoke_client() -> OAuthClient {
        OAuthClient::new(OAuthConfig::new("test_client_id"))
    }

    /// Start a mock server whose revoke endpoint answers any POST with `reply`.
    async fn revoke_server(reply: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(REVOKE_PATH))
            .respond_with(reply)
            .mount(&server)
            .await;
        server
    }

    #[test]
    fn test_oauth_config() {
        let config = OAuthConfig::new("consumer_key")
            .with_secret("secret")
            .with_redirect_uri("https://example.com/callback")
            .with_scopes(vec!["api".to_string(), "web".to_string()]);

        assert_eq!(config.consumer_key, "consumer_key");
        assert_eq!(config.consumer_secret(), Some("secret"));
        assert_eq!(
            config.redirect_uri.as_deref(),
            Some("https://example.com/callback")
        );
        assert_eq!(config.scopes, vec!["api", "web"]);
    }

    #[test]
    fn test_oauth_config_debug_redacts_secret() {
        let rendered = format!(
            "{:?}",
            OAuthConfig::new("consumer_key").with_secret("super_secret_value")
        );
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("super_secret_value"));
    }

    #[test]
    fn test_web_flow_auth_url() {
        let config = OAuthConfig::new("my_client_id")
            .with_redirect_uri("https://localhost:8080/callback")
            .with_scopes(vec!["api".to_string()]);

        let url = WebFlowAuth::new(config)
            .unwrap()
            .authorization_url("https://login.salesforce.com", Some("state123"));

        for fragment in &[
            "response_type=code",
            "client_id=my_client_id",
            "redirect_uri=",
            "state=state123",
        ] {
            assert!(url.contains(*fragment), "missing `{}` in {}", fragment, url);
        }
    }

    #[test]
    fn test_token_response_to_credentials() {
        let creds = token_fixture("access123", Some("refresh456"), None).to_credentials("62.0");

        assert_eq!(creds.instance_url(), "https://na1.salesforce.com");
        assert_eq!(creds.access_token(), "access123");
        assert_eq!(creds.refresh_token(), Some("refresh456"));
    }

    #[test]
    fn test_token_response_debug_redacts_tokens() {
        let rendered = format!(
            "{:?}",
            token_fixture(
                "super_secret_access_token",
                Some("super_secret_refresh_token"),
                Some("signature_value"),
            )
        );

        assert!(rendered.contains("[REDACTED]"));
        for secret in &[
            "super_secret_access_token",
            "super_secret_refresh_token",
            "signature_value",
        ] {
            assert!(!rendered.contains(*secret), "leaked `{}`", secret);
        }
    }

    #[tokio::test]
    async fn test_revoke_token_success() {
        let server = MockServer::start().await;

        // Successful revocation answers 200 with an empty body, but only for a
        // form-encoded request carrying the expected token.
        Mock::given(method("POST"))
            .and(path(REVOKE_PATH))
            .and(header("Content-Type", "application/x-www-form-urlencoded"))
            .and(body_string_contains("token=test_token_to_revoke"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let outcome = revoke_client()
            .revoke_token("test_token_to_revoke", &server.uri())
            .await;
        assert!(outcome.is_ok(), "Token revocation should succeed");
    }

    #[tokio::test]
    async fn test_revoke_token_idempotency() {
        // The endpoint answers 200 even for tokens that are already invalid.
        let server = revoke_server(ResponseTemplate::new(200)).await;
        let client = revoke_client();

        for expectation in &[
            "First revocation should succeed",
            "Second revocation should also succeed (idempotent)",
        ] {
            let outcome = client
                .revoke_token("already_invalid_token", &server.uri())
                .await;
            assert!(outcome.is_ok(), "{}", expectation);
        }
    }

    #[tokio::test]
    async fn test_revoke_token_failure() {
        let server = revoke_server(ResponseTemplate::new(400).set_body_json(
            serde_json::json!({
                "error": "invalid_request",
                "error_description": "Token parameter is missing"
            }),
        ))
        .await;

        let err = revoke_client()
            .revoke_token("malformed_token", &server.uri())
            .await
            .expect_err("Token revocation should fail");
        assert!(
            matches!(err.kind, ErrorKind::OAuth { .. }),
            "Should return OAuth error"
        );
    }

    #[tokio::test]
    async fn test_revoke_token_non_json_error() {
        // A non-JSON error body (HTML here) must still produce a clear error.
        let server = revoke_server(
            ResponseTemplate::new(400).set_body_string("<html>Bad Request</html>"),
        )
        .await;

        let err = revoke_client()
            .revoke_token("some_token", &server.uri())
            .await
            .expect_err("Should fail with non-JSON error body");
        assert!(
            matches!(err.kind, ErrorKind::Http(_)),
            "Should return Http error, got: {:?}",
            err.kind
        );
        assert!(
            err.to_string().contains("revocation failed"),
            "Error should mention revocation failed"
        );
    }
}