use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use reqwest::Client as HttpClient;
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
use crate::config::{AuthMode, SalesforceAuthConfig};
use crate::error::{SalesforceAuthError, SalesforceAuthResult};
use crate::jwt::build_jwt_assertion;
use crate::token::{DataCloudToken, DataCloudTokenResponse, OAuthToken, OAuthTokenResponse};
/// OAuth 2.0 token endpoint, joined onto the configured `login_url` (step 1).
// NOTE(review): relative path with no leading slash; `Url::join` resolves it against the
// base URL, so a base with a path component must end in `/` — confirm config guarantees this.
const OAUTH_TOKEN_PATH: &str = "services/oauth2/token";
/// Data Cloud token-exchange endpoint, joined onto the OAuth token's `instance_url` (step 2).
const DATA_CLOUD_TOKEN_PATH: &str = "services/a360/token";
/// Obtains and caches Salesforce Data Cloud JWTs (DC JWTs).
///
/// A DC JWT is obtained in two steps: an OAuth Access Token is requested from the
/// configured login URL using the configured [`AuthMode`], and that token is then exchanged
/// for a DC JWT at the OAuth token's instance URL. Both tokens are cached and reused while
/// they still report themselves valid.
///
/// Token-producing methods take `&mut self`; use `SharedTokenProvider` to share one
/// provider (and its caches) between tasks.
pub struct DataCloudTokenProvider {
/// Configuration validated in `new` (credentials, login URL, timeout, retry count).
config: SalesforceAuthConfig,
/// HTTP client built with the configured per-request timeout.
http_client: HttpClient,
/// Most recent OAuth Access Token (step 1); reused while `is_likely_valid()`.
cached_oauth_token: Option<OAuthToken>,
/// Most recent DC JWT (step 2); reused while `is_valid()`.
cached_dc_jwt: Option<DataCloudToken>,
}
impl DataCloudTokenProvider {
    /// Creates a provider from `config`.
    ///
    /// The configuration is validated first, then an HTTP client is built whose request
    /// timeout is `config.timeout_secs` seconds. No network request is made here; both
    /// token caches start empty.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `config.validate()`, or
    /// [`SalesforceAuthError::Http`] if the HTTP client cannot be built.
    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
        config.validate()?;
        let http_client = HttpClient::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .map_err(|e| SalesforceAuthError::Http(format!("failed to create HTTP client: {e}")))?;
        Ok(DataCloudTokenProvider {
            config,
            http_client,
            cached_oauth_token: None,
            cached_dc_jwt: None,
        })
    }

    /// Returns the configuration this provider was built with.
    #[must_use]
    pub fn config(&self) -> &SalesforceAuthConfig {
        &self.config
    }

    /// Returns a valid DC JWT, fetching a new one only when none is cached or the cached
    /// one reports itself invalid.
    ///
    /// # Errors
    ///
    /// Propagates any failure from obtaining the OAuth Access Token or from exchanging it
    /// for a DC JWT. On failure the previously cached DC JWT (if any) is left in place.
    pub async fn get_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
        let needs_refresh = match &self.cached_dc_jwt {
            Some(token) if token.is_valid() => {
                debug!("Using cached DC JWT");
                false
            }
            Some(_) => {
                debug!("Cached DC JWT expired, refreshing");
                true
            }
            None => true,
        };
        if needs_refresh {
            let token = self.fetch_dc_jwt().await?;
            self.cached_dc_jwt = Some(token);
        }
        Ok(self
            .cached_dc_jwt
            .as_ref()
            .expect("DC JWT cache is populated: either it was valid or it was just refreshed"))
    }

    /// Discards both cached tokens and fetches a brand-new OAuth Access Token and DC JWT.
    ///
    /// # Errors
    ///
    /// Same as [`Self::get_token`].
    pub async fn force_refresh(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
        self.clear_cache();
        self.get_token().await
    }

    /// Discards only the cached DC JWT and fetches a new one, reusing the cached OAuth
    /// Access Token when it is still likely valid.
    ///
    /// # Errors
    ///
    /// Same as [`Self::get_token`].
    pub async fn refresh_token(&mut self) -> SalesforceAuthResult<&DataCloudToken> {
        self.cached_dc_jwt = None;
        self.get_token().await
    }

    /// Drops both cached tokens without making any request.
    pub fn clear_cache(&mut self) {
        self.cached_oauth_token = None;
        self.cached_dc_jwt = None;
    }

    /// Returns the bearer token of the cached DC JWT if it is still valid.
    ///
    /// Never triggers a refresh; returns `None` when nothing valid is cached.
    #[must_use]
    pub fn bearer_token(&self) -> Option<String> {
        self.cached_dc_jwt
            .as_ref()
            .filter(|t| t.is_valid())
            .map(DataCloudToken::bearer_token)
    }

    /// Returns the tenant URL of the cached DC JWT if it is still valid.
    ///
    /// Never triggers a refresh; returns `None` when nothing valid is cached.
    #[must_use]
    pub fn tenant_url(&self) -> Option<&str> {
        self.cached_dc_jwt
            .as_ref()
            .filter(|t| t.is_valid())
            .map(DataCloudToken::tenant_url_str)
    }

    /// Derives the lakehouse name from the cached DC JWT and the configured dataspace.
    ///
    /// Returns `Ok(None)` when no valid DC JWT is cached; never triggers a refresh.
    ///
    /// # Errors
    ///
    /// Propagates the error from `DataCloudToken::lakehouse_name`.
    pub fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
        if let Some(ref token) = self.cached_dc_jwt {
            if token.is_valid() {
                return Ok(Some(token.lakehouse_name(self.config.dataspace_value())?));
            }
        }
        Ok(None)
    }

    /// Runs the full two-step flow: obtain an OAuth Access Token (cached if possible), then
    /// exchange it for a DC JWT.
    ///
    /// If the exchange fails, a fresh OAuth Access Token is fetched and the exchange is
    /// retried once (Step 2a). The freshly fetched OAuth token is cached.
    async fn fetch_dc_jwt(&mut self) -> SalesforceAuthResult<DataCloudToken> {
        let oauth_token = self.get_valid_oauth_access_token().await?;
        let step2_err = match self
            .exchange_oauth_access_token_for_dc_jwt(&oauth_token)
            .await
        {
            Ok(dc_jwt) => return Ok(dc_jwt),
            Err(e) => e,
        };
        warn!(
            error = %step2_err,
            "DC JWT exchange failed; force-refreshing OAuth Access Token and retrying (Step 2a)"
        );
        // Keep the stale token cached while refreshing so `fetch_oauth_access_token` can
        // report whether the refresh actually produced a different token; it is replaced on
        // success and dropped if the refresh itself fails.
        let fresh_oauth_token = match self.fetch_oauth_access_token().await {
            Ok(token) => token,
            Err(e) => {
                self.cached_oauth_token = None;
                return Err(e);
            }
        };
        self.cached_oauth_token = Some(fresh_oauth_token.clone());
        self.exchange_oauth_access_token_for_dc_jwt(&fresh_oauth_token)
            .await
            .map_err(|retry_err| {
                warn!(
                    original_error = %step2_err,
                    retry_error = %retry_err,
                    "DC JWT exchange failed again after OAuth Access Token refresh (Step 2a retry)"
                );
                retry_err
            })
    }

    /// Returns the cached OAuth Access Token while it is likely valid; otherwise fetches a
    /// new one and caches it.
    async fn get_valid_oauth_access_token(&mut self) -> SalesforceAuthResult<OAuthToken> {
        if let Some(ref token) = self.cached_oauth_token {
            if token.is_likely_valid() {
                debug!(
                    "OAuth Access Token still valid (obtained at {}), reusing",
                    token.obtained_at
                );
                return Ok(token.clone());
            }
            debug!("Cached OAuth Access Token expired, refreshing");
        }
        let token = self.fetch_oauth_access_token().await?;
        self.cached_oauth_token = Some(token.clone());
        Ok(token)
    }

    /// Step 1: requests a new OAuth Access Token from `login_url` using the configured
    /// grant (password, JWT bearer, or refresh token).
    ///
    /// Does not update the cache; it only reads it to log whether the token changed.
    ///
    /// # Errors
    ///
    /// [`SalesforceAuthError::Config`] if `auth_mode` is missing or the URL cannot be
    /// built, [`SalesforceAuthError::TokenParse`] if the response is not a valid token
    /// response, plus any error from building the JWT assertion or from the HTTP request.
    async fn fetch_oauth_access_token(&self) -> SalesforceAuthResult<OAuthToken> {
        let auth_mode =
            self.config.auth_mode.as_ref().ok_or_else(|| {
                SalesforceAuthError::Config("auth_mode not configured".to_string())
            })?;
        let mut form_data = HashMap::new();
        form_data.insert("client_id", self.config.client_id.clone());
        match auth_mode {
            AuthMode::Password { username, password } => {
                info!(username = %username, "Fetching OAuth Access Token via password grant");
                form_data.insert("grant_type", "password".to_string());
                form_data.insert("username", username.clone());
                form_data.insert("password", password.as_str().to_string());
                if let Some(ref secret) = self.config.client_secret {
                    form_data.insert("client_secret", secret.as_str().to_string());
                }
            }
            AuthMode::PrivateKey {
                username,
                private_key,
            } => {
                info!(username = %username, "Fetching OAuth Access Token via JWT Bearer Token Flow");
                let assertion = build_jwt_assertion(
                    &self.config.client_id,
                    username,
                    &self.config.login_url,
                    private_key,
                )?;
                form_data.insert(
                    "grant_type",
                    "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
                );
                form_data.insert("assertion", assertion);
            }
            AuthMode::RefreshToken { refresh_token } => {
                info!("Fetching OAuth Access Token via OAuth Refresh Token");
                form_data.insert("grant_type", "refresh_token".to_string());
                form_data.insert("refresh_token", refresh_token.as_str().to_string());
                if let Some(ref secret) = self.config.client_secret {
                    form_data.insert("client_secret", secret.as_str().to_string());
                }
            }
        }
        let token_url = self.config.login_url.join(OAUTH_TOKEN_PATH).map_err(|e| {
            SalesforceAuthError::Config(format!("failed to build OAuth Access Token URL: {e}"))
        })?;
        debug!(url = %token_url, "Requesting OAuth Access Token");
        let response = self.post_with_retry(&token_url, &form_data).await?;
        let response_text = response.text().await?;
        // The body carries the live access token; log only its size, never its contents.
        debug!(
            response_len = response_text.len(),
            "OAuth Access Token response received"
        );
        let oauth_response: OAuthTokenResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                SalesforceAuthError::TokenParse(format!(
                    "failed to parse OAuth Access Token response: {e}"
                ))
            })?;
        let token_changed = self
            .cached_oauth_token
            .as_ref()
            .map_or(true, |old| old.token != oauth_response.access_token);
        debug!(
            instance_url = %oauth_response.instance_url,
            token_type = ?oauth_response.token_type,
            scope = ?oauth_response.scope,
            token_changed = token_changed,
            "OAuth Access Token response parsed"
        );
        OAuthToken::from_response(oauth_response)
    }

    /// Step 2: exchanges an OAuth Access Token for a DC JWT at the token's instance URL,
    /// passing the configured dataspace when one is set.
    ///
    /// # Errors
    ///
    /// [`SalesforceAuthError::Config`] if the exchange URL cannot be built,
    /// [`SalesforceAuthError::TokenParse`] if the response cannot be parsed, plus any error
    /// from the HTTP request or from `DataCloudToken::from_response`.
    async fn exchange_oauth_access_token_for_dc_jwt(
        &self,
        oauth_token: &OAuthToken,
    ) -> SalesforceAuthResult<DataCloudToken> {
        let mut form_data = HashMap::new();
        form_data.insert(
            "grant_type",
            "urn:salesforce:grant-type:external:cdp".to_string(),
        );
        form_data.insert(
            "subject_token_type",
            "urn:ietf:params:oauth:token-type:access_token".to_string(),
        );
        form_data.insert("subject_token", oauth_token.token.clone());
        if let Some(ref dataspace) = self.config.dataspace {
            form_data.insert("dataspace", dataspace.clone());
        }
        let exchange_url = oauth_token
            .instance_url
            .join(DATA_CLOUD_TOKEN_PATH)
            .map_err(|e| {
                SalesforceAuthError::Config(format!("failed to build DC JWT exchange URL: {e}"))
            })?;
        debug!(url = %exchange_url, "Exchanging OAuth Access Token for DC JWT");
        let response = self.post_with_retry(&exchange_url, &form_data).await?;
        let response_text = response.text().await?;
        // The body carries the DC JWT itself; log only its size, never its contents.
        debug!(response_len = response_text.len(), "DC JWT response received");
        let dc_response: DataCloudTokenResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                SalesforceAuthError::TokenParse(format!("failed to parse DC JWT response: {e}"))
            })?;
        debug!(
            instance_url = %dc_response.instance_url,
            token_type = ?dc_response.token_type,
            expires_in = ?dc_response.expires_in,
            "DC JWT response parsed"
        );
        let token = DataCloudToken::from_response(dc_response)?;
        info!(
            tenant_url = %token.tenant_url(),
            expires_at = %token.expires_at(),
            "DC JWT obtained"
        );
        Ok(token)
    }

    /// POSTs `form_data` as `application/x-www-form-urlencoded` to `url`.
    ///
    /// Transport errors, 5xx responses and 429 Too Many Requests are treated as transient
    /// and retried up to `config.max_retries` times, with exponential backoff of 1s, 2s,
    /// 4s, 8s and then 16s per retry. Any other 4xx response fails immediately.
    ///
    /// # Errors
    ///
    /// [`SalesforceAuthError::Authorization`] for a 4xx with a JSON body, otherwise
    /// [`SalesforceAuthError::Http`]. When retries are exhausted, the error from the last
    /// attempt is returned.
    async fn post_with_retry(
        &self,
        url: &url::Url,
        form_data: &HashMap<&str, String>,
    ) -> SalesforceAuthResult<reqwest::Response> {
        let mut last_error = None;
        for attempt in 0..=self.config.max_retries {
            if attempt > 0 {
                let delay = Duration::from_secs(1 << (attempt - 1).min(4));
                warn!(
                    attempt = attempt,
                    delay_secs = delay.as_secs(),
                    "Retrying after transient failure"
                );
                tokio::time::sleep(delay).await;
            }
            let sent = self
                .http_client
                .post(url.as_str())
                .header("Accept", "application/json")
                .header("Content-Type", "application/x-www-form-urlencoded")
                .form(form_data)
                .send()
                .await;
            let response = match sent {
                Ok(response) => response,
                Err(e) => {
                    last_error = Some(SalesforceAuthError::Http(e.to_string()));
                    continue;
                }
            };
            let status = response.status();
            if status.is_client_error() {
                // 429 Too Many Requests is rate limiting, not a rejected request.
                let rate_limited = status.as_u16() == 429;
                let body = response.text().await.unwrap_or_default();
                let error = Self::client_error(status, &body);
                if rate_limited {
                    last_error = Some(error);
                    continue;
                }
                return Err(error);
            }
            if status.is_server_error() {
                last_error = Some(SalesforceAuthError::Http(format!("HTTP {status} error")));
                continue;
            }
            return Ok(response);
        }
        Err(last_error.unwrap_or_else(|| {
            SalesforceAuthError::Http("request failed after retries".to_string())
        }))
    }

    /// Converts a 4xx response into an error.
    ///
    /// A JSON body becomes [`SalesforceAuthError::Authorization`] using its `error` and
    /// `error_description` fields, with fallbacks to `"unknown"` and the raw body. Any other
    /// body becomes [`SalesforceAuthError::Http`].
    fn client_error(status: impl std::fmt::Display, body: &str) -> SalesforceAuthError {
        match serde_json::from_str::<serde_json::Value>(body) {
            Ok(error_json) => {
                let error_code = error_json
                    .get("error")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                let error_desc = error_json
                    .get("error_description")
                    .and_then(|v| v.as_str())
                    .unwrap_or(body);
                SalesforceAuthError::Authorization {
                    error_code: error_code.to_string(),
                    error_description: error_desc.to_string(),
                }
            }
            Err(_) => SalesforceAuthError::Http(format!("HTTP {status} error: {body}")),
        }
    }
}
/// Cloneable handle to a single [`DataCloudTokenProvider`] that can be shared between tasks.
///
/// All clones point at the same provider behind an `Arc<tokio::sync::Mutex<_>>`, so they
/// share one token cache. The lock is held for the whole of each call, including any
/// network round-trip, which serializes callers. A caller queued behind a refresh
/// normally finds the newly cached token instead of requesting another one.
#[derive(Clone)]
pub struct SharedTokenProvider {
    inner: Arc<Mutex<DataCloudTokenProvider>>,
}

