1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `AuthType::OAuth2::send_redirect_uri`.
///
/// Configurations that omit the field get `true`, i.e. the redirect URI is
/// sent unless a configuration explicitly opts out.
fn default_send_redirect_uri() -> bool {
    true
}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
21#[serde(tag = "type", content = "config")]
22pub enum AuthType {
23 #[serde(rename = "none")]
25 None,
26 #[serde(rename = "oauth2")]
28 OAuth2 {
29 flow_type: OAuth2FlowType,
31 authorization_url: String,
33 token_url: String,
35 refresh_url: Option<String>,
37 scopes: Vec<String>,
39 #[serde(default = "default_send_redirect_uri")]
41 send_redirect_uri: bool,
42 },
43 #[serde(rename = "secret")]
45 Secret {
46 provider: String,
47 #[serde(default)]
48 fields: Vec<SecretFieldSpec>,
49 },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
54#[serde(rename_all = "snake_case")]
55pub enum OAuth2FlowType {
56 AuthorizationCode,
57 ClientCredentials,
58 Implicit,
59 Password,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AuthSession {
65 pub access_token: String,
67 pub refresh_token: Option<String>,
69 pub expires_at: Option<DateTime<Utc>>,
71 pub token_type: String,
73 pub scopes: Vec<String>,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
79pub struct TokenLimits {
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub daily_tokens: Option<u64>,
83
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub monthly_tokens: Option<u64>,
87
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub daily_calls: Option<u64>,
91
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub weekly_calls: Option<u64>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
99pub struct TokenResponse {
100 pub access_token: String,
101 pub refresh_token: String,
102 pub expires_at: i64,
103 #[serde(skip_serializing_if = "Option::is_none")]
105 pub identifier_id: Option<String>,
106 #[serde(skip_serializing_if = "Option::is_none")]
108 pub limits: Option<TokenLimits>,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AuthSecret {
114 pub secret: String,
116 pub key: String,
118}
119impl Into<McpSession> for AuthSession {
120 fn into(self) -> McpSession {
121 McpSession {
122 token: self.access_token,
123 expiry: self.expires_at.map(|dt| dt.into()),
124 }
125 }
126}
127
128impl Into<McpSession> for AuthSecret {
129 fn into(self) -> McpSession {
130 McpSession {
131 token: self.secret,
132 expiry: None, }
134 }
135}
136
137impl AuthSession {
138 pub fn new(
140 access_token: String,
141 token_type: Option<String>,
142 expires_in: Option<i64>,
143 refresh_token: Option<String>,
144 scopes: Vec<String>,
145 ) -> Self {
146 let now = Utc::now();
147 let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
148
149 AuthSession {
150 access_token,
151 refresh_token,
152 expires_at,
153 token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
154 scopes,
155 }
156 }
157
158 pub fn is_expired(&self, buffer_seconds: i64) -> bool {
160 match &self.expires_at {
161 Some(expires_at) => {
162 let buffer = chrono::Duration::seconds(buffer_seconds);
163 Utc::now() + buffer >= *expires_at
164 }
165 None => false, }
167 }
168
169 pub fn needs_refresh(&self) -> bool {
171 self.is_expired(300) }
173
174 pub fn get_access_token(&self) -> &str {
176 &self.access_token
177 }
178
179 pub fn update_tokens(
181 &mut self,
182 access_token: String,
183 expires_in: Option<i64>,
184 refresh_token: Option<String>,
185 ) {
186 self.access_token = access_token;
187
188 if let Some(secs) = expires_in {
189 self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
190 }
191
192 if let Some(token) = refresh_token {
193 self.refresh_token = Some(token);
194 }
195 }
196}
197
198impl AuthSecret {
199 pub fn new(key: String, secret: String) -> Self {
201 AuthSecret { secret, key }
202 }
203
204 pub fn get_secret(&self) -> &str {
206 &self.secret
207 }
208
209 pub fn get_provider(&self) -> &str {
211 &self.key
212 }
213}
214
215pub trait AuthMetadata: Send + Sync {
217 fn get_auth_entity(&self) -> String;
219
220 fn get_auth_type(&self) -> AuthType;
222
223 fn requires_auth(&self) -> bool {
225 !matches!(self.get_auth_type(), AuthType::None)
226 }
227
228 fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
230 HashMap::new()
231 }
232}
233
234#[derive(Debug, thiserror::Error)]
236pub enum AuthError {
237 #[error("OAuth2 flow error: {0}")]
238 OAuth2Flow(String),
239
240 #[error("Token expired and refresh failed: {0}")]
241 TokenRefreshFailed(String),
242
243 #[error("Invalid authentication configuration: {0}")]
244 InvalidConfig(String),
245
246 #[error("Authentication required but not configured for entity: {0}")]
247 AuthRequired(String),
248
249 #[error("API key not found for entity: {0}")]
250 ApiKeyNotFound(String),
251
252 #[error("Storage error: {0}")]
253 Storage(#[from] anyhow::Error),
254
255 #[error("Store error: {0}")]
256 StoreError(String),
257
258 #[error("Provider not found: {0}")]
259 ProviderNotFound(String),
260
261 #[error("Server error: {0}")]
262 ServerError(String),
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
267#[serde(tag = "type")]
268pub enum AuthRequirement {
269 #[serde(rename = "oauth2")]
270 OAuth2 {
271 provider: String,
272 #[serde(default)]
273 scopes: Vec<String>,
274 #[serde(
275 rename = "authorizationUrl",
276 default,
277 skip_serializing_if = "Option::is_none"
278 )]
279 authorization_url: Option<String>,
280 #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
281 token_url: Option<String>,
282 #[serde(
283 rename = "refreshUrl",
284 default,
285 skip_serializing_if = "Option::is_none"
286 )]
287 refresh_url: Option<String>,
288 #[serde(
289 rename = "sendRedirectUri",
290 default,
291 skip_serializing_if = "Option::is_none"
292 )]
293 send_redirect_uri: Option<bool>,
294 },
295 #[serde(rename = "secret")]
296 Secret {
297 provider: String,
298 #[serde(default)]
299 fields: Vec<SecretFieldSpec>,
300 },
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
304pub struct SecretFieldSpec {
305 pub key: String,
306 #[serde(default)]
307 pub label: Option<String>,
308 #[serde(default)]
309 pub description: Option<String>,
310 #[serde(default)]
311 pub optional: bool,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct OAuth2State {
317 pub state: String,
319 pub provider_name: String,
321 #[serde(default, skip_serializing_if = "Option::is_none")]
323 pub redirect_uri: Option<String>,
324 pub user_id: String,
326 pub scopes: Vec<String>,
328 pub metadata: HashMap<String, serde_json::Value>,
330 pub created_at: DateTime<Utc>,
332}
333
/// Key under which the PKCE code verifier is stored in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// Value sent as the `code_challenge_method` query parameter.
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Number of random bytes the PKCE code verifier is derived from.
const PKCE_RANDOM_BYTES: usize = 32;
337
338pub fn generate_pkce_pair() -> (String, String) {
339 let mut random = vec![0u8; PKCE_RANDOM_BYTES];
340 rand::thread_rng().fill_bytes(&mut random);
341 let verifier = URL_SAFE_NO_PAD.encode(&random);
342 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
343 (verifier, challenge)
344}
345
346pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
347 let mut url = Url::parse(auth_url)
348 .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
349 {
350 let mut pairs = url.query_pairs_mut();
351 pairs.append_pair("code_challenge", challenge);
352 pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
353 }
354 Ok(url.to_string())
355}
356
357impl OAuth2State {
358 pub fn new_with_state(
360 state: String,
361 provider_name: String,
362 redirect_uri: Option<String>,
363 user_id: String,
364 scopes: Vec<String>,
365 ) -> Self {
366 Self {
367 state,
368 provider_name,
369 redirect_uri,
370 user_id,
371 scopes,
372 metadata: HashMap::new(),
373 created_at: Utc::now(),
374 }
375 }
376
377 pub fn new(
379 provider_name: String,
380 redirect_uri: Option<String>,
381 user_id: String,
382 scopes: Vec<String>,
383 ) -> Self {
384 Self::new_with_state(
385 uuid::Uuid::new_v4().to_string(),
386 provider_name,
387 redirect_uri,
388 user_id,
389 scopes,
390 )
391 }
392
393 pub fn is_expired(&self, max_age_seconds: i64) -> bool {
395 let max_age = chrono::Duration::seconds(max_age_seconds);
396 Utc::now() - self.created_at > max_age
397 }
398}
399
400#[async_trait]
403pub trait ToolAuthStore: Send + Sync {
404 async fn get_session(
408 &self,
409 auth_entity: &str,
410 user_id: &str,
411 ) -> Result<Option<AuthSession>, AuthError>;
412
413 async fn store_session(
415 &self,
416 auth_entity: &str,
417 user_id: &str,
418 session: AuthSession,
419 ) -> Result<(), AuthError>;
420
421 async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
423
424 async fn store_secret(
428 &self,
429 user_id: &str,
430 auth_entity: Option<&str>, secret: AuthSecret,
432 ) -> Result<(), AuthError>;
433
434 async fn get_secret(
436 &self,
437 user_id: &str,
438 auth_entity: Option<&str>, key: &str,
440 ) -> Result<Option<AuthSecret>, AuthError>;
441
442 async fn remove_secret(
444 &self,
445 user_id: &str,
446 auth_entity: Option<&str>, key: &str,
448 ) -> Result<bool, AuthError>;
449
450 async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
454
455 async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
457
458 async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
460
461 async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
462
463 async fn list_sessions(
464 &self,
465 _user_id: &str,
466 ) -> Result<HashMap<String, AuthSession>, AuthError>;
467}
468
469#[derive(Clone)]
471pub struct OAuthHandler {
472 store: Arc<dyn ToolAuthStore>,
473 provider_registry: Option<Arc<dyn ProviderRegistry>>,
474 redirect_uri: String,
475}
476
477#[async_trait]
479pub trait ProviderRegistry: Send + Sync {
480 async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
481 async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
482 async fn is_provider_available(&self, provider_name: &str) -> bool;
483 async fn list_providers(&self) -> Vec<String>;
484 async fn requires_pkce(&self, _provider_name: &str) -> bool {
485 false
486 }
487}
488
489impl OAuthHandler {
490 pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
491 Self {
492 store,
493 provider_registry: None,
494 redirect_uri,
495 }
496 }
497
498 pub fn with_provider_registry(
499 store: Arc<dyn ToolAuthStore>,
500 provider_registry: Arc<dyn ProviderRegistry>,
501 redirect_uri: String,
502 ) -> Self {
503 Self {
504 store,
505 provider_registry: Some(provider_registry),
506 redirect_uri,
507 }
508 }
509
510 pub async fn get_auth_url(
512 &self,
513 auth_entity: &str,
514 user_id: &str,
515 auth_config: &AuthType,
516 scopes: &[String],
517 ) -> Result<String, AuthError> {
518 tracing::debug!(
519 "Getting auth URL for entity: {} user: {:?}",
520 auth_entity,
521 user_id
522 );
523
524 match auth_config {
525 AuthType::OAuth2 {
526 flow_type: OAuth2FlowType::ClientCredentials,
527 ..
528 } => Err(AuthError::InvalidConfig(
529 "Client credentials flow doesn't require authorization URL".to_string(),
530 )),
531 auth_config @ AuthType::OAuth2 {
532 send_redirect_uri, ..
533 } => {
534 let redirect_uri = if *send_redirect_uri {
536 Some(self.redirect_uri.clone())
537 } else {
538 None
539 };
540 let mut state = OAuth2State::new(
541 auth_entity.to_string(),
542 redirect_uri.clone(),
543 user_id.to_string(),
544 scopes.to_vec(),
545 );
546
547 let mut pkce_challenge = None;
548 if let Some(registry) = &self.provider_registry {
549 if registry.requires_pkce(auth_entity).await {
550 let (verifier, challenge) = generate_pkce_pair();
551 state.metadata.insert(
552 PKCE_CODE_VERIFIER_KEY.to_string(),
553 serde_json::Value::String(verifier.clone()),
554 );
555 pkce_challenge = Some(challenge);
556 }
557 }
558
559 self.store.store_oauth2_state(state.clone()).await?;
561
562 let provider = self.get_provider(auth_entity).await?;
564
565 let mut auth_url = provider.build_auth_url(
567 auth_config,
568 &state.state,
569 scopes,
570 redirect_uri.as_deref(),
571 )?;
572
573 if let Some(challenge) = pkce_challenge {
574 auth_url = append_pkce_challenge(&auth_url, &challenge)?;
575 }
576
577 tracing::debug!("Generated auth URL: {}", auth_url);
578 Ok(auth_url)
579 }
580 AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
581 "Secret authentication doesn't require authorization URL".to_string(),
582 )),
583 AuthType::None => Err(AuthError::InvalidConfig(
584 "No authentication doesn't require authorization URL".to_string(),
585 )),
586 }
587 }
588
589 pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
591 tracing::debug!("Handling OAuth2 callback with state: {}", state);
592
593 let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
595 AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
596 })?;
597
598 self.store.remove_oauth2_state(state).await?;
600
601 if oauth2_state.is_expired(600) {
603 return Err(AuthError::OAuth2Flow(
604 "OAuth2 state has expired".to_string(),
605 ));
606 }
607
608 let auth_config = if let Some(registry) = &self.provider_registry {
610 registry
611 .get_auth_type(&oauth2_state.provider_name)
612 .await
613 .ok_or_else(|| {
614 AuthError::InvalidConfig(format!(
615 "No configuration found for provider: {}",
616 oauth2_state.provider_name
617 ))
618 })?
619 } else {
620 return Err(AuthError::InvalidConfig(
621 "No provider registry configured".to_string(),
622 ));
623 };
624
625 let provider = self.get_provider(&oauth2_state.provider_name).await?;
627
628 let redirect_uri = match &auth_config {
630 AuthType::OAuth2 {
631 send_redirect_uri, ..
632 } if *send_redirect_uri => oauth2_state
633 .redirect_uri
634 .clone()
635 .or_else(|| Some(self.redirect_uri.clone())),
636 AuthType::OAuth2 { .. } => None,
637 _ => None,
638 };
639 let pkce_code_verifier = oauth2_state
640 .metadata
641 .get(PKCE_CODE_VERIFIER_KEY)
642 .and_then(|v| v.as_str());
643
644 let session = provider
645 .exchange_code(
646 code,
647 redirect_uri.as_deref(),
648 &auth_config,
649 pkce_code_verifier,
650 )
651 .await?;
652
653 self.store
655 .store_session(
656 &oauth2_state.provider_name,
657 &oauth2_state.user_id,
658 session.clone(),
659 )
660 .await?;
661
662 tracing::debug!(
663 "Successfully stored auth session for entity: {}",
664 oauth2_state.provider_name
665 );
666 Ok(session)
667 }
668
669 pub async fn refresh_session(
671 &self,
672 auth_entity: &str,
673 user_id: &str,
674 auth_config: &AuthType,
675 ) -> Result<AuthSession, AuthError> {
676 tracing::debug!(
677 "Refreshing session for entity: {} user: {:?}",
678 auth_entity,
679 user_id
680 );
681
682 let current_session = self
684 .store
685 .get_session(auth_entity, &user_id)
686 .await?
687 .ok_or_else(|| {
688 AuthError::TokenRefreshFailed("No session found to refresh".to_string())
689 })?;
690
691 let refresh_token = current_session.refresh_token.ok_or_else(|| {
692 AuthError::TokenRefreshFailed("No refresh token available".to_string())
693 })?;
694
695 match auth_config {
696 AuthType::OAuth2 {
697 flow_type: OAuth2FlowType::ClientCredentials,
698 ..
699 } => {
700 let provider = self.get_provider(auth_entity).await?;
702 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
703
704 self.store
706 .store_session(auth_entity, &user_id, new_session.clone())
707 .await?;
708 Ok(new_session)
709 }
710 auth_config @ AuthType::OAuth2 { .. } => {
711 let provider = self.get_provider(auth_entity).await?;
713
714 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
716
717 self.store
719 .store_session(auth_entity, &user_id, new_session.clone())
720 .await?;
721 Ok(new_session)
722 }
723 _ => Err(AuthError::InvalidConfig(
724 "Cannot refresh non-OAuth2 session".to_string(),
725 )),
726 }
727 }
728
729 pub async fn refresh_get_session(
731 &self,
732 auth_entity: &str,
733 user_id: &str,
734 auth_config: &AuthType,
735 ) -> Result<Option<AuthSession>, AuthError> {
736 match self.store.get_session(auth_entity, user_id).await? {
737 Some(session) => {
738 if session.needs_refresh() {
739 tracing::debug!(
740 "Session expired for {}:{:?}, attempting refresh",
741 auth_entity,
742 user_id
743 );
744 match self
745 .refresh_session(auth_entity, user_id, auth_config)
746 .await
747 {
748 Ok(refreshed_session) => {
749 tracing::info!(
750 "Successfully refreshed session for {}:{:?}",
751 auth_entity,
752 user_id
753 );
754 Ok(Some(refreshed_session))
755 }
756 Err(e) => {
757 tracing::warn!(
758 "Failed to refresh session for {}:{:?}: {}",
759 auth_entity,
760 user_id,
761 e
762 );
763 Err(e)
764 }
765 }
766 } else {
767 Ok(Some(session))
768 }
769 }
770 None => Ok(None),
771 }
772 }
773
774 async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
775 if let Some(registry) = &self.provider_registry {
776 registry
777 .get_provider(provider_name)
778 .await
779 .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
780 } else {
781 Err(AuthError::InvalidConfig(
782 "No provider registry configured".to_string(),
783 ))
784 }
785 }
786
787 pub async fn get_session(
789 &self,
790 auth_entity: &str,
791 user_id: &str,
792 ) -> Result<Option<AuthSession>, AuthError> {
793 self.store.get_session(auth_entity, user_id).await
794 }
795
796 pub async fn store_session(
797 &self,
798 auth_entity: &str,
799 user_id: &str,
800 session: AuthSession,
801 ) -> Result<(), AuthError> {
802 self.store
803 .store_session(auth_entity, user_id, session)
804 .await
805 }
806
807 pub async fn remove_session(
808 &self,
809 auth_entity: &str,
810 user_id: &str,
811 ) -> Result<bool, AuthError> {
812 self.store.remove_session(auth_entity, user_id).await
813 }
814
815 pub async fn store_secret(
816 &self,
817 user_id: &str,
818 auth_entity: Option<&str>,
819 secret: AuthSecret,
820 ) -> Result<(), AuthError> {
821 self.store.store_secret(user_id, auth_entity, secret).await
822 }
823
824 pub async fn get_secret(
825 &self,
826 user_id: &str,
827 auth_entity: Option<&str>,
828 key: &str,
829 ) -> Result<Option<AuthSecret>, AuthError> {
830 self.store.get_secret(user_id, auth_entity, key).await
831 }
832
833 pub async fn remove_secret(
834 &self,
835 user_id: &str,
836 auth_entity: Option<&str>,
837 key: &str,
838 ) -> Result<bool, AuthError> {
839 self.store.remove_secret(user_id, auth_entity, key).await
840 }
841
842 pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
843 self.store.store_oauth2_state(state).await
844 }
845
846 pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
847 self.store.get_oauth2_state(state).await
848 }
849
850 pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
851 self.store.remove_oauth2_state(state).await
852 }
853
854 pub async fn list_secrets(
855 &self,
856 user_id: &str,
857 ) -> Result<HashMap<String, AuthSecret>, AuthError> {
858 self.store.list_secrets(user_id).await
859 }
860
861 pub async fn list_sessions(
862 &self,
863 user_id: &str,
864 ) -> Result<HashMap<String, AuthSession>, AuthError> {
865 self.store.list_sessions(user_id).await
866 }
867}
868
869#[async_trait]
871pub trait AuthProvider: Send + Sync {
872 fn provider_name(&self) -> &str;
874
875 async fn exchange_code(
877 &self,
878 code: &str,
879 redirect_uri: Option<&str>,
880 auth_config: &AuthType,
881 pkce_code_verifier: Option<&str>,
882 ) -> Result<AuthSession, AuthError>;
883
884 async fn refresh_token(
886 &self,
887 refresh_token: &str,
888 auth_config: &AuthType,
889 ) -> Result<AuthSession, AuthError>;
890
891 fn build_auth_url(
893 &self,
894 auth_config: &AuthType,
895 state: &str,
896 scopes: &[String],
897 redirect_uri: Option<&str>,
898 ) -> Result<String, AuthError>;
899}