1use chrono::{DateTime, Duration, Utc};
16use serde::Deserialize;
17use url::Url;
18
19use crate::error::{SalesforceAuthError, SalesforceAuthResult};
20
/// Safety margin, in seconds, applied by `DataCloudToken::is_valid`: a Data Cloud token only
/// counts as valid while it still has more than this much lifetime left (five minutes).
const DC_JWT_VALIDITY_BUFFER_SECS: i64 = 5 * 60;
27
28#[derive(Debug, Deserialize)]
32pub struct OAuthTokenResponse {
33 pub access_token: String,
35
36 pub instance_url: String,
38
39 #[serde(default)]
41 pub token_type: Option<String>,
42
43 #[serde(default)]
45 pub scope: Option<String>,
46
47 #[serde(default)]
49 pub issued_at: Option<String>,
50
51 #[serde(default)]
53 pub error: Option<String>,
54
55 #[serde(default)]
57 pub error_description: Option<String>,
58}
59
60impl OAuthTokenResponse {
61 pub fn check_error(&self) -> SalesforceAuthResult<()> {
63 if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
64 return Err(SalesforceAuthError::Authorization {
65 error_code: code.clone(),
66 error_description: desc.clone(),
67 });
68 }
69 if self.access_token.is_empty() {
70 return Err(SalesforceAuthError::TokenParse(
71 "missing access_token in OAuth Access Token response".to_string(),
72 ));
73 }
74 Ok(())
75 }
76}
77
/// A validated OAuth access token for a Salesforce org, plus locally tracked timing data.
///
/// Built by `OAuthToken::from_response`. The token response type carries no lifetime, so
/// `expires_at` is an estimate rather than a server-provided value.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// The raw access token, without any `Bearer ` prefix.
    pub token: String,
    /// Parsed org instance URL taken from the token response.
    pub instance_url: Url,
    /// Local wall-clock time (UTC) at which the response was converted into this token.
    pub obtained_at: DateTime<Utc>,
    /// Estimated expiry: `obtained_at` plus `OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS`.
    pub expires_at: DateTime<Utc>,
}
94
/// Assumed lifetime of an OAuth access token, in seconds: one second short of two hours.
///
/// `OAuthTokenResponse` has no expiry field, so `OAuthToken::from_response` uses this value
/// to estimate `expires_at`. NOTE(review): presumably chosen to sit just inside a two-hour
/// org session timeout — confirm against the org's session settings.
const OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS: i64 = 2 * 60 * 60 - 1;
101
102impl OAuthToken {
103 pub fn from_response(response: OAuthTokenResponse) -> SalesforceAuthResult<Self> {
113 response.check_error()?;
114
115 let instance_url = Url::parse(&response.instance_url)
116 .map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
117
118 let now = Utc::now();
119 let expires_at = now + Duration::seconds(OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS);
120
121 Ok(OAuthToken {
122 token: response.access_token,
123 instance_url,
124 obtained_at: now,
125 expires_at,
126 })
127 }
128
129 #[must_use]
131 pub fn bearer_token(&self) -> String {
132 format!("Bearer {}", self.token)
133 }
134
135 #[must_use]
138 pub fn is_likely_valid(&self) -> bool {
139 Utc::now() < self.expires_at
140 }
141}
142
143#[derive(Debug, Deserialize)]
147pub struct DataCloudTokenResponse {
148 pub access_token: String,
150
151 pub instance_url: String,
153
154 #[serde(default)]
156 pub token_type: Option<String>,
157
158 #[serde(default)]
160 pub expires_in: Option<i64>,
161
162 #[serde(default)]
164 pub error: Option<String>,
165
166 #[serde(default)]
168 pub error_description: Option<String>,
169}
170
171impl DataCloudTokenResponse {
172 pub fn check_error(&self) -> SalesforceAuthResult<()> {
174 if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
175 return Err(SalesforceAuthError::Authorization {
176 error_code: code.clone(),
177 error_description: desc.clone(),
178 });
179 }
180 if self.access_token.is_empty() {
181 return Err(SalesforceAuthError::TokenParse(
182 "missing access_token in DC JWT response".to_string(),
183 ));
184 }
185 Ok(())
186 }
187}
188
/// A validated Data Cloud access token with locally tracked creation and expiry times.
///
/// Fields are private; read them through the accessor methods. `tenant_id` decodes the token
/// as a three-segment JWT.
#[derive(Debug, Clone)]
pub struct DataCloudToken {
    /// Authorization scheme placed in front of the token by `bearer_token` (e.g. `Bearer`).
    token_type: String,
    /// The raw access token (a JWT).
    token: String,
    /// Tenant base URL; always has a scheme, because `from_response` prepends `https://`
    /// when the server omitted one.
    tenant_url: Url,
    /// Local wall-clock time (UTC) at which this value was built.
    created_at: DateTime<Utc>,
    /// Local expiry estimate: `created_at` plus the server's `expires_in` (or a default).
    expires_at: DateTime<Utc>,
}
212
213impl DataCloudToken {
214 pub fn from_response(response: DataCloudTokenResponse) -> SalesforceAuthResult<Self> {
224 response.check_error()?;
225
226 let instance_url_with_scheme = if response.instance_url.starts_with("http://")
227 || response.instance_url.starts_with("https://")
228 {
229 response.instance_url.clone()
230 } else {
231 format!("https://{}", response.instance_url)
232 };
233
234 let tenant_url = Url::parse(&instance_url_with_scheme)
235 .map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
236
237 let token_type = response.token_type.unwrap_or_else(|| "Bearer".to_string());
238
239 let now = Utc::now();
240 let expires_in_secs = response.expires_in.unwrap_or(1800);
242 let expires_at = now + Duration::seconds(expires_in_secs);
243
244 Ok(DataCloudToken {
245 token_type,
246 token: response.access_token,
247 tenant_url,
248 created_at: now,
249 expires_at,
250 })
251 }
252
253 #[must_use]
257 pub fn bearer_token(&self) -> String {
258 format!("{} {}", self.token_type, self.token)
259 }
260
261 #[must_use]
263 pub fn access_token(&self) -> &str {
264 &self.token
265 }
266
267 #[must_use]
269 pub fn token_type(&self) -> &str {
270 &self.token_type
271 }
272
273 #[must_use]
275 pub fn tenant_url(&self) -> &Url {
276 &self.tenant_url
277 }
278
279 #[must_use]
281 pub fn tenant_url_str(&self) -> &str {
282 self.tenant_url.as_str()
283 }
284
285 #[must_use]
287 pub fn created_at(&self) -> DateTime<Utc> {
288 self.created_at
289 }
290
291 #[must_use]
293 pub fn expires_at(&self) -> DateTime<Utc> {
294 self.expires_at
295 }
296
297 #[must_use]
299 pub fn age(&self) -> Duration {
300 Utc::now().signed_duration_since(self.created_at)
301 }
302
303 #[must_use]
305 pub fn remaining_lifetime(&self) -> Duration {
306 self.expires_at.signed_duration_since(Utc::now())
307 }
308
309 #[must_use]
315 pub fn is_valid(&self) -> bool {
316 self.expires_at > Utc::now() + Duration::seconds(DC_JWT_VALIDITY_BUFFER_SECS)
317 }
318
319 #[must_use]
321 pub fn is_expired(&self) -> bool {
322 self.expires_at <= Utc::now()
323 }
324
325 #[must_use]
340 pub fn needs_refresh(&self, threshold_secs: i64, max_age_secs: i64) -> bool {
341 let now = Utc::now();
342 let expiring = (self.expires_at - now).num_seconds() <= threshold_secs;
343 let too_old = (now - self.created_at).num_seconds() > max_age_secs;
344 expiring || too_old
345 }
346
347 pub fn tenant_id(&self) -> SalesforceAuthResult<String> {
362 let parts: Vec<&str> = self.token.split('.').collect();
363 if parts.len() != 3 {
364 return Err(SalesforceAuthError::TokenParse(
365 "invalid DC JWT format: expected 3 parts".to_string(),
366 ));
367 }
368
369 let payload_b64 = parts[1];
370 let payload_bytes = base64_url_decode(payload_b64)?;
371 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)?;
372
373 payload
374 .get("audienceTenantId")
375 .and_then(|v| v.as_str())
376 .map(std::string::ToString::to_string)
377 .ok_or_else(|| {
378 SalesforceAuthError::TokenParse(
379 "missing audienceTenantId in DC JWT payload".to_string(),
380 )
381 })
382 }
383
384 pub fn lakehouse_name(&self, dataspace: Option<&str>) -> SalesforceAuthResult<String> {
394 let tenant_id = self.tenant_id()?;
395 let dataspace_str = dataspace.unwrap_or("");
396 Ok(format!("lakehouse:{tenant_id};{dataspace_str}"))
397 }
398}
399
400fn base64_url_decode(input: &str) -> SalesforceAuthResult<Vec<u8>> {
402 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
403
404 let padded = match input.len() % 4 {
406 2 => format!("{input}=="),
407 3 => format!("{input}="),
408 _ => input.to_string(),
409 };
410
411 URL_SAFE_NO_PAD
412 .decode(padded.trim_end_matches('='))
413 .map_err(|e| SalesforceAuthError::TokenParse(format!("base64 decode error: {e}")))
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Builds a successful DC token-exchange response for `access_token` with `expires_in`.
    fn dc_response(access_token: &str, expires_in: Option<i64>) -> DataCloudTokenResponse {
        DataCloudTokenResponse {
            access_token: access_token.to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in,
            error: None,
            error_description: None,
        }
    }

    /// Builds a structurally valid, unsigned JWT whose payload segment encodes `payload_json`.
    fn fake_jwt(payload_json: &str) -> String {
        format!("header.{}.signature", URL_SAFE_NO_PAD.encode(payload_json))
    }

    #[test]
    fn test_oauth_access_token_response_error() {
        let response = OAuthTokenResponse {
            access_token: String::new(),
            instance_url: String::new(),
            token_type: None,
            scope: None,
            issued_at: None,
            error: Some("invalid_grant".to_string()),
            error_description: Some("authentication failure".to_string()),
        };

        let result = response.check_error();
        assert!(result.is_err());
        if let Err(SalesforceAuthError::Authorization { error_code, .. }) = result {
            assert_eq!(error_code, "invalid_grant");
        } else {
            panic!("expected Authorization error");
        }
    }

    #[test]
    fn test_oauth_access_token_response_missing_token() {
        let response = OAuthTokenResponse {
            access_token: String::new(),
            instance_url: "https://na1.salesforce.com".to_string(),
            token_type: None,
            scope: None,
            issued_at: None,
            error: None,
            error_description: None,
        };

        assert!(matches!(
            response.check_error(),
            Err(SalesforceAuthError::TokenParse(_))
        ));
    }

    #[test]
    fn test_oauth_access_token_from_response() {
        let response = OAuthTokenResponse {
            access_token: "oauth_access_tok_123".to_string(),
            instance_url: "https://na1.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            scope: None,
            issued_at: None,
            error: None,
            error_description: None,
        };

        let token = OAuthToken::from_response(response).unwrap();
        assert_eq!(token.token, "oauth_access_tok_123");
        assert_eq!(token.instance_url.as_str(), "https://na1.salesforce.com/");
        assert!(token.is_likely_valid());
        assert_eq!(token.bearer_token(), "Bearer oauth_access_tok_123");
    }

    #[test]
    fn test_dc_response_authorization_error() {
        let response = DataCloudTokenResponse {
            access_token: String::new(),
            instance_url: String::new(),
            token_type: None,
            expires_in: None,
            error: Some("invalid_request".to_string()),
            error_description: Some("bad subject token".to_string()),
        };

        match response.check_error() {
            Err(SalesforceAuthError::Authorization {
                error_code,
                error_description,
            }) => {
                assert_eq!(error_code, "invalid_request");
                assert_eq!(error_description, "bad subject token");
            }
            _ => panic!("expected Authorization error"),
        }
    }

    #[test]
    fn test_dc_jwt_validity() {
        let token = DataCloudToken::from_response(dc_response("test.token.here", Some(3600)))
            .unwrap();
        assert!(token.is_valid());
        assert!(!token.is_expired());
        assert_eq!(token.bearer_token(), "Bearer test.token.here");
        assert!(token.age().num_seconds() < 2);
        assert!(token.remaining_lifetime().num_seconds() > 3500);
    }

    #[test]
    fn test_dc_jwt_defaults_and_scheme_prefix() {
        let token = DataCloudToken::from_response(DataCloudTokenResponse {
            access_token: "a.b.c".to_string(),
            instance_url: "tenant.salesforce.com".to_string(),
            token_type: None,
            expires_in: None,
            error: None,
            error_description: None,
        })
        .unwrap();

        assert_eq!(token.tenant_url_str(), "https://tenant.salesforce.com/");
        assert_eq!(token.token_type(), "Bearer");
        assert_eq!(token.access_token(), "a.b.c");
        assert_eq!(token.bearer_token(), "Bearer a.b.c");
        let remaining = token.remaining_lifetime().num_seconds();
        assert!(remaining > 1790 && remaining <= 1800);
    }

    #[test]
    fn test_dc_jwt_needs_refresh_when_fresh() {
        let token = DataCloudToken::from_response(dc_response("fresh.dc.jwt", Some(7200)))
            .unwrap();
        assert!(!token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_near_expiry() {
        let token = DataCloudToken::from_response(dc_response("expiring.dc.jwt", Some(200)))
            .unwrap();
        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_too_old() {
        let mut token = DataCloudToken::from_response(dc_response("old.dc.jwt", Some(7200)))
            .unwrap();

        // Backdate creation past the 15-minute max age while leaving plenty of lifetime.
        token.created_at = Utc::now() - Duration::minutes(20);

        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_created_at_tracked() {
        let before = Utc::now();
        let token = DataCloudToken::from_response(dc_response("dc.jwt.value", Some(3600)))
            .unwrap();
        let after = Utc::now();

        assert!(token.created_at() >= before);
        assert!(token.created_at() <= after);
    }

    #[test]
    fn test_dc_jwt_is_valid_uses_5min_buffer() {
        // 4 minutes left: inside the 5-minute buffer, so not valid, yet not expired either.
        let token = DataCloudToken::from_response(dc_response("almost.expired.jwt", Some(240)))
            .unwrap();
        assert!(!token.is_valid());
        assert!(!token.is_expired());

        // 6 minutes left: outside the buffer, so still valid.
        let token2 = DataCloudToken::from_response(dc_response("still.valid.jwt", Some(360)))
            .unwrap();
        assert!(token2.is_valid());
    }

    #[test]
    fn test_dc_jwt_tenant_id_and_lakehouse_name() {
        let jwt = fake_jwt(r#"{"audienceTenantId":"a360/prod/123"}"#);
        let token = DataCloudToken::from_response(dc_response(&jwt, Some(3600))).unwrap();

        assert_eq!(token.tenant_id().unwrap(), "a360/prod/123");
        assert_eq!(
            token.lakehouse_name(Some("default")).unwrap(),
            "lakehouse:a360/prod/123;default"
        );
        assert_eq!(
            token.lakehouse_name(None).unwrap(),
            "lakehouse:a360/prod/123;"
        );
    }

    #[test]
    fn test_dc_jwt_tenant_id_rejects_malformed_tokens() {
        let two_parts =
            DataCloudToken::from_response(dc_response("only.two", Some(3600))).unwrap();
        assert!(matches!(
            two_parts.tenant_id(),
            Err(SalesforceAuthError::TokenParse(_))
        ));

        let no_claim = DataCloudToken::from_response(dc_response(
            &fake_jwt(r#"{"sub":"someone"}"#),
            Some(3600),
        ))
        .unwrap();
        assert!(matches!(
            no_claim.tenant_id(),
            Err(SalesforceAuthError::TokenParse(_))
        ));
    }

    #[test]
    fn test_base64_url_decode_accepts_padded_and_unpadded() {
        assert_eq!(base64_url_decode("aGk").unwrap(), b"hi");
        assert_eq!(base64_url_decode("aGk=").unwrap(), b"hi");
        // URL-safe alphabet: '-' and '_' instead of '+' and '/'.
        assert_eq!(base64_url_decode("-_8").unwrap(), vec![0xfb_u8, 0xff]);
    }
}