impl SharedTokenProvider {
    /// Builds the underlying provider from `config` and wraps it for sharing.
    ///
    /// # Errors
    ///
    /// Same as `DataCloudTokenProvider::new`.
    pub fn new(config: SalesforceAuthConfig) -> SalesforceAuthResult<Self> {
        let inner = Arc::new(Mutex::new(DataCloudTokenProvider::new(config)?));
        Ok(Self { inner })
    }

    /// Returns an owned copy of a valid DC JWT, refreshing it if necessary.
    ///
    /// # Errors
    ///
    /// Same as `DataCloudTokenProvider::get_token`.
    pub async fn get_token(&self) -> SalesforceAuthResult<DataCloudToken> {
        let mut guard = self.inner.lock().await;
        let token = guard.get_token().await?;
        Ok(token.clone())
    }

    /// Drops the cached DC JWT and returns an owned copy of a newly fetched one.
    ///
    /// # Errors
    ///
    /// Same as `DataCloudTokenProvider::refresh_token`.
    pub async fn refresh_token(&self) -> SalesforceAuthResult<DataCloudToken> {
        let mut guard = self.inner.lock().await;
        let token = guard.refresh_token().await?;
        Ok(token.clone())
    }

    /// Drops both cached tokens and returns an owned copy of a newly fetched DC JWT.
    ///
    /// # Errors
    ///
    /// Same as `DataCloudTokenProvider::force_refresh`.
    pub async fn force_refresh(&self) -> SalesforceAuthResult<DataCloudToken> {
        let mut guard = self.inner.lock().await;
        let token = guard.force_refresh().await?;
        Ok(token.clone())
    }

