hyperdb_api_salesforce/provider.rs
1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! DC JWT provider for Salesforce Data Cloud authentication.
5//!
6//! This module implements the two-stage token flow:
7//! 1. Authenticate with Salesforce to get an **OAuth Access Token**
8//! (via `/services/oauth2/token`)
9//! 2. Exchange the OAuth Access Token for a **DC JWT**
10//! (via `/services/a360/token`)
11//!
12//! The provider caches both the OAuth Access Token and the DC JWT
13//! independently. The OAuth Access Token is only refreshed when it
14//! has genuinely expired, avoiding unnecessary **OAuth Refresh Token**
15//! rotation that would invalidate tokens held by other connections.
16
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use reqwest::Client as HttpClient;
22use tokio::sync::Mutex;
23use tracing::{debug, info, warn};
24
25use crate::config::{AuthMode, SalesforceAuthConfig};
26use crate::error::{SalesforceAuthError, SalesforceAuthResult};
27use crate::jwt::build_jwt_assertion;
28use crate::token::{DataCloudToken, DataCloudTokenResponse, OAuthToken, OAuthTokenResponse};
29
/// OAuth Access Token endpoint path.
///
/// Joined onto the configured login URL when requesting an OAuth Access
/// Token.
const OAUTH_TOKEN_PATH: &str = "services/oauth2/token";

