1use serde::{Deserialize, Serialize};
12use tracing::instrument;
13
14use crate::credentials::SalesforceCredentials;
15use crate::error::{Error, ErrorKind, Result};
16
17#[derive(Clone)]
22pub struct OAuthConfig {
23 pub consumer_key: String,
25 consumer_secret: Option<String>,
27 pub redirect_uri: Option<String>,
29 pub scopes: Vec<String>,
31}
32
33impl std::fmt::Debug for OAuthConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("OAuthConfig")
36 .field("consumer_key", &self.consumer_key)
37 .field("consumer_secret", &"[REDACTED]")
38 .field("redirect_uri", &self.redirect_uri)
39 .field("scopes", &self.scopes)
40 .finish()
41 }
42}
43
44impl OAuthConfig {
45 pub fn new(consumer_key: impl Into<String>) -> Self {
47 Self {
48 consumer_key: consumer_key.into(),
49 consumer_secret: None,
50 redirect_uri: None,
51 scopes: vec!["api".to_string(), "refresh_token".to_string()],
52 }
53 }
54
55 pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
57 self.consumer_secret = Some(secret.into());
58 self
59 }
60
61 #[allow(dead_code)]
63 pub(crate) fn consumer_secret(&self) -> Option<&str> {
64 self.consumer_secret.as_deref()
65 }
66
67 pub fn with_redirect_uri(mut self, uri: impl Into<String>) -> Self {
69 self.redirect_uri = Some(uri.into());
70 self
71 }
72
73 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
75 self.scopes = scopes;
76 self
77 }
78}
79
80#[derive(Clone)]
82pub struct OAuthClient {
83 config: OAuthConfig,
84 http_client: reqwest::Client,
85}
86
87impl std::fmt::Debug for OAuthClient {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("OAuthClient")
90 .field("config", &self.config)
91 .finish_non_exhaustive()
92 }
93}
94
95impl OAuthClient {
96 pub fn new(config: OAuthConfig) -> Self {
98 Self {
99 config,
100 http_client: reqwest::Client::new(),
101 }
102 }
103
104 pub fn config(&self) -> &OAuthConfig {
106 &self.config
107 }
108
109 #[instrument(skip(self, refresh_token))]
113 pub async fn refresh_token(
114 &self,
115 refresh_token: &str,
116 login_url: &str,
117 ) -> Result<TokenResponse> {
118 let mut params = vec![
119 ("grant_type", "refresh_token"),
120 ("refresh_token", refresh_token),
121 ("client_id", &self.config.consumer_key),
122 ];
123
124 if let Some(ref secret) = self.config.consumer_secret {
125 params.push(("client_secret", secret));
126 }
127
128 let body = serde_urlencoded::to_string(params)?;
129
130 let response = self
131 .http_client
132 .post(format!("{}/services/oauth2/token", login_url))
133 .header("Content-Type", "application/x-www-form-urlencoded")
134 .body(body)
135 .send()
136 .await?;
137
138 self.handle_token_response(response).await
139 }
140
141 #[instrument(skip(self, token))]
146 pub async fn validate_token(&self, token: &str, login_url: &str) -> Result<TokenInfo> {
147 let form_data = [("access_token", token)];
150 let body = serde_urlencoded::to_string(form_data)?;
151
152 let response = self
153 .http_client
154 .post(format!("{}/services/oauth2/tokeninfo", login_url))
155 .header("Content-Type", "application/x-www-form-urlencoded")
156 .body(body)
157 .send()
158 .await?;
159
160 if !response.status().is_success() {
161 return Err(Error::new(ErrorKind::TokenInvalid(
162 "Token validation failed".to_string(),
163 )));
164 }
165
166 let info: TokenInfo = response.json().await?;
167 Ok(info)
168 }
169
170 #[instrument(skip(self, token))]
205 pub async fn revoke_token(&self, token: &str, login_url: &str) -> Result<()> {
206 let form_data = [("token", token)];
207 let body = serde_urlencoded::to_string(form_data)?;
208
209 let response = self
210 .http_client
211 .post(format!("{}/services/oauth2/revoke", login_url))
212 .header("Content-Type", "application/x-www-form-urlencoded")
213 .body(body)
214 .send()
215 .await?;
216
217 if !response.status().is_success() {
218 let status = response.status();
220 let body = response.text().await.unwrap_or_default();
221 if let Ok(error) = serde_json::from_str::<OAuthErrorResponse>(&body) {
222 return Err(Error::new(ErrorKind::OAuth {
223 error: error.error,
224 description: error.error_description,
225 }));
226 }
227 return Err(Error::new(ErrorKind::Http(format!(
228 "Token revocation failed with status {status}"
229 ))));
230 }
231
232 Ok(())
233 }
234
235 async fn handle_token_response(&self, response: reqwest::Response) -> Result<TokenResponse> {
237 if !response.status().is_success() {
238 let error: OAuthErrorResponse = response.json().await?;
239 return Err(Error::new(ErrorKind::OAuth {
240 error: error.error,
241 description: error.error_description,
242 }));
243 }
244
245 let token: TokenResponse = response.json().await?;
246 Ok(token)
247 }
248}
249
250#[derive(Clone)]
252pub struct WebFlowAuth {
253 config: OAuthConfig,
254 http_client: reqwest::Client,
255}
256
257impl std::fmt::Debug for WebFlowAuth {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("WebFlowAuth")
260 .field("config", &self.config)
261 .finish_non_exhaustive()
262 }
263}
264
265impl WebFlowAuth {
266 pub fn new(config: OAuthConfig) -> Result<Self> {
268 if config.redirect_uri.is_none() {
269 return Err(Error::new(ErrorKind::Config(
270 "redirect_uri is required for web flow".to_string(),
271 )));
272 }
273
274 Ok(Self {
275 config,
276 http_client: reqwest::Client::new(),
277 })
278 }
279
280 pub fn authorization_url(&self, login_url: &str, state: Option<&str>) -> String {
282 let redirect_uri = self.config.redirect_uri.as_ref().unwrap();
283 let scopes = self.config.scopes.join(" ");
284
285 let mut url = format!(
286 "{}/services/oauth2/authorize?response_type=code&client_id={}&redirect_uri={}",
287 login_url,
288 urlencoding::encode(&self.config.consumer_key),
289 urlencoding::encode(redirect_uri),
290 );
291
292 if !scopes.is_empty() {
293 url.push_str(&format!("&scope={}", urlencoding::encode(&scopes)));
294 }
295
296 if let Some(state) = state {
297 url.push_str(&format!("&state={}", urlencoding::encode(state)));
298 }
299
300 url
301 }
302
303 #[instrument(skip(self, code))]
307 pub async fn exchange_code(&self, code: &str, login_url: &str) -> Result<TokenResponse> {
308 let redirect_uri = self.config.redirect_uri.as_ref().unwrap();
309
310 let mut params = vec![
311 ("grant_type", "authorization_code"),
312 ("code", code),
313 ("client_id", &self.config.consumer_key),
314 ("redirect_uri", redirect_uri),
315 ];
316
317 if let Some(ref secret) = self.config.consumer_secret {
318 params.push(("client_secret", secret));
319 }
320
321 let body = serde_urlencoded::to_string(params)?;
322
323 let response = self
324 .http_client
325 .post(format!("{}/services/oauth2/token", login_url))
326 .header("Content-Type", "application/x-www-form-urlencoded")
327 .body(body)
328 .send()
329 .await?;
330
331 if !response.status().is_success() {
332 let error: OAuthErrorResponse = response.json().await?;
333 return Err(Error::new(ErrorKind::OAuth {
334 error: error.error,
335 description: error.error_description,
336 }));
337 }
338
339 let token: TokenResponse = response.json().await?;
340 Ok(token)
341 }
342}
343
344#[derive(Clone, Deserialize, Serialize)]
349pub struct TokenResponse {
350 pub access_token: String,
352 #[serde(default)]
354 pub refresh_token: Option<String>,
355 pub instance_url: String,
357 #[serde(default)]
359 pub id: Option<String>,
360 #[serde(default)]
362 pub token_type: Option<String>,
363 #[serde(default)]
365 pub scope: Option<String>,
366 #[serde(default)]
368 pub signature: Option<String>,
369 #[serde(default)]
371 pub issued_at: Option<String>,
372}
373
374impl std::fmt::Debug for TokenResponse {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 f.debug_struct("TokenResponse")
377 .field("access_token", &"[REDACTED]")
378 .field(
379 "refresh_token",
380 &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
381 )
382 .field("instance_url", &self.instance_url)
383 .field("id", &self.id)
384 .field("token_type", &self.token_type)
385 .field("scope", &self.scope)
386 .field("signature", &self.signature.as_ref().map(|_| "[REDACTED]"))
387 .field("issued_at", &self.issued_at)
388 .finish()
389 }
390}
391
392impl TokenResponse {
393 pub fn to_credentials(&self, api_version: &str) -> SalesforceCredentials {
395 let mut creds =
396 SalesforceCredentials::new(&self.instance_url, &self.access_token, api_version);
397
398 if let Some(ref rt) = self.refresh_token {
399 creds = creds.with_refresh_token(rt);
400 }
401
402 creds
403 }
404}
405
406#[derive(Debug, Clone, Deserialize)]
408pub struct TokenInfo {
409 pub active: bool,
411 #[serde(default)]
413 pub scope: Option<String>,
414 #[serde(default)]
416 pub client_id: Option<String>,
417 #[serde(default)]
419 pub username: Option<String>,
420 #[serde(default)]
422 pub token_type: Option<String>,
423 #[serde(default)]
425 pub exp: Option<u64>,
426 #[serde(default)]
428 pub iat: Option<u64>,
429 #[serde(default)]
431 pub sub: Option<String>,
432}
433
434#[derive(Debug, Deserialize)]
436struct OAuthErrorResponse {
437 error: String,
438 error_description: String,
439}
440
441#[cfg(test)]
442mod tests {
443 use super::*;
444 use crate::credentials::Credentials;
445
446 #[test]
447 fn test_oauth_config() {
448 let config = OAuthConfig::new("consumer_key")
449 .with_secret("secret")
450 .with_redirect_uri("https://example.com/callback")
451 .with_scopes(vec!["api".to_string(), "web".to_string()]);
452
453 assert_eq!(config.consumer_key, "consumer_key");
454 assert_eq!(config.consumer_secret(), Some("secret"));
455 assert_eq!(
456 config.redirect_uri,
457 Some("https://example.com/callback".to_string())
458 );
459 assert_eq!(config.scopes, vec!["api", "web"]);
460 }
461
462 #[test]
463 fn test_oauth_config_debug_redacts_secret() {
464 let config = OAuthConfig::new("consumer_key").with_secret("super_secret_value");
465
466 let debug_output = format!("{:?}", config);
467 assert!(debug_output.contains("[REDACTED]"));
468 assert!(!debug_output.contains("super_secret_value"));
469 }
470
471 #[test]
472 fn test_web_flow_auth_url() {
473 let config = OAuthConfig::new("my_client_id")
474 .with_redirect_uri("https://localhost:8080/callback")
475 .with_scopes(vec!["api".to_string()]);
476
477 let auth = WebFlowAuth::new(config).unwrap();
478 let url = auth.authorization_url("https://login.salesforce.com", Some("state123"));
479
480 assert!(url.contains("response_type=code"));
481 assert!(url.contains("client_id=my_client_id"));
482 assert!(url.contains("redirect_uri="));
483 assert!(url.contains("state=state123"));
484 }
485
486 #[test]
487 fn test_token_response_to_credentials() {
488 let token = TokenResponse {
489 access_token: "access123".to_string(),
490 refresh_token: Some("refresh456".to_string()),
491 instance_url: "https://na1.salesforce.com".to_string(),
492 id: None,
493 token_type: Some("Bearer".to_string()),
494 scope: None,
495 signature: None,
496 issued_at: None,
497 };
498
499 let creds = token.to_credentials("62.0");
500 assert_eq!(creds.instance_url(), "https://na1.salesforce.com");
501 assert_eq!(creds.access_token(), "access123");
502 assert_eq!(creds.refresh_token(), Some("refresh456"));
503 }
504
505 #[test]
506 fn test_token_response_debug_redacts_tokens() {
507 let token = TokenResponse {
508 access_token: "super_secret_access_token".to_string(),
509 refresh_token: Some("super_secret_refresh_token".to_string()),
510 instance_url: "https://na1.salesforce.com".to_string(),
511 id: None,
512 token_type: Some("Bearer".to_string()),
513 scope: None,
514 signature: Some("signature_value".to_string()),
515 issued_at: None,
516 };
517
518 let debug_output = format!("{:?}", token);
519 assert!(debug_output.contains("[REDACTED]"));
520 assert!(!debug_output.contains("super_secret_access_token"));
521 assert!(!debug_output.contains("super_secret_refresh_token"));
522 assert!(!debug_output.contains("signature_value"));
523 }
524
525 #[tokio::test]
526 async fn test_revoke_token_success() {
527 use wiremock::matchers::{body_string_contains, header, method, path};
528 use wiremock::{Mock, MockServer, ResponseTemplate};
529
530 let mock_server = MockServer::start().await;
531
532 Mock::given(method("POST"))
534 .and(path("/services/oauth2/revoke"))
535 .and(header("Content-Type", "application/x-www-form-urlencoded"))
536 .and(body_string_contains("token=test_token_to_revoke"))
537 .respond_with(ResponseTemplate::new(200))
538 .mount(&mock_server)
539 .await;
540
541 let config = OAuthConfig::new("test_client_id");
542 let client = OAuthClient::new(config);
543
544 let result = client
545 .revoke_token("test_token_to_revoke", &mock_server.uri())
546 .await;
547
548 assert!(result.is_ok(), "Token revocation should succeed");
549 }
550
551 #[tokio::test]
552 async fn test_revoke_token_idempotency() {
553 use wiremock::matchers::{method, path};
554 use wiremock::{Mock, MockServer, ResponseTemplate};
555
556 let mock_server = MockServer::start().await;
557
558 Mock::given(method("POST"))
560 .and(path("/services/oauth2/revoke"))
561 .respond_with(ResponseTemplate::new(200))
562 .mount(&mock_server)
563 .await;
564
565 let config = OAuthConfig::new("test_client_id");
566 let client = OAuthClient::new(config);
567
568 let result1 = client
570 .revoke_token("already_invalid_token", &mock_server.uri())
571 .await;
572 assert!(result1.is_ok(), "First revocation should succeed");
573
574 let result2 = client
576 .revoke_token("already_invalid_token", &mock_server.uri())
577 .await;
578 assert!(
579 result2.is_ok(),
580 "Second revocation should also succeed (idempotent)"
581 );
582 }
583
584 #[tokio::test]
585 async fn test_revoke_token_failure() {
586 use wiremock::matchers::{method, path};
587 use wiremock::{Mock, MockServer, ResponseTemplate};
588
589 let mock_server = MockServer::start().await;
590
591 Mock::given(method("POST"))
593 .and(path("/services/oauth2/revoke"))
594 .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
595 "error": "invalid_request",
596 "error_description": "Token parameter is missing"
597 })))
598 .mount(&mock_server)
599 .await;
600
601 let config = OAuthConfig::new("test_client_id");
602 let client = OAuthClient::new(config);
603
604 let result = client
605 .revoke_token("malformed_token", &mock_server.uri())
606 .await;
607
608 assert!(result.is_err(), "Token revocation should fail");
609 let err = result.unwrap_err();
610 assert!(
611 matches!(err.kind, ErrorKind::OAuth { .. }),
612 "Should return OAuth error"
613 );
614 }
615
616 #[tokio::test]
617 async fn test_revoke_token_non_json_error() {
618 use wiremock::matchers::{method, path};
619 use wiremock::{Mock, MockServer, ResponseTemplate};
620
621 let mock_server = MockServer::start().await;
622
623 Mock::given(method("POST"))
625 .and(path("/services/oauth2/revoke"))
626 .respond_with(ResponseTemplate::new(400).set_body_string("<html>Bad Request</html>"))
627 .mount(&mock_server)
628 .await;
629
630 let config = OAuthConfig::new("test_client_id");
631 let client = OAuthClient::new(config);
632
633 let result = client.revoke_token("some_token", &mock_server.uri()).await;
634
635 assert!(result.is_err(), "Should fail with non-JSON error body");
636 let err = result.unwrap_err();
637 assert!(
638 matches!(err.kind, ErrorKind::Http(_)),
639 "Should return Http error, got: {:?}",
640 err.kind
641 );
642 assert!(
643 err.to_string().contains("revocation failed"),
644 "Error should mention revocation failed"
645 );
646 }
647}