hyperdb_api_salesforce/provider.rs
1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! DC JWT provider for Salesforce Data Cloud authentication.
5//!
6//! This module implements the two-stage token flow:
7//! 1. Authenticate with Salesforce to get an **OAuth Access Token**
8//! (via `/services/oauth2/token`)
9//! 2. Exchange the OAuth Access Token for a **DC JWT**
10//! (via `/services/a360/token`)
11//!
12//! The provider caches both the OAuth Access Token and the DC JWT
13//! independently. The OAuth Access Token is only refreshed when it
14//! has genuinely expired, avoiding unnecessary **OAuth Refresh Token**
15//! rotation that would invalidate tokens held by other connections.
16
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use reqwest::Client as HttpClient;
22use tokio::sync::Mutex;
23use tracing::{debug, info, warn};
24
25use crate::config::{AuthMode, SalesforceAuthConfig};
26use crate::error::{SalesforceAuthError, SalesforceAuthResult};
27use crate::jwt::build_jwt_assertion;
28use crate::token::{DataCloudToken, DataCloudTokenResponse, OAuthToken, OAuthTokenResponse};
29
/// Relative path of the Salesforce endpoint that issues the OAuth Access Token.
///
/// It has no leading `/` and is resolved against the configured login URL
/// with `Url::join`.
const OAUTH_TOKEN_PATH: &str = "services/oauth2/token";

