1use crate::AuthSession;
30use crate::error::{AuthError, AuthResult};
31use crate::token_endpoint::{check_instance_url, exchange};
32use async_trait::async_trait;
33use std::borrow::Cow;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Login host for production orgs.
///
/// Used as the token-endpoint host when no explicit
/// [`RefreshTokenAuthBuilder::login_url`] is configured.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Login host for sandbox orgs; pass it to
/// [`RefreshTokenAuthBuilder::login_url`] when authenticating against a sandbox.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";

/// How long a minted access token is reused before a new one is requested
/// (30 minutes), unless overridden with [`RefreshTokenAuthBuilder::token_ttl`].
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);
45
/// An access token paired with the instant after which it must no longer be
/// handed out.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint.
    access_token: String,
    /// The token is served only while `Instant::now()` is strictly before this.
    expires_at: Instant,
}

/// Hand-written so the token value never ends up in logs or panic messages.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        let Self { expires_at, .. } = self;
        f.debug_struct("CachedToken")
            .field("access_token", &REDACTED)
            .field("expires_at", expires_at)
            .finish()
    }
}
60
61pub struct RefreshTokenAuth {
65 consumer_key: String,
66 consumer_secret: Option<String>,
67 refresh_token: String,
68 login_url: String,
69 instance_url: String,
70 token_ttl: Duration,
71 http: reqwest::Client,
72 cached: RwLock<Option<CachedToken>>,
73}
74
75impl std::fmt::Debug for RefreshTokenAuth {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("RefreshTokenAuth")
79 .field("login_url", &self.login_url)
80 .field("instance_url", &self.instance_url)
81 .field("token_ttl", &self.token_ttl)
82 .field("confidential", &self.consumer_secret.is_some())
83 .finish_non_exhaustive()
84 }
85}
86
87impl RefreshTokenAuth {
88 pub fn builder() -> RefreshTokenAuthBuilder {
114 RefreshTokenAuthBuilder::default()
115 }
116
117 async fn mint_token(&self) -> AuthResult<CachedToken> {
118 tracing::info!(
119 target: "cirrus::auth",
120 flow = "refresh-token",
121 login_url = %self.login_url,
122 "minting fresh access token",
123 );
124 let mut body: Vec<(&str, &str)> = vec![
127 ("grant_type", "refresh_token"),
128 ("client_id", self.consumer_key.as_str()),
129 ("refresh_token", self.refresh_token.as_str()),
130 ];
131 if let Some(secret) = self.consumer_secret.as_deref() {
132 body.push(("client_secret", secret));
133 }
134
135 let token = exchange(&self.http, &self.login_url, &body).await?;
136 check_instance_url(&self.instance_url, &token)?;
137
138 Ok(CachedToken {
139 access_token: token.access_token,
140 expires_at: Instant::now() + self.token_ttl,
141 })
142 }
143}
144
145#[async_trait]
146impl AuthSession for RefreshTokenAuth {
147 async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
148 {
150 let guard = self.cached.read().await;
151 if let Some(cached) = guard.as_ref()
152 && cached.expires_at > Instant::now()
153 {
154 return Ok(Cow::Owned(cached.access_token.clone()));
155 }
156 }
157
158 let mut guard = self.cached.write().await;
160 if let Some(cached) = guard.as_ref()
161 && cached.expires_at > Instant::now()
162 {
163 return Ok(Cow::Owned(cached.access_token.clone()));
164 }
165 let new_token = self.mint_token().await?;
166 let token_str = new_token.access_token.clone();
167 *guard = Some(new_token);
168 Ok(Cow::Owned(token_str))
169 }
170
171 fn instance_url(&self) -> &str {
172 &self.instance_url
173 }
174
175 async fn invalidate(&self, stale_token: &str) {
176 let mut guard = self.cached.write().await;
181 if let Some(cached) = guard.as_ref()
182 && cached.access_token == stale_token
183 {
184 tracing::debug!(
185 target: "cirrus::auth",
186 flow = "refresh-token",
187 "invalidating cached token (CAS matched)",
188 );
189 *guard = None;
190 } else {
191 tracing::trace!(
192 target: "cirrus::auth",
193 flow = "refresh-token",
194 "invalidate called but cached token differs (concurrent refresh?); no-op",
195 );
196 }
197 }
198}
199
200#[derive(Default)]
202pub struct RefreshTokenAuthBuilder {
203 consumer_key: Option<String>,
204 consumer_secret: Option<String>,
205 refresh_token: Option<String>,
206 login_url: Option<String>,
207 instance_url: Option<String>,
208 token_ttl: Option<Duration>,
209 http_client: Option<reqwest::Client>,
210}
211
212impl std::fmt::Debug for RefreshTokenAuthBuilder {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("RefreshTokenAuthBuilder")
215 .field("consumer_key", &self.consumer_key.is_some())
216 .field("consumer_secret", &self.consumer_secret.is_some())
217 .field("refresh_token", &self.refresh_token.is_some())
218 .field("login_url", &self.login_url)
219 .field("instance_url", &self.instance_url)
220 .field("token_ttl", &self.token_ttl)
221 .finish_non_exhaustive()
222 }
223}
224
225impl RefreshTokenAuthBuilder {
226 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
228 self.consumer_key = Some(key.into());
229 self
230 }
231
232 pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
235 self.consumer_secret = Some(secret.into());
236 self
237 }
238
239 pub fn refresh_token(mut self, token: impl Into<String>) -> Self {
242 self.refresh_token = Some(token.into());
243 self
244 }
245
246 pub fn login_url(mut self, url: impl Into<String>) -> Self {
250 self.login_url = Some(url.into());
251 self
252 }
253
254 pub fn instance_url(mut self, url: impl Into<String>) -> Self {
257 self.instance_url = Some(url.into());
258 self
259 }
260
261 pub fn token_ttl(mut self, ttl: Duration) -> Self {
264 self.token_ttl = Some(ttl);
265 self
266 }
267
268 pub fn http_client(mut self, client: reqwest::Client) -> Self {
271 self.http_client = Some(client);
272 self
273 }
274
275 pub fn build(self) -> AuthResult<RefreshTokenAuth> {
277 let consumer_key = self
278 .consumer_key
279 .ok_or(AuthError::MissingField("consumer_key"))?;
280 let refresh_token = self
281 .refresh_token
282 .ok_or(AuthError::MissingField("refresh_token"))?;
283 let mut instance_url = self
284 .instance_url
285 .ok_or(AuthError::MissingField("instance_url"))?;
286 if instance_url.ends_with('/') {
287 instance_url.pop();
288 }
289 let mut login_url = self
290 .login_url
291 .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
292 if login_url.ends_with('/') {
293 login_url.pop();
294 }
295 let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
296 let http = self.http_client.unwrap_or_default();
297
298 Ok(RefreshTokenAuth {
299 consumer_key,
300 consumer_secret: self.consumer_secret,
301 refresh_token,
302 login_url,
303 instance_url,
304 token_ttl,
305 http,
306 cached: RwLock::new(None),
307 })
308 }
309}
310
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    const TOKEN_PATH: &str = "/services/oauth2/token";
    const ORG_URL: &str = "https://my-org.my.salesforce.com";

    fn builder_with_required_fields() -> RefreshTokenAuthBuilder {
        RefreshTokenAuth::builder()
            .consumer_key("consumer-key-123")
            .refresh_token("5Aep861KIwKdekr...refresh")
            .instance_url(ORG_URL)
    }

    /// A minimal successful token-endpoint response.
    fn token_ok(access_token: &str, instance_url: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "access_token": access_token,
            "instance_url": instance_url,
        }))
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = RefreshTokenAuth::builder()
            .refresh_token("r")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_refresh_token() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("refresh_token")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .refresh_token("r")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), ORG_URL);
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    #[test]
    fn builder_keeps_explicit_login_url() {
        let auth = builder_with_required_fields()
            .login_url("https://test.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.login_url, SANDBOX_LOGIN_URL);
    }

    #[test]
    fn debug_output_redacts_credentials() {
        let builder = builder_with_required_fields().consumer_secret("top-secret");
        let rendered_builder = format!("{builder:?}");
        assert!(!rendered_builder.contains("top-secret"));
        assert!(!rendered_builder.contains("5Aep861KIwKdekr"));
        assert!(!rendered_builder.contains("consumer-key-123"));

        let auth = builder.build().unwrap();
        let rendered = format!("{auth:?}");
        assert!(!rendered.contains("top-secret"));
        assert!(!rendered.contains("5Aep861KIwKdekr"));
        assert!(!rendered.contains("consumer-key-123"));
        assert!(rendered.contains("confidential: true"), "got: {rendered}");
    }

    #[tokio::test]
    async fn refresh_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("refresh_token=5Aep861KIwKdekr"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": ORG_URL,
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confidential_client_includes_consumer_secret() {
        let server = MockServer::start().await;
        // The mock only matches when the secret is present; otherwise wiremock
        // answers 404 and `access_token` fails the unwrap below.
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(token_ok("tok", ORG_URL))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .consumer_secret("top-secret")
            .login_url(server.uri())
            .build()
            .unwrap();

        auth.access_token().await.unwrap();
    }

    #[tokio::test]
    async fn public_client_omits_consumer_secret() {
        let server = MockServer::start().await;
        let received_body = Arc::new(Mutex::new(String::new()));

        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(BodyCapturingResponder {
                captured: received_body.clone(),
                response: token_ok("tok", ORG_URL),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        auth.access_token().await.unwrap();

        let body = received_body.lock().unwrap().clone();
        // Guard against a vacuous pass: an empty capture trivially lacks
        // `client_secret`.
        assert!(
            body.contains("grant_type=refresh_token"),
            "token request body was not captured: {body:?}"
        );
        assert!(
            !body.contains("client_secret"),
            "public client should not send client_secret, got: {body}"
        );
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: token_ok("tok", ORG_URL),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_is_compare_and_swap() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: token_ok("tok", ORG_URL),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let first = auth.access_token().await.unwrap().into_owned();

        auth.invalidate("some-other-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "invalidating a non-matching token must not clear the cache"
        );

        auth.invalidate(&first).await;
        let again = auth.access_token().await.unwrap();
        assert_eq!(&*again, "tok");
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn revoked_refresh_token_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "expired authorization code"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(token_ok("tok", "https://wrong-org.my.salesforce.com"))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// Responds with a fixed template and counts how many requests it served.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }

    /// Responds with a fixed template and records the last request body.
    ///
    /// `respond` is synchronous, so a `std::sync::Mutex` is the right tool.
    /// A blocking `lock` never silently drops a capture the way `try_lock`
    /// on an async mutex could.
    struct BodyCapturingResponder {
        captured: Arc<Mutex<String>>,
        response: ResponseTemplate,
    }

    impl Respond for BodyCapturingResponder {
        fn respond(&self, request: &Request) -> ResponseTemplate {
            let body = String::from_utf8_lossy(&request.body).into_owned();
            *self.captured.lock().unwrap() = body;
            self.response.clone()
        }
    }
}