Skip to main content

hyperdb_api_salesforce/
provider.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! DC JWT provider for Salesforce Data Cloud authentication.
5//!
6//! This module implements the two-stage token flow:
7//! 1. Authenticate with Salesforce to get an **OAuth Access Token**
8//!    (via `/services/oauth2/token`)
9//! 2. Exchange the OAuth Access Token for a **DC JWT**
10//!    (via `/services/a360/token`)
11//!
12//! The provider caches both the OAuth Access Token and the DC JWT
13//! independently.  The OAuth Access Token is only refreshed when it
14//! has genuinely expired, avoiding unnecessary **OAuth Refresh Token**
15//! rotation that would invalidate tokens held by other connections.
16
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use reqwest::Client as HttpClient;
22use tokio::sync::Mutex;
23use tracing::{debug, info, warn};
24
25use crate::config::{AuthMode, SalesforceAuthConfig};
26use crate::error::{SalesforceAuthError, SalesforceAuthResult};
27use crate::jwt::build_jwt_assertion;
28use crate::token::{DataCloudToken, DataCloudTokenResponse, OAuthToken, OAuthTokenResponse};
29
/// OAuth Access Token endpoint path.
///
/// Joined onto the configured login URL in `fetch_oauth_access_token`
/// to build the Stage 1 request URL.
// NOTE(review): the path has no leading slash, so `join` treats it as a
// relative reference — confirm the result for login URLs that carry a path.
const OAUTH_TOKEN_PATH: &str = "services/oauth2/token";

