use crate::auth::token::{AccessToken, TokenResponse};
use crate::error::{ForceError, HttpError, Result};
use async_trait::async_trait;
use secrecy::{ExposeSecret, SecretString};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Authenticator for the OAuth 2.0 username-password (`grant_type=password`) flow.
///
/// Credentials are held in [`SecretString`] fields. The cached refresh token sits behind an
/// `Arc<RwLock<..>>`, so clones produced by `#[derive(Clone)]` share one cached token.
#[derive(Clone)]
pub struct UsernamePassword {
    /// Sent as the `client_id` form field.
    client_id: String,
    /// Sent as the `client_secret` form field.
    client_secret: SecretString,
    /// Sent as the `username` form field of the password grant.
    username: String,
    /// Account password; the security token is appended to it to form the `password` field.
    password: SecretString,
    /// Security token appended to `password`; may be empty.
    security_token: SecretString,
    /// Token endpoint that every grant request is POSTed to.
    token_url: String,
    /// HTTP client used for token requests (replaceable via `with_client`).
    client: reqwest::Client,
    /// Most recent refresh token returned by the token endpoint, if any.
    refresh_token: Arc<RwLock<Option<String>>>,
}
impl std::fmt::Debug for UsernamePassword {
    /// Formats the authenticator with every credential replaced by `"[REDACTED]"`.
    ///
    /// The HTTP client and the cached refresh token are not printed at all. The output
    /// therefore ends with `..` (via `finish_non_exhaustive`) so readers can tell that
    /// some fields were elided, rather than assuming the listing is complete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UsernamePassword")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[REDACTED]")
            .field("username", &self.username)
            .field("password", &"[REDACTED]")
            .field("security_token", &"[REDACTED]")
            .field("token_url", &self.token_url)
            .finish_non_exhaustive()
    }
}
impl UsernamePassword {
pub fn new(
client_id: impl Into<String>,
client_secret: impl Into<String>,
username: impl Into<String>,
password: impl Into<String>,
security_token: impl Into<String>,
token_url: impl Into<String>,
) -> Self {
Self {
client_id: client_id.into(),
client_secret: SecretString::new(client_secret.into().into()),
username: username.into(),
password: SecretString::new(password.into().into()),
security_token: SecretString::new(security_token.into().into()),
token_url: token_url.into(),
client: crate::auth::default_auth_http_client(),
refresh_token: Arc::new(RwLock::new(None)),
}
}
#[must_use]
pub fn with_client(mut self, client: reqwest::Client) -> Self {
self.client = client;
self
}
pub fn new_production(
client_id: impl Into<String>,
client_secret: impl Into<String>,
username: impl Into<String>,
password: impl Into<String>,
security_token: impl Into<String>,
) -> Self {
Self::new(
client_id,
client_secret,
username,
password,
security_token,
crate::auth::PRODUCTION_TOKEN_URL,
)
}
pub fn new_sandbox(
client_id: impl Into<String>,
client_secret: impl Into<String>,
username: impl Into<String>,
password: impl Into<String>,
security_token: impl Into<String>,
) -> Self {
Self::new(
client_id,
client_secret,
username,
password,
security_token,
crate::auth::SANDBOX_TOKEN_URL,
)
}
fn password_with_token(&self) -> String {
format!(
"{}{}",
self.password.expose_secret(),
self.security_token.expose_secret()
)
}
async fn send_token_request(&self, params: &[(&str, &str)]) -> Result<TokenResponse> {
let response = self
.client
.post(&self.token_url)
.form(params)
.send()
.await
.map_err(|e| ForceError::Http(HttpError::RequestFailed(e)))?;
if !response.status().is_success() {
return Err(crate::auth::handle_oauth_error(response, None).await);
}
let bytes = crate::http::error::read_capped_body_bytes(response, 1024 * 1024).await?;
serde_json::from_slice::<TokenResponse>(&bytes)
.map_err(crate::error::SerializationError::from)
.map_err(Into::into)
}
async fn store_refresh_token(&self, response: &TokenResponse) {
if let Some(ref rt) = response.refresh_token {
let mut stored = self.refresh_token.write().await;
*stored = Some(rt.clone());
}
}
}
#[async_trait]
impl crate::auth::authenticator::Authenticator for UsernamePassword {
    /// Runs the OAuth 2.0 `password` grant.
    ///
    /// The `password` form field is the account password immediately followed by the
    /// security token. Any refresh token in the reply is cached for later `refresh` calls.
    async fn authenticate(&self) -> Result<AccessToken> {
        let secret_with_token = self.password_with_token();
        let form = [
            ("grant_type", "password"),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.expose_secret()),
            ("username", self.username.as_str()),
            ("password", secret_with_token.as_str()),
        ];
        let reply = self.send_token_request(&form).await?;
        self.store_refresh_token(&reply).await;
        Ok(AccessToken::from_response(reply))
    }

    /// Obtains a new access token, preferring the cached refresh token.
    ///
    /// Steps:
    /// - With no cached refresh token, this is just `authenticate`.
    /// - Otherwise a `refresh_token` grant is attempted. On success, any rotated refresh
    ///   token is cached.
    /// - On any failure of that attempt, the error is discarded and the cached token is
    ///   cleared only if it is still the one that was used. A token stored meanwhile by
    ///   another caller is kept.
    /// - After that failure, it falls back to `authenticate`.
    async fn refresh(&self) -> Result<AccessToken> {
        // Clone in its own statement so the read guard is released before any further await.
        let cached = self.refresh_token.read().await.clone();
        let Some(current) = cached else {
            return self.authenticate().await;
        };

        let form = [
            ("grant_type", "refresh_token"),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.expose_secret()),
            ("refresh_token", current.as_str()),
        ];
        let refreshed = self.send_token_request(&form).await.ok();
        if let Some(reply) = refreshed {
            self.store_refresh_token(&reply).await;
            return Ok(AccessToken::from_response(reply));
        }

        {
            // Compare-and-clear; the write guard must be dropped before `authenticate`
            // re-acquires the lock to store a new token.
            let mut slot = self.refresh_token.write().await;
            if slot.as_deref() == Some(current.as_str()) {
                *slot = None;
            }
        }
        self.authenticate().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::Authenticator;
    use crate::error::AuthenticationError;
    use crate::test_support::Must;
    use wiremock::matchers::{body_string_contains, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Token endpoint reply that includes a refresh token.
    fn sample_token_response() -> serde_json::Value {
        serde_json::json!({
            "access_token": "00Dxx0000001gPL!test_token",
            "instance_url": "https://test.my.salesforce.com",
            "token_type": "Bearer",
            "issued_at": "1704067200000",
            "signature": "testSignature==",
            "refresh_token": "fake_refresh_token_for_testing"
        })
    }

    /// Token endpoint reply without a refresh token.
    fn sample_token_response_no_refresh() -> serde_json::Value {
        serde_json::json!({
            "access_token": "00Dxx0000001gPL!test_token",
            "instance_url": "https://test.my.salesforce.com",
            "token_type": "Bearer",
            "issued_at": "1704067200000",
            "signature": "testSignature=="
        })
    }

    /// OAuth `invalid_grant` error body with the given description.
    fn invalid_grant(description: &str) -> serde_json::Value {
        serde_json::json!({
            "error": "invalid_grant",
            "error_description": description
        })
    }

    /// Authenticator with placeholder credentials pointed at `server`'s token path.
    fn default_auth(server: &MockServer) -> UsernamePassword {
        UsernamePassword::new(
            "client_id",
            "client_secret",
            "user@example.com",
            "password",
            "",
            format!("{}/services/oauth2/token", server.uri()),
        )
    }

    #[test]
    fn test_new_production() {
        let auth = UsernamePassword::new_production(
            "client_id",
            "client_secret",
            "user@example.com",
            "password",
            "secToken",
        );
        assert_eq!(auth.client_id, "client_id");
        assert_eq!(auth.username, "user@example.com");
        assert_eq!(
            auth.token_url,
            "https://login.salesforce.com/services/oauth2/token"
        );
    }

    #[test]
    fn test_new_sandbox() {
        let auth = UsernamePassword::new_sandbox(
            "client_id",
            "client_secret",
            "user@example.com",
            "password",
            "secToken",
        );
        assert_eq!(
            auth.token_url,
            "https://test.salesforce.com/services/oauth2/token"
        );
    }

    #[test]
    fn test_password_with_token_concatenation() {
        let auth = UsernamePassword::new_production(
            "client_id",
            "client_secret",
            "user@example.com",
            "myPassword",
            "myToken123",
        );
        assert_eq!(auth.password_with_token(), "myPasswordmyToken123");
    }

    #[test]
    fn test_empty_security_token() {
        let auth = UsernamePassword::new_production(
            "client_id",
            "client_secret",
            "user@example.com",
            "myPassword",
            "",
        );
        assert_eq!(auth.password_with_token(), "myPassword");
    }

    #[test]
    fn test_debug_redacts_secrets() {
        let auth = UsernamePassword::new_production(
            "client_id",
            "super_secret",
            "user@example.com",
            "myPassword",
            "myToken123",
        );
        let debug = format!("{auth:?}");
        assert!(debug.contains("client_id"));
        assert!(debug.contains("user@example.com"));
        assert!(!debug.contains("super_secret"));
        assert!(!debug.contains("myPassword"));
        assert!(!debug.contains("myToken123"));
        assert!(debug.contains("[REDACTED]"));
    }

    #[tokio::test]
    async fn test_authenticate_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=password"))
            .and(body_string_contains("username=user%40example.com"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .mount(&server)
            .await;
        let auth = UsernamePassword::new(
            "client_id",
            "client_secret",
            "user@example.com",
            "myPassword",
            "secToken",
            format!("{}/services/oauth2/token", server.uri()),
        );
        let token = auth.authenticate().await.must();
        assert_eq!(token.as_str(), "00Dxx0000001gPL!test_token");
        assert_eq!(token.instance_url(), "https://test.my.salesforce.com");
    }

    #[tokio::test]
    async fn test_authenticate_sends_password_with_security_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("password=myPasswordsecToken"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .mount(&server)
            .await;
        let auth = UsernamePassword::new(
            "client_id",
            "client_secret",
            "user@example.com",
            "myPassword",
            "secToken",
            format!("{}/services/oauth2/token", server.uri()),
        );
        let token = auth.authenticate().await.must();
        assert_eq!(token.as_str(), "00Dxx0000001gPL!test_token");
    }

    #[tokio::test]
    async fn test_authenticate_invalid_credentials() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(
                ResponseTemplate::new(400).set_body_json(invalid_grant("authentication failure")),
            )
            .mount(&server)
            .await;
        let auth = UsernamePassword::new(
            "bad_id",
            "bad_secret",
            "bad@example.com",
            "wrong",
            "",
            format!("{}/services/oauth2/token", server.uri()),
        );
        let result = auth.authenticate().await;
        if let Err(ForceError::Authentication(AuthenticationError::TokenRequestFailed(msg))) =
            result
        {
            assert!(msg.contains("invalid_grant"));
            assert!(msg.contains("authentication failure"));
        } else {
            panic!("Expected TokenRequestFailed error");
        }
    }

    #[tokio::test]
    async fn test_authenticate_server_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let result = auth.authenticate().await;
        assert!(matches!(result, Err(ForceError::Http(_))));
    }

    #[tokio::test]
    async fn test_authenticate_stores_refresh_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let _token = auth.authenticate().await.must();
        let stored = auth.refresh_token.read().await;
        assert_eq!(stored.as_deref(), Some("fake_refresh_token_for_testing"));
    }

    #[tokio::test]
    async fn test_refresh_uses_refresh_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains(
                "refresh_token=fake_refresh_token_for_testing",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "refreshed_access_token",
                "instance_url": "https://test.my.salesforce.com",
                "token_type": "Bearer",
                "issued_at": "1704070800000",
                "signature": "newSig==",
                "refresh_token": "new_refresh_token"
            })))
            .expect(1)
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let _token1 = auth.authenticate().await.must();
        let token2 = auth.refresh().await.must();
        assert_eq!(token2.as_str(), "refreshed_access_token");
    }

    #[tokio::test]
    async fn test_refresh_token_rotation() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_access_token",
                "instance_url": "https://test.my.salesforce.com",
                "token_type": "Bearer",
                "issued_at": "1704070800000",
                "signature": "sig==",
                "refresh_token": "rotated_refresh_token"
            })))
            .expect(1)
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let _token1 = auth.authenticate().await.must();
        let _token2 = auth.refresh().await.must();
        let stored = auth.refresh_token.read().await;
        assert_eq!(stored.as_deref(), Some("rotated_refresh_token"));
    }

    #[tokio::test]
    async fn test_refresh_fallback_on_revoked_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .expect(2)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(
                ResponseTemplate::new(401)
                    .set_body_json(invalid_grant("expired authorization code")),
            )
            .expect(1)
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let _token1 = auth.authenticate().await.must();
        let token2 = auth.refresh().await.must();
        assert_eq!(token2.as_str(), "00Dxx0000001gPL!test_token");
    }

    #[tokio::test]
    async fn test_refresh_without_stored_token_reauthenticates() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(sample_token_response_no_refresh()),
            )
            .expect(2)
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        let _token1 = auth.authenticate().await.must();
        assert!(auth.refresh_token.read().await.is_none());
        let token2 = auth.refresh().await.must();
        assert_eq!(token2.as_str(), "00Dxx0000001gPL!test_token");
    }

    #[tokio::test]
    async fn test_refresh_clears_token_on_failure() {
        let server = MockServer::start().await;
        // The fallback password grant returns no refresh token, so whatever is cached
        // after `refresh` shows whether the rejected token was actually cleared.
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(sample_token_response_no_refresh()),
            )
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(ResponseTemplate::new(401).set_body_json(invalid_grant("expired")))
            .expect(1)
            .mount(&server)
            .await;
        let auth = default_auth(&server);
        *auth.refresh_token.write().await = Some("fake_refresh_token_for_testing".to_string());
        let token = auth.refresh().await.must();
        assert_eq!(token.as_str(), "00Dxx0000001gPL!test_token");
        assert!(auth.refresh_token.read().await.is_none());
    }

    #[tokio::test]
    async fn test_with_client_custom_http_client() {
        const CUSTOM_AGENT: &str = "custom-auth-client-test";
        let server = MockServer::start().await;
        // Only requests carrying the custom user agent match, proving the injected client
        // (not the default one) sent the request.
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(header("user-agent", CUSTOM_AGENT))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_token_response()))
            .expect(1)
            .mount(&server)
            .await;
        let custom_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .user_agent(CUSTOM_AGENT)
            .build()
            .must();
        let auth = default_auth(&server).with_client(custom_client);
        let token = auth.authenticate().await.must();
        assert_eq!(token.as_str(), "00Dxx0000001gPL!test_token");
    }

    #[tokio::test]
    async fn test_refresh_does_not_clear_newer_token_on_failure() {
        /// Rejects the refresh grant, but first stores a newer token in the shared slot.
        /// This simulates another task rotating the token while this request is in flight.
        struct RotateThenReject {
            slot: Arc<RwLock<Option<String>>>,
        }

        impl wiremock::Respond for RotateThenReject {
            fn respond(&self, _request: &wiremock::Request) -> ResponseTemplate {
                // `refresh` holds no lock while its request is in flight, so this succeeds;
                // if it ever did not, the final assertion would catch it.
                if let Ok(mut stored) = self.slot.try_write() {
                    *stored = Some("newer_token_from_another_thread".to_string());
                }
                ResponseTemplate::new(401).set_body_json(invalid_grant("expired"))
            }
        }

        let server = MockServer::start().await;
        let auth = default_auth(&server);
        *auth.refresh_token.write().await = Some("fake_refresh_token_for_testing".to_string());

        // No refresh token in the fallback reply, so it cannot mask a wrongful clear.
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=password"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(sample_token_response_no_refresh()),
            )
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains(
                "refresh_token=fake_refresh_token_for_testing",
            ))
            .respond_with(RotateThenReject {
                slot: Arc::clone(&auth.refresh_token),
            })
            .expect(1)
            .mount(&server)
            .await;

        let token = auth.refresh().await.must();
        assert_eq!(token.as_str(), "00Dxx0000001gPL!test_token");
        assert_eq!(
            auth.refresh_token.read().await.as_deref(),
            Some("newer_token_from_another_thread")
        );
    }

    #[tokio::test]
    async fn test_authenticate_network_error() {
        // Reserve an ephemeral port and release it immediately, so nothing is listening and
        // the connection attempt fails at the transport level. (An out-of-range port such as
        // 99999 would only exercise URL parsing, never the network path.)
        let port = std::net::TcpListener::bind("127.0.0.1:0")
            .must()
            .local_addr()
            .must()
            .port();
        let auth = UsernamePassword::new(
            "client_id",
            "client_secret",
            "user@example.com",
            "password",
            "",
            format!("http://127.0.0.1:{port}/oauth2/token"),
        );
        let result = auth.authenticate().await;
        assert!(matches!(
            result,
            Err(ForceError::Http(HttpError::RequestFailed(_)))
        ));
    }
}