use oauth2::basic::BasicTokenType;
use oauth2::{EmptyExtraTokenFields, RefreshToken, TokenResponse};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Path of the OAuth2 token endpoint; appended to the credentials' `instance_url`.
const DEFAULT_TOKEN_PATH: &str = "/services/oauth2/token";
/// Seconds before its recorded expiry at which a cached token is already treated as
/// expired by `Client::access_token`, so it is renewed ahead of time.
const DEFAULT_TOKEN_REFRESH_BUFFER_SECONDS: u64 = 300;
/// Errors produced while building a [`Client`], loading credentials, and obtaining or
/// renewing OAuth2 tokens.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error {
    /// The credentials file configured via `credentials_path` could not be read.
    #[error("Failed to read credentials file at {path}: {source}")]
    ReadCredentials {
        /// Path that was configured on the builder.
        path: std::path::PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// The credentials file is not valid JSON or does not match [`Credentials`].
    #[error("Failed to parse credentials JSON: {source}")]
    ParseCredentials {
        #[source]
        source: serde_json::Error,
    },
    /// A URL could not be parsed.
    ///
    /// NOTE(review): nothing in this file constructs this variant — confirm whether it is
    /// produced elsewhere in the crate.
    #[error("Invalid URL format: {source}")]
    ParseUrl {
        #[source]
        source: url::ParseError,
    },
    /// The token request could not be sent, its response body could not be read, or a
    /// successful response body could not be decoded as a token response.
    #[error("OAuth2 token exchange failed: {source}")]
    TokenExchange {
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// The token endpoint answered with a non-success HTTP status; `body` is the raw
    /// response body.
    #[error("OAuth2 request failed with status {status}: {body}")]
    OAuth2RequestFailed { status: u16, body: String },
    /// `Builder::build` was called without a credentials source.
    #[error("Credentials not provided: call credentials() or credentials_path() before build()")]
    MissingCredentials,
    /// `client_secret` is absent; every token request in this file sends it.
    #[error("Client secret is required for authentication")]
    MissingClientSecret,
    /// `username` is absent while [`AuthFlow::UsernamePassword`] is selected.
    #[error("Username is required for username-password authentication")]
    MissingUsername,
    /// `password` is absent while [`AuthFlow::UsernamePassword`] is selected.
    #[error("Password is required for username-password authentication")]
    MissingPassword,
    /// The system clock reports a time before the Unix epoch.
    #[error("Failed to get current system time: {source}")]
    SystemTimeError {
        #[source]
        source: std::time::SystemTimeError,
    },
    /// Adding the token lifetime to the current time overflowed `u64`.
    #[error("Token expiry time calculation overflow")]
    TokenExpiryOverflow,
    /// Adding the refresh buffer to the current time overflowed `u64`.
    #[error("Time threshold calculation overflow")]
    TimeThresholdOverflow,
    /// A token refresh was attempted but no refresh token is available.
    #[error("Token refresh not available: no refresh token in response")]
    NoRefreshToken,
    /// The token-state `RwLock` could not be acquired (std's `RwLock` reports this when
    /// the lock was poisoned by a panicking holder).
    #[error("Failed to acquire lock on token state")]
    LockError,
    /// A token operation was attempted on a client that has no token state yet.
    #[error(
        "Client is not connected. Call connect() first to authenticate and retrieve instance URL."
    )]
    NotConnected,
    /// The dedicated HTTP client used for token requests could not be built.
    #[error("Failed to build HTTP client: {source}")]
    HttpClientBuild {
        #[source]
        source: reqwest::Error,
    },
}
/// Token response decoded from the Salesforce token endpoint: the standard OAuth2 fields,
/// no extra fields, and the basic token type.
pub type SalesforceTokenResponse =
    oauth2::StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
