Skip to main content

hyperdb_api_salesforce/
token.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Token types for Salesforce Data Cloud authentication.
5//!
6//! This module defines the three token types used in the Salesforce Data Cloud
7//! authentication flow:
8//!
9//! 1. **OAuth Refresh Token** → used to obtain an OAuth Access Token (not modeled here;
10//!    it is a configuration input via [`AuthMode::RefreshToken`](super::config::AuthMode))
11//! 2. **OAuth Access Token** ([`OAuthToken`]) → obtained from Salesforce `/services/oauth2/token`
12//! 3. **DC JWT** ([`DataCloudToken`]) → obtained by exchanging the OAuth Access Token
13//!    at `/services/a360/token`, sent as `Authorization` header with every gRPC call
14
15use chrono::{DateTime, Duration, Utc};
16use serde::Deserialize;
17use url::Url;
18
19use crate::error::{SalesforceAuthError, SalesforceAuthResult};
20
/// Safety margin, in seconds, applied by [`DataCloudToken::is_valid`].
///
/// A DC JWT whose remaining lifetime does not exceed this margin is treated
/// as invalid, so callers never hand out a token that is about to expire.
const DC_JWT_VALIDITY_BUFFER_SECS: i64 = 5 * 60;
27
28/// OAuth Access Token response from Salesforce `/services/oauth2/token`.
29///
30/// See: <https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_jwt_flow.htm>
31#[derive(Debug, Deserialize)]
32pub struct OAuthTokenResponse {
33    /// OAuth Access Token
34    pub access_token: String,
35
36    /// Salesforce instance URL (e.g., "<https://na1.salesforce.com>")
37    pub instance_url: String,
38
39    /// Token type (usually "Bearer")
40    #[serde(default)]
41    pub token_type: Option<String>,
42
43    /// Token scope
44    #[serde(default)]
45    pub scope: Option<String>,
46
47    /// When the OAuth Access Token was issued (Unix timestamp in milliseconds)
48    #[serde(default)]
49    pub issued_at: Option<String>,
50
51    /// Error code (present on failure)
52    #[serde(default)]
53    pub error: Option<String>,
54
55    /// Error description (present on failure)
56    #[serde(default)]
57    pub error_description: Option<String>,
58}
59
60impl OAuthTokenResponse {
61    /// Checks if the response contains an error.
62    pub fn check_error(&self) -> SalesforceAuthResult<()> {
63        if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
64            return Err(SalesforceAuthError::Authorization {
65                error_code: code.clone(),
66                error_description: desc.clone(),
67            });
68        }
69        if self.access_token.is_empty() {
70            return Err(SalesforceAuthError::TokenParse(
71                "missing access_token in OAuth Access Token response".to_string(),
72            ));
73        }
74        Ok(())
75    }
76}
77
78/// Parsed OAuth Access Token with Salesforce instance URL.
79///
80/// Obtained from `/services/oauth2/token`. This token is exchanged for a
81/// DC JWT via `/services/a360/token`.
82#[derive(Debug, Clone)]
83pub struct OAuthToken {
84    /// OAuth Access Token value
85    pub token: String,
86    /// Salesforce instance URL (used as base URL for the DC JWT exchange)
87    pub instance_url: Url,
88    /// When this OAuth Access Token was obtained
89    pub obtained_at: DateTime<Utc>,
90    /// Estimated expiry (Salesforce reports ~2 hours, but server-side
91    /// inactivity timeout can invalidate it earlier)
92    pub expires_at: DateTime<Utc>,
93}
94
95/// Default OAuth Access Token lifetime in seconds.
96///
97/// Salesforce reports `access-token-expires-in: 7199` (~2 hours), but the
98/// server-side session can be invalidated earlier by the org's inactivity
99/// timeout (commonly 15 min – 2 hours).
100const OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS: i64 = 7199;
101
102impl OAuthToken {
103    /// Creates an OAuth Access Token from a response.
104    ///
105    /// # Errors
106    ///
107    /// - Returns [`SalesforceAuthError::Authorization`] if `response`
108    ///   carries both `error` and `error_description` fields (via
109    ///   `OAuthTokenResponse::check_error`).
110    /// - Returns [`SalesforceAuthError::TokenParse`] if `response.access_token`
111    ///   is empty, or if `response.instance_url` cannot be parsed as a URL.
112    pub fn from_response(response: OAuthTokenResponse) -> SalesforceAuthResult<Self> {
113        response.check_error()?;
114
115        let instance_url = Url::parse(&response.instance_url)
116            .map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
117
118        let now = Utc::now();
119        let expires_at = now + Duration::seconds(OAUTH_ACCESS_TOKEN_DEFAULT_LIFETIME_SECS);
120
121        Ok(OAuthToken {
122            token: response.access_token,
123            instance_url,
124            obtained_at: now,
125            expires_at,
126        })
127    }
128
129    /// Returns the bearer token string (e.g., "Bearer abc123...").
130    #[must_use]
131    pub fn bearer_token(&self) -> String {
132        format!("Bearer {}", self.token)
133    }
134
135    /// Returns `true` if the OAuth Access Token has not yet reached its
136    /// estimated expiry time.
137    #[must_use]
138    pub fn is_likely_valid(&self) -> bool {
139        Utc::now() < self.expires_at
140    }
141}
142
143/// DC JWT response from `/services/a360/token`.
144///
145/// See: <https://developer.salesforce.com/docs/atlas.en-us.c360a_api.meta/c360a_api/c360a_getting_started_with_cdp.htm>
146#[derive(Debug, Deserialize)]
147pub struct DataCloudTokenResponse {
148    /// DC JWT value
149    pub access_token: String,
150
151    /// Data Cloud instance URL (tenant URL)
152    pub instance_url: String,
153
154    /// Token type (usually "Bearer")
155    #[serde(default)]
156    pub token_type: Option<String>,
157
158    /// DC JWT expiration time in seconds
159    #[serde(default)]
160    pub expires_in: Option<i64>,
161
162    /// Error code (present on failure)
163    #[serde(default)]
164    pub error: Option<String>,
165
166    /// Error description (present on failure)
167    #[serde(default)]
168    pub error_description: Option<String>,
169}
170
171impl DataCloudTokenResponse {
172    /// Checks if the response contains an error.
173    pub fn check_error(&self) -> SalesforceAuthResult<()> {
174        if let (Some(code), Some(desc)) = (&self.error, &self.error_description) {
175            return Err(SalesforceAuthError::Authorization {
176                error_code: code.clone(),
177                error_description: desc.clone(),
178            });
179        }
180        if self.access_token.is_empty() {
181            return Err(SalesforceAuthError::TokenParse(
182                "missing access_token in DC JWT response".to_string(),
183            ));
184        }
185        Ok(())
186    }
187}
188
189/// Data Cloud JWT (DC JWT) for Hyper gRPC authentication.
190///
191/// Obtained by exchanging an OAuth Access Token at `/services/a360/token`.
192/// Sent as the `Authorization: Bearer <jwt>` header with every gRPC call
193/// to the Hyper query engine.
194///
195/// The DC JWT has a ~2-hour lifetime (`exp` claim), but is proactively
196/// refreshed much earlier (every ~15 minutes by default) so that the
197/// underlying OAuth Access Token is revalidated before Salesforce's
198/// server-side inactivity timeout can invalidate it.
199#[derive(Debug, Clone)]
200pub struct DataCloudToken {
201    /// Token type (e.g., "Bearer")
202    token_type: String,
203    /// DC JWT value
204    token: String,
205    /// Data Cloud tenant URL
206    tenant_url: Url,
207    /// When this DC JWT was obtained (used for maxAge-based proactive refresh)
208    created_at: DateTime<Utc>,
209    /// DC JWT expiration time (from `expires_in` in the response)
210    expires_at: DateTime<Utc>,
211}
212
213impl DataCloudToken {
214    /// Creates a DC JWT from a `/services/a360/token` response.
215    ///
216    /// # Errors
217    ///
218    /// - Returns [`SalesforceAuthError::Authorization`] if `response`
219    ///   carries both `error` and `error_description` fields.
220    /// - Returns [`SalesforceAuthError::TokenParse`] if `response.access_token`
221    ///   is empty, or if `response.instance_url` cannot be parsed as a URL
222    ///   (after prepending `https://` when the scheme is missing).
223    pub fn from_response(response: DataCloudTokenResponse) -> SalesforceAuthResult<Self> {
224        response.check_error()?;
225
226        let instance_url_with_scheme = if response.instance_url.starts_with("http://")
227            || response.instance_url.starts_with("https://")
228        {
229            response.instance_url.clone()
230        } else {
231            format!("https://{}", response.instance_url)
232        };
233
234        let tenant_url = Url::parse(&instance_url_with_scheme)
235            .map_err(|e| SalesforceAuthError::TokenParse(format!("invalid instance_url: {e}")))?;
236
237        let token_type = response.token_type.unwrap_or_else(|| "Bearer".to_string());
238
239        let now = Utc::now();
240        // Default to 30 minutes if Salesforce doesn't report expires_in
241        let expires_in_secs = response.expires_in.unwrap_or(1800);
242        let expires_at = now + Duration::seconds(expires_in_secs);
243
244        Ok(DataCloudToken {
245            token_type,
246            token: response.access_token,
247            tenant_url,
248            created_at: now,
249            expires_at,
250        })
251    }
252
253    /// Returns the bearer token string for the `Authorization` header.
254    ///
255    /// Format: `"Bearer <dc_jwt>"`
256    #[must_use]
257    pub fn bearer_token(&self) -> String {
258        format!("{} {}", self.token_type, self.token)
259    }
260
261    /// Returns just the DC JWT value (without the type prefix).
262    #[must_use]
263    pub fn access_token(&self) -> &str {
264        &self.token
265    }
266
267    /// Returns the token type (e.g., "Bearer").
268    #[must_use]
269    pub fn token_type(&self) -> &str {
270        &self.token_type
271    }
272
273    /// Returns the Data Cloud tenant URL.
274    #[must_use]
275    pub fn tenant_url(&self) -> &Url {
276        &self.tenant_url
277    }
278
279    /// Returns the tenant URL as a string (for the `audience` gRPC header).
280    #[must_use]
281    pub fn tenant_url_str(&self) -> &str {
282        self.tenant_url.as_str()
283    }
284
285    /// Returns when this DC JWT was obtained.
286    #[must_use]
287    pub fn created_at(&self) -> DateTime<Utc> {
288        self.created_at
289    }
290
291    /// Returns the DC JWT expiration time.
292    #[must_use]
293    pub fn expires_at(&self) -> DateTime<Utc> {
294        self.expires_at
295    }
296
297    /// Returns the age of this DC JWT (time since it was obtained).
298    #[must_use]
299    pub fn age(&self) -> Duration {
300        Utc::now().signed_duration_since(self.created_at)
301    }
302
303    /// Returns the remaining lifetime of this DC JWT.
304    #[must_use]
305    pub fn remaining_lifetime(&self) -> Duration {
306        self.expires_at.signed_duration_since(Utc::now())
307    }
308
309    /// Checks if the DC JWT is still valid (not expired).
310    ///
311    /// Returns `true` if the DC JWT has at least 300 seconds (5 minutes) of
312    /// remaining lifetime. This buffer ensures callers never use a DC JWT
313    /// that is about to expire.
314    #[must_use]
315    pub fn is_valid(&self) -> bool {
316        self.expires_at > Utc::now() + Duration::seconds(DC_JWT_VALIDITY_BUFFER_SECS)
317    }
318
319    /// Checks if the DC JWT is expired.
320    #[must_use]
321    pub fn is_expired(&self) -> bool {
322        self.expires_at <= Utc::now()
323    }
324
325    /// Checks if the DC JWT should be proactively refreshed.
326    ///
327    /// Mirrors the C++ `IsDCJWTExpiringSoon` logic: returns `true` when
328    /// either the DC JWT is near its hard expiry OR it has exceeded the
329    /// maximum age. This ensures:
330    /// - The OAuth Access Token is revalidated regularly (catching
331    ///   server-side inactivity timeouts)
332    /// - The DC JWT is replaced well before its ~2-hour hard expiry
333    ///
334    /// # Arguments
335    /// * `threshold_secs` - Refresh when the DC JWT has fewer than this
336    ///   many seconds remaining (default: 300 = 5 minutes)
337    /// * `max_age_secs` - Refresh when the DC JWT is older than this
338    ///   many seconds (default: 900 = 15 minutes)
339    #[must_use]
340    pub fn needs_refresh(&self, threshold_secs: i64, max_age_secs: i64) -> bool {
341        let now = Utc::now();
342        let expiring = (self.expires_at - now).num_seconds() <= threshold_secs;
343        let too_old = (now - self.created_at).num_seconds() > max_age_secs;
344        expiring || too_old
345    }
346
347    /// Extracts the tenant ID from the DC JWT payload.
348    ///
349    /// The tenant ID is stored in the `audienceTenantId` claim of the JWT.
350    ///
351    /// # Errors
352    ///
353    /// Returns [`SalesforceAuthError::TokenParse`] if:
354    /// - The JWT does not have exactly three dot-separated parts.
355    /// - The payload segment is not valid base64url (via
356    ///   `base64_url_decode`).
357    /// - The decoded payload is not valid JSON (via the [`From`] conversion
358    ///   from [`serde_json::Error`]).
359    /// - The payload object is missing a string-valued `audienceTenantId`
360    ///   claim.
361    pub fn tenant_id(&self) -> SalesforceAuthResult<String> {
362        let parts: Vec<&str> = self.token.split('.').collect();
363        if parts.len() != 3 {
364            return Err(SalesforceAuthError::TokenParse(
365                "invalid DC JWT format: expected 3 parts".to_string(),
366            ));
367        }
368
369        let payload_b64 = parts[1];
370        let payload_bytes = base64_url_decode(payload_b64)?;
371        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)?;
372
373        payload
374            .get("audienceTenantId")
375            .and_then(|v| v.as_str())
376            .map(std::string::ToString::to_string)
377            .ok_or_else(|| {
378                SalesforceAuthError::TokenParse(
379                    "missing audienceTenantId in DC JWT payload".to_string(),
380                )
381            })
382    }
383
384    /// Returns the lakehouse name for Hyper connection.
385    ///
386    /// Format: `"lakehouse:<tenant_id>;<dataspace>"`
387    ///
388    /// # Errors
389    ///
390    /// Propagates any [`SalesforceAuthError::TokenParse`] from
391    /// [`Self::tenant_id`] (malformed JWT structure, non-base64url payload,
392    /// non-JSON payload, or missing `audienceTenantId` claim).
393    pub fn lakehouse_name(&self, dataspace: Option<&str>) -> SalesforceAuthResult<String> {
394        let tenant_id = self.tenant_id()?;
395        let dataspace_str = dataspace.unwrap_or("");
396        Ok(format!("lakehouse:{tenant_id};{dataspace_str}"))
397    }
398}
399
400/// Decodes a base64url-encoded string.
401fn base64_url_decode(input: &str) -> SalesforceAuthResult<Vec<u8>> {
402    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
403
404    // Handle both padded and unpadded base64url
405    let padded = match input.len() % 4 {
406        2 => format!("{input}=="),
407        3 => format!("{input}="),
408        _ => input.to_string(),
409    };
410
411    URL_SAFE_NO_PAD
412        .decode(padded.trim_end_matches('='))
413        .map_err(|e| SalesforceAuthError::TokenParse(format!("base64 decode error: {e}")))
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a structurally valid (unsigned) JWT whose payload is `payload_json`.
    fn make_jwt(payload_json: &str) -> String {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none"}"#);
        let payload = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload}.sig")
    }

    /// Builds a DC JWT from a success response with the given fields.
    fn dc_token_with(
        access_token: String,
        instance_url: &str,
        expires_in: Option<i64>,
    ) -> DataCloudToken {
        DataCloudToken::from_response(DataCloudTokenResponse {
            access_token,
            instance_url: instance_url.to_string(),
            token_type: None,
            expires_in,
            error: None,
            error_description: None,
        })
        .unwrap()
    }

    #[test]
    fn test_oauth_access_token_response_error() {
        let response = OAuthTokenResponse {
            access_token: String::new(),
            instance_url: String::new(),
            token_type: None,
            scope: None,
            issued_at: None,
            error: Some("invalid_grant".to_string()),
            error_description: Some("authentication failure".to_string()),
        };

        let result = response.check_error();
        assert!(result.is_err());
        if let Err(SalesforceAuthError::Authorization { error_code, .. }) = result {
            assert_eq!(error_code, "invalid_grant");
        } else {
            panic!("expected Authorization error");
        }
    }

    #[test]
    fn test_oauth_access_token_from_response() {
        let response = OAuthTokenResponse {
            access_token: "oauth_access_tok_123".to_string(),
            instance_url: "https://na1.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            scope: None,
            issued_at: None,
            error: None,
            error_description: None,
        };

        let token = OAuthToken::from_response(response).unwrap();
        assert_eq!(token.token, "oauth_access_tok_123");
        assert_eq!(token.instance_url.as_str(), "https://na1.salesforce.com/");
        assert!(token.is_likely_valid());
        assert_eq!(token.bearer_token(), "Bearer oauth_access_tok_123");
    }

    #[test]
    fn test_dc_jwt_validity() {
        let response = DataCloudTokenResponse {
            access_token: "test.token.here".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(3600), // 1 hour
            error: None,
            error_description: None,
        };

        let token = DataCloudToken::from_response(response).unwrap();
        assert!(token.is_valid());
        assert!(!token.is_expired());
        assert_eq!(token.bearer_token(), "Bearer test.token.here");
        assert!(token.age().num_seconds() < 2);
        assert!(token.remaining_lifetime().num_seconds() > 3500);
    }

    #[test]
    fn test_dc_jwt_needs_refresh_when_fresh() {
        let response = DataCloudTokenResponse {
            access_token: "fresh.dc.jwt".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(7200),
            error: None,
            error_description: None,
        };

        let token = DataCloudToken::from_response(response).unwrap();
        // Fresh DC JWT (age ~0s): should NOT need refresh
        // threshold=300s (5min), maxAge=900s (15min)
        assert!(!token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_near_expiry() {
        let response = DataCloudTokenResponse {
            access_token: "expiring.dc.jwt".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(200), // expires in 200s (< 300s threshold)
            error: None,
            error_description: None,
        };

        let token = DataCloudToken::from_response(response).unwrap();
        // DC JWT with <300s remaining: SHOULD need refresh (expiring check)
        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_needs_refresh_too_old() {
        // Simulate an old DC JWT by backdating created_at
        let mut token = DataCloudToken::from_response(DataCloudTokenResponse {
            access_token: "old.dc.jwt".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(7200),
            error: None,
            error_description: None,
        })
        .unwrap();

        // Backdate created_at by 20 minutes (> 900s maxAge)
        token.created_at = Utc::now() - Duration::minutes(20);

        // DC JWT still has plenty of lifetime but is too old: SHOULD need refresh
        assert!(token.needs_refresh(300, 900));
    }

    #[test]
    fn test_dc_jwt_created_at_tracked() {
        let before = Utc::now();
        let response = DataCloudTokenResponse {
            access_token: "dc.jwt.value".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(3600),
            error: None,
            error_description: None,
        };
        let token = DataCloudToken::from_response(response).unwrap();
        let after = Utc::now();

        assert!(token.created_at() >= before);
        assert!(token.created_at() <= after);
    }

    #[test]
    fn test_dc_jwt_is_valid_uses_5min_buffer() {
        // A DC JWT with exactly 4 minutes remaining should NOT be considered valid
        // (below the 300s / 5-minute buffer)
        let response = DataCloudTokenResponse {
            access_token: "almost.expired.jwt".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(240), // 4 minutes
            error: None,
            error_description: None,
        };

        let token = DataCloudToken::from_response(response).unwrap();
        assert!(!token.is_valid());
        assert!(!token.is_expired()); // not yet hard-expired

        // A DC JWT with 6 minutes remaining SHOULD be valid
        let response2 = DataCloudTokenResponse {
            access_token: "still.valid.jwt".to_string(),
            instance_url: "https://tenant.salesforce.com".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(360), // 6 minutes
            error: None,
            error_description: None,
        };

        let token2 = DataCloudToken::from_response(response2).unwrap();
        assert!(token2.is_valid());
    }

    #[test]
    fn test_dc_jwt_defaults_scheme_type_and_lifetime() {
        // Bare host, no token_type, no expires_in.
        let token = dc_token_with("a.b.c".to_string(), "tenant.salesforce.com", None);
        assert_eq!(token.tenant_url_str(), "https://tenant.salesforce.com/");
        assert_eq!(token.token_type(), "Bearer");
        assert_eq!(token.bearer_token(), "Bearer a.b.c");
        assert_eq!(token.access_token(), "a.b.c");

        // Default lifetime is 30 minutes.
        let remaining = token.remaining_lifetime().num_seconds();
        assert!((1790..=1800).contains(&remaining), "remaining = {remaining}");
    }

    #[test]
    fn test_dc_jwt_tenant_id_and_lakehouse_name() {
        let jwt = make_jwt(r#"{"audienceTenantId":"a360/prod/abc123"}"#);
        let token = dc_token_with(jwt, "https://tenant.salesforce.com", Some(3600));

        assert_eq!(token.tenant_id().unwrap(), "a360/prod/abc123");
        assert_eq!(
            token.lakehouse_name(Some("default")).unwrap(),
            "lakehouse:a360/prod/abc123;default"
        );
        assert_eq!(
            token.lakehouse_name(None).unwrap(),
            "lakehouse:a360/prod/abc123;"
        );
    }

    #[test]
    fn test_dc_jwt_tenant_id_errors() {
        let not_a_jwt = dc_token_with(
            "only.two".to_string(),
            "https://tenant.salesforce.com",
            Some(3600),
        );
        assert!(matches!(
            not_a_jwt.tenant_id(),
            Err(SalesforceAuthError::TokenParse(_))
        ));

        let no_claim = dc_token_with(
            make_jwt(r#"{"sub":"user"}"#),
            "https://tenant.salesforce.com",
            Some(3600),
        );
        assert!(matches!(
            no_claim.tenant_id(),
            Err(SalesforceAuthError::TokenParse(_))
        ));
    }

    #[test]
    fn test_base64_url_decode_padded_unpadded_and_url_safe() {
        assert_eq!(base64_url_decode("aGk").unwrap(), b"hi");
        assert_eq!(base64_url_decode("aGk=").unwrap(), b"hi");
        // '-' and '_' are the URL-safe replacements for '+' and '/'.
        assert_eq!(base64_url_decode("-_8").unwrap(), vec![0xfb_u8, 0xff]);
        assert!(base64_url_decode("a").is_err());
    }
}