Skip to main content

hyperdb_api_salesforce/
provider.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! DC JWT provider for Salesforce Data Cloud authentication.
5//!
6//! This module implements the two-stage token flow:
7//! 1. Authenticate with Salesforce to get an **OAuth Access Token**
8//!    (via `/services/oauth2/token`)
9//! 2. Exchange the OAuth Access Token for a **DC JWT**
10//!    (via `/services/a360/token`)
11//!
12//! The provider caches both the OAuth Access Token and the DC JWT
13//! independently.  The OAuth Access Token is only refreshed when it
14//! has genuinely expired, avoiding unnecessary **OAuth Refresh Token**
15//! rotation that would invalidate tokens held by other connections.
16
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use reqwest::Client as HttpClient;
22use tokio::sync::Mutex;
23use tracing::{debug, info, warn};
24
25use crate::config::{AuthMode, SalesforceAuthConfig};
26use crate::error::{SalesforceAuthError, SalesforceAuthResult};
27use crate::jwt::build_jwt_assertion;
28use crate::token::{DataCloudToken, DataCloudTokenResponse, OAuthToken, OAuthTokenResponse};
29
/// Relative path of the OAuth 2.0 token endpoint that issues the
/// **OAuth Access Token**; it is joined onto the configured Salesforce
/// login URL.
///
/// There is no leading slash, so `Url::join` resolves it relative to the
/// base URL's path.
const OAUTH_TOKEN_PATH: &str = "services/oauth2/token";

