use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::credentials::SalesforceCredentials;
use crate::error::{Error, ErrorKind, Result};
/// Authenticator for the Salesforce OAuth 2.0 JWT Bearer flow.
///
/// Holds the consumer key (sent as the `iss` claim), the username (sent as the `sub`
/// claim), the private key used to RS256-sign the assertion, and the assertion lifetime.
///
/// `Debug` is implemented by hand so the private key bytes are never written to logs.
#[derive(Clone)]
pub struct JwtAuth {
    consumer_key: String,
    username: String,
    private_key: Vec<u8>,
    expiration: Duration,
}

impl std::fmt::Debug for JwtAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtAuth")
            .field("consumer_key", &self.consumer_key)
            .field("username", &self.username)
            // The key is the signing secret; a derived Debug would print its PEM bytes.
            .field("private_key", &"<redacted>")
            .field("expiration", &self.expiration)
            .finish()
    }
}
impl JwtAuth {
pub fn new(
consumer_key: impl Into<String>,
username: impl Into<String>,
private_key: impl Into<Vec<u8>>,
) -> Self {
Self {
consumer_key: consumer_key.into(),
username: username.into(),
private_key: private_key.into(),
expiration: Duration::minutes(3),
}
}
pub fn from_key_file(
consumer_key: impl Into<String>,
username: impl Into<String>,
key_path: impl AsRef<std::path::Path>,
) -> Result<Self> {
let private_key = std::fs::read(key_path.as_ref())?;
Ok(Self::new(consumer_key, username, private_key))
}
pub fn with_expiration(mut self, expiration: Duration) -> Self {
self.expiration = expiration;
self
}
fn generate_assertion(&self, audience: &str) -> Result<String> {
let now = Utc::now();
let exp = now + self.expiration;
let claims = JwtClaims {
iss: self.consumer_key.clone(),
sub: self.username.clone(),
aud: audience.to_string(),
exp: exp.timestamp(),
iat: now.timestamp(),
};
let header = Header::new(Algorithm::RS256);
let key = EncodingKey::from_rsa_pem(&self.private_key)?;
let token = encode(&header, &claims, &key)?;
Ok(token)
}
pub async fn authenticate(&self, login_url: &str) -> Result<SalesforceCredentials> {
let assertion = self.generate_assertion(login_url)?;
debug!(login_url, "Authenticating with JWT Bearer flow");
let client = reqwest::Client::new();
let form_data = [
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", &assertion),
];
let body = serde_urlencoded::to_string(form_data)?;
let response = client
.post(format!("{}/services/oauth2/token", login_url))
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body)
.send()
.await?;
if !response.status().is_success() {
let error: OAuthErrorResponse = response.json().await?;
return Err(Error::new(ErrorKind::OAuth {
error: error.error,
description: error.error_description,
}));
}
let token_response: JwtTokenResponse = response.json().await?;
Ok(SalesforceCredentials::new(
token_response.instance_url,
token_response.access_token,
busbar_sf_client::DEFAULT_API_VERSION,
))
}
pub async fn authenticate_production(&self) -> Result<SalesforceCredentials> {
self.authenticate(crate::PRODUCTION_LOGIN_URL).await
}
pub async fn authenticate_sandbox(&self) -> Result<SalesforceCredentials> {
self.authenticate(crate::SANDBOX_LOGIN_URL).await
}
}
/// Claim set carried by the signed JWT bearer assertion.
///
/// Field names are serialized as-is, so they must stay equal to the registered JWT claim
/// names. Their declaration order fixes the order of keys in the encoded payload.
#[derive(Debug, Serialize)]
struct JwtClaims {
    /// Issuer: the consumer key of the authenticator.
    iss: String,
    /// Subject: the username being authenticated.
    sub: String,
    /// Audience: the login URL the token is requested from.
    aud: String,
    /// Expiration time as a Unix timestamp in seconds (`iat` plus the configured lifetime).
    exp: i64,
    /// Issued-at time as a Unix timestamp in seconds.
    iat: i64,
}
/// Successful response body from the OAuth 2.0 token endpoint.
///
/// `Debug` is implemented by hand so the access token is never written to logs.
/// `token_type`, `scope` and `id` are deserialized but not read by the authentication
/// flow; they surface only in the `Debug` output.
#[derive(Deserialize)]
struct JwtTokenResponse {
    access_token: String,
    instance_url: String,
    #[serde(default)]
    token_type: String,
    #[serde(default)]
    scope: Option<String>,
    #[serde(default)]
    id: Option<String>,
}

impl std::fmt::Debug for JwtTokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtTokenResponse")
            // Bearer credential: a derived Debug would print it verbatim.
            .field("access_token", &"<redacted>")
            .field("instance_url", &self.instance_url)
            .field("token_type", &self.token_type)
            .field("scope", &self.scope)
            .field("id", &self.id)
            .finish()
    }
}
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
error: String,
error_description: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jwt_auth_creation() {
        let auth = JwtAuth::new(
            "consumer_key",
            "user@example.com",
            b"fake_private_key".to_vec(),
        );
        assert_eq!(auth.consumer_key, "consumer_key");
        assert_eq!(auth.username, "user@example.com");
        assert_eq!(auth.private_key, b"fake_private_key".to_vec());
        // `new` sets a 3-minute assertion lifetime unless overridden.
        assert_eq!(auth.expiration, Duration::minutes(3));
    }

    #[test]
    fn test_jwt_auth_with_expiration() {
        let auth =
            JwtAuth::new("key", "user", b"key".to_vec()).with_expiration(Duration::minutes(5));
        assert_eq!(auth.expiration, Duration::minutes(5));
    }

    #[test]
    fn test_from_key_file_reads_key_bytes() {
        let path = std::env::temp_dir().join(format!(
            "busbar-sf-auth-jwt-test-key-{}.pem",
            std::process::id()
        ));
        std::fs::write(&path, b"pem bytes").expect("write temporary key file");
        let result = JwtAuth::from_key_file("key", "user", &path);
        let _ = std::fs::remove_file(&path);
        let auth = match result {
            Ok(auth) => auth,
            Err(_) => panic!("from_key_file should read an existing file"),
        };
        assert_eq!(auth.private_key, b"pem bytes".to_vec());
        assert_eq!(auth.consumer_key, "key");
        assert_eq!(auth.username, "user");
    }

    #[test]
    fn test_from_key_file_missing_file_is_error() {
        let missing = std::env::temp_dir().join(format!(
            "busbar-sf-auth-jwt-missing-key-{}.pem",
            std::process::id()
        ));
        assert!(JwtAuth::from_key_file("key", "user", &missing).is_err());
    }

    #[test]
    fn test_generate_assertion_rejects_invalid_key() {
        let auth = JwtAuth::new("key", "user", b"not a pem key".to_vec());
        assert!(auth
            .generate_assertion("https://login.salesforce.com")
            .is_err());
    }
}