/// Relative path of the endpoint that trades an OAuth Access Token for a DC JWT.
///
/// It is resolved against the OAuth Access Token's instance URL with `Url::join`.
const DATA_CLOUD_TOKEN_PATH: &str = "services/a360/token";
35
36/// DC JWT provider.
37///
38/// Handles the full token flow for Salesforce Data Cloud:
39/// 1. Authenticates with Salesforce using the configured auth mode to
40/// obtain an **OAuth Access Token**
41/// 2. Exchanges the OAuth Access Token for a **DC JWT**
42/// 3. Caches both tokens and refreshes them independently:
43/// - The OAuth Access Token is refreshed only when genuinely expired
44/// (to avoid unnecessary OAuth Refresh Token rotation)
45/// - The DC JWT is refreshed whenever it is expired or requested
46///
47/// On DC JWT exchange failure, the provider retries once with a
48/// force-refreshed OAuth Access Token (Step 2a), matching the behavior
49/// described in the `GenieOAuthManagement` documentation.
50///
51/// # Example
52///
53/// ```no_run
54/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, DataCloudTokenProvider};
55///
56/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
57/// # let private_key_pem = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----";
58/// let config = SalesforceAuthConfig::new(
59/// "https://login.salesforce.com",
60/// "your-client-id",
61/// )?
62/// .auth_mode(AuthMode::private_key("user@example.com", &private_key_pem)?);
63///
64/// let mut provider = DataCloudTokenProvider::new(config)?;
65///
66/// // Get a valid DC JWT (automatically handles the full token flow)
67/// let token = provider.get_token().await?;
68/// println!("Authorization: {}", token.bearer_token());
69/// # Ok(())
70/// # }
71/// ```
72pub struct DataCloudTokenProvider {
73 /// Configuration
74 config: SalesforceAuthConfig,
75 /// HTTP client for token requests
76 http_client: HttpClient,
77 /// Cached OAuth Access Token (refreshed only when genuinely expired)
78 cached_oauth_token: Option<OAuthToken>,
79 /// Cached DC JWT
80 cached_dc_jwt: Option<DataCloudToken>,
81}
82
83impl DataCloudTokenProvider {
84 /// Creates a new DC JWT provider with the given configuration.
85 ///
86 /// # Errors
87 ///
88 /// Returns an error if the configuration is invalid.
89 pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
90 config.validate()?;
91
92 let http_client = HttpClient::builder()
93 .timeout(Duration::from_secs(config.timeout_secs))
94 .build()
95 .map_err(|e| SalesforceAuthError::http(format!("failed to create HTTP client: {e}")))?;
96
97 Ok(DataCloudTokenProvider {
98 config,
99 http_client,
100 cached_oauth_token: None,
101 cached_dc_jwt: None,
102 })
103 }
104
105 /// Returns the configuration.
106 #[must_use]
107 pub fn config(&self) -> &SalesforceAuthConfig {
108 &self.config
109 }
110
111 /// Gets a valid DC JWT.
112 ///
113 /// If a cached DC JWT exists and is still valid, it is returned.
114 /// Otherwise, a new DC JWT is obtained through the full token flow.
115 ///
116 /// # Errors
117 ///
118 /// Propagates any error from `Self::fetch_dc_jwt` — typically
119 /// [`SalesforceAuthError::Http`], [`SalesforceAuthError::Authorization`],
120 /// [`SalesforceAuthError::Jwt`], [`SalesforceAuthError::TokenExchange`],
121 /// or [`SalesforceAuthError::TokenParse`] depending on where the
122 /// three-step refresh cycle (OAuth Access Token → DC JWT) fails.
123 ///
124 /// # Panics
125 ///
126 /// Does not panic in practice. The trailing `unwrap()` on
127 /// `self.cached_dc_jwt` is guarded by the preceding cache-population
128 /// logic: either the cache was already populated with a valid token,
129 /// or `Self::fetch_dc_jwt` just filled it.
130 pub async fn get_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
131 let needs_refresh = match &self.cached_dc_jwt {
132 Some(token) if token.is_valid() => {
133 debug!("Using cached DC JWT");
134 false
135 }
136 Some(_) => {
137 debug!("Cached DC JWT expired, refreshing");
138 true
139 }
140 None => true,
141 };
142
143 if needs_refresh {
144 let token = self.fetch_dc_jwt().await?;
145 self.cached_dc_jwt = Some(token);
146 }
147
148 Ok(self.cached_dc_jwt.as_ref().unwrap())
149 }
150
151 /// Forces a full token refresh (both OAuth Access Token and DC JWT),
152 /// even if the cached tokens are still valid.
153 ///
154 /// # Errors
155 ///
156 /// Propagates any error from [`Self::get_token`] (same failure modes
157 /// as the full token-flow refresh).
158 pub async fn force_refresh(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
159 self.cached_oauth_token = None;
160 self.cached_dc_jwt = None;
161 self.get_token().await
162 }
163
164 /// Forces a DC JWT refresh while allowing the OAuth Access Token to
165 /// be reused if still valid.
166 ///
167 /// This is the preferred refresh method during normal operation: it
168 /// re-exchanges the (possibly cached) OAuth Access Token for a fresh
169 /// DC JWT without unnecessarily rotating the OAuth Refresh Token.
170 ///
171 /// # Errors
172 ///
173 /// Propagates any error from [`Self::get_token`] (HTTP, authorization,
174 /// JWT signing, or token-parse failures during the DC JWT exchange).
175 pub async fn refresh_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
176 self.cached_dc_jwt = None;
177 self.get_token().await
178 }
179
180 /// Clears all cached tokens (both OAuth Access Token and DC JWT).
181 pub fn clear_cache(&mut self) {
182 self.cached_oauth_token = None;
183 self.cached_dc_jwt = None;
184 }
185
186 /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
187 ///
188 /// Convenience method for getting the `Authorization` header value
189 /// without an async call. Returns `None` if no valid DC JWT is cached.
190 #[must_use]
191 pub fn bearer_token(&self) -> Option<String> {
192 self.cached_dc_jwt
193 .as_ref()
194 .filter(|t| t.is_valid())
195 .map(super::token::DataCloudToken::bearer_token)
196 }
197
198 /// Returns the tenant URL if a valid DC JWT is cached.
199 #[must_use]
200 pub fn tenant_url(&self) -> Option<&str> {
201 self.cached_dc_jwt
202 .as_ref()
203 .filter(|t| t.is_valid())
204 .map(super::token::DataCloudToken::tenant_url_str)
205 }
206
207 /// Returns the lakehouse name for Hyper connection.
208 ///
209 /// # Errors
210 ///
211 /// Propagates [`SalesforceAuthError::TokenParse`] from
212 /// [`DataCloudToken::lakehouse_name`] if the cached DC JWT's tenant
213 /// URL cannot be parsed into a valid lakehouse identifier.
214 pub fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
215 if let Some(ref token) = self.cached_dc_jwt {
216 if token.is_valid() {
217 return Ok(Some(token.lakehouse_name(self.config.dataspace_value())?));
218 }
219 }
220 Ok(None)
221 }
222
223 /// Fetches a new DC JWT through the full token flow.
224 ///
225 /// Implements the three-step refresh cycle from the
226 /// `GenieOAuthManagement` documentation:
227 ///
228 /// - **Step 1**: Validate / refresh the OAuth Access Token
229 /// (only refreshes when genuinely expired — avoids unnecessary
230 /// OAuth Refresh Token rotation)
231 /// - **Step 2**: Exchange the OAuth Access Token for a DC JWT
232 /// - **Step 2a** (retry): If Step 2 fails, force-refresh the
233 /// OAuth Access Token and retry the DC JWT exchange once
234 async fn fetch_dc_jwt(&mut self) -> SalesforceAuthResult<DataCloudToken> {
235 // Step 1: Validate / refresh OAuth Access Token
236 let oauth_token = self.get_valid_oauth_access_token().await?;
237
238 // Step 2: Exchange OAuth Access Token → DC JWT
239 match self
240 .exchange_oauth_access_token_for_dc_jwt(&oauth_token)
241 .await
242 {
243 Ok(dc_jwt) => Ok(dc_jwt),
244 Err(step2_err) => {
245 // Step 2a: Force-refresh the OAuth Access Token and retry once.
246 // This handles the case where the OAuth Access Token appeared
247 // valid locally but was invalidated server-side (e.g., by
248 // Salesforce's inactivity timeout).
249 warn!(
250 error = %step2_err,
251 "DC JWT exchange failed; force-refreshing OAuth Access Token and retrying (Step 2a)"
252 );
253
254 self.cached_oauth_token = None;
255 let fresh_oauth_token = self.fetch_oauth_access_token().await?;
256 self.cached_oauth_token = Some(fresh_oauth_token.clone());
257
258 self.exchange_oauth_access_token_for_dc_jwt(&fresh_oauth_token)
259 .await
260 .map_err(|retry_err| {
261 warn!(
262 original_error = %step2_err,
263 retry_error = %retry_err,
264 "DC JWT exchange failed again after OAuth Access Token refresh (Step 2a retry)"
265 );
266 retry_err
267 })
268 }
269 }
270 }
271
272 /// Returns a valid OAuth Access Token, using the cache when possible.
273 ///
274 /// Only contacts Salesforce when the cached OAuth Access Token has
275 /// genuinely expired. This avoids unnecessary OAuth Refresh Token
276 /// rotation that would invalidate tokens held by other connections.
277 async fn get_valid_oauth_access_token(&mut self) -> SalesforceAuthResult<OAuthToken> {
278 if let Some(ref token) = self.cached_oauth_token {
279 if token.is_likely_valid() {
280 debug!(
281 "OAuth Access Token still valid (obtained at {}), reusing",
282 token.obtained_at
283 );
284 return Ok(token.clone());
285 }
286 debug!("Cached OAuth Access Token expired, refreshing");
287 }
288
289 let token = self.fetch_oauth_access_token().await?;
290 self.cached_oauth_token = Some(token.clone());
291 Ok(token)
292 }
293
294 /// Fetches a fresh OAuth Access Token from Salesforce.
295 async fn fetch_oauth_access_token(&self) -> SalesforceAuthResult<OAuthToken> {
296 let auth_mode = self
297 .config
298 .auth_mode
299 .as_ref()
300 .ok_or_else(|| SalesforceAuthError::config("auth_mode not configured"))?;
301
302 let mut form_data = HashMap::new();
303 form_data.insert("client_id", self.config.client_id.clone());
304
305 match auth_mode {
306 AuthMode::Password { username, password } => {
307 info!(username = %username, "Fetching OAuth Access Token via password grant");
308 form_data.insert("grant_type", "password".to_string());
309 form_data.insert("username", username.clone());
310 form_data.insert("password", password.as_str().to_string());
311
312 if let Some(ref secret) = self.config.client_secret {
313 form_data.insert("client_secret", secret.as_str().to_string());
314 }
315 }
316
317 AuthMode::PrivateKey {
318 username,
319 private_key,
320 } => {
321 info!(username = %username, "Fetching OAuth Access Token via JWT Bearer Token Flow");
322
323 let assertion = build_jwt_assertion(
324 &self.config.client_id,
325 username,
326 &self.config.login_url,
327 private_key,
328 )?;
329
330 form_data.insert(
331 "grant_type",
332 "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
333 );
334 form_data.insert("assertion", assertion);
335 }
336
337 AuthMode::RefreshToken { refresh_token } => {
338 info!("Fetching OAuth Access Token via OAuth Refresh Token");
339 form_data.insert("grant_type", "refresh_token".to_string());
340 form_data.insert("refresh_token", refresh_token.as_str().to_string());
341
342 if let Some(ref secret) = self.config.client_secret {
343 form_data.insert("client_secret", secret.as_str().to_string());
344 }
345 }
346 }
347
348 let token_url = self.config.login_url.join(OAUTH_TOKEN_PATH).map_err(|e| {
349 SalesforceAuthError::config(format!("failed to build OAuth Access Token URL: {e}"))
350 })?;
351
352 debug!(url = %token_url, "Requesting OAuth Access Token");
353
354 let response = self.post_with_retry(&token_url, &form_data).await?;
355 let response_text = response.text().await?;
356
357 debug!(response = %response_text, "OAuth Access Token response received");
358
359 let oauth_response: OAuthTokenResponse =
360 serde_json::from_str(&response_text).map_err(|e| {
361 SalesforceAuthError::token_parse(format!(
362 "failed to parse OAuth Access Token response: {e}"
363 ))
364 })?;
365
366 let token_changed = self
367 .cached_oauth_token
368 .as_ref()
369 .map_or(true, |old| old.token != oauth_response.access_token);
370
371 debug!(
372 instance_url = %oauth_response.instance_url,
373 token_type = ?oauth_response.token_type,
374 scope = ?oauth_response.scope,
375 token_changed = token_changed,
376 "OAuth Access Token response parsed"
377 );
378
379 OAuthToken::from_response(oauth_response)
380 }
381
382 /// Exchanges an OAuth Access Token for a DC JWT.
383 ///
384 /// Calls `POST /services/a360/token` with the OAuth Access Token as
385 /// the `subject_token`.
386 async fn exchange_oauth_access_token_for_dc_jwt(
387 &self,
388 oauth_token: &OAuthToken,
389 ) -> SalesforceAuthResult<DataCloudToken> {
390 let mut form_data = HashMap::new();
391 form_data.insert(
392 "grant_type",
393 "urn:salesforce:grant-type:external:cdp".to_string(),
394 );
395 form_data.insert(
396 "subject_token_type",
397 "urn:ietf:params:oauth:token-type:access_token".to_string(),
398 );
399 form_data.insert("subject_token", oauth_token.token.clone());
400
401 if let Some(ref dataspace) = self.config.dataspace {
402 form_data.insert("dataspace", dataspace.clone());
403 }
404
405 let exchange_url = oauth_token
406 .instance_url
407 .join(DATA_CLOUD_TOKEN_PATH)
408 .map_err(|e| {
409 SalesforceAuthError::config(format!("failed to build DC JWT exchange URL: {e}"))
410 })?;
411
412 debug!(url = %exchange_url, "Exchanging OAuth Access Token for DC JWT");
413
414 let response = self.post_with_retry(&exchange_url, &form_data).await?;
415 let response_text = response.text().await?;
416
417 debug!(response = %response_text, "DC JWT response received");
418
419 let dc_response: DataCloudTokenResponse =
420 serde_json::from_str(&response_text).map_err(|e| {
421 SalesforceAuthError::token_parse(format!("failed to parse DC JWT response: {e}"))
422 })?;
423
424 debug!(
425 instance_url = %dc_response.instance_url,
426 token_type = ?dc_response.token_type,
427 expires_in = ?dc_response.expires_in,
428 "DC JWT response parsed"
429 );
430
431 let token = DataCloudToken::from_response(dc_response)?;
432
433 info!(
434 tenant_url = %token.tenant_url(),
435 expires_at = %token.expires_at(),
436 "DC JWT obtained"
437 );
438
439 Ok(token)
440 }
441
442 /// Makes a POST request with retry logic for transient failures.
443 async fn post_with_retry(
444 &self,
445 url: &url::Url,
446 form_data: &HashMap<&str, String>,
447 ) -> SalesforceAuthResult<reqwest::Response> {
448 let mut last_error = None;
449
450 for attempt in 0..=self.config.max_retries {
451 if attempt > 0 {
452 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
453 warn!(
454 attempt = attempt,
455 delay_secs = delay.as_secs(),
456 "Retrying after transient failure"
457 );
458 tokio::time::sleep(delay).await;
459 }
460
461 match self
462 .http_client
463 .post(url.as_str())
464 .header("Accept", "application/json")
465 .header("Content-Type", "application/x-www-form-urlencoded")
466 .form(form_data)
467 .send()
468 .await
469 {
470 Ok(response) => {
471 if response.status().is_client_error() {
472 let status = response.status();
473 let body = response.text().await.unwrap_or_default();
474
475 if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
476 let error_code = error_json
477 .get("error")
478 .and_then(|v| v.as_str())
479 .unwrap_or("unknown");
480 let error_desc = error_json
481 .get("error_description")
482 .and_then(|v| v.as_str())
483 .unwrap_or(&body);
484
485 return Err(SalesforceAuthError::authorization(
486 error_code.to_string(),
487 error_desc.to_string(),
488 ));
489 }
490
491 return Err(SalesforceAuthError::http(format!(
492 "HTTP {status} error: {body}"
493 )));
494 }
495
496 if response.status().is_server_error() {
497 last_error = Some(SalesforceAuthError::http(format!(
498 "HTTP {} error",
499 response.status()
500 )));
501 continue;
502 }
503
504 return Ok(response);
505 }
506 Err(e) => {
507 last_error = Some(SalesforceAuthError::Http(e.to_string()));
508 }
509 }
510 }
511
512 Err(last_error.unwrap_or_else(|| SalesforceAuthError::http("request failed after retries")))
513 }
514}
515
516/// Thread-safe wrapper around [`DataCloudTokenProvider`].
517///
518/// Allows sharing the DC JWT provider between multiple tasks/threads
519/// while ensuring exclusive access during token operations. All access
520/// is protected by a [`tokio::sync::Mutex`].
521///
522/// # Example
523///
524/// ```no_run
525/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, SharedTokenProvider};
526///
527/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
528/// # let config = SalesforceAuthConfig::new("https://login.salesforce.com", "client_id")?
529/// # .auth_mode(AuthMode::password("user", "pass"));
530/// let provider = SharedTokenProvider::new(config)?;
531///
532/// // Can be cloned and shared between tasks
533/// let provider_clone = provider.clone();
534///
535/// tokio::spawn(async move {
536/// let dc_jwt = provider_clone.get_token().await.unwrap();
537/// // use dc_jwt.bearer_token() as the Authorization header
538/// });
539/// # Ok(())
540/// # }
541/// ```
542#[derive(Clone)]
543pub struct SharedTokenProvider {
544 inner: Arc<Mutex<DataCloudTokenProvider>>,
545}
546
547impl SharedTokenProvider {
548 /// Creates a new shared DC JWT provider.
549 ///
550 /// # Errors
551 ///
552 /// Propagates any error from [`DataCloudTokenProvider::new`]:
553 /// configuration validation failures or HTTP client construction
554 /// failures (surfaced as [`SalesforceAuthError::Http`]).
555 pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
556 let provider = DataCloudTokenProvider::new(config)?;
557 Ok(SharedTokenProvider {
558 inner: Arc::new(Mutex::new(provider)),
559 })
560 }
561
562 /// Gets a valid DC JWT.
563 ///
564 /// # Errors
565 ///
566 /// Propagates any error from [`DataCloudTokenProvider::get_token`]
567 /// (HTTP failure, authorization rejection, JWT signing error, or
568 /// token-parse failure during the refresh cycle).
569 pub async fn get_token(&self) -> SalesforceAuthResult<DataCloudToken> {
570 let mut provider = self.inner.lock().await;
571 provider.get_token().await.cloned()
572 }
573
574 /// Forces a DC JWT refresh (reuses OAuth Access Token if still valid).
575 ///
576 /// # Errors
577 ///
578 /// Propagates any error from [`DataCloudTokenProvider::refresh_token`].
579 pub async fn refresh_token(&self) -> SalesforceAuthResult<DataCloudToken> {
580 let mut provider = self.inner.lock().await;
581 provider.refresh_token().await.cloned()
582 }
583
584 /// Forces a full refresh (both OAuth Access Token and DC JWT).
585 ///
586 /// # Errors
587 ///
588 /// Propagates any error from [`DataCloudTokenProvider::force_refresh`].
589 pub async fn force_refresh(&self) -> SalesforceAuthResult<DataCloudToken> {
590 let mut provider = self.inner.lock().await;
591 provider.force_refresh().await.cloned()
592 }
593
594 /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
595 pub async fn bearer_token(&self) -> Option<String> {
596 let provider = self.inner.lock().await;
597 provider.bearer_token()
598 }
599
600 /// Returns the tenant URL if a valid DC JWT is cached.
601 pub async fn tenant_url(&self) -> Option<String> {
602 let provider = self.inner.lock().await;
603 provider.tenant_url().map(std::string::ToString::to_string)
604 }
605
606 /// Returns the lakehouse name for Hyper connection.
607 ///
608 /// # Errors
609 ///
610 /// Propagates [`SalesforceAuthError::TokenParse`] from
611 /// [`DataCloudTokenProvider::lakehouse_name`] if the cached DC JWT's
612 /// tenant URL cannot be parsed into a valid lakehouse identifier.
613 pub async fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
614 let provider = self.inner.lock().await;
615 provider.lakehouse_name()
616 }
617}
618
619impl std::fmt::Debug for DataCloudTokenProvider {
620 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621 f.debug_struct("DataCloudTokenProvider")
622 .field("config", &self.config)
623 .field("has_cached_oauth_token", &self.cached_oauth_token.is_some())
624 .field("has_cached_dc_jwt", &self.cached_dc_jwt.is_some())
625 .finish_non_exhaustive()
626 }
627}
628
629impl std::fmt::Debug for SharedTokenProvider {
630 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631 f.debug_struct("SharedTokenProvider")
632 .finish_non_exhaustive()
633 }
634}