/// DC JWT exchange endpoint path.
///
/// Joined onto the `instance_url` of the OAuth Access Token in
/// `exchange_oauth_access_token_for_dc_jwt` to build the Stage 2 request URL.
const DATA_CLOUD_TOKEN_PATH: &str = "services/a360/token";
35
/// DC JWT provider.
///
/// Handles the full token flow for Salesforce Data Cloud:
/// 1. Authenticates with Salesforce using the configured auth mode to
///    obtain an **OAuth Access Token**
/// 2. Exchanges the OAuth Access Token for a **DC JWT**
/// 3. Caches both tokens and refreshes them independently:
///    - The OAuth Access Token is refreshed only when genuinely expired
///      (to avoid unnecessary OAuth Refresh Token rotation)
///    - The DC JWT is refreshed whenever it is expired or requested
///
/// On DC JWT exchange failure, the provider retries once with a
/// force-refreshed OAuth Access Token (Step 2a), matching the behavior
/// described in the `GenieOAuthManagement` documentation.
///
/// Every method that may refresh a token takes `&mut self`, so a single
/// provider cannot be used from several tasks at once; use
/// [`SharedTokenProvider`] for that.
///
/// The hand-written [`std::fmt::Debug`] implementation reports only
/// *whether* tokens are cached, never the token values themselves.
///
/// # Example
///
/// ```no_run
/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, DataCloudTokenProvider};
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// # let private_key_pem = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----";
/// let config = SalesforceAuthConfig::new(
///     "https://login.salesforce.com",
///     "your-client-id",
/// )?
/// .auth_mode(AuthMode::private_key("user@example.com", &private_key_pem)?);
///
/// let mut provider = DataCloudTokenProvider::new(config)?;
///
/// // Get a valid DC JWT (automatically handles the full token flow)
/// let token = provider.get_token().await?;
/// println!("Authorization: {}", token.bearer_token());
/// # Ok(())
/// # }
/// ```
pub struct DataCloudTokenProvider {
    /// Validated authentication configuration (login URL, client id,
    /// auth mode, timeout and retry settings).
    config: SalesforceAuthConfig,
    /// HTTP client for token requests, built once in `new` with the
    /// configured request timeout.
    http_client: HttpClient,
    /// Cached OAuth Access Token (refreshed only when genuinely expired).
    /// `None` until the first successful Stage 1 request or after a
    /// cache clear.
    cached_oauth_token: Option<OAuthToken>,
    /// Cached DC JWT.  `None` until the first successful exchange or
    /// after a cache clear / forced refresh.
    cached_dc_jwt: Option<DataCloudToken>,
}
82
83impl DataCloudTokenProvider {
84    /// Creates a new DC JWT provider with the given configuration.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if the configuration is invalid.
89    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
90        config.validate()?;
91
92        let http_client = HttpClient::builder()
93            .timeout(Duration::from_secs(config.timeout_secs))
94            .build()
95            .map_err(|e| SalesforceAuthError::http(format!("failed to create HTTP client: {e}")))?;
96
97        Ok(DataCloudTokenProvider {
98            config,
99            http_client,
100            cached_oauth_token: None,
101            cached_dc_jwt: None,
102        })
103    }
104
105    /// Returns the configuration.
106    #[must_use]
107    pub fn config(&self) -> &SalesforceAuthConfig {
108        &self.config
109    }
110
111    /// Gets a valid DC JWT.
112    ///
113    /// If a cached DC JWT exists and is still valid, it is returned.
114    /// Otherwise, a new DC JWT is obtained through the full token flow.
115    ///
116    /// # Errors
117    ///
118    /// Propagates any error from `Self::fetch_dc_jwt` — typically
119    /// [`SalesforceAuthError::Http`], [`SalesforceAuthError::Authorization`],
120    /// [`SalesforceAuthError::Jwt`], [`SalesforceAuthError::TokenExchange`],
121    /// or [`SalesforceAuthError::TokenParse`] depending on where the
122    /// three-step refresh cycle (OAuth Access Token → DC JWT) fails.
123    ///
124    /// # Panics
125    ///
126    /// Does not panic in practice. The trailing `unwrap()` on
127    /// `self.cached_dc_jwt` is guarded by the preceding cache-population
128    /// logic: either the cache was already populated with a valid token,
129    /// or `Self::fetch_dc_jwt` just filled it.
130    pub async fn get_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
131        let needs_refresh = match &self.cached_dc_jwt {
132            Some(token) if token.is_valid() => {
133                debug!("Using cached DC JWT");
134                false
135            }
136            Some(_) => {
137                debug!("Cached DC JWT expired, refreshing");
138                true
139            }
140            None => true,
141        };
142
143        if needs_refresh {
144            let token = self.fetch_dc_jwt().await?;
145            self.cached_dc_jwt = Some(token);
146        }
147
148        Ok(self.cached_dc_jwt.as_ref().unwrap())
149    }
150
151    /// Forces a full token refresh (both OAuth Access Token and DC JWT),
152    /// even if the cached tokens are still valid.
153    ///
154    /// # Errors
155    ///
156    /// Propagates any error from [`Self::get_token`] (same failure modes
157    /// as the full token-flow refresh).
158    pub async fn force_refresh(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
159        self.cached_oauth_token = None;
160        self.cached_dc_jwt = None;
161        self.get_token().await
162    }
163
164    /// Forces a DC JWT refresh while allowing the OAuth Access Token to
165    /// be reused if still valid.
166    ///
167    /// This is the preferred refresh method during normal operation: it
168    /// re-exchanges the (possibly cached) OAuth Access Token for a fresh
169    /// DC JWT without unnecessarily rotating the OAuth Refresh Token.
170    ///
171    /// # Errors
172    ///
173    /// Propagates any error from [`Self::get_token`] (HTTP, authorization,
174    /// JWT signing, or token-parse failures during the DC JWT exchange).
175    pub async fn refresh_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
176        self.cached_dc_jwt = None;
177        self.get_token().await
178    }
179
180    /// Clears all cached tokens (both OAuth Access Token and DC JWT).
181    pub fn clear_cache(&mut self) {
182        self.cached_oauth_token = None;
183        self.cached_dc_jwt = None;
184    }
185
186    /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
187    ///
188    /// Convenience method for getting the `Authorization` header value
189    /// without an async call.  Returns `None` if no valid DC JWT is cached.
190    #[must_use]
191    pub fn bearer_token(&self) -> Option<String> {
192        self.cached_dc_jwt
193            .as_ref()
194            .filter(|t| t.is_valid())
195            .map(super::token::DataCloudToken::bearer_token)
196    }
197
198    /// Returns the tenant URL if a valid DC JWT is cached.
199    #[must_use]
200    pub fn tenant_url(&self) -> Option<&str> {
201        self.cached_dc_jwt
202            .as_ref()
203            .filter(|t| t.is_valid())
204            .map(super::token::DataCloudToken::tenant_url_str)
205    }
206
207    /// Returns the lakehouse name for Hyper connection.
208    ///
209    /// # Errors
210    ///
211    /// Propagates [`SalesforceAuthError::TokenParse`] from
212    /// [`DataCloudToken::lakehouse_name`] if the cached DC JWT's tenant
213    /// URL cannot be parsed into a valid lakehouse identifier.
214    pub fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
215        if let Some(ref token) = self.cached_dc_jwt {
216            if token.is_valid() {
217                return Ok(Some(token.lakehouse_name(self.config.dataspace_value())?));
218            }
219        }
220        Ok(None)
221    }
222
223    /// Fetches a new DC JWT through the full token flow.
224    ///
225    /// Implements the three-step refresh cycle from the
226    /// `GenieOAuthManagement` documentation:
227    ///
228    /// - **Step 1**: Validate / refresh the OAuth Access Token
229    ///   (only refreshes when genuinely expired — avoids unnecessary
230    ///   OAuth Refresh Token rotation)
231    /// - **Step 2**: Exchange the OAuth Access Token for a DC JWT
232    /// - **Step 2a** (retry): If Step 2 fails, force-refresh the
233    ///   OAuth Access Token and retry the DC JWT exchange once
234    async fn fetch_dc_jwt(&mut self) -> SalesforceAuthResult<DataCloudToken> {
235        // Step 1: Validate / refresh OAuth Access Token
236        let oauth_token = self.get_valid_oauth_access_token().await?;
237
238        // Step 2: Exchange OAuth Access Token → DC JWT
239        match self
240            .exchange_oauth_access_token_for_dc_jwt(&oauth_token)
241            .await
242        {
243            Ok(dc_jwt) => Ok(dc_jwt),
244            Err(step2_err) => {
245                // Step 2a: Force-refresh the OAuth Access Token and retry once.
246                // This handles the case where the OAuth Access Token appeared
247                // valid locally but was invalidated server-side (e.g., by
248                // Salesforce's inactivity timeout).
249                warn!(
250                    error = %step2_err,
251                    "DC JWT exchange failed; force-refreshing OAuth Access Token and retrying (Step 2a)"
252                );
253
254                self.cached_oauth_token = None;
255                let fresh_oauth_token = self.fetch_oauth_access_token().await?;
256                self.cached_oauth_token = Some(fresh_oauth_token.clone());
257
258                self.exchange_oauth_access_token_for_dc_jwt(&fresh_oauth_token)
259                    .await
260                    .map_err(|retry_err| {
261                        warn!(
262                            original_error = %step2_err,
263                            retry_error = %retry_err,
264                            "DC JWT exchange failed again after OAuth Access Token refresh (Step 2a retry)"
265                        );
266                        retry_err
267                    })
268            }
269        }
270    }
271
272    /// Returns a valid OAuth Access Token, using the cache when possible.
273    ///
274    /// Only contacts Salesforce when the cached OAuth Access Token has
275    /// genuinely expired.  This avoids unnecessary OAuth Refresh Token
276    /// rotation that would invalidate tokens held by other connections.
277    async fn get_valid_oauth_access_token(&mut self) -> SalesforceAuthResult<OAuthToken> {
278        if let Some(ref token) = self.cached_oauth_token {
279            if token.is_likely_valid() {
280                debug!(
281                    "OAuth Access Token still valid (obtained at {}), reusing",
282                    token.obtained_at
283                );
284                return Ok(token.clone());
285            }
286            debug!("Cached OAuth Access Token expired, refreshing");
287        }
288
289        let token = self.fetch_oauth_access_token().await?;
290        self.cached_oauth_token = Some(token.clone());
291        Ok(token)
292    }
293
294    /// Fetches a fresh OAuth Access Token from Salesforce.
295    async fn fetch_oauth_access_token(&self) -> SalesforceAuthResult<OAuthToken> {
296        let auth_mode = self
297            .config
298            .auth_mode
299            .as_ref()
300            .ok_or_else(|| SalesforceAuthError::config("auth_mode not configured"))?;
301
302        let mut form_data = HashMap::new();
303        form_data.insert("client_id", self.config.client_id.clone());
304
305        match auth_mode {
306            AuthMode::Password { username, password } => {
307                info!(username = %username, "Fetching OAuth Access Token via password grant");
308                form_data.insert("grant_type", "password".to_string());
309                form_data.insert("username", username.clone());
310                form_data.insert("password", password.as_str().to_string());
311
312                if let Some(ref secret) = self.config.client_secret {
313                    form_data.insert("client_secret", secret.as_str().to_string());
314                }
315            }
316
317            AuthMode::PrivateKey {
318                username,
319                private_key,
320            } => {
321                info!(username = %username, "Fetching OAuth Access Token via JWT Bearer Token Flow");
322
323                let assertion = build_jwt_assertion(
324                    &self.config.client_id,
325                    username,
326                    &self.config.login_url,
327                    private_key,
328                )?;
329
330                form_data.insert(
331                    "grant_type",
332                    "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
333                );
334                form_data.insert("assertion", assertion);
335            }
336
337            AuthMode::RefreshToken { refresh_token } => {
338                info!("Fetching OAuth Access Token via OAuth Refresh Token");
339                form_data.insert("grant_type", "refresh_token".to_string());
340                form_data.insert("refresh_token", refresh_token.as_str().to_string());
341
342                if let Some(ref secret) = self.config.client_secret {
343                    form_data.insert("client_secret", secret.as_str().to_string());
344                }
345            }
346        }
347
348        let token_url = self.config.login_url.join(OAUTH_TOKEN_PATH).map_err(|e| {
349            SalesforceAuthError::config(format!("failed to build OAuth Access Token URL: {e}"))
350        })?;
351
352        debug!(url = %token_url, "Requesting OAuth Access Token");
353
354        let response = self.post_with_retry(&token_url, &form_data).await?;
355        let response_text = response.text().await?;
356
357        debug!(response = %response_text, "OAuth Access Token response received");
358
359        let oauth_response: OAuthTokenResponse =
360            serde_json::from_str(&response_text).map_err(|e| {
361                SalesforceAuthError::token_parse(format!(
362                    "failed to parse OAuth Access Token response: {e}"
363                ))
364            })?;
365
366        let token_changed = self
367            .cached_oauth_token
368            .as_ref()
369            .map_or(true, |old| old.token != oauth_response.access_token);
370
371        debug!(
372            instance_url = %oauth_response.instance_url,
373            token_type = ?oauth_response.token_type,
374            scope = ?oauth_response.scope,
375            token_changed = token_changed,
376            "OAuth Access Token response parsed"
377        );
378
379        OAuthToken::from_response(oauth_response)
380    }
381
382    /// Exchanges an OAuth Access Token for a DC JWT.
383    ///
384    /// Calls `POST /services/a360/token` with the OAuth Access Token as
385    /// the `subject_token`.
386    async fn exchange_oauth_access_token_for_dc_jwt(
387        &self,
388        oauth_token: &OAuthToken,
389    ) -> SalesforceAuthResult<DataCloudToken> {
390        let mut form_data = HashMap::new();
391        form_data.insert(
392            "grant_type",
393            "urn:salesforce:grant-type:external:cdp".to_string(),
394        );
395        form_data.insert(
396            "subject_token_type",
397            "urn:ietf:params:oauth:token-type:access_token".to_string(),
398        );
399        form_data.insert("subject_token", oauth_token.token.clone());
400
401        if let Some(ref dataspace) = self.config.dataspace {
402            form_data.insert("dataspace", dataspace.clone());
403        }
404
405        let exchange_url = oauth_token
406            .instance_url
407            .join(DATA_CLOUD_TOKEN_PATH)
408            .map_err(|e| {
409                SalesforceAuthError::config(format!("failed to build DC JWT exchange URL: {e}"))
410            })?;
411
412        debug!(url = %exchange_url, "Exchanging OAuth Access Token for DC JWT");
413
414        let response = self.post_with_retry(&exchange_url, &form_data).await?;
415        let response_text = response.text().await?;
416
417        debug!(response = %response_text, "DC JWT response received");
418
419        let dc_response: DataCloudTokenResponse =
420            serde_json::from_str(&response_text).map_err(|e| {
421                SalesforceAuthError::token_parse(format!("failed to parse DC JWT response: {e}"))
422            })?;
423
424        debug!(
425            instance_url = %dc_response.instance_url,
426            token_type = ?dc_response.token_type,
427            expires_in = ?dc_response.expires_in,
428            "DC JWT response parsed"
429        );
430
431        let token = DataCloudToken::from_response(dc_response)?;
432
433        info!(
434            tenant_url = %token.tenant_url(),
435            expires_at = %token.expires_at(),
436            "DC JWT obtained"
437        );
438
439        Ok(token)
440    }
441
442    /// Makes a POST request with retry logic for transient failures.
443    async fn post_with_retry(
444        &self,
445        url: &url::Url,
446        form_data: &HashMap<&str, String>,
447    ) -> SalesforceAuthResult<reqwest::Response> {
448        let mut last_error = None;
449
450        for attempt in 0..=self.config.max_retries {
451            if attempt > 0 {
452                let delay = Duration::from_secs(1 << (attempt - 1).min(4));
453                warn!(
454                    attempt = attempt,
455                    delay_secs = delay.as_secs(),
456                    "Retrying after transient failure"
457                );
458                tokio::time::sleep(delay).await;
459            }
460
461            match self
462                .http_client
463                .post(url.as_str())
464                .header("Accept", "application/json")
465                .header("Content-Type", "application/x-www-form-urlencoded")
466                .form(form_data)
467                .send()
468                .await
469            {
470                Ok(response) => {
471                    if response.status().is_client_error() {
472                        let status = response.status();
473                        let body = response.text().await.unwrap_or_default();
474
475                        if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
476                            let error_code = error_json
477                                .get("error")
478                                .and_then(|v| v.as_str())
479                                .unwrap_or("unknown");
480                            let error_desc = error_json
481                                .get("error_description")
482                                .and_then(|v| v.as_str())
483                                .unwrap_or(&body);
484
485                            return Err(SalesforceAuthError::authorization(
486                                error_code.to_string(),
487                                error_desc.to_string(),
488                            ));
489                        }
490
491                        return Err(SalesforceAuthError::http(format!(
492                            "HTTP {status} error: {body}"
493                        )));
494                    }
495
496                    if response.status().is_server_error() {
497                        last_error = Some(SalesforceAuthError::http(format!(
498                            "HTTP {} error",
499                            response.status()
500                        )));
501                        continue;
502                    }
503
504                    return Ok(response);
505                }
506                Err(e) => {
507                    last_error = Some(SalesforceAuthError::Http(e.to_string()));
508                }
509            }
510        }
511
512        Err(last_error.unwrap_or_else(|| SalesforceAuthError::http("request failed after retries")))
513    }
514}
515
516/// Thread-safe wrapper around [`DataCloudTokenProvider`].
517///
518/// Allows sharing the DC JWT provider between multiple tasks/threads
519/// while ensuring exclusive access during token operations.  All access
520/// is protected by a [`tokio::sync::Mutex`].
521///
522/// # Example
523///
524/// ```no_run
525/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, SharedTokenProvider};
526///
527/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
528/// # let config = SalesforceAuthConfig::new("https://login.salesforce.com", "client_id")?
529/// #     .auth_mode(AuthMode::password("user", "pass"));
530/// let provider = SharedTokenProvider::new(config)?;
531///
532/// // Can be cloned and shared between tasks
533/// let provider_clone = provider.clone();
534///
535/// tokio::spawn(async move {
536///     let dc_jwt = provider_clone.get_token().await.unwrap();
537///     // use dc_jwt.bearer_token() as the Authorization header
538/// });
539/// # Ok(())
540/// # }
541/// ```
542#[derive(Clone)]
543pub struct SharedTokenProvider {
544    inner: Arc<Mutex<DataCloudTokenProvider>>,
545}
546
547impl SharedTokenProvider {
548    /// Creates a new shared DC JWT provider.
549    ///
550    /// # Errors
551    ///
552    /// Propagates any error from [`DataCloudTokenProvider::new`]:
553    /// configuration validation failures or HTTP client construction
554    /// failures (surfaced as [`SalesforceAuthError::Http`]).
555    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
556        let provider = DataCloudTokenProvider::new(config)?;
557        Ok(SharedTokenProvider {
558            inner: Arc::new(Mutex::new(provider)),
559        })
560    }
561
562    /// Gets a valid DC JWT.
563    ///
564    /// # Errors
565    ///
566    /// Propagates any error from [`DataCloudTokenProvider::get_token`]
567    /// (HTTP failure, authorization rejection, JWT signing error, or
568    /// token-parse failure during the refresh cycle).
569    pub async fn get_token(&self) -> SalesforceAuthResult<DataCloudToken> {
570        let mut provider = self.inner.lock().await;
571        provider.get_token().await.cloned()
572    }
573
574    /// Forces a DC JWT refresh (reuses OAuth Access Token if still valid).
575    ///
576    /// # Errors
577    ///
578    /// Propagates any error from [`DataCloudTokenProvider::refresh_token`].
579    pub async fn refresh_token(&self) -> SalesforceAuthResult<DataCloudToken> {
580        let mut provider = self.inner.lock().await;
581        provider.refresh_token().await.cloned()
582    }
583
584    /// Forces a full refresh (both OAuth Access Token and DC JWT).
585    ///
586    /// # Errors
587    ///
588    /// Propagates any error from [`DataCloudTokenProvider::force_refresh`].
589    pub async fn force_refresh(&self) -> SalesforceAuthResult<DataCloudToken> {
590        let mut provider = self.inner.lock().await;
591        provider.force_refresh().await.cloned()
592    }
593
594    /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
595    pub async fn bearer_token(&self) -> Option<String> {
596        let provider = self.inner.lock().await;
597        provider.bearer_token()
598    }
599
600    /// Returns the tenant URL if a valid DC JWT is cached.
601    pub async fn tenant_url(&self) -> Option<String> {
602        let provider = self.inner.lock().await;
603        provider.tenant_url().map(std::string::ToString::to_string)
604    }
605
606    /// Returns the lakehouse name for Hyper connection.
607    ///
608    /// # Errors
609    ///
610    /// Propagates [`SalesforceAuthError::TokenParse`] from
611    /// [`DataCloudTokenProvider::lakehouse_name`] if the cached DC JWT's
612    /// tenant URL cannot be parsed into a valid lakehouse identifier.
613    pub async fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
614        let provider = self.inner.lock().await;
615        provider.lakehouse_name()
616    }
617}
618
619impl std::fmt::Debug for DataCloudTokenProvider {
620    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621        f.debug_struct("DataCloudTokenProvider")
622            .field("config", &self.config)
623            .field("has_cached_oauth_token", &self.cached_oauth_token.is_some())
624            .field("has_cached_dc_jwt", &self.cached_dc_jwt.is_some())
625            .finish_non_exhaustive()
626    }
627}
628
629impl std::fmt::Debug for SharedTokenProvider {
630    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631        f.debug_struct("SharedTokenProvider")
632            .finish_non_exhaustive()
633    }
634}