/// A cached token response together with the absolute time at which it expires.
#[derive(Debug, Clone)]
pub(crate) struct TokenState {
    /// Most recent response from the token endpoint.
    token_response: SalesforceTokenResponse,
    /// Expiry as whole seconds since the Unix epoch, computed in `TokenState::new`.
    expires_at: u64,
}
impl TokenState {
pub(crate) fn new(token_response: SalesforceTokenResponse) -> Result<Self, Error> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|source| Error::SystemTimeError { source })?
.as_secs();
let expires_at = if let Some(expires_in) = token_response.expires_in() {
now.checked_add(expires_in.as_secs())
.ok_or(Error::TokenExpiryOverflow)?
} else {
now.checked_add(7200).ok_or(Error::TokenExpiryOverflow)?
};
Ok(Self {
token_response,
expires_at,
})
}
fn is_expired(&self, buffer_seconds: u64) -> Result<bool, Error> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|source| Error::SystemTimeError { source })?
.as_secs();
let threshold = now
.checked_add(buffer_seconds)
.ok_or(Error::TimeThresholdOverflow)?;
Ok(threshold >= self.expires_at)
}
fn access_token(&self) -> &str {
self.token_response.access_token().secret()
}
fn refresh_token(&self) -> Option<&RefreshToken> {
self.token_response.refresh_token()
}
}
/// OAuth2 grant used to obtain tokens.
///
/// Serialized in `snake_case` (`"client_credentials"`, `"username_password"`).
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum AuthFlow {
    /// `grant_type=client_credentials`: sends `client_id` and `client_secret`.
    /// This is the default.
    #[default]
    ClientCredentials,
    /// `grant_type=password`: additionally sends `username` and `password`.
    UsernamePassword,
}
/// Connected-app credentials used to request OAuth2 tokens.
///
/// Which optional fields are required depends on the [`AuthFlow`]: `client_secret` is
/// needed by both flows, `username` and `password` only by
/// [`AuthFlow::UsernamePassword`]. Absent optional fields are omitted when serialized.
///
/// `Debug` is implemented by hand so that `client_secret` and `password` are never
/// printed — for example when a `Client` holding these credentials is logged with `{:?}`.
#[derive(Serialize, Deserialize, Clone)]
pub struct Credentials {
    /// OAuth2 client id of the connected app.
    pub client_id: String,
    /// OAuth2 client secret of the connected app.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Username for the username-password flow.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Password for the username-password flow.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
    /// Base URL of the Salesforce instance; the token endpoint path is appended to it.
    pub instance_url: String,
    /// Tenant identifier; copied onto the client once it connects.
    pub tenant_id: String,
}

