// distri_types/auth.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::sync::Arc;
10use url::Url;
11
12use crate::McpSession;
13
/// Serde default for `AuthType::OAuth2::send_redirect_uri`: when the field is
/// absent from a config, the provider is asked to include `redirect_uri`.
const fn default_send_redirect_uri() -> bool {
    true
}
17
/// Authentication types supported by the tool auth system
// Adjacently tagged: the variant name (see the `rename` attributes) goes in a
// "type" key and struct-variant fields go under a "config" key.
// Kept as `//` comments so the generated JSON schema descriptions are unchanged.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(tag = "type", content = "config")]
pub enum AuthType {
    /// No authentication required
    #[serde(rename = "none")]
    None,
    /// OAuth2 authentication flows
    #[serde(rename = "oauth2")]
    OAuth2 {
        /// OAuth2 flow type
        flow_type: OAuth2FlowType,
        /// Authorization URL
        authorization_url: String,
        /// Token URL
        token_url: String,
        /// Optional refresh URL
        refresh_url: Option<String>,
        /// Required scopes
        scopes: Vec<String>,
        /// Whether the provider should include redirect_uri in requests
        // Falls back to `default_send_redirect_uri()` (i.e. `true`) when absent.
        #[serde(default = "default_send_redirect_uri")]
        send_redirect_uri: bool,
    },
    /// Secret-based authentication (API keys etc.)
    #[serde(rename = "secret")]
    Secret {
        provider: String,
        // Empty when the serialized config omits `fields`.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
50
/// OAuth2 flow types
// Serialized as snake_case strings: "authorization_code", "client_credentials",
// "implicit", "password". `OAuthHandler::get_auth_url` rejects
// `ClientCredentials` because that flow has no authorization URL.
// (Variant docs are deliberately `//`: `///` would change the JSON schema shape.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OAuth2FlowType {
    AuthorizationCode,
    ClientCredentials,
    Implicit,
    Password,
}
60
61/// OAuth2 authentication session - only contains OAuth tokens
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct AuthSession {
64    /// Access token
65    pub access_token: String,
66    /// Optional refresh token
67    pub refresh_token: Option<String>,
68    /// Token expiry time
69    pub expires_at: Option<DateTime<Utc>>,
70    /// Token type (usually "Bearer")
71    pub token_type: String,
72    /// Granted scopes
73    pub scopes: Vec<String>,
74}
75
/// Usage limits that can be embedded in tokens
// `Default` yields every limit as `None` (unlimited); `None` limits are omitted
// from serialized output by `skip_serializing_if`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct TokenLimits {
    /// Maximum tokens per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_tokens: Option<u64>,

    /// Maximum tokens per month (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub monthly_tokens: Option<u64>,

    /// Maximum API calls per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_calls: Option<u64>,

