#[cfg(feature = "jwt")]
use crate::auth::authenticator::Authenticator;
#[cfg(feature = "jwt")]
use crate::auth::token::{AccessToken, TokenResponse};
#[cfg(feature = "jwt")]
use crate::error::{AuthenticationError, ForceError, HttpError, Result};
#[cfg(feature = "jwt")]
use async_trait::async_trait;
#[cfg(feature = "jwt")]
use jsonwebtoken::{EncodingKey, Header, encode};
#[cfg(feature = "jwt")]
use serde::Serialize;
#[cfg(feature = "jwt")]
use std::time::{SystemTime, UNIX_EPOCH};
#[cfg(feature = "jwt")]
/// Claim set that is signed into the JWT bearer assertion.
///
/// Built in `JwtBearerFlow::generate_jwt` from string fields borrowed from the
/// flow, hence the lifetime parameter. `serde` serializes struct fields in
/// declaration order, so this order is also the JSON key order in the payload.
#[derive(Debug, Serialize)]
struct JwtClaims<'a> {
/// `iss` claim: the OAuth client id (`JwtBearerFlow::client_id`).
iss: &'a str,
/// `sub` claim: the user the token is requested for (`JwtBearerFlow::username`).
sub: &'a str,
/// `aud` claim: the audience (`JwtBearerFlow::audience`).
aud: &'a str,
/// `exp` claim: expiry as whole seconds since the Unix epoch.
exp: u64,
}
#[cfg(feature = "jwt")]
/// Authenticator for the OAuth 2.0 JWT bearer grant.
///
/// Every authentication signs a fresh RS256 assertion with `private_key` and
/// POSTs it to `token_url`. Cloning is cheap apart from the owned strings.
#[derive(Clone)]
pub struct JwtBearerFlow {
    /// OAuth client id, sent as the `iss` claim.
    client_id: String,
    /// User to authenticate as, sent as the `sub` claim.
    username: String,
    /// Parsed RSA signing key; the manual `Debug` impl never prints it.
    private_key: EncodingKey,
    /// Value sent as the `aud` claim.
    audience: String,
    /// Token endpoint that receives the signed assertion.
    token_url: String,
    /// HTTP client used for the token request.
    http_client: reqwest::Client,
}
#[cfg(feature = "jwt")]
impl JwtBearerFlow {
    /// How many seconds after "now" a generated assertion expires.
    const ASSERTION_LIFETIME_SECS: u64 = 300;

    /// Creates a flow with explicit audience and token endpoint.
    ///
    /// * `client_id` - placed in the assertion's `iss` claim.
    /// * `username` - placed in the assertion's `sub` claim.
    /// * `private_key_pem` - RSA private key in PEM form, used to sign with RS256.
    /// * `audience` - placed in the assertion's `aud` claim.
    /// * `token_url` - endpoint the signed assertion is POSTed to.
    ///
    /// The flow starts with the crate's default auth HTTP client; use
    /// [`JwtBearerFlow::with_client`] to supply a different one.
    ///
    /// # Errors
    ///
    /// Returns `ForceError::Authentication(AuthenticationError::InvalidJwtConfig(_))`
    /// when `private_key_pem` cannot be parsed as an RSA PEM key.
    pub fn new(
        client_id: impl Into<String>,
        username: impl Into<String>,
        private_key_pem: impl Into<String>,
        audience: impl Into<String>,
        token_url: impl Into<String>,
    ) -> Result<Self> {
        let pem: String = private_key_pem.into();
        // Parse the key first so a bad key fails fast, before the other
        // arguments are converted.
        let private_key = match EncodingKey::from_rsa_pem(pem.as_bytes()) {
            Ok(key) => key,
            Err(e) => {
                return Err(ForceError::Authentication(
                    AuthenticationError::InvalidJwtConfig(format!(
                        "Invalid RSA private key: {e}"
                    )),
                ));
            }
        };
        Ok(Self {
            client_id: client_id.into(),
            username: username.into(),
            private_key,
            audience: audience.into(),
            token_url: token_url.into(),
            http_client: crate::auth::default_auth_http_client(),
        })
    }

