use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use rsa::RsaPrivateKey;
use serde::Serialize;
use crate::error::{SalesforceAuthError, SalesforceAuthResult};
/// Claim set signed into the JWT assertion built by `build_jwt_assertion`.
///
/// Field names are the registered JWT claim names and are serialized as-is by
/// `serde`, in declaration order.
#[derive(Debug, Serialize)]
struct JwtClaims {
    /// Issuer: set to the caller-supplied `client_id`.
    iss: String,
    /// Subject: set to the caller-supplied `username`.
    sub: String,
    /// Audience: derived from the login URL in `build_jwt_assertion`.
    aud: String,
    /// Issued-at time, Unix seconds (`Utc::now().timestamp()`).
    iat: i64,
    /// Expiry time, Unix seconds; set two minutes after `iat`.
    exp: i64,
}
/// Builds a signed RS256 JWT assertion for the Salesforce OAuth 2.0 JWT bearer flow.
///
/// Claims written into the token:
/// - `iss`: `client_id`,
/// - `sub`: `username`,
/// - `aud`: `scheme://host[:port]` derived from `login_url` (see `jwt_audience`),
/// - `iat` / `exp`: the current time and two minutes later, in Unix seconds.
///
/// # Errors
///
/// Returns an error if the private key cannot be re-encoded as PKCS#8 PEM, if
/// `jsonwebtoken` rejects that PEM as an RSA encoding key, or if signing fails.
pub(crate) fn build_jwt_assertion(
    client_id: &str,
    username: &str,
    login_url: &url::Url,
    private_key: &RsaPrivateKey,
) -> SalesforceAuthResult<String> {
    let now = Utc::now();
    let claims = JwtClaims {
        iss: client_id.to_string(),
        sub: username.to_string(),
        aud: jwt_audience(login_url),
        iat: now.timestamp(),
        // NOTE(review): Salesforce documents a maximum of 3 minutes for `exp`;
        // confirm against the JWT bearer flow docs before raising this.
        exp: (now + Duration::minutes(2)).timestamp(),
    };

    // jsonwebtoken only accepts key material in encoded form, so round-trip through PEM.
    let private_key_pem = private_key_to_pem(private_key)?;
    let encoding_key = EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
        .map_err(|e| SalesforceAuthError::Jwt(format!("failed to create encoding key: {e}")))?;

    let header = Header::new(Algorithm::RS256);
    let token = encode(&header, &claims, &encoding_key)?;

    tracing::debug!(
        iss = %client_id,
        sub = %username,
        aud = %claims.aud,
        "JWT assertion created"
    );

    Ok(token)
}

/// Derives the `aud` claim from the login URL as `scheme://host[:port]`.
///
/// Path, query and fragment are discarded. The port is included only when it is
/// not the scheme's default (`Url::port` yields `None` for default ports). So
/// `https://login.salesforce.com` stays `https://login.salesforce.com`, while a
/// server at `http://127.0.0.1:8080` keeps its port instead of silently losing it.
/// URLs without a host fall back to the host `login.salesforce.com`.
fn jwt_audience(login_url: &url::Url) -> String {
    let scheme = login_url.scheme();
    match (login_url.host_str(), login_url.port()) {
        (Some(host), Some(port)) => format!("{scheme}://{host}:{port}"),
        (Some(host), None) => format!("{scheme}://{host}"),
        // Keeps the pre-existing fallback for host-less URLs.
        (None, _) => format!("{scheme}://login.salesforce.com"),
    }
}
/// Encodes `key` as a PKCS#8 PEM document with LF line endings.
///
/// The PEM text is raw private-key material. It is therefore returned in the
/// `Zeroizing` wrapper produced by `to_pkcs8_pem`, which wipes the buffer on drop.
/// The wrapper derefs to `String`, so callers can still use `.as_bytes()` and the
/// other `&str` methods.
///
/// # Errors
///
/// Returns `SalesforceAuthError::PrivateKey` if PKCS#8 encoding fails.
fn private_key_to_pem(
    key: &RsaPrivateKey,
) -> SalesforceAuthResult<rsa::pkcs8::der::zeroize::Zeroizing<String>> {
    use rsa::pkcs8::EncodePrivateKey;
    // Return the `Zeroizing` buffer as-is: calling `.to_string()` here would copy the
    // key into a plain `String` whose memory is never cleared.
    key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
        .map_err(|e| SalesforceAuthError::PrivateKey(format!("failed to encode private key: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    fn generate_test_key() -> RsaPrivateKey {
        let mut rng = rsa::rand_core::OsRng;
        RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate RSA key")
    }

    /// Base64url-decodes (no padding) one JWT segment and parses it as JSON.
    fn decode_segment(segment: &str) -> serde_json::Value {
        let bytes = URL_SAFE_NO_PAD
            .decode(segment)
            .expect("JWT segment is not base64url");
        serde_json::from_slice(&bytes).expect("JWT segment is not JSON")
    }

    #[test]
    fn test_build_jwt_assertion() {
        let private_key = generate_test_key();
        let login_url = url::Url::parse("https://login.salesforce.com").unwrap();
        let jwt = build_jwt_assertion(
            "test-client-id",
            "user@example.com",
            &login_url,
            &private_key,
        )
        .unwrap();

        let parts: Vec<&str> = jwt.split('.').collect();
        assert_eq!(parts.len(), 3, "expected header.payload.signature");

        let header = decode_segment(parts[0]);
        assert_eq!(header["alg"], "RS256");
        assert_eq!(header["typ"], "JWT");

        let claims = decode_segment(parts[1]);
        assert_eq!(claims["iss"], "test-client-id");
        assert_eq!(claims["sub"], "user@example.com");
        assert_eq!(claims["aud"], "https://login.salesforce.com");
        let iat = claims["iat"].as_i64().expect("iat must be an integer");
        let exp = claims["exp"].as_i64().expect("exp must be an integer");
        assert_eq!(exp - iat, 120, "assertion lifetime should be two minutes");

        assert!(!parts[2].is_empty(), "signature segment must not be empty");
    }
}