/// DC JWT exchange endpoint path.
///
/// Joined onto the `instance_url` carried by the OAuth Access Token when
/// exchanging that token for a DC JWT.
const DATA_CLOUD_TOKEN_PATH: &str = "services/a360/token";
35
36/// DC JWT provider.
37///
38/// Handles the full token flow for Salesforce Data Cloud:
39/// 1. Authenticates with Salesforce using the configured auth mode to
40/// obtain an **OAuth Access Token**
41/// 2. Exchanges the OAuth Access Token for a **DC JWT**
42/// 3. Caches both tokens and refreshes them independently:
43/// - The OAuth Access Token is refreshed only when genuinely expired
44/// (to avoid unnecessary OAuth Refresh Token rotation)
45/// - The DC JWT is refreshed whenever it is expired or requested
46///
47/// On DC JWT exchange failure, the provider retries once with a
48/// force-refreshed OAuth Access Token (Step 2a), matching the behavior
49/// described in the `GenieOAuthManagement` documentation.
50///
51/// # Example
52///
53/// ```no_run
54/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, DataCloudTokenProvider};
55///
56/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
57/// # let private_key_pem = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----";
58/// let config = SalesforceAuthConfig::new(
59/// "https://login.salesforce.com",
60/// "your-client-id",
61/// )?
62/// .auth_mode(AuthMode::private_key("user@example.com", &private_key_pem)?);
63///
64/// let mut provider = DataCloudTokenProvider::new(config)?;
65///
66/// // Get a valid DC JWT (automatically handles the full token flow)
67/// let token = provider.get_token().await?;
68/// println!("Authorization: {}", token.bearer_token());
69/// # Ok(())
70/// # }
71/// ```
pub struct DataCloudTokenProvider {
    /// Configuration (validated in `new` before the provider is built)
    config: SalesforceAuthConfig,
    /// HTTP client for token requests (its request timeout comes from the configuration)
    http_client: HttpClient,
    /// Cached OAuth Access Token (refreshed only when genuinely expired)
    cached_oauth_token: Option<OAuthToken>,
    /// Cached DC JWT; it may have expired since it was stored, so accessors
    /// re-check validity before handing it out
    cached_dc_jwt: Option<DataCloudToken>,
}
82
83impl DataCloudTokenProvider {
84 /// Creates a new DC JWT provider with the given configuration.
85 ///
86 /// # Errors
87 ///
88 /// Returns an error if the configuration is invalid.
89 pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
90 config.validate()?;
91
92 let http_client = HttpClient::builder()
93 .timeout(Duration::from_secs(config.timeout_secs))
94 .build()
95 .map_err(|e| SalesforceAuthError::Http(format!("failed to create HTTP client: {e}")))?;
96
97 Ok(DataCloudTokenProvider {
98 config,
99 http_client,
100 cached_oauth_token: None,
101 cached_dc_jwt: None,
102 })
103 }
104
/// Returns the configuration this provider was constructed with.
#[must_use]
pub fn config(&self) -> &SalesforceAuthConfig {
    &self.config
}
110
111 /// Gets a valid DC JWT.
112 ///
113 /// If a cached DC JWT exists and is still valid, it is returned.
114 /// Otherwise, a new DC JWT is obtained through the full token flow.
115 ///
116 /// # Errors
117 ///
118 /// Propagates any error from `Self::fetch_dc_jwt` — typically
119 /// [`SalesforceAuthError::Http`], [`SalesforceAuthError::Authorization`],
120 /// [`SalesforceAuthError::Jwt`], [`SalesforceAuthError::TokenExchange`],
121 /// or [`SalesforceAuthError::TokenParse`] depending on where the
122 /// three-step refresh cycle (OAuth Access Token → DC JWT) fails.
123 ///
124 /// # Panics
125 ///
126 /// Does not panic in practice. The trailing `unwrap()` on
127 /// `self.cached_dc_jwt` is guarded by the preceding cache-population
128 /// logic: either the cache was already populated with a valid token,
129 /// or `Self::fetch_dc_jwt` just filled it.
130 pub async fn get_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
131 let needs_refresh = match &self.cached_dc_jwt {
132 Some(token) if token.is_valid() => {
133 debug!("Using cached DC JWT");
134 false
135 }
136 Some(_) => {
137 debug!("Cached DC JWT expired, refreshing");
138 true
139 }
140 None => true,
141 };
142
143 if needs_refresh {
144 let token = self.fetch_dc_jwt().await?;
145 self.cached_dc_jwt = Some(token);
146 }
147
148 Ok(self.cached_dc_jwt.as_ref().unwrap())
149 }
150
151 /// Forces a full token refresh (both OAuth Access Token and DC JWT),
152 /// even if the cached tokens are still valid.
153 ///
154 /// # Errors
155 ///
156 /// Propagates any error from [`Self::get_token`] (same failure modes
157 /// as the full token-flow refresh).
158 pub async fn force_refresh(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
159 self.cached_oauth_token = None;
160 self.cached_dc_jwt = None;
161 self.get_token().await
162 }
163
/// Forces a DC JWT refresh while allowing the OAuth Access Token to
/// be reused if still valid.
///
/// This is the preferred refresh method during normal operation: it
/// re-exchanges the (possibly cached) OAuth Access Token for a fresh
/// DC JWT without unnecessarily rotating the OAuth Refresh Token.
///
/// # Errors
///
/// Propagates any error from [`Self::get_token`] (HTTP, authorization,
/// JWT signing, or token-parse failures during the DC JWT exchange).
pub async fn refresh_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
    // Only the DC JWT is dropped; the cached OAuth Access Token is left in
    // place so the lookup below can reuse it.
    self.cached_dc_jwt = None;
    self.get_token().await
}
179
/// Clears all cached tokens (both OAuth Access Token and DC JWT).
///
/// Nothing is fetched here; the next [`Self::get_token`] call finds both
/// caches empty and runs the full token flow.
pub fn clear_cache(&mut self) {
    self.cached_oauth_token = None;
    self.cached_dc_jwt = None;
}
185
186 /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
187 ///
188 /// Convenience method for getting the `Authorization` header value
189 /// without an async call. Returns `None` if no valid DC JWT is cached.
190 #[must_use]
191 pub fn bearer_token(&self) -> Option<String> {
192 self.cached_dc_jwt
193 .as_ref()
194 .filter(|t| t.is_valid())
195 .map(super::token::DataCloudToken::bearer_token)
196 }
197
198 /// Returns the tenant URL if a valid DC JWT is cached.
199 #[must_use]
200 pub fn tenant_url(&self) -> Option<&str> {
201 self.cached_dc_jwt
202 .as_ref()
203 .filter(|t| t.is_valid())
204 .map(super::token::DataCloudToken::tenant_url_str)
205 }
206
207 /// Returns the lakehouse name for Hyper connection.
208 ///
209 /// # Errors
210 ///
211 /// Propagates [`SalesforceAuthError::TokenParse`] from
212 /// [`DataCloudToken::lakehouse_name`] if the cached DC JWT's tenant
213 /// URL cannot be parsed into a valid lakehouse identifier.
214 pub fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
215 if let Some(ref token) = self.cached_dc_jwt {
216 if token.is_valid() {
217 return Ok(Some(token.lakehouse_name(self.config.dataspace_value())?));
218 }
219 }
220 Ok(None)
221 }
222
223 /// Fetches a new DC JWT through the full token flow.
224 ///
225 /// Implements the three-step refresh cycle from the
226 /// `GenieOAuthManagement` documentation:
227 ///
228 /// - **Step 1**: Validate / refresh the OAuth Access Token
229 /// (only refreshes when genuinely expired — avoids unnecessary
230 /// OAuth Refresh Token rotation)
231 /// - **Step 2**: Exchange the OAuth Access Token for a DC JWT
232 /// - **Step 2a** (retry): If Step 2 fails, force-refresh the
233 /// OAuth Access Token and retry the DC JWT exchange once
234 async fn fetch_dc_jwt(&mut self) -> SalesforceAuthResult<DataCloudToken> {
235 // Step 1: Validate / refresh OAuth Access Token
236 let oauth_token = self.get_valid_oauth_access_token().await?;
237
238 // Step 2: Exchange OAuth Access Token → DC JWT
239 match self
240 .exchange_oauth_access_token_for_dc_jwt(&oauth_token)
241 .await
242 {
243 Ok(dc_jwt) => Ok(dc_jwt),
244 Err(step2_err) => {
245 // Step 2a: Force-refresh the OAuth Access Token and retry once.
246 // This handles the case where the OAuth Access Token appeared
247 // valid locally but was invalidated server-side (e.g., by
248 // Salesforce's inactivity timeout).
249 warn!(
250 error = %step2_err,
251 "DC JWT exchange failed; force-refreshing OAuth Access Token and retrying (Step 2a)"
252 );
253
254 self.cached_oauth_token = None;
255 let fresh_oauth_token = self.fetch_oauth_access_token().await?;
256 self.cached_oauth_token = Some(fresh_oauth_token.clone());
257
258 self.exchange_oauth_access_token_for_dc_jwt(&fresh_oauth_token)
259 .await
260 .map_err(|retry_err| {
261 warn!(
262 original_error = %step2_err,
263 retry_error = %retry_err,
264 "DC JWT exchange failed again after OAuth Access Token refresh (Step 2a retry)"
265 );
266 retry_err
267 })
268 }
269 }
270 }
271
272 /// Returns a valid OAuth Access Token, using the cache when possible.
273 ///
274 /// Only contacts Salesforce when the cached OAuth Access Token has
275 /// genuinely expired. This avoids unnecessary OAuth Refresh Token
276 /// rotation that would invalidate tokens held by other connections.
277 async fn get_valid_oauth_access_token(&mut self) -> SalesforceAuthResult<OAuthToken> {
278 if let Some(ref token) = self.cached_oauth_token {
279 if token.is_likely_valid() {
280 debug!(
281 "OAuth Access Token still valid (obtained at {}), reusing",
282 token.obtained_at
283 );
284 return Ok(token.clone());
285 }
286 debug!("Cached OAuth Access Token expired, refreshing");
287 }
288
289 let token = self.fetch_oauth_access_token().await?;
290 self.cached_oauth_token = Some(token.clone());
291 Ok(token)
292 }
293
294 /// Fetches a fresh OAuth Access Token from Salesforce.
295 async fn fetch_oauth_access_token(&self) -> SalesforceAuthResult<OAuthToken> {
296 let auth_mode =
297 self.config.auth_mode.as_ref().ok_or_else(|| {
298 SalesforceAuthError::Config("auth_mode not configured".to_string())
299 })?;
300
301 let mut form_data = HashMap::new();
302 form_data.insert("client_id", self.config.client_id.clone());
303
304 match auth_mode {
305 AuthMode::Password { username, password } => {
306 info!(username = %username, "Fetching OAuth Access Token via password grant");
307 form_data.insert("grant_type", "password".to_string());
308 form_data.insert("username", username.clone());
309 form_data.insert("password", password.as_str().to_string());
310
311 if let Some(ref secret) = self.config.client_secret {
312 form_data.insert("client_secret", secret.as_str().to_string());
313 }
314 }
315
316 AuthMode::PrivateKey {
317 username,
318 private_key,
319 } => {
320 info!(username = %username, "Fetching OAuth Access Token via JWT Bearer Token Flow");
321
322 let assertion = build_jwt_assertion(
323 &self.config.client_id,
324 username,
325 &self.config.login_url,
326 private_key,
327 )?;
328
329 form_data.insert(
330 "grant_type",
331 "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
332 );
333 form_data.insert("assertion", assertion);
334 }
335
336 AuthMode::RefreshToken { refresh_token } => {
337 info!("Fetching OAuth Access Token via OAuth Refresh Token");
338 form_data.insert("grant_type", "refresh_token".to_string());
339 form_data.insert("refresh_token", refresh_token.as_str().to_string());
340
341 if let Some(ref secret) = self.config.client_secret {
342 form_data.insert("client_secret", secret.as_str().to_string());
343 }
344 }
345 }
346
347 let token_url = self.config.login_url.join(OAUTH_TOKEN_PATH).map_err(|e| {
348 SalesforceAuthError::Config(format!("failed to build OAuth Access Token URL: {e}"))
349 })?;
350
351 debug!(url = %token_url, "Requesting OAuth Access Token");
352
353 let response = self.post_with_retry(&token_url, &form_data).await?;
354 let response_text = response.text().await?;
355
356 debug!(response = %response_text, "OAuth Access Token response received");
357
358 let oauth_response: OAuthTokenResponse =
359 serde_json::from_str(&response_text).map_err(|e| {
360 SalesforceAuthError::TokenParse(format!(
361 "failed to parse OAuth Access Token response: {e}"
362 ))
363 })?;
364
365 let token_changed = self
366 .cached_oauth_token
367 .as_ref()
368 .map_or(true, |old| old.token != oauth_response.access_token);
369
370 debug!(
371 instance_url = %oauth_response.instance_url,
372 token_type = ?oauth_response.token_type,
373 scope = ?oauth_response.scope,
374 token_changed = token_changed,
375 "OAuth Access Token response parsed"
376 );
377
378 OAuthToken::from_response(oauth_response)
379 }
380
381 /// Exchanges an OAuth Access Token for a DC JWT.
382 ///
383 /// Calls `POST /services/a360/token` with the OAuth Access Token as
384 /// the `subject_token`.
385 async fn exchange_oauth_access_token_for_dc_jwt(
386 &self,
387 oauth_token: &OAuthToken,
388 ) -> SalesforceAuthResult<DataCloudToken> {
389 let mut form_data = HashMap::new();
390 form_data.insert(
391 "grant_type",
392 "urn:salesforce:grant-type:external:cdp".to_string(),
393 );
394 form_data.insert(
395 "subject_token_type",
396 "urn:ietf:params:oauth:token-type:access_token".to_string(),
397 );
398 form_data.insert("subject_token", oauth_token.token.clone());
399
400 if let Some(ref dataspace) = self.config.dataspace {
401 form_data.insert("dataspace", dataspace.clone());
402 }
403
404 let exchange_url = oauth_token
405 .instance_url
406 .join(DATA_CLOUD_TOKEN_PATH)
407 .map_err(|e| {
408 SalesforceAuthError::Config(format!("failed to build DC JWT exchange URL: {e}"))
409 })?;
410
411 debug!(url = %exchange_url, "Exchanging OAuth Access Token for DC JWT");
412
413 let response = self.post_with_retry(&exchange_url, &form_data).await?;
414 let response_text = response.text().await?;
415
416 debug!(response = %response_text, "DC JWT response received");
417
418 let dc_response: DataCloudTokenResponse =
419 serde_json::from_str(&response_text).map_err(|e| {
420 SalesforceAuthError::TokenParse(format!("failed to parse DC JWT response: {e}"))
421 })?;
422
423 debug!(
424 instance_url = %dc_response.instance_url,
425 token_type = ?dc_response.token_type,
426 expires_in = ?dc_response.expires_in,
427 "DC JWT response parsed"
428 );
429
430 let token = DataCloudToken::from_response(dc_response)?;
431
432 info!(
433 tenant_url = %token.tenant_url(),
434 expires_at = %token.expires_at(),
435 "DC JWT obtained"
436 );
437
438 Ok(token)
439 }
440
441 /// Makes a POST request with retry logic for transient failures.
442 async fn post_with_retry(
443 &self,
444 url: &url::Url,
445 form_data: &HashMap<&str, String>,
446 ) -> SalesforceAuthResult<reqwest::Response> {
447 let mut last_error = None;
448
449 for attempt in 0..=self.config.max_retries {
450 if attempt > 0 {
451 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
452 warn!(
453 attempt = attempt,
454 delay_secs = delay.as_secs(),
455 "Retrying after transient failure"
456 );
457 tokio::time::sleep(delay).await;
458 }
459
460 match self
461 .http_client
462 .post(url.as_str())
463 .header("Accept", "application/json")
464 .header("Content-Type", "application/x-www-form-urlencoded")
465 .form(form_data)
466 .send()
467 .await
468 {
469 Ok(response) => {
470 if response.status().is_client_error() {
471 let status = response.status();
472 let body = response.text().await.unwrap_or_default();
473
474 if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
475 let error_code = error_json
476 .get("error")
477 .and_then(|v| v.as_str())
478 .unwrap_or("unknown");
479 let error_desc = error_json
480 .get("error_description")
481 .and_then(|v| v.as_str())
482 .unwrap_or(&body);
483
484 return Err(SalesforceAuthError::Authorization {
485 error_code: error_code.to_string(),
486 error_description: error_desc.to_string(),
487 });
488 }
489
490 return Err(SalesforceAuthError::Http(format!(
491 "HTTP {status} error: {body}"
492 )));
493 }
494
495 if response.status().is_server_error() {
496 last_error = Some(SalesforceAuthError::Http(format!(
497 "HTTP {} error",
498 response.status()
499 )));
500 continue;
501 }
502
503 return Ok(response);
504 }
505 Err(e) => {
506 last_error = Some(SalesforceAuthError::Http(e.to_string()));
507 }
508 }
509 }
510
511 Err(last_error.unwrap_or_else(|| {
512 SalesforceAuthError::Http("request failed after retries".to_string())
513 }))
514 }
515}
516
517/// Thread-safe wrapper around [`DataCloudTokenProvider`].
518///
519/// Allows sharing the DC JWT provider between multiple tasks/threads
520/// while ensuring exclusive access during token operations. All access
521/// is protected by a [`tokio::sync::Mutex`].
522///
523/// # Example
524///
525/// ```no_run
526/// use hyperdb_api_salesforce::{SalesforceAuthConfig, AuthMode, SharedTokenProvider};
527///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
529/// # let config = SalesforceAuthConfig::new("https://login.salesforce.com", "client_id")?
530/// # .auth_mode(AuthMode::password("user", "pass"));
531/// let provider = SharedTokenProvider::new(config)?;
532///
533/// // Can be cloned and shared between tasks
534/// let provider_clone = provider.clone();
535///
536/// tokio::spawn(async move {
537/// let dc_jwt = provider_clone.get_token().await.unwrap();
538/// // use dc_jwt.bearer_token() as the Authorization header
539/// });
540/// # Ok(())
541/// # }
542/// ```
#[derive(Clone)]
pub struct SharedTokenProvider {
    /// The wrapped provider. Clones of this handle share it through the
    /// `Arc`, and every operation locks the mutex before touching it.
    inner: Arc<Mutex<DataCloudTokenProvider>>,
}
547
548impl SharedTokenProvider {
549 /// Creates a new shared DC JWT provider.
550 ///
551 /// # Errors
552 ///
553 /// Propagates any error from [`DataCloudTokenProvider::new`]:
554 /// configuration validation failures or HTTP client construction
555 /// failures (surfaced as [`SalesforceAuthError::Http`]).
556 pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
557 let provider = DataCloudTokenProvider::new(config)?;
558 Ok(SharedTokenProvider {
559 inner: Arc::new(Mutex::new(provider)),
560 })
561 }
562
563 /// Gets a valid DC JWT.
564 ///
565 /// # Errors
566 ///
567 /// Propagates any error from [`DataCloudTokenProvider::get_token`]
568 /// (HTTP failure, authorization rejection, JWT signing error, or
569 /// token-parse failure during the refresh cycle).
570 pub async fn get_token(&self) -> SalesforceAuthResult<DataCloudToken> {
571 let mut provider = self.inner.lock().await;
572 provider.get_token().await.cloned()
573 }
574
575 /// Forces a DC JWT refresh (reuses OAuth Access Token if still valid).
576 ///
577 /// # Errors
578 ///
579 /// Propagates any error from [`DataCloudTokenProvider::refresh_token`].
580 pub async fn refresh_token(&self) -> SalesforceAuthResult<DataCloudToken> {
581 let mut provider = self.inner.lock().await;
582 provider.refresh_token().await.cloned()
583 }
584
585 /// Forces a full refresh (both OAuth Access Token and DC JWT).
586 ///
587 /// # Errors
588 ///
589 /// Propagates any error from [`DataCloudTokenProvider::force_refresh`].
590 pub async fn force_refresh(&self) -> SalesforceAuthResult<DataCloudToken> {
591 let mut provider = self.inner.lock().await;
592 provider.force_refresh().await.cloned()
593 }
594
595 /// Returns the DC JWT bearer token string if a valid DC JWT is cached.
596 pub async fn bearer_token(&self) -> Option<String> {
597 let provider = self.inner.lock().await;
598 provider.bearer_token()
599 }
600
601 /// Returns the tenant URL if a valid DC JWT is cached.
602 pub async fn tenant_url(&self) -> Option<String> {
603 let provider = self.inner.lock().await;
604 provider.tenant_url().map(std::string::ToString::to_string)
605 }
606
607 /// Returns the lakehouse name for Hyper connection.
608 ///
609 /// # Errors
610 ///
611 /// Propagates [`SalesforceAuthError::TokenParse`] from
612 /// [`DataCloudTokenProvider::lakehouse_name`] if the cached DC JWT's
613 /// tenant URL cannot be parsed into a valid lakehouse identifier.
614 pub async fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
615 let provider = self.inner.lock().await;
616 provider.lakehouse_name()
617 }
618}
619
620impl std::fmt::Debug for DataCloudTokenProvider {
621 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
622 f.debug_struct("DataCloudTokenProvider")
623 .field("config", &self.config)
624 .field("has_cached_oauth_token", &self.cached_oauth_token.is_some())
625 .field("has_cached_dc_jwt", &self.cached_dc_jwt.is_some())
626 .finish_non_exhaustive()
627 }
628}
629
630impl std::fmt::Debug for SharedTokenProvider {
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 f.debug_struct("SharedTokenProvider")
633 .finish_non_exhaustive()
634 }
635}