    /// Returns this flow with its HTTP client replaced by `client`.
    #[must_use]
    pub fn with_client(self, client: reqwest::Client) -> Self {
        Self {
            http_client: client,
            ..self
        }
    }

    /// Creates a flow against the production login and token endpoints
    /// (`crate::auth::PRODUCTION_LOGIN_URL` / `PRODUCTION_TOKEN_URL`).
    ///
    /// # Errors
    ///
    /// Same as [`JwtBearerFlow::new`]: fails when the PEM key cannot be parsed.
    pub fn new_production(
        client_id: impl Into<String>,
        username: impl Into<String>,
        private_key_pem: impl Into<String>,
    ) -> Result<Self> {
        let audience = crate::auth::PRODUCTION_LOGIN_URL;
        let token_url = crate::auth::PRODUCTION_TOKEN_URL;
        Self::new(client_id, username, private_key_pem, audience, token_url)
    }

    /// Creates a flow against the sandbox login and token endpoints
    /// (`crate::auth::SANDBOX_LOGIN_URL` / `SANDBOX_TOKEN_URL`).
    ///
    /// # Errors
    ///
    /// Same as [`JwtBearerFlow::new`]: fails when the PEM key cannot be parsed.
    pub fn new_sandbox(
        client_id: impl Into<String>,
        username: impl Into<String>,
        private_key_pem: impl Into<String>,
    ) -> Result<Self> {
        let audience = crate::auth::SANDBOX_LOGIN_URL;
        let token_url = crate::auth::SANDBOX_TOKEN_URL;
        Self::new(client_id, username, private_key_pem, audience, token_url)
    }

    /// Builds and RS256-signs a fresh assertion that expires
    /// `ASSERTION_LIFETIME_SECS` seconds from now.
    ///
    /// # Errors
    ///
    /// Returns `AuthenticationError::JwtCreationFailed` (wrapped in
    /// `ForceError::Authentication`) if the system clock is before the Unix
    /// epoch or if JWT encoding fails.
    fn generate_jwt(&self) -> Result<String> {
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
            ForceError::Authentication(AuthenticationError::JwtCreationFailed(format!(
                "System time error: {e}"
            )))
        })?;
        let claims = JwtClaims {
            iss: &self.client_id,
            sub: &self.username,
            aud: &self.audience,
            exp: since_epoch.as_secs() + Self::ASSERTION_LIFETIME_SECS,
        };
        let header = Header::new(jsonwebtoken::Algorithm::RS256);
        encode(&header, &claims, &self.private_key).map_err(|e| {
            ForceError::Authentication(AuthenticationError::JwtCreationFailed(format!(
                "JWT encoding failed: {e}"
            )))
        })
    }
}
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JwtBearerFlow {
    /// Formats the flow for diagnostics without exposing the signing key.
    ///
    /// `private_key` is always rendered as `"[REDACTED]"`. `http_client` is
    /// omitted, and `finish_non_exhaustive` marks that omission with a trailing
    /// `..` instead of implying the field list is complete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtBearerFlow")
            .field("client_id", &self.client_id)
            .field("username", &self.username)
            .field("private_key", &"[REDACTED]")
            .field("audience", &self.audience)
            .field("token_url", &self.token_url)
            .finish_non_exhaustive()
    }
}
#[cfg(feature = "jwt")]
#[async_trait]
impl Authenticator for JwtBearerFlow {
    /// Signs a fresh JWT assertion and exchanges it at the token endpoint.
    ///
    /// Sends a form-encoded POST with
    /// `grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer` and the signed
    /// `assertion` to `token_url`, then parses the success body as a
    /// `TokenResponse`.
    ///
    /// # Errors
    ///
    /// - JWT creation errors from `generate_jwt`.
    /// - `ForceError::Http(HttpError::RequestFailed(_))` if the request cannot be sent.
    /// - The error built by `crate::auth::handle_oauth_error` for a non-success status.
    /// - Errors from `read_capped_body_bytes` (capped at 1 MiB here).
    /// - A serialization error if the body is not a valid `TokenResponse`.
    async fn authenticate(&self) -> Result<AccessToken> {
        let assertion = self.generate_jwt()?;
        let params = [
            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
            ("assertion", assertion.as_str()),
        ];
        let response = self
            .http_client
            .post(&self.token_url)
            // `form` borrows the serializable pairs; the assertion stays owned here.
            .form(&params)
            .send()
            .await
            .map_err(|e| ForceError::Http(HttpError::RequestFailed(e)))?;
        if !response.status().is_success() {
            return Err(crate::auth::handle_oauth_error(response, None).await);
        }
        // Limit how much of the success body is buffered to 1 MiB; the helper's
        // over-limit behavior is defined in `crate::http::error`.
        let bytes = crate::http::error::read_capped_body_bytes(response, 1024 * 1024).await?;
        let token_response = serde_json::from_slice::<TokenResponse>(&bytes)
            .map_err(crate::error::SerializationError::from)?;
        Ok(AccessToken::from_response(token_response))
    }

