1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `AuthType::OAuth2::send_redirect_uri` when the field is absent:
/// the handler's redirect URI is sent unless a config explicitly opts out.
fn default_send_redirect_uri() -> bool {
    const SEND_BY_DEFAULT: bool = true;
    SEND_BY_DEFAULT
}
18
// Authentication scheme a tool/provider is configured with.
//
// Adjacently tagged: the variant name is written to "type" and the variant's
// fields to "config", e.g. {"type": "secret", "config": {"provider": ..., "fields": [...]}};
// the unit variant serializes as {"type": "none"}.
// (Plain `//` comments are used on purpose: `///` docs would be picked up by
// the JsonSchema derive as schema descriptions and change its output.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(tag = "type", content = "config")]
pub enum AuthType {
    // No authentication; `AuthMetadata::requires_auth` returns false for this variant.
    #[serde(rename = "none")]
    None,
    // OAuth2 with the given grant type and endpoints.
    #[serde(rename = "oauth2")]
    OAuth2 {
        // Grant type; `ClientCredentials` is rejected by `OAuthHandler::get_auth_url`.
        flow_type: OAuth2FlowType,
        authorization_url: String,
        token_url: String,
        // NOTE(review): presumably an alternate refresh endpoint used by providers — confirm.
        refresh_url: Option<String>,
        scopes: Vec<String>,
        // When true, `OAuthHandler` passes its redirect URI to the provider both when
        // building the authorization URL and when exchanging the code. Defaults to
        // true when missing from the serialized config.
        #[serde(default = "default_send_redirect_uri")]
        send_redirect_uri: bool,
    },
    // A user-supplied secret for `provider`.
    #[serde(rename = "secret")]
    Secret {
        provider: String,
        // Fields to collect from the user; empty when absent in the serialized config.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
51
// OAuth2 grant type. Serialized in snake_case, e.g. "authorization_code",
// "client_credentials", "implicit", "password".
// (`//` rather than `///` so the JsonSchema derive output is unchanged.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OAuth2FlowType {
    AuthorizationCode,
    // `OAuthHandler::get_auth_url` rejects this flow: it has no user authorization step.
    ClientCredentials,
    Implicit,
    Password,
}
61
/// OAuth2 tokens stored for one user and auth entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSession {
    /// Access token; becomes `McpSession::token` on conversion.
    pub access_token: String,
    /// Refresh token, if the provider issued one; required by `OAuthHandler::refresh_session`.
    pub refresh_token: Option<String>,
    /// Absolute expiry of `access_token`. `None` means unknown, and `is_expired` then
    /// reports the session as never expiring.
    pub expires_at: Option<DateTime<Utc>>,
    /// Token type; `AuthSession::new` defaults it to `"Bearer"`.
    pub token_type: String,
    /// Scopes associated with the session.
    pub scopes: Vec<String>,
}
76
// Token triple as exposed to callers/schema consumers.
// (`//` rather than `///` so the JsonSchema derive output is unchanged.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
pub struct TokenResponse {
    pub access_token: String,
    pub refresh_token: String,
    // NOTE(review): unit not visible here — presumably a Unix timestamp in seconds; confirm.
    pub expires_at: i64,
}
84
/// A user-supplied secret value stored under `key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSecret {
    /// The secret value; becomes `McpSession::token` on conversion.
    pub secret: String,
    /// Identifier the secret is stored under (also returned by `get_provider`).
    pub key: String,
}
93impl Into<McpSession> for AuthSession {
94 fn into(self) -> McpSession {
95 McpSession {
96 token: self.access_token,
97 expiry: self.expires_at.map(|dt| dt.into()),
98 }
99 }
100}
101
102impl Into<McpSession> for AuthSecret {
103 fn into(self) -> McpSession {
104 McpSession {
105 token: self.secret,
106 expiry: None, }
108 }
109}
110
111impl AuthSession {
112 pub fn new(
114 access_token: String,
115 token_type: Option<String>,
116 expires_in: Option<i64>,
117 refresh_token: Option<String>,
118 scopes: Vec<String>,
119 ) -> Self {
120 let now = Utc::now();
121 let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
122
123 AuthSession {
124 access_token,
125 refresh_token,
126 expires_at,
127 token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
128 scopes,
129 }
130 }
131
132 pub fn is_expired(&self, buffer_seconds: i64) -> bool {
134 match &self.expires_at {
135 Some(expires_at) => {
136 let buffer = chrono::Duration::seconds(buffer_seconds);
137 Utc::now() + buffer >= *expires_at
138 }
139 None => false, }
141 }
142
143 pub fn needs_refresh(&self) -> bool {
145 self.is_expired(300) }
147
148 pub fn get_access_token(&self) -> &str {
150 &self.access_token
151 }
152
153 pub fn update_tokens(
155 &mut self,
156 access_token: String,
157 expires_in: Option<i64>,
158 refresh_token: Option<String>,
159 ) {
160 self.access_token = access_token;
161
162 if let Some(secs) = expires_in {
163 self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
164 }
165
166 if let Some(token) = refresh_token {
167 self.refresh_token = Some(token);
168 }
169 }
170}
171
172impl AuthSecret {
173 pub fn new(key: String, secret: String) -> Self {
175 AuthSecret { secret, key }
176 }
177
178 pub fn get_secret(&self) -> &str {
180 &self.secret
181 }
182
183 pub fn get_provider(&self) -> &str {
185 &self.key
186 }
187}
188
189pub trait AuthMetadata: Send + Sync {
191 fn get_auth_entity(&self) -> String;
193
194 fn get_auth_type(&self) -> AuthType;
196
197 fn requires_auth(&self) -> bool {
199 !matches!(self.get_auth_type(), AuthType::None)
200 }
201
202 fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
204 HashMap::new()
205 }
206}
207
/// Errors produced by authentication flows and auth storage.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The OAuth2 flow failed, e.g. unknown or expired `state` in a callback.
    #[error("OAuth2 flow error: {0}")]
    OAuth2Flow(String),

    /// Refreshing a session failed, e.g. no stored session or no refresh token.
    #[error("Token expired and refresh failed: {0}")]
    TokenRefreshFailed(String),

    /// The auth configuration does not support the requested operation or is malformed.
    #[error("Invalid authentication configuration: {0}")]
    InvalidConfig(String),

    /// Authentication is required but nothing is configured for the entity.
    #[error("Authentication required but not configured for entity: {0}")]
    AuthRequired(String),

    /// No API key was found for the entity.
    #[error("API key not found for entity: {0}")]
    ApiKeyNotFound(String),

    /// Wraps an `anyhow::Error`; `?` on an `anyhow::Result` converts into this variant.
    #[error("Storage error: {0}")]
    Storage(#[from] anyhow::Error),

    /// Store failure reported as a message.
    #[error("Store error: {0}")]
    StoreError(String),

    /// The provider registry has no provider with this name.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// Error reported by a remote server.
    #[error("Server error: {0}")]
    ServerError(String),
}
238
// Authentication a tool declares it needs.
//
// Internally tagged by "type" ("oauth2" or "secret"). The optional OAuth2 endpoint
// overrides use camelCase keys and are omitted from output when `None`.
// (`//` rather than `///` so the JsonSchema derive output is unchanged.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type")]
pub enum AuthRequirement {
    #[serde(rename = "oauth2")]
    OAuth2 {
        provider: String,
        // Scopes to request; empty when absent.
        #[serde(default)]
        scopes: Vec<String>,
        #[serde(
            rename = "authorizationUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        authorization_url: Option<String>,
        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
        token_url: Option<String>,
        #[serde(
            rename = "refreshUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        refresh_url: Option<String>,
        // Optional here, unlike `AuthType::OAuth2::send_redirect_uri` which defaults to true.
        #[serde(
            rename = "sendRedirectUri",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        send_redirect_uri: Option<bool>,
    },
    #[serde(rename = "secret")]
    Secret {
        provider: String,
        // Fields to collect from the user; empty when absent.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
276
// One field of a secret that a user must (or may) supply.
// (`//` rather than `///` so the JsonSchema derive output is unchanged.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SecretFieldSpec {
    // Identifier of the field.
    pub key: String,
    // Optional human-readable label; `None` when absent.
    #[serde(default)]
    pub label: Option<String>,
    // Optional help text; `None` when absent.
    #[serde(default)]
    pub description: Option<String>,
    // Whether the field may be left empty; `false` when absent.
    #[serde(default)]
    pub optional: bool,
}
287
/// Pending OAuth2 authorization, stored between the authorization redirect and the callback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2State {
    /// Value sent as the OAuth2 `state` parameter and used as the lookup key on callback.
    pub state: String,
    /// Provider (auth entity) the authorization was started for.
    pub provider_name: String,
    /// Redirect URI sent with the authorization request, if any; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// User the resulting session will be stored for.
    pub user_id: String,
    /// Scopes requested.
    pub scopes: Vec<String>,
    /// Extra per-flow data, e.g. the PKCE verifier under `PKCE_CODE_VERIFIER_KEY`.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time, used by `is_expired`.
    pub created_at: DateTime<Utc>,
}
307
/// Key under which the PKCE code verifier is kept in `OAuth2State::metadata`
/// between the authorization request and the callback.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// PKCE `code_challenge_method` sent with authorization requests (SHA-256, RFC 7636 §4.2).
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Random bytes behind each code verifier. 32 bytes base64url-encode (unpadded) to
/// 43 characters, the minimum verifier length allowed by RFC 7636 §4.1.
const PKCE_RANDOM_BYTES: usize = 32;
311
312pub fn generate_pkce_pair() -> (String, String) {
313 let mut random = vec![0u8; PKCE_RANDOM_BYTES];
314 rand::thread_rng().fill_bytes(&mut random);
315 let verifier = URL_SAFE_NO_PAD.encode(&random);
316 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
317 (verifier, challenge)
318}
319
320pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
321 let mut url = Url::parse(auth_url)
322 .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
323 {
324 let mut pairs = url.query_pairs_mut();
325 pairs.append_pair("code_challenge", challenge);
326 pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
327 }
328 Ok(url.to_string())
329}
330
331impl OAuth2State {
332 pub fn new_with_state(
334 state: String,
335 provider_name: String,
336 redirect_uri: Option<String>,
337 user_id: String,
338 scopes: Vec<String>,
339 ) -> Self {
340 Self {
341 state,
342 provider_name,
343 redirect_uri,
344 user_id,
345 scopes,
346 metadata: HashMap::new(),
347 created_at: Utc::now(),
348 }
349 }
350
351 pub fn new(
353 provider_name: String,
354 redirect_uri: Option<String>,
355 user_id: String,
356 scopes: Vec<String>,
357 ) -> Self {
358 Self::new_with_state(
359 uuid::Uuid::new_v4().to_string(),
360 provider_name,
361 redirect_uri,
362 user_id,
363 scopes,
364 )
365 }
366
367 pub fn is_expired(&self, max_age_seconds: i64) -> bool {
369 let max_age = chrono::Duration::seconds(max_age_seconds);
370 Utc::now() - self.created_at > max_age
371 }
372}
373
/// Persistence backend for OAuth2 sessions, user secrets, and pending OAuth2 states.
///
/// `OAuthHandler` holds an implementation behind an `Arc<dyn ToolAuthStore>`.
#[async_trait]
pub trait ToolAuthStore: Send + Sync {
    /// Fetches the session stored for `auth_entity` and `user_id`, if any.
    async fn get_session(
        &self,
        auth_entity: &str,
        user_id: &str,
    ) -> Result<Option<AuthSession>, AuthError>;

