use chrono::{DateTime, Duration, Utc};
use serde::Deserialize;
use url::Url;
use crate::error::{SalesforceAuthError, SalesforceAuthResult};
/// Safety margin, in seconds (five minutes), used by `DataCloudToken::is_valid`: a token
/// only counts as valid if it stays unexpired for at least this long from now.
const DC_JWT_VALIDITY_BUFFER_SECS: i64 = 5 * 60;
/// Deserialized body of an OAuth access-token response.
///
/// `access_token` and `instance_url` default to empty strings when absent. An error body
/// that omits them (carrying only `error` / `error_description`) therefore still
/// deserializes, and [`OAuthTokenResponse::check_error`] reports it as an authorization
/// error instead of it surfacing as a generic deserialization failure.
#[derive(Debug, Deserialize)]
pub struct OAuthTokenResponse {
    /// Issued access token; empty when the response did not include one.
    #[serde(default)]
    pub access_token: String,
    /// Instance URL as returned by the server (not yet parsed); empty when absent.
    #[serde(default)]
    pub instance_url: String,
    /// Token type reported by the server, if any.
    #[serde(default)]
    pub token_type: Option<String>,
    /// Granted scopes as reported by the server, if any.
    #[serde(default)]
    pub scope: Option<String>,
    /// Issue timestamp as reported by the server, kept as the raw string.
    #[serde(default)]
    pub issued_at: Option<String>,
    /// Error code of a failed request, if any.
    #[serde(default)]
    pub error: Option<String>,
    /// Human-readable description accompanying `error`, if any.
    #[serde(default)]
    pub error_description: Option<String>,
}
impl OAuthTokenResponse {
    /// Validates that this response represents a successfully issued token.
    ///
    /// # Errors
    ///
    /// - `SalesforceAuthError::Authorization` if the response carries an `error` code. The
    ///   description is taken from `error_description`, or is empty when that field is absent.
    /// - `SalesforceAuthError::TokenParse` if `access_token` is empty.
    pub fn check_error(&self) -> SalesforceAuthResult<()> {
        // An error code alone marks the response as failed. Requiring a description as well
        // would let a bare `error` slip through validation.
        if let Some(code) = &self.error {
            return Err(SalesforceAuthError::Authorization {
                error_code: code.clone(),
                error_description: self.error_description.clone().unwrap_or_default(),
            });
        }
        if self.access_token.is_empty() {
            return Err(SalesforceAuthError::TokenParse(
                "missing access_token in OAuth Access Token response".to_string(),
            ));
        }
        Ok(())
    }
}
/// A validated OAuth access token together with the instance URL it was issued for.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// Raw access token value, without any `Bearer ` prefix.
    pub token: String,
    /// Parsed `instance_url` from the token response.
    pub instance_url: Url,
    /// Local time at which the token was constructed (`from_response` uses `Utc::now()`).
    pub obtained_at: DateTime<Utc>,
    /// Estimated expiry. `from_response` sets this to `obtained_at` plus
    /// `OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS`, because `OAuthTokenResponse` has no
    /// lifetime field.
    pub expires_at: DateTime<Utc>,
}
/// Assumed lifetime of an OAuth access token, in seconds: one second short of two hours.
/// `OAuthToken::from_response` uses it because the response does not report a lifetime.
const OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS: i64 = 2 * 60 * 60 - 1;
impl OAuthToken {
pub fn from_response(response: OAuthTokenResponse) -> SalesforceAuthResult<Self> {
response.check_error()?;
let instance_url = Url::parse(&response.instance_url)
.map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
let now = Utc::now();
let expires_at = now + Duration::seconds(OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS);
Ok(OAuthToken {
token: response.access_token,
instance_url,
obtained_at: now,
expires_at,
})
}
#[must_use]
pub fn bearer_token(&self) -> String {
format!("Bearer {}", self.token)
}
#[must_use]
pub fn is_likely_valid(&self) -> bool {
Utc::now() < self.expires_at
}
}
/// Deserialized body of a Data Cloud token exchange response.
///
/// `access_token` and `instance_url` default to empty strings when absent. An error body
/// that omits them (carrying only `error` / `error_description`) therefore still
/// deserializes, and [`DataCloudTokenResponse::check_error`] reports it as an authorization
/// error instead of it surfacing as a generic deserialization failure.
#[derive(Debug, Deserialize)]
pub struct DataCloudTokenResponse {
    /// Issued Data Cloud JWT; empty when the response did not include one.
    #[serde(default)]
    pub access_token: String,
    /// Tenant instance URL as returned, possibly without a scheme; empty when absent.
    #[serde(default)]
    pub instance_url: String,
    /// Token type reported by the server, if any.
    #[serde(default)]
    pub token_type: Option<String>,
    /// Token lifetime in seconds as reported by the server, if any.
    #[serde(default)]
    pub expires_in: Option<i64>,
    /// Error code of a failed request, if any.
    #[serde(default)]
    pub error: Option<String>,
    /// Human-readable description accompanying `error`, if any.
    #[serde(default)]
    pub error_description: Option<String>,
}
impl DataCloudTokenResponse {
    /// Validates that this response represents a successfully issued Data Cloud JWT.
    ///
    /// # Errors
    ///
    /// - `SalesforceAuthError::Authorization` if the response carries an `error` code. The
    ///   description is taken from `error_description`, or is empty when that field is absent.
    /// - `SalesforceAuthError::TokenParse` if `access_token` is empty.
    pub fn check_error(&self) -> SalesforceAuthResult<()> {
        // An error code alone marks the response as failed. Requiring a description as well
        // would let a bare `error` slip through validation.
        if let Some(code) = &self.error {
            return Err(SalesforceAuthError::Authorization {
                error_code: code.clone(),
                error_description: self.error_description.clone().unwrap_or_default(),
            });
        }
        if self.access_token.is_empty() {
            return Err(SalesforceAuthError::TokenParse(
                "missing access_token in DC JWT response".to_string(),
            ));
        }
        Ok(())
    }
}
/// A validated Data Cloud JWT together with its tenant URL and local lifetime bookkeeping.
///
/// Fields are private; read them through the accessor methods.
#[derive(Debug, Clone)]
pub struct DataCloudToken {
    /// Token type used as the prefix in `bearer_token` (`from_response` defaults it to `"Bearer"`).
    token_type: String,
    /// Raw JWT string.
    token: String,
    /// Tenant URL; `from_response` adds an `https://` scheme when the response omits one.
    tenant_url: Url,
    /// Local time at which the token was constructed.
    created_at: DateTime<Utc>,
    /// Expiry computed locally from `created_at` and the response's `expires_in`.
    expires_at: DateTime<Utc>,
}
impl DataCloudToken {
    /// Lifetime, in seconds, assumed when the response omits `expires_in`.
    const DEFAULT_EXPIRES_IN_SECS: i64 = 1800;