    /// "Refreshes" by running the full bearer exchange again.
    ///
    /// This flow keeps no refresh token; `authenticate` always signs a new
    /// assertion, so refreshing is the same operation.
    ///
    /// # Errors
    ///
    /// Same as [`Authenticator::authenticate`] for this type.
    async fn refresh(&self) -> Result<AccessToken> {
        self.authenticate().await
    }
}
#[cfg(all(test, feature = "jwt"))]
mod tests {
    use super::*;
    #[cfg(feature = "mock")]
    use crate::auth::Authenticator;
    use crate::test_support::Must;

    /// Test-only RSA private key (PEM with a `BEGIN PRIVATE KEY` header).
    /// Lines stay at column 0 so the PEM text is not altered by indentation.
    const TEST_PRIVATE_KEY: &str = r"-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCk2XyObD+F8vk9
cTEBJXymtVVFKG/IlXvwY8aPDsjog4O3/iUtmEgQss0JbckPqEAso1GOXeCoC2Sw
mDN3PdTxWIN864BbKI3aR/jzcPPstTK6QpxGiwXyM0Q3Dyi7fmBOxtDRwYTPB/aP
mbsXXhW1ZJM8Fd5pDZSBN9MX41jipwm76MUHyViPyiu5tXcsiMtNDo7KpB7mbHOt
RIHDGTuloxaT7sUVyauNWCHLNlUr5OsIUyS2LrtKeSUu5sWzhE+hrpQF7cjyDYJD
HbRgjAQs2AStI0axhcbawrBdX0TkU9/RFKVzslPm2T1l1clCeip3640ONZsbgz5j
RqPSfP5vAgMBAAECggEASOJNcyzB8yencbZra63WzGAs6KxFrAIHb5O1lMd9JWwM
Hxuy9VM4PYXIKGyNMip54SJ+KvsvmiybYoaQbp58WQ6A6Ai5UdR+zyz2ES/18Mh2
Oqq7rGbIBLsM5GkD4c2wp/O4HJ06akyDgx79fInhADeM70pd8MWLzIvRfWTLhj2N
PjOkUSGUOEVP0p/2SRyzMzOcnmhOujDFMM1yXqmPDGdPQWfVkhJa2umQfqd3dMbe
j8kXtQDhcA9GWm1I0dqpDV9eJg0WZgTXMB1yb0tZjoKNHNx0KPQguwG35bFVuz+X
CYE81kdGb0xIsMFH2FIxQY0yLGAN+bMnnQ0pApLCFQKBgQDWBtGxTz7FMGhT9zWB
ZUJvFQHx9MiRVOMk+mRiZsZ2aZuvMN6kXpBsx0BWIebCJkhFpAfhgwntCRNgD9Gg
+gGNczXeIoYV+Pjo9QNROJaF4QvtZvJgLmqfAesCnNJg4ARCLe/4eFoCp1d89F8s
xok0gmkiu7l8jLDvQ8R/hYCaXQKBgQDFLbz6Hm/WdDuOJ9AyRXdLeRTRTDiw35XQ
gc8OW0tvfHo6KI+Yf2wOZSassm7xMm3iMXbvKPFkAFIzfdL0GGY+hXUtIZstTnpa
zmwSQ8BjHcUkHL7GA2JQ4Vees9JIJw5xBhtNmTVr5Nk0oNHHfgWGD5qrgtCh4l54
T21g/tZnOwKBgDyqEiW/4HrkDa4/E9tpaDs0KSj7yR3ogbmpf2qk1vwZUxeFMpZE
d4tdrs67LT06vKGArPsuuVGGkQdZdIG8W1RMo6gjAP6ZY3Qkfpw2/fNUppzT4T+B
6JbJZGOJL9hlps9bVfmHo3u9Ev9IBPIcFCfeDw7ZRuoWttAa1UeP/7PBAoGAESSy
44Q18Q1WCDwJ6/UCNDuoxbG81BP8cI54tCTX4C+QaPIR2g5qFK5SuH0jDDF4QExQ
rOaAZlNo0jVEXBiq+xCbaXschMnn9XExED13wqZZ95PQOmMc7y9IcPHtfHx40vbW
9N43ONRC1kKNOq0ISemdZwAOp6SI1ikBt4cwmPUCgYEAv2S66uf1hO832lYjPjwv
JGmrmoxGzif0L840eWGb4lJ2relNe6Z5o0Z2a15HVq1wuRh3k09sfnn6bkhPQda7
g1FZTFRZVk+gGC+cHE9oq10Gk/upIGx+4kx/vG5qIg5zBqpzRKCRh5D7+/+pp1uh
QcWLHR6ul3bFRWNhXoThNBQ=
-----END PRIVATE KEY-----";

