// cirrus_auth/refresh.rs

//! OAuth 2.0 refresh-token grant for long-lived Salesforce sessions.
//!
//! Several Salesforce OAuth flows return a `refresh_token` alongside the
//! initial access token. Refresh tokens are long-lived and can be exchanged
//! for fresh access tokens repeatedly, until they are revoked or expire under
//! the Connected App's refresh-token policy. This module wraps that grant in
//! an [`AuthSession`] so the rest of the SDK does not need to know which flow
//! originally produced the refresh token.
//!
//! ## Usage
//!
//! Perform the initial OAuth exchange to obtain a `refresh_token` and
//! `instance_url`, build a [`RefreshTokenAuth`] from those values, and hand
//! it (wrapped in `Arc<dyn AuthSession>`) to a Cirrus client. New access
//! tokens are minted on demand by POSTing `grant_type=refresh_token` to
//! `/services/oauth2/token` on the configured login host.
//!
//! ## Confidential vs. public clients
//!
//! Connected Apps configured as **confidential clients** require a
//! `client_secret` on every refresh; **public clients** (PKCE-based) do
//! not. The builder therefore treats `consumer_secret` as optional: set it
//! for confidential clients and omit it for public ones.
//!
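//! ## Token rotation
//!
//! Refresh-token rotation is not supported. The configured refresh token is
//! reused for every refresh, and any replacement `refresh_token` in the token
//! response is ignored. Use a Connected App that does not rotate refresh
//! tokens with this session.
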
29use crate::AuthSession;
30use crate::error::{AuthError, AuthResult};
31use crate::token_endpoint::{check_instance_url, exchange};
32use async_trait::async_trait;
33use std::borrow::Cow;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Default token-exchange host: the Salesforce production login endpoint.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Token-exchange host for Salesforce sandbox orgs.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";

/// How long a freshly minted access token stays cached (30 minutes) when
/// the builder is not given an explicit TTL.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);

/// An access token together with the instant after which it is treated as
/// stale and must be re-minted.
#[derive(Clone)]
struct CachedToken {
    access_token: String,
    expires_at: Instant,
}

// Hand-written so the bearer token never ends up in logs or panic messages.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachedToken");
        out.field("access_token", &"[redacted]");
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}

61/// Refresh-token-grant auth session.
62///
63/// Construct via [`RefreshTokenAuth::builder`].
64pub struct RefreshTokenAuth {
65    consumer_key: String,
66    consumer_secret: Option<String>,
67    refresh_token: String,
68    login_url: String,
69    instance_url: String,
70    token_ttl: Duration,
71    http: reqwest::Client,
72    cached: RwLock<Option<CachedToken>>,
73}
74
75impl std::fmt::Debug for RefreshTokenAuth {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        // Omit consumer_key, consumer_secret, and refresh_token — all secrets.
78        f.debug_struct("RefreshTokenAuth")
79            .field("login_url", &self.login_url)
80            .field("instance_url", &self.instance_url)
81            .field("token_ttl", &self.token_ttl)
82            .field("confidential", &self.consumer_secret.is_some())
83            .finish_non_exhaustive()
84    }
85}
86
87impl RefreshTokenAuth {
88    /// Begins constructing a [`RefreshTokenAuth`].
89    ///
90    /// Refresh-token grant (RFC 6749 §6): once an access token is
91    /// obtained through any flow that issues a refresh token (typically
92    /// Web Server with PKCE), use that refresh token to mint new access
93    /// tokens at will. The refresh token itself is long-lived.
94    ///
95    /// # Example
96    ///
97    /// ```no_run
98    /// use cirrus_auth::RefreshTokenAuth;
99    /// use std::sync::Arc;
100    ///
101    /// # fn example() -> Result<(), cirrus_auth::AuthError> {
102    /// let auth = RefreshTokenAuth::builder()
103    ///     .consumer_key("3MVG9...")
104    ///     .refresh_token("5Aep861...")
105    ///     .login_url("https://login.salesforce.com")
106    ///     .instance_url("https://my-org.my.salesforce.com")
107    ///     .build()?;
108    /// // Wrap as Arc<dyn AuthSession> and hand to a Cirrus client.
109    /// let _shared = Arc::new(auth);
110    /// # Ok(())
111    /// # }
112    /// ```
113    pub fn builder() -> RefreshTokenAuthBuilder {
114        RefreshTokenAuthBuilder::default()
115    }
116
117    async fn mint_token(&self) -> AuthResult<CachedToken> {
118        tracing::info!(
119            target: "cirrus::auth",
120            flow = "refresh-token",
121            login_url = %self.login_url,
122            "minting fresh access token",
123        );
124        // Compose the form body. consumer_secret is conditional on whether
125        // the connected app is confidential.
126        let mut body: Vec<(&str, &str)> = vec![
127            ("grant_type", "refresh_token"),
128            ("client_id", self.consumer_key.as_str()),
129            ("refresh_token", self.refresh_token.as_str()),
130        ];
131        if let Some(secret) = self.consumer_secret.as_deref() {
132            body.push(("client_secret", secret));
133        }
134
135        let token = exchange(&self.http, &self.login_url, &body).await?;
136        check_instance_url(&self.instance_url, &token)?;
137
138        Ok(CachedToken {
139            access_token: token.access_token,
140            expires_at: Instant::now() + self.token_ttl,
141        })
142    }
143}
144
145#[async_trait]
146impl AuthSession for RefreshTokenAuth {
147    async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
148        // Fast path — read lock, return clone of cached token if still valid.
149        {
150            let guard = self.cached.read().await;
151            if let Some(cached) = guard.as_ref()
152                && cached.expires_at > Instant::now()
153            {
154                return Ok(Cow::Owned(cached.access_token.clone()));
155            }
156        }
157
158        // Slow path — write lock, double-check, mint.
159        let mut guard = self.cached.write().await;
160        if let Some(cached) = guard.as_ref()
161            && cached.expires_at > Instant::now()
162        {
163            return Ok(Cow::Owned(cached.access_token.clone()));
164        }
165        let new_token = self.mint_token().await?;
166        let token_str = new_token.access_token.clone();
167        *guard = Some(new_token);
168        Ok(Cow::Owned(token_str))
169    }
170
171    fn instance_url(&self) -> &str {
172        &self.instance_url
173    }
174
175    async fn invalidate(&self, stale_token: &str) {
176        // Compare-and-swap: only clear the cached access token if it
177        // still matches what the failing request used. The underlying
178        // refresh_token isn't affected — we only ever want the
179        // *short-lived* access token re-minted.
180        let mut guard = self.cached.write().await;
181        if let Some(cached) = guard.as_ref()
182            && cached.access_token == stale_token
183        {
184            tracing::debug!(
185                target: "cirrus::auth",
186                flow = "refresh-token",
187                "invalidating cached token (CAS matched)",
188            );
189            *guard = None;
190        } else {
191            tracing::trace!(
192                target: "cirrus::auth",
193                flow = "refresh-token",
194                "invalidate called but cached token differs (concurrent refresh?); no-op",
195            );
196        }
197    }
198}
199
200/// Builder for [`RefreshTokenAuth`].
201#[derive(Default)]
202pub struct RefreshTokenAuthBuilder {
203    consumer_key: Option<String>,
204    consumer_secret: Option<String>,
205    refresh_token: Option<String>,
206    login_url: Option<String>,
207    instance_url: Option<String>,
208    token_ttl: Option<Duration>,
209    http_client: Option<reqwest::Client>,
210}
211
212impl std::fmt::Debug for RefreshTokenAuthBuilder {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("RefreshTokenAuthBuilder")
215            .field("consumer_key", &self.consumer_key.is_some())
216            .field("consumer_secret", &self.consumer_secret.is_some())
217            .field("refresh_token", &self.refresh_token.is_some())
218            .field("login_url", &self.login_url)
219            .field("instance_url", &self.instance_url)
220            .field("token_ttl", &self.token_ttl)
221            .finish_non_exhaustive()
222    }
223}
224
225impl RefreshTokenAuthBuilder {
226    /// Connected App's Consumer Key (Client ID). Required.
227    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
228        self.consumer_key = Some(key.into());
229        self
230    }
231
232    /// Connected App's Consumer Secret (Client Secret). Required for
233    /// confidential clients; omit for public/PKCE clients.
234    pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
235        self.consumer_secret = Some(secret.into());
236        self
237    }
238
239    /// Refresh token issued by a prior OAuth flow (Web Server, Device,
240    /// User-Agent). Required.
241    pub fn refresh_token(mut self, token: impl Into<String>) -> Self {
242        self.refresh_token = Some(token.into());
243        self
244    }
245
246    /// Login URL — the host that issued the refresh token. Defaults to
247    /// [`PRODUCTION_LOGIN_URL`]. Use [`SANDBOX_LOGIN_URL`] for sandboxes,
248    /// or your org's My Domain login URL where required.
249    pub fn login_url(mut self, url: impl Into<String>) -> Self {
250        self.login_url = Some(url.into());
251        self
252    }
253
254    /// REST instance URL — the org's My Domain. Required. Must match the
255    /// `instance_url` returned by the token-exchange response.
256    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
257        self.instance_url = Some(url.into());
258        self
259    }
260
261    /// How long to cache an access token before re-minting. Defaults to 30
262    /// minutes.
263    pub fn token_ttl(mut self, ttl: Duration) -> Self {
264        self.token_ttl = Some(ttl);
265        self
266    }
267
268    /// Supplies a pre-configured `reqwest::Client`. Useful for sharing a
269    /// connection pool.
270    pub fn http_client(mut self, client: reqwest::Client) -> Self {
271        self.http_client = Some(client);
272        self
273    }
274
275    /// Finalizes the builder.
276    pub fn build(self) -> AuthResult<RefreshTokenAuth> {
277        let consumer_key = self
278            .consumer_key
279            .ok_or(AuthError::MissingField("consumer_key"))?;
280        let refresh_token = self
281            .refresh_token
282            .ok_or(AuthError::MissingField("refresh_token"))?;
283        let mut instance_url = self
284            .instance_url
285            .ok_or(AuthError::MissingField("instance_url"))?;
286        if instance_url.ends_with('/') {
287            instance_url.pop();
288        }
289        let mut login_url = self
290            .login_url
291            .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
292        if login_url.ends_with('/') {
293            login_url.pop();
294        }
295        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
296        let http = self.http_client.unwrap_or_default();
297
298        Ok(RefreshTokenAuth {
299            consumer_key,
300            consumer_secret: self.consumer_secret,
301            refresh_token,
302            login_url,
303            instance_url,
304            token_ttl,
305            http,
306            cached: RwLock::new(None),
307        })
308    }
309}
310
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    fn builder_with_required_fields() -> RefreshTokenAuthBuilder {
        RefreshTokenAuth::builder()
            .consumer_key("consumer-key-123")
            .refresh_token("5Aep861KIwKdekr...refresh")
            .instance_url("https://my-org.my.salesforce.com")
    }

    /// Successful token-endpoint response whose `instance_url` matches
    /// `builder_with_required_fields`.
    fn ok_token_response(access_token: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "access_token": access_token,
            "instance_url": "https://my-org.my.salesforce.com"
        }))
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = RefreshTokenAuth::builder()
            .refresh_token("r")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_refresh_token() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("refresh_token")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .refresh_token("r")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    #[test]
    fn debug_output_redacts_credentials() {
        let builder = builder_with_required_fields().consumer_secret("top-secret");
        let builder_dbg = format!("{builder:?}");
        assert!(!builder_dbg.contains("consumer-key-123"), "{builder_dbg}");
        assert!(!builder_dbg.contains("top-secret"), "{builder_dbg}");
        assert!(!builder_dbg.contains("5Aep861"), "{builder_dbg}");

        let auth = builder.build().unwrap();
        let auth_dbg = format!("{auth:?}");
        assert!(!auth_dbg.contains("consumer-key-123"), "{auth_dbg}");
        assert!(!auth_dbg.contains("top-secret"), "{auth_dbg}");
        assert!(!auth_dbg.contains("5Aep861"), "{auth_dbg}");
        assert!(auth_dbg.contains("confidential: true"), "{auth_dbg}");
    }

    #[tokio::test]
    async fn refresh_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("refresh_token=5Aep861KIwKdekr"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confidential_client_includes_consumer_secret() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(ok_token_response("tok"))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .consumer_secret("top-secret")
            .login_url(server.uri())
            .build()
            .unwrap();

        // The body matcher above asserts client_secret is present. If it
        // weren't, the mock would 404 and this would error.
        auth.access_token().await.unwrap();
    }

    #[tokio::test]
    async fn public_client_omits_consumer_secret() {
        let server = MockServer::start().await;
        // wiremock has no "body does not contain" matcher, so the responder
        // records the raw form body and the assertions below inspect it.
        let received_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(BodyCapturingResponder {
                captured: received_body.clone(),
                response: ok_token_response("tok"),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        auth.access_token().await.unwrap();

        let body = received_body
            .lock()
            .unwrap()
            .clone()
            .expect("token endpoint was never called");
        // Guard against a vacuous pass: make sure the real form was captured.
        assert!(
            body.contains("grant_type=refresh_token"),
            "unexpected token request body: {body}"
        );
        assert!(
            !body.contains("client_secret"),
            "public client should not send client_secret, got: {body}"
        );
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ok_token_response("tok"),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_matching_token_forces_remint() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ok_token_response("tok"),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let first = auth.access_token().await.unwrap().into_owned();
        auth.invalidate(&first).await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalidate_with_stale_token_is_noop() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ok_token_response("tok"),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        // A token that no longer matches the cache must not evict it.
        auth.invalidate("some-older-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn revoked_refresh_token_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "expired access/refresh token"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// Counts invocations and returns a fixed response. Same as the JWT
    /// tests' helper — duplicated rather than shared to keep test modules
    /// self-contained.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }

    /// Records the raw request body so a test can assert on what was *not*
    /// sent. `None` means the endpoint was never hit.
    struct BodyCapturingResponder {
        captured: Arc<Mutex<Option<String>>>,
        response: ResponseTemplate,
    }

    impl Respond for BodyCapturingResponder {
        fn respond(&self, request: &Request) -> ResponseTemplate {
            let body = String::from_utf8_lossy(&request.body).into_owned();
            // A std Mutex is appropriate: `respond` is synchronous and the
            // lock is never held across an `.await`. Unlike the previous
            // `try_lock`, this can never silently drop the body.
            *self.captured.lock().unwrap() = Some(body);
            self.response.clone()
        }
    }
}
569            }
570            self.response.clone()
571        }
572    }
573}