1use crate::error::{AuthError, AuthResult};
49use crate::token_endpoint::exchange;
50
/// Grant-type URN for a standard token exchange.
///
/// This is the string returned by `TokenExchangeGrantType::TokenExchange.as_urn()`.
pub const GRANT_TYPE_TOKEN_EXCHANGE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";

/// Grant-type URN for the hybrid token-exchange variant.
///
/// This is the string returned by `TokenExchangeGrantType::HybridTokenExchange.as_urn()`.
pub const GRANT_TYPE_HYBRID_TOKEN_EXCHANGE: &str =
    "urn:ietf:params:oauth:grant-type:hybrid-token-exchange";
57
58#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
60pub enum TokenExchangeGrantType {
61 #[default]
63 TokenExchange,
64 HybridTokenExchange,
67}
68
69impl TokenExchangeGrantType {
70 pub fn as_urn(self) -> &'static str {
72 match self {
73 Self::TokenExchange => GRANT_TYPE_TOKEN_EXCHANGE,
74 Self::HybridTokenExchange => GRANT_TYPE_HYBRID_TOKEN_EXCHANGE,
75 }
76 }
77}
78
/// Identifies the kind of token being presented as the `subject_token`.
///
/// Each built-in variant maps to a fixed `urn:ietf:params:oauth:token-type:*`
/// identifier; [`SubjectTokenType::Custom`] carries a caller-supplied string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubjectTokenType {
    /// `urn:ietf:params:oauth:token-type:access_token`
    AccessToken,
    /// `urn:ietf:params:oauth:token-type:refresh_token`
    RefreshToken,
    /// `urn:ietf:params:oauth:token-type:id_token`
    IdToken,
    /// `urn:ietf:params:oauth:token-type:saml2`
    Saml2,
    /// `urn:ietf:params:oauth:token-type:jwt`
    Jwt,
    /// A caller-provided identifier, passed through unchanged by [`Self::as_urn`].
    Custom(String),
}

impl SubjectTokenType {
    /// Returns the token-type identifier for this variant.
    ///
    /// Built-in variants yield a string literal; `Custom` yields a borrow of the
    /// wrapped string with no validation or normalisation applied.
    pub fn as_urn(&self) -> &str {
        match self {
            Self::Custom(urn) => urn.as_str(),
            Self::Jwt => "urn:ietf:params:oauth:token-type:jwt",
            Self::Saml2 => "urn:ietf:params:oauth:token-type:saml2",
            Self::IdToken => "urn:ietf:params:oauth:token-type:id_token",
            Self::RefreshToken => "urn:ietf:params:oauth:token-type:refresh_token",
            Self::AccessToken => "urn:ietf:params:oauth:token-type:access_token",
        }
    }
}
115
116#[derive(Debug)]
120pub struct TokenExchangeFlow {
121 consumer_key: String,
122 consumer_secret: Option<String>,
123 login_url: String,
124 subject_token: String,
125 subject_token_type: SubjectTokenType,
126 grant_type: TokenExchangeGrantType,
127 scopes: Vec<String>,
128 token_handler: Option<String>,
129 http: reqwest::Client,
130}
131
132impl TokenExchangeFlow {
133 pub fn builder() -> TokenExchangeFlowBuilder {
135 TokenExchangeFlowBuilder::default()
136 }
137
138 pub async fn exchange(self) -> AuthResult<TokenExchangeSession> {
144 let scope_joined;
145 let mut body: Vec<(&str, &str)> = vec![
146 ("grant_type", self.grant_type.as_urn()),
147 ("subject_token", self.subject_token.as_str()),
148 ("subject_token_type", self.subject_token_type.as_urn()),
149 ("client_id", self.consumer_key.as_str()),
150 ];
151 if let Some(secret) = self.consumer_secret.as_deref() {
152 body.push(("client_secret", secret));
153 }
154 if !self.scopes.is_empty() {
155 scope_joined = self.scopes.join(" ");
156 body.push(("scope", scope_joined.as_str()));
157 }
158 if let Some(handler) = self.token_handler.as_deref() {
159 body.push(("token_handler", handler));
160 }
161
162 let token = exchange(&self.http, &self.login_url, &body).await?;
163 Ok(TokenExchangeSession {
164 access_token: token.access_token,
165 refresh_token: token.refresh_token,
166 id_token: token.id_token,
167 instance_url: token.instance_url,
168 issued_at: token.issued_at,
169 scope: token.scope,
170 id: token.id,
171 signature: token.signature,
172 })
173 }
174}
175
/// The outcome of a successful token exchange.
///
/// Every field is copied as-is from the token-endpoint response.
///
/// `Debug` is implemented by hand rather than derived: a derived impl would
/// print `access_token`, `refresh_token` and `id_token` verbatim into any log
/// line or panic message that formats a session. Those three fields are shown
/// as `<redacted>` (keeping `Some`/`None` visible); all other fields are printed
/// normally.
#[derive(Clone)]
pub struct TokenExchangeSession {
    /// The `access_token` value from the response.
    pub access_token: String,
    /// The `refresh_token` value from the response, if one was returned.
    pub refresh_token: Option<String>,
    /// The `id_token` value from the response, if one was returned.
    pub id_token: Option<String>,
    /// The `instance_url` value from the response.
    pub instance_url: String,
    /// The `issued_at` value from the response, kept as the raw string.
    pub issued_at: Option<String>,
    /// The `scope` value from the response, if one was returned.
    pub scope: Option<String>,
    /// The `id` value from the response, if one was returned.
    pub id: Option<String>,
    /// The `signature` value from the response, if one was returned.
    pub signature: Option<String>,
}