    #[test]
    fn test_jwt_bearer_invalid_private_key() {
        let result = JwtBearerFlow::new(
            "test_client",
            "user@example.com",
            "invalid key",
            "https://login.salesforce.com",
            "https://login.salesforce.com/services/oauth2/token",
        );
        if let Err(ForceError::Authentication(AuthenticationError::InvalidJwtConfig(msg))) = result
        {
            assert!(msg.contains("Invalid RSA private key"));
        } else {
            panic!("Expected InvalidJwtConfig error");
        }
    }

    #[test]
    fn test_jwt_bearer_debug_redacts_private_key() {
        let flow = JwtBearerFlow::new(
            "test_client",
            "user@example.com",
            TEST_PRIVATE_KEY,
            "https://login.salesforce.com",
            "https://login.salesforce.com/services/oauth2/token",
        )
        .must();
        let debug_str = format!("{flow:?}");
        assert!(debug_str.contains("test_client"));
        assert!(debug_str.contains("user@example.com"));
        // The fixture is a `BEGIN PRIVATE KEY` PEM, so check for that marker
        // rather than the RSA-specific header, which could never match.
        assert!(!debug_str.contains("PRIVATE KEY"));
        for key_line in TEST_PRIVATE_KEY.lines() {
            assert!(
                !debug_str.contains(key_line),
                "Debug output leaked key material: {key_line}"
            );
        }
        assert!(debug_str.contains("[REDACTED]"));
    }