    /// Maximum API calls per week (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weekly_calls: Option<u64>,
}
95
/// Response for issuing access + refresh tokens.
// NOTE(review): the derived `Debug` prints both tokens verbatim; avoid `{:?}`
// on this type in logs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
pub struct TokenResponse {
    pub access_token: String,
    pub refresh_token: String,
    // NOTE(review): `expires_at` is a bare i64 — its unit and epoch (e.g. Unix
    // seconds vs. milliseconds) are not visible here; confirm with the issuer.
    pub expires_at: i64,
    /// Identifier for usage tracking (e.g., "blinksheets", "my-app")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identifier_id: Option<String>,
    /// Effective limits applied to this token
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limits: Option<TokenLimits>,
}
109
110/// Secret storage for API keys and other non-OAuth authentication
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct AuthSecret {
113    /// The secret value (API key, token, etc.)
114    pub secret: String,
115    /// Key name for this secret
116    pub key: String,
117}
118impl From<AuthSession> for McpSession {
119    fn from(val: AuthSession) -> Self {
120        McpSession {
121            token: val.access_token,
122            expiry: val.expires_at.map(|dt| dt.into()),
123        }
124    }
125}
126
127impl From<AuthSecret> for McpSession {
128    fn from(val: AuthSecret) -> Self {
129        McpSession {
130            token: val.secret,
131            expiry: None, // Secrets don't expire
132        }
133    }
134}
135
136impl AuthSession {
137    /// Create a new OAuth auth session
138    pub fn new(
139        access_token: String,
140        token_type: Option<String>,
141        expires_in: Option<i64>,
142        refresh_token: Option<String>,
143        scopes: Vec<String>,
144    ) -> Self {
145        let now = Utc::now();
146        let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
147
148        AuthSession {
149            access_token,
150            refresh_token,
151            expires_at,
152            token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
153            scopes,
154        }
155    }
156
157    /// Check if the OAuth token is expired or will expire within the given buffer
158    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
159        match &self.expires_at {
160            Some(expires_at) => {
161                let buffer = chrono::Duration::seconds(buffer_seconds);
162                Utc::now() + buffer >= *expires_at
163            }
164            None => false, // No expiry means it doesn't expire
165        }
166    }
167
168    /// Check if the OAuth token needs refreshing (expired with 5 minute buffer)
169    pub fn needs_refresh(&self) -> bool {
170        self.is_expired(300) // 5 minutes buffer
171    }
172
173    /// Get access token for OAuth sessions
174    pub fn get_access_token(&self) -> &str {
175        &self.access_token
176    }
177
178    /// Update OAuth session with new token data
179    pub fn update_tokens(
180        &mut self,
181        access_token: String,
182        expires_in: Option<i64>,
183        refresh_token: Option<String>,
184    ) {
185        self.access_token = access_token;
186
187        if let Some(secs) = expires_in {
188            self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
189        }
190
191        if let Some(token) = refresh_token {
192            self.refresh_token = Some(token);
193        }
194    }
195}
196
197impl AuthSecret {
198    /// Create a new secret
199    pub fn new(key: String, secret: String) -> Self {
200        AuthSecret { secret, key }
201    }
202
203    /// Get the secret value
204    pub fn get_secret(&self) -> &str {
205        &self.secret
206    }
207
208    /// Get the provider name
209    pub fn get_provider(&self) -> &str {
210        &self.key
211    }
212}
213
214/// Authentication metadata trait for tools
215pub trait AuthMetadata: Send + Sync {
216    /// Get the auth entity identifier (e.g., "google", "twitter", "api_key_service")
217    fn get_auth_entity(&self) -> String;
218
219    /// Get the authentication type and configuration
220    fn get_auth_type(&self) -> AuthType;
221
222    /// Check if authentication is required for this tool
223    fn requires_auth(&self) -> bool {
224        !matches!(self.get_auth_type(), AuthType::None)
225    }
226
227    /// Get additional authentication configuration
228    fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
229        HashMap::new()
230    }
231}
232
233/// OAuth2 authentication error types
234#[derive(Debug, thiserror::Error)]
235pub enum AuthError {
236    #[error("OAuth2 flow error: {0}")]
237    OAuth2Flow(String),
238
239    #[error("Token expired and refresh failed: {0}")]
240    TokenRefreshFailed(String),
241
242    #[error("Invalid authentication configuration: {0}")]
243    InvalidConfig(String),
244
245    #[error("Authentication required but not configured for entity: {0}")]
246    AuthRequired(String),
247
248    #[error("API key not found for entity: {0}")]
249    ApiKeyNotFound(String),
250
251    #[error("Storage error: {0}")]
252    Storage(#[from] anyhow::Error),
253
254    #[error("Store error: {0}")]
255    StoreError(String),
256
257    #[error("Provider not found: {0}")]
258    ProviderNotFound(String),
259
260    #[error("Server error: {0}")]
261    ServerError(String),
262}
263
/// Authentication requirement specification for tools
// Internally tagged: the variant name ("oauth2" / "secret") is written to a
// "type" key next to the variant's own fields. Unlike `AuthType`, the OAuth2
// URL/redirect fields here use camelCase keys, are optional, default to
// `None` when absent, and are omitted from serialized output when `None`.
// NOTE(review): `send_redirect_uri` is `Option<bool>` here (plain `bool` in
// `AuthType`); presumably `None` means "use the default" — confirm where
// requirements are converted into `AuthType`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type")]
pub enum AuthRequirement {
    #[serde(rename = "oauth2")]
    OAuth2 {
        provider: String,
        #[serde(default)]
        scopes: Vec<String>,
        #[serde(
            rename = "authorizationUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        authorization_url: Option<String>,
        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
        token_url: Option<String>,
        #[serde(
            rename = "refreshUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        refresh_url: Option<String>,
        #[serde(
            rename = "sendRedirectUri",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        send_redirect_uri: Option<bool>,
    },
    #[serde(rename = "secret")]
    Secret {
        provider: String,
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
301
// One entry in the `fields` list of a secret-based auth config: the field's
// `key`, plus an optional display `label` and `description`. `optional`
// defaults to `false`; presumably it marks fields a user may leave empty —
// confirm against the UI/consumer. (Plain `//` comments so the JSON schema is
// unchanged.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SecretFieldSpec {
    pub key: String,
    #[serde(default)]
    pub label: Option<String>,
    #[serde(default)]
    pub description: Option<String>,
    #[serde(default)]
    pub optional: bool,
}
312
/// OAuth2 flow state for managing authorization flows
///
/// Persisted via `ToolAuthStore::store_oauth2_state` when the authorization URL
/// is built, and looked up by its `state` value when the callback arrives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2State {
    /// Value sent as the OAuth2 `state` parameter; also the lookup key in the store
    pub state: String,
    /// Provider name for this OAuth flow
    pub provider_name: String,
    /// Redirect URI for the OAuth flow (if the provider requires it)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// ID of the user the resulting session is stored for (always present)
    pub user_id: String,
    /// Requested scopes
    pub scopes: Vec<String>,
    /// Additional metadata, e.g. the PKCE code verifier under `PKCE_CODE_VERIFIER_KEY`
    pub metadata: HashMap<String, serde_json::Value>,
    /// State creation time, used by `is_expired`
    pub created_at: DateTime<Utc>,
}
332
/// Metadata key under which the PKCE code verifier is kept in an `OAuth2State`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// `code_challenge_method` sent with PKCE authorization requests (SHA-256).
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// 256 bits of verifier entropy (32 bytes), which encodes to a 43-character
/// unpadded base64url verifier.
const PKCE_RANDOM_BYTES: usize = 256 / 8;
336
337pub fn generate_pkce_pair() -> (String, String) {
338    let mut random = vec![0u8; PKCE_RANDOM_BYTES];
339    rand::fill(&mut random);
340    let verifier = URL_SAFE_NO_PAD.encode(&random);
341    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
342    (verifier, challenge)
343}
344
345pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
346    let mut url = Url::parse(auth_url)
347        .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
348    {
349        let mut pairs = url.query_pairs_mut();
350        pairs.append_pair("code_challenge", challenge);
351        pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
352    }
353    Ok(url.to_string())
354}
355
356impl OAuth2State {
357    /// Create a new OAuth2 state with provided state parameter
358    pub fn new_with_state(
359        state: String,
360        provider_name: String,
361        redirect_uri: Option<String>,
362        user_id: String,
363        scopes: Vec<String>,
364    ) -> Self {
365        Self {
366            state,
367            provider_name,
368            redirect_uri,
369            user_id,
370            scopes,
371            metadata: HashMap::new(),
372            created_at: Utc::now(),
373        }
374    }
375
376    /// Create a new OAuth2 state with auto-generated state parameter (deprecated)
377    pub fn new(
378        provider_name: String,
379        redirect_uri: Option<String>,
380        user_id: String,
381        scopes: Vec<String>,
382    ) -> Self {
383        Self::new_with_state(
384            uuid::Uuid::new_v4().to_string(),
385            provider_name,
386            redirect_uri,
387            user_id,
388            scopes,
389        )
390    }
391
392    /// Check if the state has expired (default 10 minutes)
393    pub fn is_expired(&self, max_age_seconds: i64) -> bool {
394        let max_age = chrono::Duration::seconds(max_age_seconds);
395        Utc::now() - self.created_at > max_age
396    }
397}
398
399/// Storage-only trait for authentication stores
400/// Implementations only need to handle storage operations
401#[async_trait]
402pub trait ToolAuthStore: Send + Sync {
403    /// Session Management
404    /// Get current authentication session for an entity
405    async fn get_session(
406        &self,
407        auth_entity: &str,
408        user_id: &str,
409    ) -> Result<Option<AuthSession>, AuthError>;
410
411    /// Store authentication session
412    async fn store_session(
413        &self,
414        auth_entity: &str,
415        user_id: &str,
416        session: AuthSession,
417    ) -> Result<(), AuthError>;
418
419    /// Remove authentication session
420    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
421
422    /// Secret Management
423    /// Store secret for a user (optionally scoped to auth_entity)
424    async fn store_secret(
425        &self,
426        user_id: &str,
427        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
428        secret: AuthSecret,
429    ) -> Result<(), AuthError>;
430
431    /// Get stored secret by key (optionally scoped to auth_entity)  
432    async fn get_secret(
433        &self,
434        user_id: &str,
435        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
436        key: &str,
437    ) -> Result<Option<AuthSecret>, AuthError>;
438
439    /// Remove stored secret by key (optionally scoped to auth_entity)
440    async fn remove_secret(
441        &self,
442        user_id: &str,
443        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
444        key: &str,
445    ) -> Result<bool, AuthError>;
446
447    /// State Management (for OAuth2 flows)
448    /// Store OAuth2 state for security
449    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
450
451    /// Get OAuth2 state by state parameter
452    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
453
454    /// Remove OAuth2 state (after successful callback)
455    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
456
457    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
458
459    async fn list_sessions(
460        &self,
461        _user_id: &str,
462    ) -> Result<HashMap<String, AuthSession>, AuthError>;
463}
464
465/// OAuth handler that works with any AuthStore implementation
466#[derive(Clone)]
467pub struct OAuthHandler {
468    store: Arc<dyn ToolAuthStore>,
469    provider_registry: Option<Arc<dyn ProviderRegistry>>,
470    redirect_uri: String,
471}
472
473/// Provider registry trait for getting auth providers
474#[async_trait]
475pub trait ProviderRegistry: Send + Sync {
476    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
477    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
478    async fn is_provider_available(&self, provider_name: &str) -> bool;
479    async fn list_providers(&self) -> Vec<String>;
480    async fn requires_pkce(&self, _provider_name: &str) -> bool {
481        false
482    }
483}
484
485impl OAuthHandler {
486    pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
487        Self {
488            store,
489            provider_registry: None,
490            redirect_uri,
491        }
492    }
493
494    pub fn with_provider_registry(
495        store: Arc<dyn ToolAuthStore>,
496        provider_registry: Arc<dyn ProviderRegistry>,
497        redirect_uri: String,
498    ) -> Self {
499        Self {
500            store,
501            provider_registry: Some(provider_registry),
502            redirect_uri,
503        }
504    }
505
506    /// Generate authorization URL for OAuth2 flow
507    pub async fn get_auth_url(
508        &self,
509        auth_entity: &str,
510        user_id: &str,
511        auth_config: &AuthType,
512        scopes: &[String],
513    ) -> Result<String, AuthError> {
514        tracing::debug!(
515            "Getting auth URL for entity: {} user: {:?}",
516            auth_entity,
517            user_id
518        );
519
520        match auth_config {
521            AuthType::OAuth2 {
522                flow_type: OAuth2FlowType::ClientCredentials,
523                ..
524            } => Err(AuthError::InvalidConfig(
525                "Client credentials flow doesn't require authorization URL".to_string(),
526            )),
527            auth_config @ AuthType::OAuth2 {
528                send_redirect_uri, ..
529            } => {
530                // Create OAuth2 state
531                let redirect_uri = if *send_redirect_uri {
532                    Some(self.redirect_uri.clone())
533                } else {
534                    None
535                };
536                let mut state = OAuth2State::new(
537                    auth_entity.to_string(),
538                    redirect_uri.clone(),
539                    user_id.to_string(),
540                    scopes.to_vec(),
541                );
542
543                let mut pkce_challenge = None;
544                if let Some(registry) = &self.provider_registry
545                    && registry.requires_pkce(auth_entity).await
546                {
547                    let (verifier, challenge) = generate_pkce_pair();
548                    state.metadata.insert(
549                        PKCE_CODE_VERIFIER_KEY.to_string(),
550                        serde_json::Value::String(verifier.clone()),
551                    );
552                    pkce_challenge = Some(challenge);
553                }
554
555                // Store the state
556                self.store.store_oauth2_state(state.clone()).await?;
557
558                // Get the appropriate provider using the auth_entity as provider name
559                let provider = self.get_provider(auth_entity).await?;
560
561                // Build the authorization URL
562                let mut auth_url = provider.build_auth_url(
563                    auth_config,
564                    &state.state,
565                    scopes,
566                    redirect_uri.as_deref(),
567                )?;
568
569                if let Some(challenge) = pkce_challenge {
570                    auth_url = append_pkce_challenge(&auth_url, &challenge)?;
571                }
572
573                tracing::debug!("Generated auth URL: {}", auth_url);
574                Ok(auth_url)
575            }
576            AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
577                "Secret authentication doesn't require authorization URL".to_string(),
578            )),
579            AuthType::None => Err(AuthError::InvalidConfig(
580                "No authentication doesn't require authorization URL".to_string(),
581            )),
582        }
583    }
584
585    /// Handle OAuth2 callback and exchange code for tokens
586    pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
587        tracing::debug!("Handling OAuth2 callback with state: {}", state);
588
589        // Get and remove the state
590        let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
591            AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
592        })?;
593
594        // Remove the used state
595        self.store.remove_oauth2_state(state).await?;
596
597        // Check if state is expired (10 minutes max)
598        if oauth2_state.is_expired(600) {
599            return Err(AuthError::OAuth2Flow(
600                "OAuth2 state has expired".to_string(),
601            ));
602        }
603
604        // Get auth config from provider registry
605        let auth_config = if let Some(registry) = &self.provider_registry {
606            registry
607                .get_auth_type(&oauth2_state.provider_name)
608                .await
609                .ok_or_else(|| {
610                    AuthError::InvalidConfig(format!(
611                        "No configuration found for provider: {}",
612                        oauth2_state.provider_name
613                    ))
614                })?
615        } else {
616            return Err(AuthError::InvalidConfig(
617                "No provider registry configured".to_string(),
618            ));
619        };
620
621        // Get the appropriate provider
622        let provider = self.get_provider(&oauth2_state.provider_name).await?;
623
624        // Exchange the authorization code for tokens
625        let redirect_uri = match &auth_config {
626            AuthType::OAuth2 {
627                send_redirect_uri, ..
628            } if *send_redirect_uri => oauth2_state
629                .redirect_uri
630                .clone()
631                .or_else(|| Some(self.redirect_uri.clone())),
632            AuthType::OAuth2 { .. } => None,
633            _ => None,
634        };
635        let pkce_code_verifier = oauth2_state
636            .metadata
637            .get(PKCE_CODE_VERIFIER_KEY)
638            .and_then(|v| v.as_str());
639
640        let session = provider
641            .exchange_code(
642                code,
643                redirect_uri.as_deref(),
644                &auth_config,
645                pkce_code_verifier,
646            )
647            .await?;
648
649        // Store the session
650        self.store
651            .store_session(
652                &oauth2_state.provider_name,
653                &oauth2_state.user_id,
654                session.clone(),
655            )
656            .await?;
657
658        tracing::debug!(
659            "Successfully stored auth session for entity: {}",
660            oauth2_state.provider_name
661        );
662        Ok(session)
663    }
664
665    /// Refresh an expired session
666    pub async fn refresh_session(
667        &self,
668        auth_entity: &str,
669        user_id: &str,
670        auth_config: &AuthType,
671    ) -> Result<AuthSession, AuthError> {
672        tracing::debug!(
673            "Refreshing session for entity: {} user: {:?}",
674            auth_entity,
675            user_id
676        );
677
678        // Get current session
679        let current_session = self
680            .store
681            .get_session(auth_entity, user_id)
682            .await?
683            .ok_or_else(|| {
684                AuthError::TokenRefreshFailed("No session found to refresh".to_string())
685            })?;
686
687        let refresh_token = current_session.refresh_token.ok_or_else(|| {
688            AuthError::TokenRefreshFailed("No refresh token available".to_string())
689        })?;
690
691        match auth_config {
692            AuthType::OAuth2 {
693                flow_type: OAuth2FlowType::ClientCredentials,
694                ..
695            } => {
696                // For client credentials, get a new token instead of refreshing
697                let provider = self.get_provider(auth_entity).await?;
698                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
699
700                // Store the new session
701                self.store
702                    .store_session(auth_entity, user_id, new_session.clone())
703                    .await?;
704                Ok(new_session)
705            }
706            auth_config @ AuthType::OAuth2 { .. } => {
707                // Get the appropriate provider
708                let provider = self.get_provider(auth_entity).await?;
709
710                // Refresh the token
711                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
712
713                // Store the new session
714                self.store
715                    .store_session(auth_entity, user_id, new_session.clone())
716                    .await?;
717                Ok(new_session)
718            }
719            _ => Err(AuthError::InvalidConfig(
720                "Cannot refresh non-OAuth2 session".to_string(),
721            )),
722        }
723    }
724
725    /// Get session, automatically refreshing if expired
726    pub async fn refresh_get_session(
727        &self,
728        auth_entity: &str,
729        user_id: &str,
730        auth_config: &AuthType,
731    ) -> Result<Option<AuthSession>, AuthError> {
732        match self.store.get_session(auth_entity, user_id).await? {
733            Some(session) => {
734                if session.needs_refresh() {
735                    tracing::debug!(
736                        "Session expired for {}:{:?}, attempting refresh",
737                        auth_entity,
738                        user_id
739                    );
740                    match self
741                        .refresh_session(auth_entity, user_id, auth_config)
742                        .await
743                    {
744                        Ok(refreshed_session) => {
745                            tracing::info!(
746                                "Successfully refreshed session for {}:{:?}",
747                                auth_entity,
748                                user_id
749                            );
750                            Ok(Some(refreshed_session))
751                        }
752                        Err(e) => {
753                            tracing::warn!(
754                                "Failed to refresh session for {}:{:?}: {}",
755                                auth_entity,
756                                user_id,
757                                e
758                            );
759                            Err(e)
760                        }
761                    }
762                } else {
763                    Ok(Some(session))
764                }
765            }
766            None => Ok(None),
767        }
768    }
769
770    async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
771        if let Some(registry) = &self.provider_registry {
772            registry
773                .get_provider(provider_name)
774                .await
775                .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
776        } else {
777            Err(AuthError::InvalidConfig(
778                "No provider registry configured".to_string(),
779            ))
780        }
781    }
782
783    // Storage delegation methods
784    pub async fn get_session(
785        &self,
786        auth_entity: &str,
787        user_id: &str,
788    ) -> Result<Option<AuthSession>, AuthError> {
789        self.store.get_session(auth_entity, user_id).await
790    }
791
792    pub async fn store_session(
793        &self,
794        auth_entity: &str,
795        user_id: &str,
796        session: AuthSession,
797    ) -> Result<(), AuthError> {
798        self.store
799            .store_session(auth_entity, user_id, session)
800            .await
801    }
802
803    pub async fn remove_session(
804        &self,
805        auth_entity: &str,
806        user_id: &str,
807    ) -> Result<bool, AuthError> {
808        self.store.remove_session(auth_entity, user_id).await
809    }
810
811    pub async fn store_secret(
812        &self,
813        user_id: &str,
814        auth_entity: Option<&str>,
815        secret: AuthSecret,
816    ) -> Result<(), AuthError> {
817        self.store.store_secret(user_id, auth_entity, secret).await
818    }
819
820    pub async fn get_secret(
821        &self,
822        user_id: &str,
823        auth_entity: Option<&str>,
824        key: &str,
825    ) -> Result<Option<AuthSecret>, AuthError> {
826        self.store.get_secret(user_id, auth_entity, key).await
827    }
828
829    pub async fn remove_secret(
830        &self,
831        user_id: &str,
832        auth_entity: Option<&str>,
833        key: &str,
834    ) -> Result<bool, AuthError> {
835        self.store.remove_secret(user_id, auth_entity, key).await
836    }
837
838    pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
839        self.store.store_oauth2_state(state).await
840    }
841
842    pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
843        self.store.get_oauth2_state(state).await
844    }
845
846    pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
847        self.store.remove_oauth2_state(state).await
848    }
849
850    pub async fn list_secrets(
851        &self,
852        user_id: &str,
853    ) -> Result<HashMap<String, AuthSecret>, AuthError> {
854        self.store.list_secrets(user_id).await
855    }
856
857    pub async fn list_sessions(
858        &self,
859        user_id: &str,
860    ) -> Result<HashMap<String, AuthSession>, AuthError> {
861        self.store.list_sessions(user_id).await
862    }
863}
864
865/// Authentication provider trait for different OAuth2 providers
866#[async_trait]
867pub trait AuthProvider: Send + Sync {
868    /// Provider name (e.g., "google", "github", "twitter")
869    fn provider_name(&self) -> &str;
870
871    /// Exchange authorization code for access token
872    async fn exchange_code(
873        &self,
874        code: &str,
875        redirect_uri: Option<&str>,
876        auth_config: &AuthType,
877        pkce_code_verifier: Option<&str>,
878    ) -> Result<AuthSession, AuthError>;
879
880    /// Refresh an access token
881    async fn refresh_token(
882        &self,
883        refresh_token: &str,
884        auth_config: &AuthType,
885    ) -> Result<AuthSession, AuthError>;
886
887    /// Build authorization URL
888    fn build_auth_url(
889        &self,
890        auth_config: &AuthType,
891        state: &str,
892        scopes: &[String],
893        redirect_uri: Option<&str>,
894    ) -> Result<String, AuthError>;
895}