1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `send_redirect_uri`: when the field is absent, the redirect URI is sent.
const fn default_send_redirect_uri() -> bool {
    true
}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
21#[serde(tag = "type", content = "config")]
22pub enum AuthType {
23 #[serde(rename = "none")]
25 None,
26 #[serde(rename = "oauth2")]
28 OAuth2 {
29 flow_type: OAuth2FlowType,
31 authorization_url: String,
33 token_url: String,
35 refresh_url: Option<String>,
37 scopes: Vec<String>,
39 #[serde(default = "default_send_redirect_uri")]
41 send_redirect_uri: bool,
42 },
43 #[serde(rename = "secret")]
45 Secret {
46 provider: String,
47 #[serde(default)]
48 fields: Vec<SecretFieldSpec>,
49 },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
54#[serde(rename_all = "snake_case")]
55pub enum OAuth2FlowType {
56 AuthorizationCode,
57 ClientCredentials,
58 Implicit,
59 Password,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AuthSession {
65 pub access_token: String,
67 pub refresh_token: Option<String>,
69 pub expires_at: Option<DateTime<Utc>>,
71 pub token_type: String,
73 pub scopes: Vec<String>,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
79pub struct TokenLimits {
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub daily_tokens: Option<u64>,
83
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub monthly_tokens: Option<u64>,
87
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub daily_calls: Option<u64>,
91
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub weekly_calls: Option<u64>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
99pub struct TokenResponse {
100 pub access_token: String,
101 pub refresh_token: String,
102 pub expires_at: i64,
103 #[serde(skip_serializing_if = "Option::is_none")]
105 pub identifier_id: Option<String>,
106 #[serde(skip_serializing_if = "Option::is_none")]
108 pub limits: Option<TokenLimits>,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AuthSecret {
114 pub secret: String,
116 pub key: String,
118}
119impl From<AuthSession> for McpSession {
120 fn from(val: AuthSession) -> Self {
121 McpSession {
122 token: val.access_token,
123 expiry: val.expires_at.map(|dt| dt.into()),
124 }
125 }
126}
127
128impl From<AuthSecret> for McpSession {
129 fn from(val: AuthSecret) -> Self {
130 McpSession {
131 token: val.secret,
132 expiry: None, }
134 }
135}
136
137impl AuthSession {
138 pub fn new(
140 access_token: String,
141 token_type: Option<String>,
142 expires_in: Option<i64>,
143 refresh_token: Option<String>,
144 scopes: Vec<String>,
145 ) -> Self {
146 let now = Utc::now();
147 let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
148
149 AuthSession {
150 access_token,
151 refresh_token,
152 expires_at,
153 token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
154 scopes,
155 }
156 }
157
158 pub fn is_expired(&self, buffer_seconds: i64) -> bool {
160 match &self.expires_at {
161 Some(expires_at) => {
162 let buffer = chrono::Duration::seconds(buffer_seconds);
163 Utc::now() + buffer >= *expires_at
164 }
165 None => false, }
167 }
168
169 pub fn needs_refresh(&self) -> bool {
171 self.is_expired(300) }
173
174 pub fn get_access_token(&self) -> &str {
176 &self.access_token
177 }
178
179 pub fn update_tokens(
181 &mut self,
182 access_token: String,
183 expires_in: Option<i64>,
184 refresh_token: Option<String>,
185 ) {
186 self.access_token = access_token;
187
188 if let Some(secs) = expires_in {
189 self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
190 }
191
192 if let Some(token) = refresh_token {
193 self.refresh_token = Some(token);
194 }
195 }
196}
197
198impl AuthSecret {
199 pub fn new(key: String, secret: String) -> Self {
201 AuthSecret { secret, key }
202 }
203
204 pub fn get_secret(&self) -> &str {
206 &self.secret
207 }
208
209 pub fn get_provider(&self) -> &str {
211 &self.key
212 }
213}
214
215pub trait AuthMetadata: Send + Sync {
217 fn get_auth_entity(&self) -> String;
219
220 fn get_auth_type(&self) -> AuthType;
222
223 fn requires_auth(&self) -> bool {
225 !matches!(self.get_auth_type(), AuthType::None)
226 }
227
228 fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
230 HashMap::new()
231 }
232}
233
234#[derive(Debug, thiserror::Error)]
236pub enum AuthError {
237 #[error("OAuth2 flow error: {0}")]
238 OAuth2Flow(String),
239
240 #[error("Token expired and refresh failed: {0}")]
241 TokenRefreshFailed(String),
242
243 #[error("Invalid authentication configuration: {0}")]
244 InvalidConfig(String),
245
246 #[error("Authentication required but not configured for entity: {0}")]
247 AuthRequired(String),
248
249 #[error("API key not found for entity: {0}")]
250 ApiKeyNotFound(String),
251
252 #[error("Storage error: {0}")]
253 Storage(#[from] anyhow::Error),
254
255 #[error("Store error: {0}")]
256 StoreError(String),
257
258 #[error("Provider not found: {0}")]
259 ProviderNotFound(String),
260
261 #[error("Server error: {0}")]
262 ServerError(String),
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
267#[serde(tag = "type")]
268pub enum AuthRequirement {
269 #[serde(rename = "oauth2")]
270 OAuth2 {
271 provider: String,
272 #[serde(default)]
273 scopes: Vec<String>,
274 #[serde(
275 rename = "authorizationUrl",
276 default,
277 skip_serializing_if = "Option::is_none"
278 )]
279 authorization_url: Option<String>,
280 #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
281 token_url: Option<String>,
282 #[serde(
283 rename = "refreshUrl",
284 default,
285 skip_serializing_if = "Option::is_none"
286 )]
287 refresh_url: Option<String>,
288 #[serde(
289 rename = "sendRedirectUri",
290 default,
291 skip_serializing_if = "Option::is_none"
292 )]
293 send_redirect_uri: Option<bool>,
294 },
295 #[serde(rename = "secret")]
296 Secret {
297 provider: String,
298 #[serde(default)]
299 fields: Vec<SecretFieldSpec>,
300 },
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
304pub struct SecretFieldSpec {
305 pub key: String,
306 #[serde(default)]
307 pub label: Option<String>,
308 #[serde(default)]
309 pub description: Option<String>,
310 #[serde(default)]
311 pub optional: bool,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct OAuth2State {
317 pub state: String,
319 pub provider_name: String,
321 #[serde(default, skip_serializing_if = "Option::is_none")]
323 pub redirect_uri: Option<String>,
324 pub user_id: String,
326 pub scopes: Vec<String>,
328 pub metadata: HashMap<String, serde_json::Value>,
330 pub created_at: DateTime<Utc>,
332}
333
/// Key under which the PKCE code verifier is kept in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";

/// `code_challenge_method` appended to authorization URLs (SHA-256 challenge).
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";

/// Number of random bytes encoded into each PKCE code verifier.
const PKCE_RANDOM_BYTES: usize = 32;
337
338pub fn generate_pkce_pair() -> (String, String) {
339 let mut random = vec![0u8; PKCE_RANDOM_BYTES];
340 rand::thread_rng().fill_bytes(&mut random);
341 let verifier = URL_SAFE_NO_PAD.encode(&random);
342 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
343 (verifier, challenge)
344}
345
346pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
347 let mut url = Url::parse(auth_url)
348 .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
349 {
350 let mut pairs = url.query_pairs_mut();
351 pairs.append_pair("code_challenge", challenge);
352 pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
353 }
354 Ok(url.to_string())
355}
356
357impl OAuth2State {
358 pub fn new_with_state(
360 state: String,
361 provider_name: String,
362 redirect_uri: Option<String>,
363 user_id: String,
364 scopes: Vec<String>,
365 ) -> Self {
366 Self {
367 state,
368 provider_name,
369 redirect_uri,
370 user_id,
371 scopes,
372 metadata: HashMap::new(),
373 created_at: Utc::now(),
374 }
375 }
376
377 pub fn new(
379 provider_name: String,
380 redirect_uri: Option<String>,
381 user_id: String,
382 scopes: Vec<String>,
383 ) -> Self {
384 Self::new_with_state(
385 uuid::Uuid::new_v4().to_string(),
386 provider_name,
387 redirect_uri,
388 user_id,
389 scopes,
390 )
391 }
392
393 pub fn is_expired(&self, max_age_seconds: i64) -> bool {
395 let max_age = chrono::Duration::seconds(max_age_seconds);
396 Utc::now() - self.created_at > max_age
397 }
398}
399
400#[async_trait]
403pub trait ToolAuthStore: Send + Sync {
404 async fn get_session(
407 &self,
408 auth_entity: &str,
409 user_id: &str,
410 ) -> Result<Option<AuthSession>, AuthError>;
411
412 async fn store_session(
414 &self,
415 auth_entity: &str,
416 user_id: &str,
417 session: AuthSession,
418 ) -> Result<(), AuthError>;
419
420 async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
422
423 async fn store_secret(
426 &self,
427 user_id: &str,
428 auth_entity: Option<&str>, secret: AuthSecret,
430 ) -> Result<(), AuthError>;
431
432 async fn get_secret(
434 &self,
435 user_id: &str,
436 auth_entity: Option<&str>, key: &str,
438 ) -> Result<Option<AuthSecret>, AuthError>;
439
440 async fn remove_secret(
442 &self,
443 user_id: &str,
444 auth_entity: Option<&str>, key: &str,
446 ) -> Result<bool, AuthError>;
447
448 async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
451
452 async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
454
455 async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
457
458 async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
459
460 async fn list_sessions(
461 &self,
462 _user_id: &str,
463 ) -> Result<HashMap<String, AuthSession>, AuthError>;
464}
465
466#[derive(Clone)]
468pub struct OAuthHandler {
469 store: Arc<dyn ToolAuthStore>,
470 provider_registry: Option<Arc<dyn ProviderRegistry>>,
471 redirect_uri: String,
472}
473
474#[async_trait]
476pub trait ProviderRegistry: Send + Sync {
477 async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
478 async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
479 async fn is_provider_available(&self, provider_name: &str) -> bool;
480 async fn list_providers(&self) -> Vec<String>;
481 async fn requires_pkce(&self, _provider_name: &str) -> bool {
482 false
483 }
484}
485
486impl OAuthHandler {
487 pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
488 Self {
489 store,
490 provider_registry: None,
491 redirect_uri,
492 }
493 }
494
495 pub fn with_provider_registry(
496 store: Arc<dyn ToolAuthStore>,
497 provider_registry: Arc<dyn ProviderRegistry>,
498 redirect_uri: String,
499 ) -> Self {
500 Self {
501 store,
502 provider_registry: Some(provider_registry),
503 redirect_uri,
504 }
505 }
506
507 pub async fn get_auth_url(
509 &self,
510 auth_entity: &str,
511 user_id: &str,
512 auth_config: &AuthType,
513 scopes: &[String],
514 ) -> Result<String, AuthError> {
515 tracing::debug!(
516 "Getting auth URL for entity: {} user: {:?}",
517 auth_entity,
518 user_id
519 );
520
521 match auth_config {
522 AuthType::OAuth2 {
523 flow_type: OAuth2FlowType::ClientCredentials,
524 ..
525 } => Err(AuthError::InvalidConfig(
526 "Client credentials flow doesn't require authorization URL".to_string(),
527 )),
528 auth_config @ AuthType::OAuth2 {
529 send_redirect_uri, ..
530 } => {
531 let redirect_uri = if *send_redirect_uri {
533 Some(self.redirect_uri.clone())
534 } else {
535 None
536 };
537 let mut state = OAuth2State::new(
538 auth_entity.to_string(),
539 redirect_uri.clone(),
540 user_id.to_string(),
541 scopes.to_vec(),
542 );
543
544 let mut pkce_challenge = None;
545 if let Some(registry) = &self.provider_registry
546 && registry.requires_pkce(auth_entity).await {
547 let (verifier, challenge) = generate_pkce_pair();
548 state.metadata.insert(
549 PKCE_CODE_VERIFIER_KEY.to_string(),
550 serde_json::Value::String(verifier.clone()),
551 );
552 pkce_challenge = Some(challenge);
553 }
554
555 self.store.store_oauth2_state(state.clone()).await?;
557
558 let provider = self.get_provider(auth_entity).await?;
560
561 let mut auth_url = provider.build_auth_url(
563 auth_config,
564 &state.state,
565 scopes,
566 redirect_uri.as_deref(),
567 )?;
568
569 if let Some(challenge) = pkce_challenge {
570 auth_url = append_pkce_challenge(&auth_url, &challenge)?;
571 }
572
573 tracing::debug!("Generated auth URL: {}", auth_url);
574 Ok(auth_url)
575 }
576 AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
577 "Secret authentication doesn't require authorization URL".to_string(),
578 )),
579 AuthType::None => Err(AuthError::InvalidConfig(
580 "No authentication doesn't require authorization URL".to_string(),
581 )),
582 }
583 }
584
585 pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
587 tracing::debug!("Handling OAuth2 callback with state: {}", state);
588
589 let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
591 AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
592 })?;
593
594 self.store.remove_oauth2_state(state).await?;
596
597 if oauth2_state.is_expired(600) {
599 return Err(AuthError::OAuth2Flow(
600 "OAuth2 state has expired".to_string(),
601 ));
602 }
603
604 let auth_config = if let Some(registry) = &self.provider_registry {
606 registry
607 .get_auth_type(&oauth2_state.provider_name)
608 .await
609 .ok_or_else(|| {
610 AuthError::InvalidConfig(format!(
611 "No configuration found for provider: {}",
612 oauth2_state.provider_name
613 ))
614 })?
615 } else {
616 return Err(AuthError::InvalidConfig(
617 "No provider registry configured".to_string(),
618 ));
619 };
620
621 let provider = self.get_provider(&oauth2_state.provider_name).await?;
623
624 let redirect_uri = match &auth_config {
626 AuthType::OAuth2 {
627 send_redirect_uri, ..
628 } if *send_redirect_uri => oauth2_state
629 .redirect_uri
630 .clone()
631 .or_else(|| Some(self.redirect_uri.clone())),
632 AuthType::OAuth2 { .. } => None,
633 _ => None,
634 };
635 let pkce_code_verifier = oauth2_state
636 .metadata
637 .get(PKCE_CODE_VERIFIER_KEY)
638 .and_then(|v| v.as_str());
639
640 let session = provider
641 .exchange_code(
642 code,
643 redirect_uri.as_deref(),
644 &auth_config,
645 pkce_code_verifier,
646 )
647 .await?;
648
649 self.store
651 .store_session(
652 &oauth2_state.provider_name,
653 &oauth2_state.user_id,
654 session.clone(),
655 )
656 .await?;
657
658 tracing::debug!(
659 "Successfully stored auth session for entity: {}",
660 oauth2_state.provider_name
661 );
662 Ok(session)
663 }
664
665 pub async fn refresh_session(
667 &self,
668 auth_entity: &str,
669 user_id: &str,
670 auth_config: &AuthType,
671 ) -> Result<AuthSession, AuthError> {
672 tracing::debug!(
673 "Refreshing session for entity: {} user: {:?}",
674 auth_entity,
675 user_id
676 );
677
678 let current_session = self
680 .store
681 .get_session(auth_entity, user_id)
682 .await?
683 .ok_or_else(|| {
684 AuthError::TokenRefreshFailed("No session found to refresh".to_string())
685 })?;
686
687 let refresh_token = current_session.refresh_token.ok_or_else(|| {
688 AuthError::TokenRefreshFailed("No refresh token available".to_string())
689 })?;
690
691 match auth_config {
692 AuthType::OAuth2 {
693 flow_type: OAuth2FlowType::ClientCredentials,
694 ..
695 } => {
696 let provider = self.get_provider(auth_entity).await?;
698 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
699
700 self.store
702 .store_session(auth_entity, user_id, new_session.clone())
703 .await?;
704 Ok(new_session)
705 }
706 auth_config @ AuthType::OAuth2 { .. } => {
707 let provider = self.get_provider(auth_entity).await?;
709
710 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
712
713 self.store
715 .store_session(auth_entity, user_id, new_session.clone())
716 .await?;
717 Ok(new_session)
718 }
719 _ => Err(AuthError::InvalidConfig(
720 "Cannot refresh non-OAuth2 session".to_string(),
721 )),
722 }
723 }
724
725 pub async fn refresh_get_session(
727 &self,
728 auth_entity: &str,
729 user_id: &str,
730 auth_config: &AuthType,
731 ) -> Result<Option<AuthSession>, AuthError> {
732 match self.store.get_session(auth_entity, user_id).await? {
733 Some(session) => {
734 if session.needs_refresh() {
735 tracing::debug!(
736 "Session expired for {}:{:?}, attempting refresh",
737 auth_entity,
738 user_id
739 );
740 match self
741 .refresh_session(auth_entity, user_id, auth_config)
742 .await
743 {
744 Ok(refreshed_session) => {
745 tracing::info!(
746 "Successfully refreshed session for {}:{:?}",
747 auth_entity,
748 user_id
749 );
750 Ok(Some(refreshed_session))
751 }
752 Err(e) => {
753 tracing::warn!(
754 "Failed to refresh session for {}:{:?}: {}",
755 auth_entity,
756 user_id,
757 e
758 );
759 Err(e)
760 }
761 }
762 } else {
763 Ok(Some(session))
764 }
765 }
766 None => Ok(None),
767 }
768 }
769
770 async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
771 if let Some(registry) = &self.provider_registry {
772 registry
773 .get_provider(provider_name)
774 .await
775 .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
776 } else {
777 Err(AuthError::InvalidConfig(
778 "No provider registry configured".to_string(),
779 ))
780 }
781 }
782
783 pub async fn get_session(
785 &self,
786 auth_entity: &str,
787 user_id: &str,
788 ) -> Result<Option<AuthSession>, AuthError> {
789 self.store.get_session(auth_entity, user_id).await
790 }
791
792 pub async fn store_session(
793 &self,
794 auth_entity: &str,
795 user_id: &str,
796 session: AuthSession,
797 ) -> Result<(), AuthError> {
798 self.store
799 .store_session(auth_entity, user_id, session)
800 .await
801 }
802
803 pub async fn remove_session(
804 &self,
805 auth_entity: &str,
806 user_id: &str,
807 ) -> Result<bool, AuthError> {
808 self.store.remove_session(auth_entity, user_id).await
809 }
810
811 pub async fn store_secret(
812 &self,
813 user_id: &str,
814 auth_entity: Option<&str>,
815 secret: AuthSecret,
816 ) -> Result<(), AuthError> {
817 self.store.store_secret(user_id, auth_entity, secret).await
818 }
819
820 pub async fn get_secret(
821 &self,
822 user_id: &str,
823 auth_entity: Option<&str>,
824 key: &str,
825 ) -> Result<Option<AuthSecret>, AuthError> {
826 self.store.get_secret(user_id, auth_entity, key).await
827 }
828
829 pub async fn remove_secret(
830 &self,
831 user_id: &str,
832 auth_entity: Option<&str>,
833 key: &str,
834 ) -> Result<bool, AuthError> {
835 self.store.remove_secret(user_id, auth_entity, key).await
836 }
837
838 pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
839 self.store.store_oauth2_state(state).await
840 }
841
842 pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
843 self.store.get_oauth2_state(state).await
844 }
845
846 pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
847 self.store.remove_oauth2_state(state).await
848 }
849
850 pub async fn list_secrets(
851 &self,
852 user_id: &str,
853 ) -> Result<HashMap<String, AuthSecret>, AuthError> {
854 self.store.list_secrets(user_id).await
855 }
856
857 pub async fn list_sessions(
858 &self,
859 user_id: &str,
860 ) -> Result<HashMap<String, AuthSession>, AuthError> {
861 self.store.list_sessions(user_id).await
862 }
863}
864
865#[async_trait]
867pub trait AuthProvider: Send + Sync {
868 fn provider_name(&self) -> &str;
870
871 async fn exchange_code(
873 &self,
874 code: &str,
875 redirect_uri: Option<&str>,
876 auth_config: &AuthType,
877 pkce_code_verifier: Option<&str>,
878 ) -> Result<AuthSession, AuthError>;
879
880 async fn refresh_token(
882 &self,
883 refresh_token: &str,
884 auth_config: &AuthType,
885 ) -> Result<AuthSession, AuthError>;
886
887 fn build_auth_url(
889 &self,
890 auth_config: &AuthType,
891 state: &str,
892 scopes: &[String],
893 redirect_uri: Option<&str>,
894 ) -> Result<String, AuthError>;
895}