    #[test]
    fn test_generate_jwt() {
        let flow = JwtBearerFlow::new(
            "test_client_id",
            "test@example.com",
            TEST_PRIVATE_KEY,
            "https://test.salesforce.com",
            "https://test.salesforce.com/services/oauth2/token",
        )
        .must();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .must()
            .as_secs();
        let jwt = flow.generate_jwt().must();
        assert!(!jwt.is_empty());
        let parts: Vec<&str> = jwt.split('.').collect();
        assert_eq!(parts.len(), 3);
        let payload_b64 = parts[1];
        let payload_bytes = base64::Engine::decode(
            &base64::engine::general_purpose::URL_SAFE_NO_PAD,
            payload_b64,
        )
        .must();
        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).must();
        assert_eq!(payload["iss"], "test_client_id");
        assert_eq!(payload["sub"], "test@example.com");
        assert_eq!(payload["aud"], "https://test.salesforce.com");
        let exp = payload["exp"].as_u64().must();
        // `now` is sampled before signing, so exp is at least now + 300 and
        // only a few seconds later in the worst case.
        assert!(exp >= now + 300);
        assert!(exp <= now + 305);
    }

    #[cfg(feature = "mock")]
    #[tokio::test]
    async fn test_jwt_bearer_authenticate_success() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let mock_server = MockServer::start().await;
        let token_response = serde_json::json!({
            "access_token": "jwt_bearer_token",
            "instance_url": "https://test.salesforce.com",
            "id": "https://login.salesforce.com/id/00Dxx/005xx",
            "token_type": "Bearer",
            "issued_at": "1704067200000",
            "signature": "sig=="
        });
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response))
            .mount(&mock_server)
            .await;
        let flow = JwtBearerFlow::new(
            "test_client",
            "test@example.com",
            TEST_PRIVATE_KEY,
            "https://login.salesforce.com",
            format!("{}/services/oauth2/token", mock_server.uri()),
        )
        .must();
        let token = flow.authenticate().await.must();
        assert_eq!(token.as_str(), "jwt_bearer_token");
        assert_eq!(token.instance_url(), "https://test.salesforce.com");
    }

    #[cfg(feature = "mock")]
    #[tokio::test]
    async fn test_jwt_bearer_authenticate_oauth_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let mock_server = MockServer::start().await;
        let error_response = serde_json::json!({
            "error": "invalid_grant",
            "error_description": "user hasn't approved this consumer"
        });
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_response))
            .mount(&mock_server)
            .await;
        let flow = JwtBearerFlow::new(
            "test_client",
            "test@example.com",
            TEST_PRIVATE_KEY,
            "https://login.salesforce.com",
            format!("{}/services/oauth2/token", mock_server.uri()),
        )
        .must();
        let result = flow.authenticate().await;
        if let Err(ForceError::Authentication(AuthenticationError::TokenRequestFailed(msg))) =
            result
        {
            assert!(msg.contains("invalid_grant"));
            assert!(msg.contains("user hasn't approved"));
        } else {
            panic!("Expected TokenRequestFailed error");
        }
    }

    #[cfg(feature = "mock")]
    #[tokio::test]
    async fn test_jwt_bearer_refresh_calls_authenticate() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let mock_server = MockServer::start().await;
        let token_response = serde_json::json!({
            "access_token": "refreshed_jwt_token",
            "instance_url": "https://test.salesforce.com",
            "token_type": "Bearer",
            "issued_at": "1704067200000",
            "signature": "sig=="
        });
        // One hit for `authenticate` and one for `refresh`.
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response))
            .expect(2)
            .mount(&mock_server)
            .await;
        let flow = JwtBearerFlow::new(
            "test_client",
            "test@example.com",
            TEST_PRIVATE_KEY,
            "https://login.salesforce.com",
            format!("{}/services/oauth2/token", mock_server.uri()),
        )
        .must();
        let _token1 = flow.authenticate().await.must();
        let token2 = flow.refresh().await.must();
        assert_eq!(token2.as_str(), "refreshed_jwt_token");
    }

    #[test]
    fn test_jwt_bearer_new_sandbox() {
        let flow =
            JwtBearerFlow::new_sandbox("test_client", "user@example.com", TEST_PRIVATE_KEY).must();
        assert_eq!(flow.audience, "https://test.salesforce.com");
        assert_eq!(
            flow.token_url,
            "https://test.salesforce.com/services/oauth2/token"
        );
    }

    #[cfg(feature = "mock")]
    #[tokio::test]
    async fn test_jwt_bearer_authenticate_error_truncation() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let mock_server = MockServer::start().await;
        let large_body = "A".repeat(2 * 1024 * 1024);
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_string(large_body))
            .mount(&mock_server)
            .await;
        let flow = JwtBearerFlow::new(
            "test_client",
            "test@example.com",
            TEST_PRIVATE_KEY,
            "https://login.salesforce.com",
            format!("{}/services/oauth2/token", mock_server.uri()),
        )
        .must();
        let result = flow.authenticate().await;
        if let Err(ForceError::Http(HttpError::PayloadTooLarge { limit_bytes })) = result {
            assert_eq!(limit_bytes, 1024 * 1024);
        } else {
            panic!("Expected HttpError::PayloadTooLarge, got {result:?}");
        }
    }

    #[test]
    fn test_jwt_bearer_new_production() {
        let flow =
            JwtBearerFlow::new_production("test_client", "user@example.com", TEST_PRIVATE_KEY)
                .must();
        assert_eq!(flow.audience, "https://login.salesforce.com");
        assert_eq!(
            flow.token_url,
            "https://login.salesforce.com/services/oauth2/token"
        );
    }
}