1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::sync::Arc;
10use url::Url;
11
12use crate::McpSession;
13
/// Serde default for `AuthType::OAuth2::send_redirect_uri`.
///
/// Configurations that omit the field behave as if it were `true`.
fn default_send_redirect_uri() -> bool {
    true
}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
20#[serde(tag = "type", content = "config")]
21pub enum AuthType {
22 #[serde(rename = "none")]
24 None,
25 #[serde(rename = "oauth2")]
27 OAuth2 {
28 flow_type: OAuth2FlowType,
30 authorization_url: String,
32 token_url: String,
34 refresh_url: Option<String>,
36 scopes: Vec<String>,
38 #[serde(default = "default_send_redirect_uri")]
40 send_redirect_uri: bool,
41 },
42 #[serde(rename = "secret")]
44 Secret {
45 provider: String,
46 #[serde(default)]
47 fields: Vec<SecretFieldSpec>,
48 },
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
53#[serde(rename_all = "snake_case")]
54pub enum OAuth2FlowType {
55 AuthorizationCode,
56 ClientCredentials,
57 Implicit,
58 Password,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct AuthSession {
64 pub access_token: String,
66 pub refresh_token: Option<String>,
68 pub expires_at: Option<DateTime<Utc>>,
70 pub token_type: String,
72 pub scopes: Vec<String>,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
78pub struct TokenLimits {
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub daily_tokens: Option<u64>,
82
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub monthly_tokens: Option<u64>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub daily_calls: Option<u64>,
90
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub weekly_calls: Option<u64>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
98pub struct TokenResponse {
99 pub access_token: String,
100 pub refresh_token: String,
101 pub expires_at: i64,
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub identifier_id: Option<String>,
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub limits: Option<TokenLimits>,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct AuthSecret {
113 pub secret: String,
115 pub key: String,
117}
118impl From<AuthSession> for McpSession {
119 fn from(val: AuthSession) -> Self {
120 McpSession {
121 token: val.access_token,
122 expiry: val.expires_at.map(|dt| dt.into()),
123 }
124 }
125}
126
127impl From<AuthSecret> for McpSession {
128 fn from(val: AuthSecret) -> Self {
129 McpSession {
130 token: val.secret,
131 expiry: None, }
133 }
134}
135
136impl AuthSession {
137 pub fn new(
139 access_token: String,
140 token_type: Option<String>,
141 expires_in: Option<i64>,
142 refresh_token: Option<String>,
143 scopes: Vec<String>,
144 ) -> Self {
145 let now = Utc::now();
146 let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
147
148 AuthSession {
149 access_token,
150 refresh_token,
151 expires_at,
152 token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
153 scopes,
154 }
155 }
156
157 pub fn is_expired(&self, buffer_seconds: i64) -> bool {
159 match &self.expires_at {
160 Some(expires_at) => {
161 let buffer = chrono::Duration::seconds(buffer_seconds);
162 Utc::now() + buffer >= *expires_at
163 }
164 None => false, }
166 }
167
168 pub fn needs_refresh(&self) -> bool {
170 self.is_expired(300) }
172
173 pub fn get_access_token(&self) -> &str {
175 &self.access_token
176 }
177
178 pub fn update_tokens(
180 &mut self,
181 access_token: String,
182 expires_in: Option<i64>,
183 refresh_token: Option<String>,
184 ) {
185 self.access_token = access_token;
186
187 if let Some(secs) = expires_in {
188 self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
189 }
190
191 if let Some(token) = refresh_token {
192 self.refresh_token = Some(token);
193 }
194 }
195}
196
197impl AuthSecret {
198 pub fn new(key: String, secret: String) -> Self {
200 AuthSecret { secret, key }
201 }
202
203 pub fn get_secret(&self) -> &str {
205 &self.secret
206 }
207
208 pub fn get_provider(&self) -> &str {
210 &self.key
211 }
212}
213
214pub trait AuthMetadata: Send + Sync {
216 fn get_auth_entity(&self) -> String;
218
219 fn get_auth_type(&self) -> AuthType;
221
222 fn requires_auth(&self) -> bool {
224 !matches!(self.get_auth_type(), AuthType::None)
225 }
226
227 fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
229 HashMap::new()
230 }
231}
232
233#[derive(Debug, thiserror::Error)]
235pub enum AuthError {
236 #[error("OAuth2 flow error: {0}")]
237 OAuth2Flow(String),
238
239 #[error("Token expired and refresh failed: {0}")]
240 TokenRefreshFailed(String),
241
242 #[error("Invalid authentication configuration: {0}")]
243 InvalidConfig(String),
244
245 #[error("Authentication required but not configured for entity: {0}")]
246 AuthRequired(String),
247
248 #[error("API key not found for entity: {0}")]
249 ApiKeyNotFound(String),
250
251 #[error("Storage error: {0}")]
252 Storage(#[from] anyhow::Error),
253
254 #[error("Store error: {0}")]
255 StoreError(String),
256
257 #[error("Provider not found: {0}")]
258 ProviderNotFound(String),
259
260 #[error("Server error: {0}")]
261 ServerError(String),
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
266#[serde(tag = "type")]
267pub enum AuthRequirement {
268 #[serde(rename = "oauth2")]
269 OAuth2 {
270 provider: String,
271 #[serde(default)]
272 scopes: Vec<String>,
273 #[serde(
274 rename = "authorizationUrl",
275 default,
276 skip_serializing_if = "Option::is_none"
277 )]
278 authorization_url: Option<String>,
279 #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
280 token_url: Option<String>,
281 #[serde(
282 rename = "refreshUrl",
283 default,
284 skip_serializing_if = "Option::is_none"
285 )]
286 refresh_url: Option<String>,
287 #[serde(
288 rename = "sendRedirectUri",
289 default,
290 skip_serializing_if = "Option::is_none"
291 )]
292 send_redirect_uri: Option<bool>,
293 },
294 #[serde(rename = "secret")]
295 Secret {
296 provider: String,
297 #[serde(default)]
298 fields: Vec<SecretFieldSpec>,
299 },
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
303pub struct SecretFieldSpec {
304 pub key: String,
305 #[serde(default)]
306 pub label: Option<String>,
307 #[serde(default)]
308 pub description: Option<String>,
309 #[serde(default)]
310 pub optional: bool,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct OAuth2State {
316 pub state: String,
318 pub provider_name: String,
320 #[serde(default, skip_serializing_if = "Option::is_none")]
322 pub redirect_uri: Option<String>,
323 pub user_id: String,
325 pub scopes: Vec<String>,
327 pub metadata: HashMap<String, serde_json::Value>,
329 pub created_at: DateTime<Utc>,
331}
332
/// Key under which the PKCE code verifier is kept in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// Value sent as `code_challenge_method`: the challenge is the base64url
/// encoded SHA-256 digest of the verifier.
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Number of random bytes behind a code verifier (43 characters once
/// base64url encoded without padding).
const PKCE_RANDOM_BYTES: usize = 32;
336
337pub fn generate_pkce_pair() -> (String, String) {
338 let mut random = vec![0u8; PKCE_RANDOM_BYTES];
339 rand::fill(&mut random);
340 let verifier = URL_SAFE_NO_PAD.encode(&random);
341 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
342 (verifier, challenge)
343}
344
345pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
346 let mut url = Url::parse(auth_url)
347 .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
348 {
349 let mut pairs = url.query_pairs_mut();
350 pairs.append_pair("code_challenge", challenge);
351 pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
352 }
353 Ok(url.to_string())
354}
355
356impl OAuth2State {
357 pub fn new_with_state(
359 state: String,
360 provider_name: String,
361 redirect_uri: Option<String>,
362 user_id: String,
363 scopes: Vec<String>,
364 ) -> Self {
365 Self {
366 state,
367 provider_name,
368 redirect_uri,
369 user_id,
370 scopes,
371 metadata: HashMap::new(),
372 created_at: Utc::now(),
373 }
374 }
375
376 pub fn new(
378 provider_name: String,
379 redirect_uri: Option<String>,
380 user_id: String,
381 scopes: Vec<String>,
382 ) -> Self {
383 Self::new_with_state(
384 uuid::Uuid::new_v4().to_string(),
385 provider_name,
386 redirect_uri,
387 user_id,
388 scopes,
389 )
390 }
391
392 pub fn is_expired(&self, max_age_seconds: i64) -> bool {
394 let max_age = chrono::Duration::seconds(max_age_seconds);
395 Utc::now() - self.created_at > max_age
396 }
397}
398
399#[async_trait]
402pub trait ToolAuthStore: Send + Sync {
403 async fn get_session(
406 &self,
407 auth_entity: &str,
408 user_id: &str,
409 ) -> Result<Option<AuthSession>, AuthError>;
410
411 async fn store_session(
413 &self,
414 auth_entity: &str,
415 user_id: &str,
416 session: AuthSession,
417 ) -> Result<(), AuthError>;
418
419 async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
421
422 async fn store_secret(
425 &self,
426 user_id: &str,
427 auth_entity: Option<&str>, secret: AuthSecret,
429 ) -> Result<(), AuthError>;
430
431 async fn get_secret(
433 &self,
434 user_id: &str,
435 auth_entity: Option<&str>, key: &str,
437 ) -> Result<Option<AuthSecret>, AuthError>;
438
439 async fn remove_secret(
441 &self,
442 user_id: &str,
443 auth_entity: Option<&str>, key: &str,
445 ) -> Result<bool, AuthError>;
446
447 async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
450
451 async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
453
454 async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
456
457 async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
458
459 async fn list_sessions(
460 &self,
461 _user_id: &str,
462 ) -> Result<HashMap<String, AuthSession>, AuthError>;
463}
464
465#[derive(Clone)]
467pub struct OAuthHandler {
468 store: Arc<dyn ToolAuthStore>,
469 provider_registry: Option<Arc<dyn ProviderRegistry>>,
470 redirect_uri: String,
471}
472
473#[async_trait]
475pub trait ProviderRegistry: Send + Sync {
476 async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
477 async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
478 async fn is_provider_available(&self, provider_name: &str) -> bool;
479 async fn list_providers(&self) -> Vec<String>;
480 async fn requires_pkce(&self, _provider_name: &str) -> bool {
481 false
482 }
483}
484
485impl OAuthHandler {
486 pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
487 Self {
488 store,
489 provider_registry: None,
490 redirect_uri,
491 }
492 }
493
494 pub fn with_provider_registry(
495 store: Arc<dyn ToolAuthStore>,
496 provider_registry: Arc<dyn ProviderRegistry>,
497 redirect_uri: String,
498 ) -> Self {
499 Self {
500 store,
501 provider_registry: Some(provider_registry),
502 redirect_uri,
503 }
504 }
505
506 pub async fn get_auth_url(
508 &self,
509 auth_entity: &str,
510 user_id: &str,
511 auth_config: &AuthType,
512 scopes: &[String],
513 ) -> Result<String, AuthError> {
514 tracing::debug!(
515 "Getting auth URL for entity: {} user: {:?}",
516 auth_entity,
517 user_id
518 );
519
520 match auth_config {
521 AuthType::OAuth2 {
522 flow_type: OAuth2FlowType::ClientCredentials,
523 ..
524 } => Err(AuthError::InvalidConfig(
525 "Client credentials flow doesn't require authorization URL".to_string(),
526 )),
527 auth_config @ AuthType::OAuth2 {
528 send_redirect_uri, ..
529 } => {
530 let redirect_uri = if *send_redirect_uri {
532 Some(self.redirect_uri.clone())
533 } else {
534 None
535 };
536 let mut state = OAuth2State::new(
537 auth_entity.to_string(),
538 redirect_uri.clone(),
539 user_id.to_string(),
540 scopes.to_vec(),
541 );
542
543 let mut pkce_challenge = None;
544 if let Some(registry) = &self.provider_registry
545 && registry.requires_pkce(auth_entity).await
546 {
547 let (verifier, challenge) = generate_pkce_pair();
548 state.metadata.insert(
549 PKCE_CODE_VERIFIER_KEY.to_string(),
550 serde_json::Value::String(verifier.clone()),
551 );
552 pkce_challenge = Some(challenge);
553 }
554
555 self.store.store_oauth2_state(state.clone()).await?;
557
558 let provider = self.get_provider(auth_entity).await?;
560
561 let mut auth_url = provider.build_auth_url(
563 auth_config,
564 &state.state,
565 scopes,
566 redirect_uri.as_deref(),
567 )?;
568
569 if let Some(challenge) = pkce_challenge {
570 auth_url = append_pkce_challenge(&auth_url, &challenge)?;
571 }
572
573 tracing::debug!("Generated auth URL: {}", auth_url);
574 Ok(auth_url)
575 }
576 AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
577 "Secret authentication doesn't require authorization URL".to_string(),
578 )),
579 AuthType::None => Err(AuthError::InvalidConfig(
580 "No authentication doesn't require authorization URL".to_string(),
581 )),
582 }
583 }
584
585 pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
587 tracing::debug!("Handling OAuth2 callback with state: {}", state);
588
589 let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
591 AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
592 })?;
593
594 self.store.remove_oauth2_state(state).await?;
596
597 if oauth2_state.is_expired(600) {
599 return Err(AuthError::OAuth2Flow(
600 "OAuth2 state has expired".to_string(),
601 ));
602 }
603
604 let auth_config = if let Some(registry) = &self.provider_registry {
606 registry
607 .get_auth_type(&oauth2_state.provider_name)
608 .await
609 .ok_or_else(|| {
610 AuthError::InvalidConfig(format!(
611 "No configuration found for provider: {}",
612 oauth2_state.provider_name
613 ))
614 })?
615 } else {
616 return Err(AuthError::InvalidConfig(
617 "No provider registry configured".to_string(),
618 ));
619 };
620
621 let provider = self.get_provider(&oauth2_state.provider_name).await?;
623
624 let redirect_uri = match &auth_config {
626 AuthType::OAuth2 {
627 send_redirect_uri, ..
628 } if *send_redirect_uri => oauth2_state
629 .redirect_uri
630 .clone()
631 .or_else(|| Some(self.redirect_uri.clone())),
632 AuthType::OAuth2 { .. } => None,
633 _ => None,
634 };
635 let pkce_code_verifier = oauth2_state
636 .metadata
637 .get(PKCE_CODE_VERIFIER_KEY)
638 .and_then(|v| v.as_str());
639
640 let session = provider
641 .exchange_code(
642 code,
643 redirect_uri.as_deref(),
644 &auth_config,
645 pkce_code_verifier,
646 )
647 .await?;
648
649 self.store
651 .store_session(
652 &oauth2_state.provider_name,
653 &oauth2_state.user_id,
654 session.clone(),
655 )
656 .await?;
657
658 tracing::debug!(
659 "Successfully stored auth session for entity: {}",
660 oauth2_state.provider_name
661 );
662 Ok(session)
663 }
664
665 pub async fn refresh_session(
667 &self,
668 auth_entity: &str,
669 user_id: &str,
670 auth_config: &AuthType,
671 ) -> Result<AuthSession, AuthError> {
672 tracing::debug!(
673 "Refreshing session for entity: {} user: {:?}",
674 auth_entity,
675 user_id
676 );
677
678 let current_session = self
680 .store
681 .get_session(auth_entity, user_id)
682 .await?
683 .ok_or_else(|| {
684 AuthError::TokenRefreshFailed("No session found to refresh".to_string())
685 })?;
686
687 let refresh_token = current_session.refresh_token.ok_or_else(|| {
688 AuthError::TokenRefreshFailed("No refresh token available".to_string())
689 })?;
690
691 match auth_config {
692 AuthType::OAuth2 {
693 flow_type: OAuth2FlowType::ClientCredentials,
694 ..
695 } => {
696 let provider = self.get_provider(auth_entity).await?;
698 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
699
700 self.store
702 .store_session(auth_entity, user_id, new_session.clone())
703 .await?;
704 Ok(new_session)
705 }
706 auth_config @ AuthType::OAuth2 { .. } => {
707 let provider = self.get_provider(auth_entity).await?;
709
710 let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
712
713 self.store
715 .store_session(auth_entity, user_id, new_session.clone())
716 .await?;
717 Ok(new_session)
718 }
719 _ => Err(AuthError::InvalidConfig(
720 "Cannot refresh non-OAuth2 session".to_string(),
721 )),
722 }
723 }
724
725 pub async fn refresh_get_session(
727 &self,
728 auth_entity: &str,
729 user_id: &str,
730 auth_config: &AuthType,
731 ) -> Result<Option<AuthSession>, AuthError> {
732 match self.store.get_session(auth_entity, user_id).await? {
733 Some(session) => {
734 if session.needs_refresh() {
735 tracing::debug!(
736 "Session expired for {}:{:?}, attempting refresh",
737 auth_entity,
738 user_id
739 );
740 match self
741 .refresh_session(auth_entity, user_id, auth_config)
742 .await
743 {
744 Ok(refreshed_session) => {
745 tracing::info!(
746 "Successfully refreshed session for {}:{:?}",
747 auth_entity,
748 user_id
749 );
750 Ok(Some(refreshed_session))
751 }
752 Err(e) => {
753 tracing::warn!(
754 "Failed to refresh session for {}:{:?}: {}",
755 auth_entity,
756 user_id,
757 e
758 );
759 Err(e)
760 }
761 }
762 } else {
763 Ok(Some(session))
764 }
765 }
766 None => Ok(None),
767 }
768 }
769
770 async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
771 if let Some(registry) = &self.provider_registry {
772 registry
773 .get_provider(provider_name)
774 .await
775 .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
776 } else {
777 Err(AuthError::InvalidConfig(
778 "No provider registry configured".to_string(),
779 ))
780 }
781 }
782
783 pub async fn get_session(
785 &self,
786 auth_entity: &str,
787 user_id: &str,
788 ) -> Result<Option<AuthSession>, AuthError> {
789 self.store.get_session(auth_entity, user_id).await
790 }
791
792 pub async fn store_session(
793 &self,
794 auth_entity: &str,
795 user_id: &str,
796 session: AuthSession,
797 ) -> Result<(), AuthError> {
798 self.store
799 .store_session(auth_entity, user_id, session)
800 .await
801 }
802
803 pub async fn remove_session(
804 &self,
805 auth_entity: &str,
806 user_id: &str,
807 ) -> Result<bool, AuthError> {
808 self.store.remove_session(auth_entity, user_id).await
809 }
810
811 pub async fn store_secret(
812 &self,
813 user_id: &str,
814 auth_entity: Option<&str>,
815 secret: AuthSecret,
816 ) -> Result<(), AuthError> {
817 self.store.store_secret(user_id, auth_entity, secret).await
818 }
819
820 pub async fn get_secret(
821 &self,
822 user_id: &str,
823 auth_entity: Option<&str>,
824 key: &str,
825 ) -> Result<Option<AuthSecret>, AuthError> {
826 self.store.get_secret(user_id, auth_entity, key).await
827 }
828
829 pub async fn remove_secret(
830 &self,
831 user_id: &str,
832 auth_entity: Option<&str>,
833 key: &str,
834 ) -> Result<bool, AuthError> {
835 self.store.remove_secret(user_id, auth_entity, key).await
836 }
837
838 pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
839 self.store.store_oauth2_state(state).await
840 }
841
842 pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
843 self.store.get_oauth2_state(state).await
844 }
845
846 pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
847 self.store.remove_oauth2_state(state).await
848 }
849
850 pub async fn list_secrets(
851 &self,
852 user_id: &str,
853 ) -> Result<HashMap<String, AuthSecret>, AuthError> {
854 self.store.list_secrets(user_id).await
855 }
856
857 pub async fn list_sessions(
858 &self,
859 user_id: &str,
860 ) -> Result<HashMap<String, AuthSession>, AuthError> {
861 self.store.list_sessions(user_id).await
862 }
863}
864
865#[async_trait]
867pub trait AuthProvider: Send + Sync {
868 fn provider_name(&self) -> &str;
870
871 async fn exchange_code(
873 &self,
874 code: &str,
875 redirect_uri: Option<&str>,
876 auth_config: &AuthType,
877 pkce_code_verifier: Option<&str>,
878 ) -> Result<AuthSession, AuthError>;
879
880 async fn refresh_token(
882 &self,
883 refresh_token: &str,
884 auth_config: &AuthType,
885 ) -> Result<AuthSession, AuthError>;
886
887 fn build_auth_url(
889 &self,
890 auth_config: &AuthType,
891 state: &str,
892 scopes: &[String],
893 redirect_uri: Option<&str>,
894 ) -> Result<String, AuthError>;
895}