    /// Upper bound, in seconds (ten years), applied to a server-supplied `expires_in`.
    ///
    /// `Duration::seconds` and `DateTime + Duration` panic when their result is out of
    /// range, so an arbitrarily large value taken from the wire must be capped first.
    const MAX_EXPIRES_IN_SECS: i64 = 10 * 365 * 24 * 60 * 60;

    /// Builds a [`DataCloudToken`] from a deserialized token exchange response.
    ///
    /// Normalization applied to the response:
    /// - `instance_url` gets an `https://` prefix unless it already starts with `http://`
    ///   or `https://`. The check ignores ASCII case, because URL schemes are
    ///   case-insensitive.
    /// - `token_type` defaults to `"Bearer"` when absent.
    /// - `expires_in` defaults to 1800 seconds when absent. Negative values are treated as
    ///   `0` (already expired), and values above ten years are capped so that computing the
    ///   expiry cannot panic.
    ///
    /// # Errors
    ///
    /// Returns whatever [`DataCloudTokenResponse::check_error`] reports, or
    /// `SalesforceAuthError::TokenParse` if the normalized `instance_url` is not a valid URL.
    pub fn from_response(response: DataCloudTokenResponse) -> SalesforceAuthResult<Self> {
        response.check_error()?;
        let instance_url_with_scheme = if Self::has_http_scheme(&response.instance_url) {
            response.instance_url.clone()
        } else {
            format!("https://{}", response.instance_url)
        };
        let tenant_url = Url::parse(&instance_url_with_scheme)
            .map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
        let token_type = response.token_type.unwrap_or_else(|| "Bearer".to_string());
        let now = Utc::now();
        // Clamp untrusted input so the Duration construction and addition below cannot panic.
        let expires_in_secs = response
            .expires_in
            .unwrap_or(Self::DEFAULT_EXPIRES_IN_SECS)
            .clamp(0, Self::MAX_EXPIRES_IN_SECS);
        let expires_at = now + Duration::seconds(expires_in_secs);
        Ok(DataCloudToken {
            token_type,
            token: response.access_token,
            tenant_url,
            created_at: now,
            expires_at,
        })
    }