impl std::fmt::Debug for TokenExchangeSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("TokenExchangeSession")
            .field("access_token", &REDACTED)
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| REDACTED),
            )
            .field("id_token", &self.id_token.as_ref().map(|_| REDACTED))
            .field("instance_url", &self.instance_url)
            .field("issued_at", &self.issued_at)
            .field("scope", &self.scope)
            .field("id", &self.id)
            .field("signature", &self.signature)
            .finish()
    }
}
202
203#[derive(Default)]
205pub struct TokenExchangeFlowBuilder {
206 consumer_key: Option<String>,
207 consumer_secret: Option<String>,
208 login_url: Option<String>,
209 subject_token: Option<String>,
210 subject_token_type: Option<SubjectTokenType>,
211 grant_type: Option<TokenExchangeGrantType>,
212 scopes: Vec<String>,
213 token_handler: Option<String>,
214 http_client: Option<reqwest::Client>,
215}
216
217impl std::fmt::Debug for TokenExchangeFlowBuilder {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("TokenExchangeFlowBuilder")
220 .field("consumer_key", &self.consumer_key.is_some())
221 .field("consumer_secret", &self.consumer_secret.is_some())
222 .field("login_url", &self.login_url)
223 .field("subject_token", &self.subject_token.is_some())
224 .field("subject_token_type", &self.subject_token_type)
225 .field("grant_type", &self.grant_type)
226 .field("scopes", &self.scopes)
227 .field("token_handler", &self.token_handler)
228 .finish_non_exhaustive()
229 }
230}
231
232impl TokenExchangeFlowBuilder {
233 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
235 self.consumer_key = Some(key.into());
236 self
237 }
238
239 pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
244 self.consumer_secret = Some(secret.into());
245 self
246 }
247
248 pub fn login_url(mut self, url: impl Into<String>) -> Self {
252 self.login_url = Some(url.into());
253 self
254 }
255
256 pub fn subject_token(mut self, token: impl Into<String>) -> Self {
260 self.subject_token = Some(token.into());
261 self
262 }
263
264 pub fn subject_token_type(mut self, ty: SubjectTokenType) -> Self {
266 self.subject_token_type = Some(ty);
267 self
268 }
269
270 pub fn grant_type(mut self, gt: TokenExchangeGrantType) -> Self {
275 self.grant_type = Some(gt);
276 self
277 }
278
279 pub fn scope(mut self, scope: impl Into<String>) -> Self {
282 self.scopes.push(scope.into());
283 self
284 }
285
286 pub fn scopes<I, S>(mut self, scopes: I) -> Self
288 where
289 I: IntoIterator<Item = S>,
290 S: Into<String>,
291 {
292 self.scopes = scopes.into_iter().map(Into::into).collect();
293 self
294 }
295
296 pub fn token_handler(mut self, name: impl Into<String>) -> Self {
300 self.token_handler = Some(name.into());
301 self
302 }
303
304 pub fn http_client(mut self, client: reqwest::Client) -> Self {
306 self.http_client = Some(client);
307 self
308 }
309
310 pub fn build(self) -> AuthResult<TokenExchangeFlow> {
312 let consumer_key = self
313 .consumer_key
314 .ok_or(AuthError::MissingField("consumer_key"))?;
315 let subject_token = self
316 .subject_token
317 .ok_or(AuthError::MissingField("subject_token"))?;
318 let subject_token_type = self
319 .subject_token_type
320 .ok_or(AuthError::MissingField("subject_token_type"))?;
321 let mut login_url = self.login_url.ok_or(AuthError::MissingField("login_url"))?;
322 if login_url.ends_with('/') {
323 login_url.pop();
324 }
325 let http = self.http_client.unwrap_or_default();
326 Ok(TokenExchangeFlow {
327 consumer_key,
328 consumer_secret: self.consumer_secret,
329 login_url,
330 subject_token,
331 subject_token_type,
332 grant_type: self.grant_type.unwrap_or_default(),
333 scopes: self.scopes,
334 token_handler: self.token_handler,
335 http,
336 })
337 }
338}
339
#[cfg(test)]
// Panicking on unexpected results is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    /// Builder pre-populated with every required field; tests override what they need.
    fn builder_with_required_fields() -> TokenExchangeFlowBuilder {
        TokenExchangeFlow::builder()
            .consumer_key("consumer-key-123")
            .login_url("https://my-org.my.salesforce.com")
            .subject_token("idp-issued-token-xyz")
            .subject_token_type(SubjectTokenType::AccessToken)
    }

    #[test]
    fn subject_token_type_urns_match_rfc_8693() {
        assert_eq!(
            SubjectTokenType::AccessToken.as_urn(),
            "urn:ietf:params:oauth:token-type:access_token"
        );
        assert_eq!(
            SubjectTokenType::RefreshToken.as_urn(),
            "urn:ietf:params:oauth:token-type:refresh_token"
        );
        assert_eq!(
            SubjectTokenType::IdToken.as_urn(),
            "urn:ietf:params:oauth:token-type:id_token"
        );
        assert_eq!(
            SubjectTokenType::Saml2.as_urn(),
            "urn:ietf:params:oauth:token-type:saml2"
        );
        assert_eq!(
            SubjectTokenType::Jwt.as_urn(),
            "urn:ietf:params:oauth:token-type:jwt"
        );
        assert_eq!(
            SubjectTokenType::Custom("urn:custom:foo".into()).as_urn(),
            "urn:custom:foo"
        );
    }

    #[test]
    fn grant_type_urns_match_spec() {
        assert_eq!(
            TokenExchangeGrantType::TokenExchange.as_urn(),
            "urn:ietf:params:oauth:grant-type:token-exchange"
        );
        assert_eq!(
            TokenExchangeGrantType::HybridTokenExchange.as_urn(),
            "urn:ietf:params:oauth:grant-type:hybrid-token-exchange"
        );
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = TokenExchangeFlow::builder()
            .login_url("https://x")
            .subject_token("t")
            .subject_token_type(SubjectTokenType::Jwt)
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_login_url() {
        let err = TokenExchangeFlow::builder()
            .consumer_key("k")
            .subject_token("t")
            .subject_token_type(SubjectTokenType::Jwt)
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("login_url")));
    }

    #[test]
    fn builder_requires_subject_token() {
        let err = TokenExchangeFlow::builder()
            .consumer_key("k")
            .login_url("https://x")
            .subject_token_type(SubjectTokenType::Jwt)
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("subject_token")));
    }

    #[test]
    fn builder_requires_subject_token_type() {
        let err = TokenExchangeFlow::builder()
            .consumer_key("k")
            .login_url("https://x")
            .subject_token("t")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("subject_token_type")));
    }

    #[test]
    fn builder_strips_trailing_slash_on_login_url() {
        let flow = builder_with_required_fields()
            .login_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(flow.login_url, "https://my-org.my.salesforce.com");
    }

    /// The four mandatory form parameters are sent and the response fields are
    /// surfaced on the session.
    #[tokio::test]
    async fn exchange_sends_required_params() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains(
                "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange",
            ))
            .and(body_string_contains("subject_token=idp-issued-token-xyz"))
            .and(body_string_contains(
                "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token",
            ))
            .and(body_string_contains("client_id=consumer-key-123"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "00DXX!ACCESS",
                "instance_url": "https://my-org.my.salesforce.com",
                "token_type": "Bearer",
                "scope": "api refresh_token",
                "issued_at": "1700000000000",
                "id": "https://login.salesforce.com/id/00DXX/005XX",
                "signature": "abcdef==",
            })))
            .mount(&server)
            .await;

        let session = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap()
            .exchange()
            .await
            .unwrap();
        assert_eq!(session.access_token, "00DXX!ACCESS");
        assert_eq!(session.instance_url, "https://my-org.my.salesforce.com");
        assert_eq!(session.scope.as_deref(), Some("api refresh_token"));
        assert_eq!(session.issued_at.as_deref(), Some("1700000000000"));
        assert_eq!(
            session.id.as_deref(),
            Some("https://login.salesforce.com/id/00DXX/005XX")
        );
        assert_eq!(session.signature.as_deref(), Some("abcdef=="));
    }

    /// `client_secret`, `scope` and `token_handler` appear in the body once configured.
    #[tokio::test]
    async fn exchange_includes_optional_params_when_set() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("client_secret=hunter2"))
            .and(body_string_contains("scope=api+refresh_token"))
            .and(body_string_contains("token_handler=MyHandler"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://my-org.my.salesforce.com",
                "id_token": "eyJ...",
                "refresh_token": "5Aep861...",
            })))
            .mount(&server)
            .await;

        let session = builder_with_required_fields()
            .login_url(server.uri())
            .consumer_secret("hunter2")
            .scope("api")
            .scope("refresh_token")
            .token_handler("MyHandler")
            .build()
            .unwrap()
            .exchange()
            .await
            .unwrap();
        assert_eq!(session.id_token.as_deref(), Some("eyJ..."));
        assert_eq!(session.refresh_token.as_deref(), Some("5Aep861..."));
    }

    /// Without a configured secret, the request body must not mention `client_secret`.
    #[tokio::test]
    async fn public_client_omits_client_secret() {
        let server = MockServer::start().await;
        let captured = Arc::new(Mutex::new(String::new()));
        let captured_clone = Arc::clone(&captured);

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(BodyCapturingResponder {
                captured: captured_clone,
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap()
            .exchange()
            .await
            .unwrap();

        // Copy the body out so the lock is released before asserting.
        let body = captured.lock().unwrap().clone();
        // Guard against a vacuous pass: an empty capture would trivially
        // "not contain" client_secret, so first prove the body was recorded.
        assert!(
            body.contains("client_id=consumer-key-123"),
            "responder did not capture the request body, got: {body}"
        );
        assert!(
            !body.contains("client_secret"),
            "public client should not send client_secret, got: {body}"
        );
    }

    /// Selecting the hybrid grant type changes the `grant_type` parameter accordingly.
    #[tokio::test]
    async fn hybrid_grant_type_sets_correct_urn() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains(
                "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ahybrid-token-exchange",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://my-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        builder_with_required_fields()
            .login_url(server.uri())
            .grant_type(TokenExchangeGrantType::HybridTokenExchange)
            .build()
            .unwrap()
            .exchange()
            .await
            .unwrap();
    }

    /// A 400 response carrying an OAuth error body becomes `AuthError::OAuth`.
    #[tokio::test]
    async fn rejected_subject_token_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "subject_token validation failed"
            })))
            .mount(&server)
            .await;

        let err = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap()
            .exchange()
            .await
            .unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    /// Responder that records the raw request body before returning a canned response.
    ///
    /// Uses a `std::sync::Mutex` because `Respond::respond` is a synchronous
    /// callback and the lock is never held across an `.await`.
    struct BodyCapturingResponder {
        captured: Arc<Mutex<String>>,
        response: ResponseTemplate,
    }

    impl Respond for BodyCapturingResponder {
        fn respond(&self, request: &Request) -> ResponseTemplate {
            let body = String::from_utf8_lossy(&request.body).into_owned();
            // Fail loudly rather than silently dropping the capture.
            *self.captured.lock().unwrap() = body;
            self.response.clone()
        }
    }
}