1use chrono::{DateTime, Duration, Utc};
16use serde::Deserialize;
17use url::Url;
18
19use crate::error::{SalesforceAuthError, SalesforceAuthResult};
20
/// Safety margin, in seconds, used by `DataCloudToken::is_valid`: a Data Cloud token
/// only counts as valid while its expiry lies more than this far in the future.
const DC_JWT_VALIDITY_BUFFER_SECS: i64 = 300;
27
28#[derive(Debug, Deserialize)]
32pub struct OAuthTokenResponse {
33 pub access_token: String,
35
36 pub instance_url: String,
38
39 #[serde(default)]
41 pub token_type: Option<String>,
42
43 #[serde(default)]
45 pub scope: Option<String>,
46
47 #[serde(default)]
49 pub issued_at: Option<String>,
50
51 #[serde(default)]
53 pub error: Option<String>,
54
55 #[serde(default)]
57 pub error_description: Option<String>,
58}
59
60impl OAuthTokenResponse {
61 pub fn check_error(&self) -> SalesforceAuthResult<()> {
63 if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
64 return Err(SalesforceAuthError::authorization(
65 code.clone(),
66 desc.clone(),
67 ));
68 }
69 if self.access_token.is_empty() {
70 return Err(SalesforceAuthError::token_parse(
71 "missing access_token in OAuth Access Token response",
72 ));
73 }
74 Ok(())
75 }
76}
77
/// A validated OAuth access token together with the timestamps recorded when it was
/// built by `OAuthToken::from_response`.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// The raw access token string taken from the response.
    pub token: String,
    /// The response's `instance_url`, parsed into a [`Url`].
    pub instance_url: Url,
    /// Local UTC time at which this value was constructed.
    pub obtained_at: DateTime<Utc>,
    /// `obtained_at` plus the assumed default lifetime; an estimate, not a value
    /// reported by the server.
    pub expires_at: DateTime<Utc>,
}
94
/// Lifetime, in seconds, that `OAuthToken::from_response` assumes for every OAuth
/// access token when computing `expires_at`.
// NOTE(review): 7199 is presumably one second short of a two-hour session timeout — confirm.
const OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS: i64 = 7199;
101
102impl OAuthToken {
103 pub fn from_response(response: OAuthTokenResponse) -> SalesforceAuthResult<Self> {
113 response.check_error()?;
114
115 let instance_url = Url::parse(&response.instance_url)
116 .map_err(|e| SalesforceAuthError::token_parse(format!("invalid instance_url: {e}")))?;
117
118 let now = Utc::now();
119 let expires_at = now + Duration::seconds(OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS);
120
121 Ok(OAuthToken {
122 token: response.access_token,
123 instance_url,
124 obtained_at: now,
125 expires_at,
126 })
127 }
128
129 #[must_use]
131 pub fn bearer_token(&self) -> String {
132 format!("Bearer {}", self.token)
133 }
134
135 #[must_use]
138 pub fn is_likely_valid(&self) -> bool {
139 Utc::now() < self.expires_at
140 }
141}
142
143#[derive(Debug, Deserialize)]
147pub struct DataCloudTokenResponse {
148 pub access_token: String,
150
151 pub instance_url: String,
153
154 #[serde(default)]
156 pub token_type: Option<String>,
157
158 #[serde(default)]
160 pub expires_in: Option<i64>,
161
162 #[serde(default)]
164 pub error: Option<String>,
165
166 #[serde(default)]
168 pub error_description: Option<String>,
169}
170
171impl DataCloudTokenResponse {
172 pub fn check_error(&self) -> SalesforceAuthResult<()> {
174 if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
175 return Err(SalesforceAuthError::authorization(
176 code.clone(),
177 desc.clone(),
178 ));
179 }
180 if self.access_token.is_empty() {
181 return Err(SalesforceAuthError::token_parse(
182 "missing access_token in DC JWT response",
183 ));
184 }
185 Ok(())
186 }
187}
188
/// A validated Data Cloud token plus the bookkeeping needed to decide when it must
/// be refreshed. Fields are private; read them through the accessor methods.
#[derive(Debug, Clone)]
pub struct DataCloudToken {
    /// Token type from the response; `"Bearer"` when the response omitted it.
    token_type: String,
    /// The raw token string taken from the response's `access_token`.
    token: String,
    /// The response's `instance_url` parsed as a URL (with `https://` prepended when
    /// the response value had no `http://` / `https://` prefix).
    tenant_url: Url,
    /// Local UTC time at which this value was constructed.
    created_at: DateTime<Utc>,
    /// `created_at` plus the lifetime taken from the response (or the default).
    expires_at: DateTime<Utc>,
}
212
213impl DataCloudToken {
214 pub fn from_response(response: DataCloudTokenResponse) -> SalesforceAuthResult<Self> {
224 response.check_error()?;
225
226 let instance_url_with_scheme = if response.instance_url.starts_with("http://")
227 || response.instance_url.starts_with("https://")
228 {
229 response.instance_url.clone()
230 } else {
231 format!("https://{}", response.instance_url)
232 };
233
234 let tenant_url = Url::parse(&instance_url_with_scheme)
235 .map_err(|e| SalesforceAuthError::token_parse(format!("invalid instance_url: {e}")))?;
236
237 let token_type = response.token_type.unwrap_or_else(|| "Bearer".to_string());
238
239 let now = Utc::now();
240 let expires_in_secs = response.expires_in.unwrap_or(1800);
242 let expires_at = now + Duration::seconds(expires_in_secs);
243
244 Ok(DataCloudToken {
245 token_type,
246 token: response.access_token,
247 tenant_url,
248 created_at: now,
249 expires_at,
250 })
251 }
252
253 #[must_use]
257 pub fn bearer_token(&self) -> String {
258 format!("{} {}", self.token_type, self.token)
259 }
260
261 #[must_use]
263 pub fn access_token(&self) -> &str {
264 &self.token
265 }
266
267 #[must_use]
269 pub fn token_type(&self) -> &str {
270 &self.token_type
271 }
272
273 #[must_use]
275 pub fn tenant_url(&self) -> &Url {
276 &self.tenant_url
277 }
278
279 #[must_use]
281 pub fn tenant_url_str(&self) -> &str {
282 self.tenant_url.as_str()
283 }
284
285 #[must_use]
287 pub fn created_at(&self) -> DateTime<Utc> {
288 self.created_at
289 }
290
291 #[must_use]
293 pub fn expires_at(&self) -> DateTime<Utc> {
294 self.expires_at
295 }
296
297 #[must_use]
299 pub fn age(&self) -> Duration {
300 Utc::now().signed_duration_since(self.created_at)
301 }
302
303 #[must_use]
305 pub fn remaining_lifetime(&self) -> Duration {
306 self.expires_at.signed_duration_since(Utc::now())
307 }
308
309 #[must_use]
315 pub fn is_valid(&self) -> bool {
316 self.expires_at > Utc::now() + Duration::seconds(DC_JWT_VALIDITY_BUFFER_SECS)
317 }
318
319 #[must_use]
321 pub fn is_expired(&self) -> bool {
322 self.expires_at <= Utc::now()
323 }
324
325 #[must_use]
340 pub fn needs_refresh(&self, threshold_secs: i64, max_age_secs: i64) -> bool {
341 let now = Utc::now();
342 let expiring = (self.expires_at - now).num_seconds() <= threshold_secs;
343 let too_old = (now - self.created_at).num_seconds() > max_age_secs;
344 expiring || too_old
345 }
346
347 pub fn tenant_id(&self) -> SalesforceAuthResult<String> {
362 let parts: Vec<&str> = self.token.split('.').collect();
363 if parts.len() != 3 {
364 return Err(SalesforceAuthError::token_parse(
365 "invalid DC JWT format: expected 3 parts",
366 ));
367 }
368
369 let payload_b64 = parts[1];
370 let payload_bytes = base64_url_decode(payload_b64)?;
371 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)?;
372
373 payload
374 .get("audienceTenantId")
375 .and_then(|v| v.as_str())
376 .map(std::string::ToString::to_string)
377 .ok_or_else(|| {
378 SalesforceAuthError::token_parse("missing audienceTenantId in DC JWT payload")
379 })
380 }
381
382 pub fn lakehouse_name(&self, dataspace: Option<&str>) -> SalesforceAuthResult<String> {
392 let tenant_id = self.tenant_id()?;
393 let dataspace_str = dataspace.unwrap_or("");
394 Ok(format!("lakehouse:{tenant_id};{dataspace_str}"))
395 }
396}
397
398fn base64_url_decode(input: &str) -> SalesforceAuthResult<Vec<u8>> {
400 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
401
402 let padded = match input.len() % 4 {
404 2 => format!("{input}=="),
405 3 => format!("{input}="),
406 _ => input.to_string(),
407 };
408
409 URL_SAFE_NO_PAD
410 .decode(padded.trim_end_matches('='))
411 .map_err(|e| SalesforceAuthError::token_parse(format!("base64 decode error: {e}")))
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful Data Cloud response for the fixed test tenant.
    fn dc_response(access_token: &str, expires_in: i64) -> DataCloudTokenResponse {
        DataCloudTokenResponse {
            access_token: access_token.to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(expires_in),
            error: None,
            error_description: None,
        }
    }

    /// Builds a `DataCloudToken` from `dc_response`, panicking on failure.
    fn dc_token(access_token: &str, expires_in: i64) -> DataCloudToken {
        DataCloudToken::from_response(dc_response(access_token, expires_in)).unwrap()
    }

    #[test]
    fn test_oauth_access_token_response_error() {
        let response = OAuthTokenResponse {
            access_token: String::new(),
            instance_url: String::new(),
            token_type: None,
            scope: None,
            issued_at: None,
            error: Some("invalid_grant".to_string()),
            error_description: Some("authentication failure".to_string()),
        };

        match response.check_error() {
            Err(SalesforceAuthError::Authorization { error_code, .. }) => {
                assert_eq!(error_code, "invalid_grant");
            }
            _ => panic!("expected Authorization error"),
        }
    }

    #[test]
    fn test_oauth_access_token_from_response() {
        let response = OAuthTokenResponse {
            access_token: "oauth_access_tok_123".to_string(),
            instance_url: "https://na1.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            scope: None,
            issued_at: None,
            error: None,
            error_description: None,
        };

        let token = OAuthToken::from_response(response).unwrap();
        assert_eq!(token.token, "oauth_access_tok_123");
        assert_eq!(token.instance_url.as_str(), "https://na1.salesforce.com/");
        assert!(token.is_likely_valid());
        assert_eq!(token.bearer_token(), "Bearer oauth_access_tok_123");
    }

    #[test]
    fn test_dc_jwt_validity() {
        // One-hour lifetime: comfortably outside the validity buffer.
        let token = dc_token("test.token.here", 3600);

        assert!(token.is_valid());
        assert!(!token.is_expired());
        assert_eq!(token.bearer_token(), "Bearer test.token.here");
        assert!(token.age().num_seconds() < 2);
        assert!(token.remaining_lifetime().num_seconds() > 3500);
    }

    #[test]
    fn test_dc_jwt_needs_refresh_when_fresh() {
        let token = dc_token("fresh.dc.jwt", 7200);

        assert!(!token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_near_expiry() {
        // 200 s remaining is below the 300 s threshold.
        let token = dc_token("expiring.dc.jwt", 200);

        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_too_old() {
        let mut token = dc_token("old.dc.jwt", 7200);

        // Back-date creation past the 900 s maximum age.
        token.created_at = Utc::now() - Duration::minutes(20);

        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_created_at_tracked() {
        let before = Utc::now();
        let response = dc_response("dc.jwt.value", 3600);
        let token = DataCloudToken::from_response(response).unwrap();
        let after = Utc::now();

        let created = token.created_at();
        assert!(created >= before);
        assert!(created <= after);
    }

    #[test]
    fn test_dc_jwt_is_valid_uses_5min_buffer() {
        // 240 s left: inside the buffer, so not valid, yet not expired either.
        let inside_buffer = dc_token("almost.expired.jwt", 240);
        assert!(!inside_buffer.is_valid());
        assert!(!inside_buffer.is_expired());

        // 360 s left: outside the buffer, so still valid.
        let outside_buffer = dc_token("still.valid.jwt", 360);
        assert!(outside_buffer.is_valid());
    }
}