impl std::fmt::Debug for Credentials {
    /// Formats the credentials with secret fields masked.
    ///
    /// A present `client_secret`/`password` is shown as `Some("[redacted]")`; an absent
    /// one as `None`, so the output still reveals which fields are configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        f.debug_struct("Credentials")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| REDACTED),
            )
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| REDACTED))
            .field("instance_url", &self.instance_url)
            .field("tenant_id", &self.tenant_id)
            .finish()
    }
}
/// Where a client obtains its [`Credentials`] each time it authenticates.
#[derive(Debug, Clone)]
pub enum CredentialsFrom {
    /// A JSON file that is read and deserialized on every credentials load.
    Path(PathBuf),
    /// Credentials supplied directly; cloned on every credentials load.
    Value(Credentials),
}
/// Salesforce client that obtains, caches and renews OAuth2 access tokens.
///
/// Build one with [`Builder`], then call `connect`. Cloning a connected client clones the
/// `Arc` around its token state, so renewals performed by `access_token` through one
/// clone are visible to the others.
#[derive(Debug, Clone)]
// NOTE(review): this lint allowance does not appear necessary for the current field
// types — confirm and remove.
#[allow(clippy::type_complexity)]
pub struct Client {
    /// Source the credentials are (re)loaded from on every authentication.
    credentials_from: CredentialsFrom,
    /// Grant used to obtain new tokens.
    auth_flow: AuthFlow,
    /// Cached token; `None` until the client has connected.
    pub(crate) token_state: Option<Arc<RwLock<TokenState>>>,
    /// Instance URL taken from the credentials by `connect`; `None` before that.
    pub instance_url: Option<String>,
    /// Tenant id taken from the credentials by `connect`; `None` before that.
    pub tenant_id: Option<String>,
}
impl Client {
fn validate_credentials(&self, credentials: &Credentials) -> Result<(), Error> {
match self.auth_flow {
AuthFlow::ClientCredentials => {
credentials
.client_secret
.as_ref()
.ok_or(Error::MissingClientSecret)?;
}
AuthFlow::UsernamePassword => {
credentials
.client_secret
.as_ref()
.ok_or(Error::MissingClientSecret)?;
credentials
.username
.as_ref()
.ok_or(Error::MissingUsername)?;
credentials
.password
.as_ref()
.ok_or(Error::MissingPassword)?;
}
}
Ok(())
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn connect(mut self) -> Result<Self, Error> {
let credentials = self.load_credentials().await?;
self.validate_credentials(&credentials)?;
let http_client = Self::build_auth_http_client()?;
let token_response = match self.auth_flow {
AuthFlow::ClientCredentials => {
self.exchange_client_credentials(&credentials, &http_client)
.await?
}
AuthFlow::UsernamePassword => {
self.exchange_password(&credentials, &http_client).await?
}
};
let token_state = TokenState::new(token_response)?;
self.token_state = Some(Arc::new(RwLock::new(token_state)));
self.instance_url = Some(credentials.instance_url);
self.tenant_id = Some(credentials.tenant_id);
Ok(self)
}
async fn exchange_client_credentials(
&self,
credentials: &Credentials,
http_client: &reqwest::Client,
) -> Result<SalesforceTokenResponse, Error> {
let client_secret = credentials
.client_secret
.as_ref()
.ok_or(Error::MissingClientSecret)?;
let token_url = format!("{}{}", credentials.instance_url, DEFAULT_TOKEN_PATH);
let response = http_client
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &credentials.client_id),
("client_secret", client_secret),
])
.send()
.await
.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
let status = response.status();
let body = response.text().await.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
if !status.is_success() {
return Err(Error::OAuth2RequestFailed {
status: status.as_u16(),
body,
});
}
serde_json::from_str(&body).map_err(|e| Error::TokenExchange {
source: Box::new(e),
})
}
async fn exchange_password(
&self,
credentials: &Credentials,
http_client: &reqwest::Client,
) -> Result<SalesforceTokenResponse, Error> {
let client_secret = credentials
.client_secret
.as_ref()
.ok_or(Error::MissingClientSecret)?;
let username = credentials
.username
.as_ref()
.ok_or(Error::MissingUsername)?;
let password = credentials
.password
.as_ref()
.ok_or(Error::MissingPassword)?;
let token_url = format!("{}{}", credentials.instance_url, DEFAULT_TOKEN_PATH);
let response = http_client
.post(&token_url)
.form(&[
("grant_type", "password"),
("client_id", &credentials.client_id),
("client_secret", client_secret),
("username", username),
("password", password),
])
.send()
.await
.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
let status = response.status();
let body = response.text().await.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
if !status.is_success() {
return Err(Error::OAuth2RequestFailed {
status: status.as_u16(),
body,
});
}
serde_json::from_str(&body).map_err(|e| Error::TokenExchange {
source: Box::new(e),
})
}
async fn load_credentials(&self) -> Result<Credentials, Error> {
match &self.credentials_from {
CredentialsFrom::Value(creds) => Ok(creds.clone()),
CredentialsFrom::Path(path) => {
let credentials_string =
fs::read_to_string(path).map_err(|e| Error::ReadCredentials {
path: path.clone(),
source: e,
})?;
serde_json::from_str(&credentials_string)
.map_err(|e| Error::ParseCredentials { source: e })
}
}
}
fn build_auth_http_client() -> Result<reqwest::Client, Error> {
reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.connect_timeout(std::time::Duration::from_secs(
crate::DEFAULT_AUTH_CONNECT_TIMEOUT_SECS,
))
.timeout(std::time::Duration::from_secs(
crate::DEFAULT_AUTH_REQUEST_TIMEOUT_SECS,
))
.build()
.map_err(|source| Error::HttpClientBuild { source })
}
async fn reauthenticate(&self) -> Result<(), Error> {
let token_state_arc = self.token_state.as_ref().ok_or(Error::NotConnected)?;
let credentials = self.load_credentials().await?;
self.validate_credentials(&credentials)?;
let http_client = Self::build_auth_http_client()?;
let token_response = match self.auth_flow {
AuthFlow::ClientCredentials => {
self.exchange_client_credentials(&credentials, &http_client)
.await?
}
AuthFlow::UsernamePassword => {
self.exchange_password(&credentials, &http_client).await?
}
};
let new_state = TokenState::new(token_response)?;
let mut state = token_state_arc.write().map_err(|_| Error::LockError)?;
*state = new_state;
Ok(())
}
async fn refresh_token(&self) -> Result<(), Error> {
let token_state_arc = self.token_state.as_ref().ok_or(Error::NoRefreshToken)?;
let refresh_token = {
let state = token_state_arc.read().map_err(|_| Error::LockError)?;
state.refresh_token().ok_or(Error::NoRefreshToken)?.clone()
};
let credentials = self.load_credentials().await?;
let client_secret = credentials
.client_secret
.as_ref()
.ok_or(Error::MissingClientSecret)?;
let token_url = format!("{}{}", credentials.instance_url, DEFAULT_TOKEN_PATH);
let http_client = Self::build_auth_http_client()?;
let response = http_client
.post(&token_url)
.form(&[
("grant_type", "refresh_token"),
("client_id", &credentials.client_id),
("client_secret", client_secret),
("refresh_token", refresh_token.secret()),
])
.send()
.await
.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
let status = response.status();
let body = response.text().await.map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
if !status.is_success() {
return Err(Error::OAuth2RequestFailed {
status: status.as_u16(),
body,
});
}
let new_token_response: SalesforceTokenResponse =
serde_json::from_str(&body).map_err(|e| Error::TokenExchange {
source: Box::new(e),
})?;
let new_state = TokenState::new(new_token_response)?;
let mut state = token_state_arc.write().map_err(|_| Error::LockError)?;
*state = new_state;
Ok(())
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub fn current_access_token(&self) -> Result<String, Error> {
let token_state_arc = self.token_state.as_ref().ok_or(Error::NotConnected)?;
let state = token_state_arc.read().map_err(|_| Error::LockError)?;
Ok(state.access_token().to_string())
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn reconnect(&mut self) -> Result<(), Error> {
let credentials = self.load_credentials().await?;
self.validate_credentials(&credentials)?;
let http_client = Self::build_auth_http_client()?;
let token_response = match self.auth_flow {
AuthFlow::ClientCredentials => {
self.exchange_client_credentials(&credentials, &http_client)
.await?
}
AuthFlow::UsernamePassword => {
self.exchange_password(&credentials, &http_client).await?
}
};
let token_state = TokenState::new(token_response)?;
self.token_state = Some(Arc::new(RwLock::new(token_state)));
Ok(())
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn access_token(&self) -> Result<String, Error> {
let token_state_arc = self.token_state.as_ref().ok_or(Error::NotConnected)?;
let needs_refresh = {
let state = token_state_arc.read().map_err(|_| Error::LockError)?;
state.is_expired(DEFAULT_TOKEN_REFRESH_BUFFER_SECONDS)?
};
if needs_refresh {
match self.refresh_token().await {
Ok(()) => {}
Err(Error::NoRefreshToken) => {
self.reauthenticate().await?;
}
Err(e) => return Err(e),
}
}
let state = token_state_arc.read().map_err(|_| Error::LockError)?;
Ok(state.access_token().to_string())
}
}
/// Configures and creates a [`Client`].
#[derive(Debug, Clone, Default)]
pub struct Builder {
    /// Where credentials come from; required by `build`.
    credentials_from: Option<CredentialsFrom>,
    /// Grant to use; [`AuthFlow::default`] when unset.
    auth_flow: Option<AuthFlow>,
}
impl Builder {
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub fn new() -> Self {
Self::default()
}
pub fn credentials_path(mut self, path: PathBuf) -> Self {
self.credentials_from = Some(CredentialsFrom::Path(path));
self
}
pub fn credentials(mut self, credentials: Credentials) -> Self {
self.credentials_from = Some(CredentialsFrom::Value(credentials));
self
}
pub fn auth_flow(mut self, auth_flow: AuthFlow) -> Self {
self.auth_flow = Some(auth_flow);
self
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub fn build(self) -> Result<Client, Error> {
Ok(Client {
credentials_from: self.credentials_from.ok_or(Error::MissingCredentials)?,
auth_flow: self.auth_flow.unwrap_or_default(),
token_state: None,
instance_url: None,
tenant_id: None,
})
}
}
#[cfg(test)]
mod tests {
    //! Tests for the builder, credential loading/validation, token-state expiry and the
    //! client's token accessors.
    //!
    //! NOTE: tests that assert `Error::OAuth2RequestFailed` send real requests to
    //! `https://test.salesforce.com`. Without network access they fail with
    //! `Error::TokenExchange` instead.
    use std::env;

    use super::*;

    #[test]
    fn test_build_without_credentials() {
        let client = Builder::new().build();
        assert!(matches!(client, Err(Error::MissingCredentials)));
    }

    #[test]
    fn test_build_with_credentials() {
        // `build` never reads the file, so a nonexistent path is fine here.
        let mut path = env::temp_dir();
        path.push(format!("credentials_{}.json", std::process::id()));
        let client = Builder::new().credentials_path(path).build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_connect_with_invalid_credentials() {
        // Valid JSON, but the required `instance_url` and `tenant_id` are missing.
        let creds: &str = r#"{"client_id":"client_id"}"#;
        let mut path = env::temp_dir();
        path.push(format!("invalid_credentials_{}.json", std::process::id()));
        fs::write(&path, creds).expect("failed to write temporary credentials file");
        let client = Builder::new()
            .credentials_path(path.clone())
            .build()
            .unwrap();
        let result = client.connect().await;
        let _ = fs::remove_file(path);
        assert!(matches!(result, Err(Error::ParseCredentials { .. })));
    }

    #[tokio::test]
    async fn test_connect_with_invalid_url() {
        // No scheme, so the HTTP client rejects the URL before sending anything.
        let creds: &str = r#"
{
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"instance_url": "mydomain.salesforce.com",
"tenant_id": "some_tenant_id"
}"#;
        let mut path = env::temp_dir();
        path.push(format!(
            "invalid_url_credentials_{}.json",
            std::process::id()
        ));
        fs::write(&path, creds).expect("failed to write temporary credentials file");
        let client = Builder::new()
            .credentials_path(path.clone())
            .build()
            .unwrap();
        let result = client.connect().await;
        let _ = fs::remove_file(path);
        assert!(matches!(result, Err(Error::TokenExchange { .. })));
    }

    #[tokio::test]
    async fn test_connect_with_missing_file() {
        let mut path = env::temp_dir();
        path.push(format!("nonexistent_{}.json", std::process::id()));
        let client = Builder::new().credentials_path(path).build().unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::ReadCredentials { .. })));
    }

    #[tokio::test]
    async fn test_connect_with_valid_json_but_invalid_credentials() {
        // Requires network access (see module docs).
        let creds: &str = r#"
{
"client_id": "test_client_id",
"client_secret": "test_client_secret",
"instance_url": "https://test.salesforce.com",
"tenant_id": "test_tenant_id"
}"#;
        let mut path = env::temp_dir();
        path.push(format!(
            "valid_json_invalid_creds_{}.json",
            std::process::id()
        ));
        fs::write(&path, creds).expect("failed to write temporary credentials file");
        let client = Builder::new()
            .credentials_path(path.clone())
            .build()
            .unwrap();
        let result = client.connect().await;
        let _ = fs::remove_file(path);
        assert!(matches!(result, Err(Error::OAuth2RequestFailed { .. })));
    }

    #[tokio::test]
    async fn test_connect_with_direct_credentials() {
        // Requires network access (see module docs).
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: Some("test_client_secret".to_string()),
            username: None,
            password: None,
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new().credentials(creds).build().unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::OAuth2RequestFailed { .. })));
    }

    #[tokio::test]
    async fn test_client_credentials_flow_missing_secret() {
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: None,
            username: None,
            password: None,
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new()
            .credentials(creds)
            .auth_flow(AuthFlow::ClientCredentials)
            .build()
            .unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::MissingClientSecret)));
    }

    #[tokio::test]
    async fn test_username_password_flow_missing_username() {
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: Some("test_secret".to_string()),
            username: None,
            password: Some("test_password".to_string()),
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new()
            .credentials(creds)
            .auth_flow(AuthFlow::UsernamePassword)
            .build()
            .unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::MissingUsername)));
    }

    #[tokio::test]
    async fn test_username_password_flow_missing_password() {
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: Some("test_secret".to_string()),
            username: Some("test_user".to_string()),
            password: None,
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new()
            .credentials(creds)
            .auth_flow(AuthFlow::UsernamePassword)
            .build()
            .unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::MissingPassword)));
    }

    #[tokio::test]
    async fn test_username_password_flow_with_valid_fields() {
        // Requires network access (see module docs).
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: Some("test_secret".to_string()),
            username: Some("test_user".to_string()),
            password: Some("test_password".to_string()),
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new()
            .credentials(creds)
            .auth_flow(AuthFlow::UsernamePassword)
            .build()
            .unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::OAuth2RequestFailed { .. })));
    }

    #[tokio::test]
    async fn test_username_password_flow_missing_client_secret() {
        let creds = Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: None,
            username: Some("test_user".to_string()),
            password: Some("test_password".to_string()),
            instance_url: "https://test.salesforce.com".to_string(),
            tenant_id: "test_tenant_id".to_string(),
        };
        let client = Builder::new()
            .credentials(creds)
            .auth_flow(AuthFlow::UsernamePassword)
            .build()
            .unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::MissingClientSecret)));
    }

    #[test]
    fn test_token_state_expiry_check_with_buffer() {
        use oauth2::basic::BasicTokenResponse;
        use oauth2::AccessToken;
        use std::time::Duration;
        let mut token_response = BasicTokenResponse::new(
            AccessToken::new("test_token".to_string()),
            oauth2::basic::BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        // Shorter than the 300 s buffer, but long enough that the token cannot really
        // expire while the test runs. A 1 s lifetime made `is_expired(0)` flaky whenever
        // a second boundary passed between construction and the check.
        token_response.set_expires_in(Some(&Duration::from_secs(100)));
        let token_state = TokenState::new(token_response).unwrap();
        let is_expired = token_state.is_expired(300);
        assert!(is_expired.is_ok());
        assert!(is_expired.unwrap());
        let is_expired = token_state.is_expired(0);
        assert!(is_expired.is_ok());
        assert!(!is_expired.unwrap());
    }

    #[test]
    fn test_current_access_token_without_connection() {
        let client = Builder::new()
            .credentials(Credentials {
                client_id: "test_id".to_string(),
                client_secret: Some("test_secret".to_string()),
                username: None,
                password: None,
                instance_url: "https://test.salesforce.com".to_string(),
                tenant_id: "test_tenant".to_string(),
            })
            .build()
            .unwrap();
        let result = client.current_access_token();
        assert!(matches!(result, Err(Error::NotConnected)));
    }

    #[tokio::test]
    async fn test_access_token_without_connection() {
        let client = Builder::new()
            .credentials(Credentials {
                client_id: "test_id".to_string(),
                client_secret: Some("test_secret".to_string()),
                username: None,
                password: None,
                instance_url: "https://test.salesforce.com".to_string(),
                tenant_id: "test_tenant".to_string(),
            })
            .build()
            .unwrap();
        let result = client.access_token().await;
        assert!(matches!(result, Err(Error::NotConnected)));
    }

    #[tokio::test]
    async fn test_access_token_reauthenticates_when_expired_and_no_refresh_token() {
        // Requires network access (see module docs).
        use oauth2::basic::BasicTokenResponse;
        use oauth2::AccessToken;
        use std::time::Duration;
        let mut client = Builder::new()
            .credentials(Credentials {
                client_id: "test_id".to_string(),
                client_secret: Some("test_secret".to_string()),
                username: None,
                password: None,
                instance_url: "https://test.salesforce.com".to_string(),
                tenant_id: "test_tenant".to_string(),
            })
            .build()
            .unwrap();
        let mut token_response = BasicTokenResponse::new(
            AccessToken::new("expired_token".to_string()),
            oauth2::basic::BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        token_response.set_expires_in(Some(&Duration::from_secs(0)));
        let token_state = TokenState::new(token_response).unwrap();
        client.token_state = Some(Arc::new(RwLock::new(token_state)));
        client.instance_url = Some("https://test.salesforce.com".to_string());
        client.tenant_id = Some("test_tenant".to_string());
        let result = client.access_token().await;
        assert!(
            matches!(result, Err(Error::OAuth2RequestFailed { .. })),
            "expected OAuth2RequestFailed from reauthenticate fallback, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_reconnect_without_connection() {
        // Requires network access (see module docs).
        let mut client = Builder::new()
            .credentials(Credentials {
                client_id: "test_id".to_string(),
                client_secret: Some("test_secret".to_string()),
                username: None,
                password: None,
                instance_url: "https://test.salesforce.com".to_string(),
                tenant_id: "test_tenant".to_string(),
            })
            .build()
            .unwrap();
        let result = client.reconnect().await;
        assert!(matches!(result, Err(Error::OAuth2RequestFailed { .. })));
    }
}