/// Relative path of the endpoint that exchanges an OAuth Access Token for a
/// **DC JWT**; it is joined onto the OAuth Access Token's instance URL.
const DATA_CLOUD_TOKEN_PATH: &str = "services/a360/token";
35
36/// DC JWT provider.
37///
38/// Handles the full token flow for Salesforce Data Cloud:
39/// 1. Authenticates with Salesforce using the configured auth mode to
40///    obtain an **OAuth Access Token**
41/// 2. Exchanges the OAuth Access Token for a **DC JWT**
42/// 3. Caches both tokens and refreshes them independently:
43///    - The OAuth Access Token is refreshed only when genuinely expired
44///      (to avoid unnecessary OAuth Refresh Token rotation)
45///    - The DC JWT is refreshed whenever it is expired or requested
46///
47/// On DC JWT exchange failure, the provider retries once with a
48/// force-refreshed OAuth Access Token (Step 2a), matching the behavior
49/// described in the `GenieOAuthManagement` documentation.
50///
51/// # Example
52///
53/// ```no_run
54/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, DataCloudTokenProvider};
55///
56/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
57/// # let private_key_pem = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----";
58/// let config = SalesforceAuthConfig::new(
59///     "https://login.salesforce.com",
60///     "your-client-id",
61/// )?
62/// .auth_mode(AuthMode::private_key("user@example.com", &private_key_pem)?);
63///
64/// let mut provider = DataCloudTokenProvider::new(config)?;
65///
66/// // Get a valid DC JWT (automatically handles the full token flow)
67/// let token = provider.get_token().await?;
68/// println!("Authorization: {}", token.bearer_token());
69/// # Ok(())
70/// # }
71/// ```
72pub struct DataCloudTokenProvider {
73    /// Configuration
74    config: SalesforceAuthConfig,
75    /// HTTP client for token requests
76    http_client: HttpClient,
77    /// Cached OAuth Access Token (refreshed only when genuinely expired)
78    cached_oauth_token: Option<OAuthToken>,
79    /// Cached DC JWT
80    cached_dc_jwt: Option<DataCloudToken>,
81}
82
83impl DataCloudTokenProvider {
84    /// Creates a new DC JWT provider with the given configuration.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if the configuration is invalid.
89    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
90        config.validate()?;
91
92        let http_client = HttpClient::builder()
93            .timeout(Duration::from_secs(config.timeout_secs))
94            .build()
95            .map_err(|e| SalesforceAuthError::Http(format!("failed to create HTTP client: {e}")))?;
96
97        Ok(DataCloudTokenProvider {
98            config,
99            http_client,
100            cached_oauth_token: None,
101            cached_dc_jwt: None,
102        })
103    }
104
105    /// Returns the configuration.
106    #[must_use]
107    pub fn config(&self) -> &SalesforceAuthConfig {
108        &self.config
109    }
110
111    /// Gets a valid DC JWT.
112    ///
113    /// If a cached DC JWT exists and is still valid, it is returned.
114    /// Otherwise, a new DC JWT is obtained through the full token flow.
115    ///
116    /// # Errors
117    ///
118    /// Propagates any error from `Self::fetch_dc_jwt` — typically
119    /// [`SalesforceAuthError::Http`], [`SalesforceAuthError::Authorization`],
120    /// [`SalesforceAuthError::Jwt`], [`SalesforceAuthError::TokenExchange`],
121    /// or [`SalesforceAuthError::TokenParse`] depending on where the
122    /// three-step refresh cycle (OAuth Access Token → DC JWT) fails.
123    ///
124    /// # Panics
125    ///
126    /// Does not panic in practice. The trailing `unwrap()` on
127    /// `self.cached_dc_jwt` is guarded by the preceding cache-population
128    /// logic: either the cache was already populated with a valid token,
129    /// or `Self::fetch_dc_jwt` just filled it.
130    pub async fn get_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
131        let needs_refresh = match &self.cached_dc_jwt {
132            Some(token) if token.is_valid() => {
133                debug!("Using cached DC JWT");
134                false
135            }
136            Some(_) => {
137                debug!("Cached DC JWT expired, refreshing");
138                true
139            }
140            None => true,
141        };
142
143        if needs_refresh {
144            let token = self.fetch_dc_jwt().await?;
145            self.cached_dc_jwt = Some(token);
146        }
147
148        Ok(self.cached_dc_jwt.as_ref().unwrap())
149    }
150
151    /// Forces a full token refresh (both OAuth Access Token and DC JWT),
152    /// even if the cached tokens are still valid.
153    ///
154    /// # Errors
155    ///
156    /// Propagates any error from [`Self::get_token`] (same failure modes
157    /// as the full token-flow refresh).
158    pub async fn force_refresh(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
159        self.cached_oauth_token = None;
160        self.cached_dc_jwt = None;
161        self.get_token().await
162    }
163
164    /// Forces a DC JWT refresh while allowing the OAuth Access Token to
165    /// be reused if still valid.
166    ///
167    /// This is the preferred refresh method during normal operation: it
168    /// re-exchanges the (possibly cached) OAuth Access Token for a fresh
169    /// DC JWT without unnecessarily rotating the OAuth Refresh Token.
170    ///
171    /// # Errors
172    ///
173    /// Propagates any error from [`Self::get_token`] (HTTP, authorization,
174    /// JWT signing, or token-parse failures during the DC JWT exchange).
175    pub async fn refresh_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
176        self.cached_dc_jwt = None;
177        self.get_token().await
178    }
179
180    /// Clears all cached tokens (both OAuth Access Token and DC JWT).
181    pub fn clear_cache(&mut self) {
182        self.cached_oauth_token = None;
183        self.cached_dc_jwt = None;
184    }
185
186    /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
187    ///
188    /// Convenience method for getting the `Authorization` header value
189    /// without an async call.  Returns `None` if no valid DC JWT is cached.
190    #[must_use]
191    pub fn bearer_token(&self) -> Option<String> {
192        self.cached_dc_jwt
193            .as_ref()
194            .filter(|t| t.is_valid())
195            .map(super::token::DataCloudToken::bearer_token)
196    }
197
198    /// Returns the tenant URL if a valid DC JWT is cached.
199    #[must_use]
200    pub fn tenant_url(&self) -> Option<&str> {
201        self.cached_dc_jwt
202            .as_ref()
203            .filter(|t| t.is_valid())
204            .map(super::token::DataCloudToken::tenant_url_str)
205    }
206
207    /// Returns the lakehouse name for Hyper connection.
208    ///
209    /// # Errors
210    ///
211    /// Propagates [`SalesforceAuthError::TokenParse`] from
212    /// [`DataCloudToken::lakehouse_name`] if the cached DC JWT's tenant
213    /// URL cannot be parsed into a valid lakehouse identifier.
214    pub fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
215        if let Some(ref token) = self.cached_dc_jwt {
216            if token.is_valid() {
217                return Ok(Some(token.lakehouse_name(self.config.dataspace_value())?));
218            }
219        }
220        Ok(None)
221    }
222
223    /// Fetches a new DC JWT through the full token flow.
224    ///
225    /// Implements the three-step refresh cycle from the
226    /// `GenieOAuthManagement` documentation:
227    ///
228    /// - **Step 1**: Validate / refresh the OAuth Access Token
229    ///   (only refreshes when genuinely expired — avoids unnecessary
230    ///   OAuth Refresh Token rotation)
231    /// - **Step 2**: Exchange the OAuth Access Token for a DC JWT
232    /// - **Step 2a** (retry): If Step 2 fails, force-refresh the
233    ///   OAuth Access Token and retry the DC JWT exchange once
234    async fn fetch_dc_jwt(&mut self) -> SalesforceAuthResult<DataCloudToken> {
235        // Step 1: Validate / refresh OAuth Access Token
236        let oauth_token = self.get_valid_oauth_access_token().await?;
237
238        // Step 2: Exchange OAuth Access Token → DC JWT
239        match self
240            .exchange_oauth_access_token_for_dc_jwt(&oauth_token)
241            .await
242        {
243            Ok(dc_jwt) => Ok(dc_jwt),
244            Err(step2_err) => {
245                // Step 2a: Force-refresh the OAuth Access Token and retry once.
246                // This handles the case where the OAuth Access Token appeared
247                // valid locally but was invalidated server-side (e.g., by
248                // Salesforce's inactivity timeout).
249                warn!(
250                    error = %step2_err,
251                    "DC JWT exchange failed; force-refreshing OAuth Access Token and retrying (Step 2a)"
252                );
253
254                self.cached_oauth_token = None;
255                let fresh_oauth_token = self.fetch_oauth_access_token().await?;
256                self.cached_oauth_token = Some(fresh_oauth_token.clone());
257
258                self.exchange_oauth_access_token_for_dc_jwt(&fresh_oauth_token)
259                    .await
260                    .map_err(|retry_err| {
261                        warn!(
262                            original_error = %step2_err,
263                            retry_error = %retry_err,
264                            "DC JWT exchange failed again after OAuth Access Token refresh (Step 2a retry)"
265                        );
266                        retry_err
267                    })
268            }
269        }
270    }
271
272    /// Returns a valid OAuth Access Token, using the cache when possible.
273    ///
274    /// Only contacts Salesforce when the cached OAuth Access Token has
275    /// genuinely expired.  This avoids unnecessary OAuth Refresh Token
276    /// rotation that would invalidate tokens held by other connections.
277    async fn get_valid_oauth_access_token(&mut self) -> SalesforceAuthResult<OAuthToken> {
278        if let Some(ref token) = self.cached_oauth_token {
279            if token.is_likely_valid() {
280                debug!(
281                    "OAuth Access Token still valid (obtained at {}), reusing",
282                    token.obtained_at
283                );
284                return Ok(token.clone());
285            }
286            debug!("Cached OAuth Access Token expired, refreshing");
287        }
288
289        let token = self.fetch_oauth_access_token().await?;
290        self.cached_oauth_token = Some(token.clone());
291        Ok(token)
292    }
293
294    /// Fetches a fresh OAuth Access Token from Salesforce.
295    async fn fetch_oauth_access_token(&self) -> SalesforceAuthResult<OAuthToken> {
296        let auth_mode =
297            self.config.auth_mode.as_ref().ok_or_else(|| {
298                SalesforceAuthError::Config("auth_mode not configured".to_string())
299            })?;
300
301        let mut form_data = HashMap::new();
302        form_data.insert("client_id", self.config.client_id.clone());
303
304        match auth_mode {
305            AuthMode::Password { username, password } => {
306                info!(username = %username, "Fetching OAuth Access Token via password grant");
307                form_data.insert("grant_type", "password".to_string());
308                form_data.insert("username", username.clone());
309                form_data.insert("password", password.as_str().to_string());
310
311                if let Some(ref secret) = self.config.client_secret {
312                    form_data.insert("client_secret", secret.as_str().to_string());
313                }
314            }
315
316            AuthMode::PrivateKey {
317                username,
318                private_key,
319            } => {
320                info!(username = %username, "Fetching OAuth Access Token via JWT Bearer Token Flow");
321
322                let assertion = build_jwt_assertion(
323                    &self.config.client_id,
324                    username,
325                    &self.config.login_url,
326                    private_key,
327                )?;
328
329                form_data.insert(
330                    "grant_type",
331                    "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
332                );
333                form_data.insert("assertion", assertion);
334            }
335
336            AuthMode::RefreshToken { refresh_token } => {
337                info!("Fetching OAuth Access Token via OAuth Refresh Token");
338                form_data.insert("grant_type", "refresh_token".to_string());
339                form_data.insert("refresh_token", refresh_token.as_str().to_string());
340
341                if let Some(ref secret) = self.config.client_secret {
342                    form_data.insert("client_secret", secret.as_str().to_string());
343                }
344            }
345        }
346
347        let token_url = self.config.login_url.join(OAUTH_TOKEN_PATH).map_err(|e| {
348            SalesforceAuthError::Config(format!("failed to build OAuth Access Token URL: {e}"))
349        })?;
350
351        debug!(url = %token_url, "Requesting OAuth Access Token");
352
353        let response = self.post_with_retry(&token_url, &form_data).await?;
354        let response_text = response.text().await?;
355
356        debug!(response = %response_text, "OAuth Access Token response received");
357
358        let oauth_response: OAuthTokenResponse =
359            serde_json::from_str(&response_text).map_err(|e| {
360                SalesforceAuthError::TokenParse(format!(
361                    "failed to parse OAuth Access Token response: {e}"
362                ))
363            })?;
364
365        let token_changed = self
366            .cached_oauth_token
367            .as_ref()
368            .map_or(true, |old| old.token != oauth_response.access_token);
369
370        debug!(
371            instance_url = %oauth_response.instance_url,
372            token_type = ?oauth_response.token_type,
373            scope = ?oauth_response.scope,
374            token_changed = token_changed,
375            "OAuth Access Token response parsed"
376        );
377
378        OAuthToken::from_response(oauth_response)
379    }
380
381    /// Exchanges an OAuth Access Token for a DC JWT.
382    ///
383    /// Calls `POST /services/a360/token` with the OAuth Access Token as
384    /// the `subject_token`.
385    async fn exchange_oauth_access_token_for_dc_jwt(
386        &self,
387        oauth_token: &OAuthToken,
388    ) -> SalesforceAuthResult<DataCloudToken> {
389        let mut form_data = HashMap::new();
390        form_data.insert(
391            "grant_type",
392            "urn:salesforce:grant-type:external:cdp".to_string(),
393        );
394        form_data.insert(
395            "subject_token_type",
396            "urn:ietf:params:oauth:token-type:access_token".to_string(),
397        );
398        form_data.insert("subject_token", oauth_token.token.clone());
399
400        if let Some(ref dataspace) = self.config.dataspace {
401            form_data.insert("dataspace", dataspace.clone());
402        }
403
404        let exchange_url = oauth_token
405            .instance_url
406            .join(DATA_CLOUD_TOKEN_PATH)
407            .map_err(|e| {
408                SalesforceAuthError::Config(format!("failed to build DC JWT exchange URL: {e}"))
409            })?;
410
411        debug!(url = %exchange_url, "Exchanging OAuth Access Token for DC JWT");
412
413        let response = self.post_with_retry(&exchange_url, &form_data).await?;
414        let response_text = response.text().await?;
415
416        debug!(response = %response_text, "DC JWT response received");
417
418        let dc_response: DataCloudTokenResponse =
419            serde_json::from_str(&response_text).map_err(|e| {
420                SalesforceAuthError::TokenParse(format!("failed to parse DC JWT response: {e}"))
421            })?;
422
423        debug!(
424            instance_url = %dc_response.instance_url,
425            token_type = ?dc_response.token_type,
426            expires_in = ?dc_response.expires_in,
427            "DC JWT response parsed"
428        );
429
430        let token = DataCloudToken::from_response(dc_response)?;
431
432        info!(
433            tenant_url = %token.tenant_url(),
434            expires_at = %token.expires_at(),
435            "DC JWT obtained"
436        );
437
438        Ok(token)
439    }
440
441    /// Makes a POST request with retry logic for transient failures.
442    async fn post_with_retry(
443        &self,
444        url: &url::Url,
445        form_data: &HashMap<&str, String>,
446    ) -> SalesforceAuthResult<reqwest::Response> {
447        let mut last_error = None;
448
449        for attempt in 0..=self.config.max_retries {
450            if attempt > 0 {
451                let delay = Duration::from_secs(1 << (attempt - 1).min(4));
452                warn!(
453                    attempt = attempt,
454                    delay_secs = delay.as_secs(),
455                    "Retrying after transient failure"
456                );
457                tokio::time::sleep(delay).await;
458            }
459
460            match self
461                .http_client
462                .post(url.as_str())
463                .header("Accept", "application/json")
464                .header("Content-Type", "application/x-www-form-urlencoded")
465                .form(form_data)
466                .send()
467                .await
468            {
469                Ok(response) => {
470                    if response.status().is_client_error() {
471                        let status = response.status();
472                        let body = response.text().await.unwrap_or_default();
473
474                        if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
475                            let error_code = error_json
476                                .get("error")
477                                .and_then(|v| v.as_str())
478                                .unwrap_or("unknown");
479                            let error_desc = error_json
480                                .get("error_description")
481                                .and_then(|v| v.as_str())
482                                .unwrap_or(&body);
483
484                            return Err(SalesforceAuthError::Authorization {
485                                error_code: error_code.to_string(),
486                                error_description: error_desc.to_string(),
487                            });
488                        }
489
490                        return Err(SalesforceAuthError::Http(format!(
491                            "HTTP {status} error: {body}"
492                        )));
493                    }
494
495                    if response.status().is_server_error() {
496                        last_error = Some(SalesforceAuthError::Http(format!(
497                            "HTTP {} error",
498                            response.status()
499                        )));
500                        continue;
501                    }
502
503                    return Ok(response);
504                }
505                Err(e) => {
506                    last_error = Some(SalesforceAuthError::Http(e.to_string()));
507                }
508            }
509        }
510
511        Err(last_error.unwrap_or_else(|| {
512            SalesforceAuthError::Http("request failed after retries".to_string())
513        }))
514    }
515}
516
517/// Thread-safe wrapper around [`DataCloudTokenProvider`].
518///
519/// Allows sharing the DC JWT provider between multiple tasks/threads
520/// while ensuring exclusive access during token operations.  All access
521/// is protected by a [`tokio::sync::Mutex`].
522///
523/// # Example
524///
525/// ```no_run
526/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, SharedTokenProvider};
527///
528/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
529/// # let config = SalesforceAuthConfig::new("https://login.salesforce.com", "client_id")?
530/// #     .auth_mode(AuthMode::password("user", "pass"));
531/// let provider = SharedTokenProvider::new(config)?;
532///
533/// // Can be cloned and shared between tasks
534/// let provider_clone = provider.clone();
535///
536/// tokio::spawn(async move {
537///     let dc_jwt = provider_clone.get_token().await.unwrap();
538///     // use dc_jwt.bearer_token() as the Authorization header
539/// });
540/// # Ok(())
541/// # }
542/// ```
543#[derive(Clone)]
544pub struct SharedTokenProvider {
545    inner: Arc<Mutex<DataCloudTokenProvider>>,
546}
547
548impl SharedTokenProvider {
549    /// Creates a new shared DC JWT provider.
550    ///
551    /// # Errors
552    ///
553    /// Propagates any error from [`DataCloudTokenProvider::new`]:
554    /// configuration validation failures or HTTP client construction
555    /// failures (surfaced as [`SalesforceAuthError::Http`]).
556    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
557        let provider = DataCloudTokenProvider::new(config)?;
558        Ok(SharedTokenProvider {
559            inner: Arc::new(Mutex::new(provider)),
560        })
561    }
562
563    /// Gets a valid DC JWT.
564    ///
565    /// # Errors
566    ///
567    /// Propagates any error from [`DataCloudTokenProvider::get_token`]
568    /// (HTTP failure, authorization rejection, JWT signing error, or
569    /// token-parse failure during the refresh cycle).
570    pub async fn get_token(&self) -> SalesforceAuthResult<DataCloudToken> {
571        let mut provider = self.inner.lock().await;
572        provider.get_token().await.cloned()
573    }
574
575    /// Forces a DC JWT refresh (reuses OAuth Access Token if still valid).
576    ///
577    /// # Errors
578    ///
579    /// Propagates any error from [`DataCloudTokenProvider::refresh_token`].
580    pub async fn refresh_token(&self) -> SalesforceAuthResult<DataCloudToken> {
581        let mut provider = self.inner.lock().await;
582        provider.refresh_token().await.cloned()
583    }
584
585    /// Forces a full refresh (both OAuth Access Token and DC JWT).
586    ///
587    /// # Errors
588    ///
589    /// Propagates any error from [`DataCloudTokenProvider::force_refresh`].
590    pub async fn force_refresh(&self) -> SalesforceAuthResult<DataCloudToken> {
591        let mut provider = self.inner.lock().await;
592        provider.force_refresh().await.cloned()
593    }
594
595    /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
596    pub async fn bearer_token(&self) -> Option<String> {
597        let provider = self.inner.lock().await;
598        provider.bearer_token()
599    }
600
601    /// Returns the tenant URL if a valid DC JWT is cached.
602    pub async fn tenant_url(&self) -> Option<String> {
603        let provider = self.inner.lock().await;
604        provider.tenant_url().map(std::string::ToString::to_string)
605    }
606
607    /// Returns the lakehouse name for Hyper connection.
608    ///
609    /// # Errors
610    ///
611    /// Propagates [`SalesforceAuthError::TokenParse`] from
612    /// [`DataCloudTokenProvider::lakehouse_name`] if the cached DC JWT's
613    /// tenant URL cannot be parsed into a valid lakehouse identifier.
614    pub async fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
615        let provider = self.inner.lock().await;
616        provider.lakehouse_name()
617    }
618}
619
620impl std::fmt::Debug for DataCloudTokenProvider {
621    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
622        f.debug_struct("DataCloudTokenProvider")
623            .field("config", &self.config)
624            .field("has_cached_oauth_token", &self.cached_oauth_token.is_some())
625            .field("has_cached_dc_jwt", &self.cached_dc_jwt.is_some())
626            .finish_non_exhaustive()
627    }
628}
629
630impl std::fmt::Debug for SharedTokenProvider {
631    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632        f.debug_struct("SharedTokenProvider")
633            .finish_non_exhaustive()
634    }
635}