    /// Returns `true` when `url` starts with `http://` or `https://`, ignoring ASCII case.
    ///
    /// `str::get` returns `None` for out-of-range or non-char-boundary slices, so this never
    /// panics on short or non-ASCII input.
    fn has_http_scheme(url: &str) -> bool {
        ["http://", "https://"].iter().any(|&scheme| {
            url.get(..scheme.len())
                .map_or(false, |head| head.eq_ignore_ascii_case(scheme))
        })
    }

    /// Returns `"{token_type} {token}"`, e.g. `"Bearer abc.def.ghi"`.
    #[must_use]
    pub fn bearer_token(&self) -> String {
        format!("{} {}", self.token_type, self.token)
    }

    /// Returns the raw JWT string.
    #[must_use]
    pub fn access_token(&self) -> &str {
        &self.token
    }

    /// Returns the token type (e.g. `"Bearer"`).
    #[must_use]
    pub fn token_type(&self) -> &str {
        &self.token_type
    }

    /// Returns the parsed tenant URL.
    #[must_use]
    pub fn tenant_url(&self) -> &Url {
        &self.tenant_url
    }

    /// Returns the tenant URL serialized as a string.
    #[must_use]
    pub fn tenant_url_str(&self) -> &str {
        self.tenant_url.as_str()
    }

    /// Returns the local time at which this token was constructed.
    #[must_use]
    pub fn created_at(&self) -> DateTime<Utc> {
        self.created_at
    }

    /// Returns the locally computed expiry time.
    #[must_use]
    pub fn expires_at(&self) -> DateTime<Utc> {
        self.expires_at
    }

    /// Returns the time elapsed since construction.
    #[must_use]
    pub fn age(&self) -> Duration {
        Utc::now().signed_duration_since(self.created_at)
    }

    /// Returns the time left until expiry; negative once the token has expired.
    #[must_use]
    pub fn remaining_lifetime(&self) -> Duration {
        self.expires_at.signed_duration_since(Utc::now())
    }