    /// Bearer token of the cached DC JWT, if it is still valid. Never refreshes.
    pub async fn bearer_token(&self) -> Option<String> {
        let guard = self.inner.lock().await;
        guard.bearer_token()
    }

    /// Owned tenant URL of the cached DC JWT, if it is still valid. Never refreshes.
    pub async fn tenant_url(&self) -> Option<String> {
        let guard = self.inner.lock().await;
        guard.tenant_url().map(str::to_owned)
    }

    /// Lakehouse name derived from the cached DC JWT, or `Ok(None)` if none is valid.
    ///
    /// # Errors
    ///
    /// Same as `DataCloudTokenProvider::lakehouse_name`.
    pub async fn lakehouse_name(&self) -> SalesforceAuthResult<Option<String>> {
        let guard = self.inner.lock().await;
        guard.lakehouse_name()
    }
}
impl std::fmt::Debug for DataCloudTokenProvider {
    /// Prints the configuration plus presence flags for the two token caches.
    ///
    /// Cached tokens are reported only as booleans, so their contents are never printed.
    /// The HTTP client is omitted, which is marked by the trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_oauth = self.cached_oauth_token.is_some();
        let has_dc_jwt = self.cached_dc_jwt.is_some();
        let mut out = f.debug_struct("DataCloudTokenProvider");
        out.field("config", &self.config);
        out.field("has_cached_oauth_token", &has_oauth);
        out.field("has_cached_dc_jwt", &has_dc_jwt);
        out.finish_non_exhaustive()
    }
}
impl std::fmt::Debug for SharedTokenProvider {
    /// Prints `SharedTokenProvider { .. }`.
    ///
    /// The wrapped provider sits behind an async mutex and is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SharedTokenProvider");
        out.finish_non_exhaustive()
    }
}