    /// Persists `session` for `auth_entity` and `user_id`.
    /// `OAuthHandler` calls this after a code exchange and after a refresh; presumably it
    /// overwrites any existing session — confirm with implementations.
    async fn store_session(
        &self,
        auth_entity: &str,
        user_id: &str,
        session: AuthSession,
    ) -> Result<(), AuthError>;

    /// Removes a session. NOTE(review): the `bool` presumably reports whether one existed.
    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;

    /// Stores `secret` for `user_id`, optionally scoped to `auth_entity`.
    async fn store_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>,
        secret: AuthSecret,
    ) -> Result<(), AuthError>;

    /// Looks up the secret under `key` for `user_id`, optionally scoped to `auth_entity`.
    async fn get_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>,
        key: &str,
    ) -> Result<Option<AuthSecret>, AuthError>;

    /// Removes the secret under `key`. NOTE(review): the `bool` presumably reports whether one existed.
    async fn remove_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>,
        key: &str,
    ) -> Result<bool, AuthError>;

    /// Saves a pending OAuth2 state until its callback arrives.
    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;

    /// Looks up a pending OAuth2 state by its `state` value.
    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;

    /// Deletes a pending OAuth2 state. `OAuthHandler::handle_callback` calls this right
    /// after lookup so each state value can be used only once.
    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;

    /// Lists all secrets of `user_id`. NOTE(review): map keys presumably are secret keys — confirm.
    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;

    /// Lists all sessions of the user. NOTE(review): map keys presumably are auth entity names — confirm.
    async fn list_sessions(
        &self,
        _user_id: &str,
    ) -> Result<HashMap<String, AuthSession>, AuthError>;
}
442
/// Runs OAuth2 authorization-code flows and token refresh, and forwards
/// session/secret/state storage calls to the configured store.
#[derive(Clone)]
pub struct OAuthHandler {
    /// Backing store for sessions, secrets, and pending OAuth2 states.
    store: Arc<dyn ToolAuthStore>,
    /// Provider lookup. Without it, every call that needs a provider fails with
    /// `AuthError::InvalidConfig`.
    provider_registry: Option<Arc<dyn ProviderRegistry>>,
    /// Redirect URI sent to providers whose config has `send_redirect_uri` set.
    redirect_uri: String,
}
450
/// Registry of configured auth providers, looked up by name.
#[async_trait]
pub trait ProviderRegistry: Send + Sync {
    /// Returns the provider registered under `provider_name`, if any.
    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
    /// Returns the auth configuration for `provider_name`. `OAuthHandler::handle_callback`
    /// uses it to rebuild the token-exchange request.
    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
    /// Whether `provider_name` can currently be used (semantics defined by the implementation).
    async fn is_provider_available(&self, provider_name: &str) -> bool;
    /// Names of all registered providers.
    async fn list_providers(&self) -> Vec<String>;
    /// Whether authorization requests for `provider_name` must carry a PKCE challenge.
    /// Defaults to `false`.
    async fn requires_pkce(&self, _provider_name: &str) -> bool {
        false
    }
}
462
463impl OAuthHandler {
464 pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
465 Self {
466 store,
467 provider_registry: None,
468 redirect_uri,
469 }
470 }
471
472 pub fn with_provider_registry(
473 store: Arc<dyn ToolAuthStore>,
474 provider_registry: Arc<dyn ProviderRegistry>,
475 redirect_uri: String,
476 ) -> Self {
477 Self {
478 store,
479 provider_registry: Some(provider_registry),
480 redirect_uri,
481 }
482 }
483
484 pub async fn get_auth_url(
486 &self,
487 auth_entity: &str,
488 user_id: &str,
489 auth_config: &AuthType,
490 scopes: &[String],
491 ) -> Result<String, AuthError> {
492 tracing::debug!(
493 "Getting auth URL for entity: {} user: {:?}",
494 auth_entity,
495 user_id
496 );
497
498 match auth_config {
499 AuthType::OAuth2 {
500 flow_type: OAuth2FlowType::ClientCredentials,
501 ..
502 } => Err(AuthError::InvalidConfig(
503 "Client credentials flow doesn't require authorization URL".to_string(),
504 )),
505 auth_config @ AuthType::OAuth2 {
506 send_redirect_uri, ..
507 } => {
508 let redirect_uri = if *send_redirect_uri {
510 Some(self.redirect_uri.clone())
511 } else {
512 None
513 };
514 let mut state = OAuth2State::new(
515 auth_entity.to_string(),
516 redirect_uri.clone(),
517 user_id.to_string(),
518 scopes.to_vec(),
519 );
520
521 let mut pkce_challenge = None;
522 if let Some(registry) = &self.provider_registry {
523 if registry.requires_pkce(auth_entity).await {
524 let (verifier, challenge) = generate_pkce_pair();
525 state.metadata.insert(
526 PKCE_CODE_VERIFIER_KEY.to_string(),
527 serde_json::Value::String(verifier.clone()),
528 );
529 pkce_challenge = Some(challenge);
530 }
531 }
532
533 self.store.store_oauth2_state(state.clone()).await?;
535
536 let provider = self.get_provider(auth_entity).await?;
538
539 let mut auth_url = provider.build_auth_url(
541 auth_config,
542 &state.state,
543 scopes,
544 redirect_uri.as_deref(),
545 )?;
546
547 if let Some(challenge) = pkce_challenge {
548 auth_url = append_pkce_challenge(&auth_url, &challenge)?;
549 }
550
551 tracing::debug!("Generated auth URL: {}", auth_url);
552 Ok(auth_url)
553 }
554 AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
555 "Secret authentication doesn't require authorization URL".to_string(),
556 )),
557 AuthType::None => Err(AuthError::InvalidConfig(
558 "No authentication doesn't require authorization URL".to_string(),
559 )),
560 }
561 }
562
563 pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
565 tracing::debug!("Handling OAuth2 callback with state: {}", state);
566
567 let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
569 AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
570 })?;
571
572 self.store.remove_oauth2_state(state).await?;
574
575 if oauth2_state.is_expired(600) {
577 return Err(AuthError::OAuth2Flow(
578 "OAuth2 state has expired".to_string(),
579 ));
580 }
581
582 let auth_config = if let Some(registry) = &self.provider_registry {
584 registry
585 .get_auth_type(&oauth2_state.provider_name)
586 .await
587 .ok_or_else(|| {
588 AuthError::InvalidConfig(format!(
589 "No configuration found for provider: {}",
590 oauth2_state.provider_name
591 ))
592 })?
593 } else {
594 return Err(AuthError::InvalidConfig(
595 "No provider registry configured".to_string(),
596 ));
597 };
598
599 let provider = self.get_provider(&oauth2_state.provider_name).await?;
601
602 let redirect_uri = match &auth_config {
604 AuthType::OAuth2 {
605 send_redirect_uri, ..
606 } if *send_redirect_uri => oauth2_state
607 .redirect_uri
608 .clone()
609 .or_else(|| Some(self.redirect_uri.clone())),
610 AuthType::OAuth2 { .. } => None,
611 _ => None,
612 };
613 let pkce_code_verifier = oauth2_state
614 .metadata
615 .get(PKCE_CODE_VERIFIER_KEY)
616 .and_then(|v| v.as_str());
617
618 let session = provider
619 .exchange_code(
620 code,
621 redirect_uri.as_deref(),
622 &auth_config,
623 pkce_code_verifier,
624 )
625 .await?;
626
627 self.store
629 .store_session(
630 &oauth2_state.provider_name,
631 &oauth2_state.user_id,
632 session.clone(),
633 )
634 .await?;
635
636 tracing::debug!(
637 "Successfully stored auth session for entity: {}",
638 oauth2_state.provider_name
639 );
640 Ok(session)
641 }
642
643 pub async fn refresh_session(
645 &self,
646 auth_entity: &str,
647 user_id: &str,
648 auth_config: &AuthType,
649 ) -> Result<AuthSession, AuthError> {
650 tracing::debug!(
651 "Refreshing session for entity: {} user: {:?}",
652 auth_entity,
653 user_id
654 );
655
656 let current_session = self
658 .store
659 .get_session(auth_entity, &user_id)
660 .await?
661 .ok_or_else(|| {
662 AuthError::TokenRefreshFailed("No session found to refresh".to_string())
663 })?;
664
665 let refresh_token = current_session.refresh_token.ok_or_else(|| {
666 AuthError::TokenRefreshFailed("No refresh token available".to_string())
667 })?;
668
669 match auth_config {
670 AuthType::OAuth2 {
671 flow_type: OAuth2FlowType::ClientCredentials,
672 ..
673 } => {
674 let provider = self.get_provider(auth_entity).await?;
676 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
677
678 self.store
680 .store_session(auth_entity, &user_id, new_session.clone())
681 .await?;
682 Ok(new_session)
683 }
684 auth_config @ AuthType::OAuth2 { .. } => {
685 let provider = self.get_provider(auth_entity).await?;
687
688 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
690
691 self.store
693 .store_session(auth_entity, &user_id, new_session.clone())
694 .await?;
695 Ok(new_session)
696 }
697 _ => Err(AuthError::InvalidConfig(
698 "Cannot refresh non-OAuth2 session".to_string(),
699 )),
700 }
701 }
702
703 pub async fn refresh_get_session(
705 &self,
706 auth_entity: &str,
707 user_id: &str,
708 auth_config: &AuthType,
709 ) -> Result<Option<AuthSession>, AuthError> {
710 match self.store.get_session(auth_entity, user_id).await? {
711 Some(session) => {
712 if session.needs_refresh() {
713 tracing::debug!(
714 "Session expired for {}:{:?}, attempting refresh",
715 auth_entity,
716 user_id
717 );
718 match self
719 .refresh_session(auth_entity, user_id, auth_config)
720 .await
721 {
722 Ok(refreshed_session) => {
723 tracing::info!(
724 "Successfully refreshed session for {}:{:?}",
725 auth_entity,
726 user_id
727 );
728 Ok(Some(refreshed_session))
729 }
730 Err(e) => {
731 tracing::warn!(
732 "Failed to refresh session for {}:{:?}: {}",
733 auth_entity,
734 user_id,
735 e
736 );
737 Err(e)
738 }
739 }
740 } else {
741 Ok(Some(session))
742 }
743 }
744 None => Ok(None),
745 }
746 }
747
748 async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
749 if let Some(registry) = &self.provider_registry {
750 registry
751 .get_provider(provider_name)
752 .await
753 .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
754 } else {
755 Err(AuthError::InvalidConfig(
756 "No provider registry configured".to_string(),
757 ))
758 }
759 }
760
761 pub async fn get_session(
763 &self,
764 auth_entity: &str,
765 user_id: &str,
766 ) -> Result<Option<AuthSession>, AuthError> {
767 self.store.get_session(auth_entity, user_id).await
768 }
769
770 pub async fn store_session(
771 &self,
772 auth_entity: &str,
773 user_id: &str,
774 session: AuthSession,
775 ) -> Result<(), AuthError> {
776 self.store
777 .store_session(auth_entity, user_id, session)
778 .await
779 }
780
781 pub async fn remove_session(
782 &self,
783 auth_entity: &str,
784 user_id: &str,
785 ) -> Result<bool, AuthError> {
786 self.store.remove_session(auth_entity, user_id).await
787 }
788
789 pub async fn store_secret(
790 &self,
791 user_id: &str,
792 auth_entity: Option<&str>,
793 secret: AuthSecret,
794 ) -> Result<(), AuthError> {
795 self.store.store_secret(user_id, auth_entity, secret).await
796 }
797
798 pub async fn get_secret(
799 &self,
800 user_id: &str,
801 auth_entity: Option<&str>,
802 key: &str,
803 ) -> Result<Option<AuthSecret>, AuthError> {
804 self.store.get_secret(user_id, auth_entity, key).await
805 }
806
807 pub async fn remove_secret(
808 &self,
809 user_id: &str,
810 auth_entity: Option<&str>,
811 key: &str,
812 ) -> Result<bool, AuthError> {
813 self.store.remove_secret(user_id, auth_entity, key).await
814 }
815
816 pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
817 self.store.store_oauth2_state(state).await
818 }
819
820 pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
821 self.store.get_oauth2_state(state).await
822 }
823
824 pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
825 self.store.remove_oauth2_state(state).await
826 }
827
828 pub async fn list_secrets(
829 &self,
830 user_id: &str,
831 ) -> Result<HashMap<String, AuthSecret>, AuthError> {
832 self.store.list_secrets(user_id).await
833 }
834
835 pub async fn list_sessions(
836 &self,
837 user_id: &str,
838 ) -> Result<HashMap<String, AuthSession>, AuthError> {
839 self.store.list_sessions(user_id).await
840 }
841}
842
/// Provider-specific OAuth2 operations used by `OAuthHandler`.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Name of this provider.
    fn provider_name(&self) -> &str;

    /// Exchanges an authorization `code` for a session.
    ///
    /// `redirect_uri` is `Some` only when the config's `send_redirect_uri` is set.
    /// `pkce_code_verifier` is `Some` when the authorization request carried a PKCE challenge.
    async fn exchange_code(
        &self,
        code: &str,
        redirect_uri: Option<&str>,
        auth_config: &AuthType,
        pkce_code_verifier: Option<&str>,
    ) -> Result<AuthSession, AuthError>;

    /// Obtains a new session using `refresh_token`.
    async fn refresh_token(
        &self,
        refresh_token: &str,
        auth_config: &AuthType,
    ) -> Result<AuthSession, AuthError>;

    /// Builds the authorization URL for `state` and `scopes`.
    ///
    /// `redirect_uri` is `Some` only when the config's `send_redirect_uri` is set. Any PKCE
    /// parameters are appended afterwards by `OAuthHandler::get_auth_url`.
    fn build_auth_url(
        &self,
        auth_config: &AuthType,
        state: &str,
        scopes: &[String],
        redirect_uri: Option<&str>,
    ) -> Result<String, AuthError>;
}
873}