    /// Returns `true` if the token will still be unexpired `DC_JWT_VALIDITY_BUFFER_SECS`
    /// (five minutes) from now.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.expires_at > Utc::now() + Duration::seconds(DC_JWT_VALIDITY_BUFFER_SECS)
    }

    /// Returns `true` once the expiry time has been reached.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.expires_at <= Utc::now()
    }

    /// Returns `true` if the token expires within `threshold_secs` seconds, or is older
    /// than `max_age_secs` seconds.
    #[must_use]
    pub fn needs_refresh(&self, threshold_secs: i64, max_age_secs: i64) -> bool {
        let now = Utc::now();
        let expiring = (self.expires_at - now).num_seconds() <= threshold_secs;
        let too_old = (now - self.created_at).num_seconds() > max_age_secs;
        expiring || too_old
    }

    /// Extracts the `audienceTenantId` claim from the JWT payload.
    ///
    /// The payload segment is only decoded; the signature is **not** verified.
    ///
    /// # Errors
    ///
    /// - `SalesforceAuthError::TokenParse` if the token does not have exactly three
    ///   `.`-separated parts, if the payload is not valid base64url, or if the claim is
    ///   missing or not a string.
    /// - A JSON error (converted via `?`) if the payload is not valid JSON.
    pub fn tenant_id(&self) -> SalesforceAuthResult<String> {
        let parts: Vec<&str> = self.token.split('.').collect();
        if parts.len() != 3 {
            return Err(SalesforceAuthError::TokenParse(
                "invalid DC JWT format: expected 3 parts".to_string(),
            ));
        }
        let payload_b64 = parts[1];
        let payload_bytes = base64_url_decode(payload_b64)?;
        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)?;
        payload
            .get("audienceTenantId")
            .and_then(|v| v.as_str())
            .map(std::string::ToString::to_string)
            .ok_or_else(|| {
                SalesforceAuthError::TokenParse(
                    "missing audienceTenantId in DC JWT payload".to_string(),
                )
            })
    }

    /// Returns `"lakehouse:{tenant_id};{dataspace}"`, with an empty dataspace when `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DataCloudToken::tenant_id`].
    pub fn lakehouse_name(&self, dataspace: Option<&str>) -> SalesforceAuthResult<String> {
        let tenant_id = self.tenant_id()?;
        let dataspace_str = dataspace.unwrap_or("");
        Ok(format!("lakehouse:{tenant_id};{dataspace_str}"))
    }
}
/// Decodes a base64url string (RFC 4648 §5 alphabet), such as a JWT segment.
///
/// Trailing `=` padding is stripped before decoding, so both padded and unpadded input are
/// accepted.
///
/// # Errors
///
/// Returns `SalesforceAuthError::TokenParse` if the input is not valid base64url.
fn base64_url_decode(input: &str) -> SalesforceAuthResult<Vec<u8>> {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    // `URL_SAFE_NO_PAD` requires unpadded input, so drop any `=` the sender included.
    // Adding padding first would be pointless: it would be trimmed straight back off.
    URL_SAFE_NO_PAD
        .decode(input.trim_end_matches('='))
        .map_err(|e| SalesforceAuthError::TokenParse(format!("base64 decode error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful Data Cloud token response for `tenant.salesforce.com`.
    fn dc_response(access_token: &str, expires_in: i64) -> DataCloudTokenResponse {
        DataCloudTokenResponse {
            access_token: access_token.to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(expires_in),
            error: None,
            error_description: None,
        }
    }

    /// Converts a successful fixture response into a token, failing the test otherwise.
    fn dc_token(access_token: &str, expires_in: i64) -> DataCloudToken {
        DataCloudToken::from_response(dc_response(access_token, expires_in))
            .expect("fixture response should produce a DataCloudToken")
    }

    #[test]
    fn test_oauth_access_token_response_error() {
        let response = OAuthTokenResponse {
            access_token: String::new(),
            instance_url: String::new(),
            token_type: None,
            scope: None,
            issued_at: None,
            error: Some("invalid_grant".to_string()),
            error_description: Some("authentication failure".to_string()),
        };
        match response.check_error() {
            Err(SalesforceAuthError::Authorization { error_code, .. }) => {
                assert_eq!(error_code, "invalid_grant");
            }
            _ => panic!("expected Authorization error"),
        }
    }

    #[test]
    fn test_oauth_access_token_from_response() {
        let token = OAuthToken::from_response(OAuthTokenResponse {
            access_token: "oauth_access_tok_123".to_string(),
            instance_url: "https://na1.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            scope: None,
            issued_at: None,
            error: None,
            error_description: None,
        })
        .expect("valid OAuth response should produce a token");

        assert_eq!(token.token, "oauth_access_tok_123");
        assert_eq!(token.instance_url.as_str(), "https://na1.salesforce.com/");
        assert!(token.is_likely_valid());
        assert_eq!(token.bearer_token(), "Bearer oauth_access_tok_123");
    }

    #[test]
    fn test_dc_jwt_validity() {
        let token = dc_token("test.token.here", 3600);

        assert!(token.is_valid());
        assert!(!token.is_expired());
        assert_eq!(token.bearer_token(), "Bearer test.token.here");
        assert!(token.age().num_seconds() < 2);
        assert!(token.remaining_lifetime().num_seconds() > 3500);
    }

    #[test]
    fn test_dc_jwt_needs_refresh_when_fresh() {
        let token = dc_token("fresh.dc.jwt", 7200);
        assert!(!token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_near_expiry() {
        let token = dc_token("expiring.dc.jwt", 200);
        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_too_old() {
        let mut token = dc_token("old.dc.jwt", 7200);
        // Backdate creation past the 15-minute max age while expiry stays far away.
        token.created_at = Utc::now() - Duration::minutes(20);
        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_created_at_tracked() {
        let before = Utc::now();
        let token = dc_token("dc.jwt.value", 3600);
        let after = Utc::now();

        let created = token.created_at();
        assert!(created >= before);
        assert!(created <= after);
    }

    #[test]
    fn test_dc_jwt_is_valid_uses_5min_buffer() {
        // 4 minutes left: inside the 5-minute buffer, so not valid, yet not expired either.
        let almost_expired = dc_token("almost.expired.jwt", 240);
        assert!(!almost_expired.is_valid());
        assert!(!almost_expired.is_expired());

        // 6 minutes left: outside the buffer, so still valid.
        let still_valid = dc_token("still.valid.jwt", 360);
        assert!(still_valid.is_valid());
    }
}