Skip to main content

distri_types/
auth.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `send_redirect_uri`: a configuration that omits the
/// field deserializes with the flag switched on.
fn default_send_redirect_uri() -> bool {
    const SEND_BY_DEFAULT: bool = true;
    SEND_BY_DEFAULT
}
18
19/// Authentication types supported by the tool auth system
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
21#[serde(tag = "type", content = "config")]
22pub enum AuthType {
23    /// No authentication required
24    #[serde(rename = "none")]
25    None,
26    /// OAuth2 authentication flows
27    #[serde(rename = "oauth2")]
28    OAuth2 {
29        /// OAuth2 flow type
30        flow_type: OAuth2FlowType,
31        /// Authorization URL
32        authorization_url: String,
33        /// Token URL
34        token_url: String,
35        /// Optional refresh URL
36        refresh_url: Option<String>,
37        /// Required scopes
38        scopes: Vec<String>,
39        /// Whether the provider should include redirect_uri in requests
40        #[serde(default = "default_send_redirect_uri")]
41        send_redirect_uri: bool,
42    },
43    /// Secret-based authentication (API keys etc.)
44    #[serde(rename = "secret")]
45    Secret {
46        provider: String,
47        #[serde(default)]
48        fields: Vec<SecretFieldSpec>,
49    },
50}
51
52/// OAuth2 flow types
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
54#[serde(rename_all = "snake_case")]
55pub enum OAuth2FlowType {
56    AuthorizationCode,
57    ClientCredentials,
58    Implicit,
59    Password,
60}
61
62/// OAuth2 authentication session - only contains OAuth tokens
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AuthSession {
65    /// Access token
66    pub access_token: String,
67    /// Optional refresh token
68    pub refresh_token: Option<String>,
69    /// Token expiry time
70    pub expires_at: Option<DateTime<Utc>>,
71    /// Token type (usually "Bearer")
72    pub token_type: String,
73    /// Granted scopes
74    pub scopes: Vec<String>,
75}
76
77/// Usage limits that can be embedded in tokens
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
79pub struct TokenLimits {
80    /// Maximum tokens per day (None = unlimited)
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub daily_tokens: Option<u64>,
83
84    /// Maximum tokens per month (None = unlimited)
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub monthly_tokens: Option<u64>,
87
88    /// Maximum API calls per day (None = unlimited)
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub daily_calls: Option<u64>,
91
92    /// Maximum API calls per week (None = unlimited)
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub weekly_calls: Option<u64>,
95}
96
97/// Response for issuing access + refresh tokens.
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
99pub struct TokenResponse {
100    pub access_token: String,
101    pub refresh_token: String,
102    pub expires_at: i64,
103    /// Identifier for usage tracking (e.g., "blinksheets", "my-app")
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub identifier_id: Option<String>,
106    /// Effective limits applied to this token
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub limits: Option<TokenLimits>,
109}
110
111/// Secret storage for API keys and other non-OAuth authentication
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AuthSecret {
114    /// The secret value (API key, token, etc.)
115    pub secret: String,
116    /// Key name for this secret
117    pub key: String,
118}
119impl From<AuthSession> for McpSession {
120    fn from(val: AuthSession) -> Self {
121        McpSession {
122            token: val.access_token,
123            expiry: val.expires_at.map(|dt| dt.into()),
124        }
125    }
126}
127
128impl From<AuthSecret> for McpSession {
129    fn from(val: AuthSecret) -> Self {
130        McpSession {
131            token: val.secret,
132            expiry: None, // Secrets don't expire
133        }
134    }
135}
136
137impl AuthSession {
138    /// Create a new OAuth auth session
139    pub fn new(
140        access_token: String,
141        token_type: Option<String>,
142        expires_in: Option<i64>,
143        refresh_token: Option<String>,
144        scopes: Vec<String>,
145    ) -> Self {
146        let now = Utc::now();
147        let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
148
149        AuthSession {
150            access_token,
151            refresh_token,
152            expires_at,
153            token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
154            scopes,
155        }
156    }
157
158    /// Check if the OAuth token is expired or will expire within the given buffer
159    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
160        match &self.expires_at {
161            Some(expires_at) => {
162                let buffer = chrono::Duration::seconds(buffer_seconds);
163                Utc::now() + buffer >= *expires_at
164            }
165            None => false, // No expiry means it doesn't expire
166        }
167    }
168
169    /// Check if the OAuth token needs refreshing (expired with 5 minute buffer)
170    pub fn needs_refresh(&self) -> bool {
171        self.is_expired(300) // 5 minutes buffer
172    }
173
174    /// Get access token for OAuth sessions
175    pub fn get_access_token(&self) -> &str {
176        &self.access_token
177    }
178
179    /// Update OAuth session with new token data
180    pub fn update_tokens(
181        &mut self,
182        access_token: String,
183        expires_in: Option<i64>,
184        refresh_token: Option<String>,
185    ) {
186        self.access_token = access_token;
187
188        if let Some(secs) = expires_in {
189            self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
190        }
191
192        if let Some(token) = refresh_token {
193            self.refresh_token = Some(token);
194        }
195    }
196}
197
198impl AuthSecret {
199    /// Create a new secret
200    pub fn new(key: String, secret: String) -> Self {
201        AuthSecret { secret, key }
202    }
203
204    /// Get the secret value
205    pub fn get_secret(&self) -> &str {
206        &self.secret
207    }
208
209    /// Get the provider name
210    pub fn get_provider(&self) -> &str {
211        &self.key
212    }
213}
214
215/// Authentication metadata trait for tools
216pub trait AuthMetadata: Send + Sync {
217    /// Get the auth entity identifier (e.g., "google", "twitter", "api_key_service")
218    fn get_auth_entity(&self) -> String;
219
220    /// Get the authentication type and configuration
221    fn get_auth_type(&self) -> AuthType;
222
223    /// Check if authentication is required for this tool
224    fn requires_auth(&self) -> bool {
225        !matches!(self.get_auth_type(), AuthType::None)
226    }
227
228    /// Get additional authentication configuration
229    fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
230        HashMap::new()
231    }
232}
233
234/// OAuth2 authentication error types
235#[derive(Debug, thiserror::Error)]
236pub enum AuthError {
237    #[error("OAuth2 flow error: {0}")]
238    OAuth2Flow(String),
239
240    #[error("Token expired and refresh failed: {0}")]
241    TokenRefreshFailed(String),
242
243    #[error("Invalid authentication configuration: {0}")]
244    InvalidConfig(String),
245
246    #[error("Authentication required but not configured for entity: {0}")]
247    AuthRequired(String),
248
249    #[error("API key not found for entity: {0}")]
250    ApiKeyNotFound(String),
251
252    #[error("Storage error: {0}")]
253    Storage(#[from] anyhow::Error),
254
255    #[error("Store error: {0}")]
256    StoreError(String),
257
258    #[error("Provider not found: {0}")]
259    ProviderNotFound(String),
260
261    #[error("Server error: {0}")]
262    ServerError(String),
263}
264
265/// Authentication requirement specification for tools
266#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
267#[serde(tag = "type")]
268pub enum AuthRequirement {
269    #[serde(rename = "oauth2")]
270    OAuth2 {
271        provider: String,
272        #[serde(default)]
273        scopes: Vec<String>,
274        #[serde(
275            rename = "authorizationUrl",
276            default,
277            skip_serializing_if = "Option::is_none"
278        )]
279        authorization_url: Option<String>,
280        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
281        token_url: Option<String>,
282        #[serde(
283            rename = "refreshUrl",
284            default,
285            skip_serializing_if = "Option::is_none"
286        )]
287        refresh_url: Option<String>,
288        #[serde(
289            rename = "sendRedirectUri",
290            default,
291            skip_serializing_if = "Option::is_none"
292        )]
293        send_redirect_uri: Option<bool>,
294    },
295    #[serde(rename = "secret")]
296    Secret {
297        provider: String,
298        #[serde(default)]
299        fields: Vec<SecretFieldSpec>,
300    },
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
304pub struct SecretFieldSpec {
305    pub key: String,
306    #[serde(default)]
307    pub label: Option<String>,
308    #[serde(default)]
309    pub description: Option<String>,
310    #[serde(default)]
311    pub optional: bool,
312}
313
314/// OAuth2 flow state for managing authorization flows
315#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct OAuth2State {
317    /// Random state parameter for security  
318    pub state: String,
319    /// Provider name for this OAuth flow
320    pub provider_name: String,
321    /// Redirect URI for the OAuth flow (if the provider requires it)
322    #[serde(default, skip_serializing_if = "Option::is_none")]
323    pub redirect_uri: Option<String>,
324    /// User ID if available
325    pub user_id: String,
326    /// Requested scopes
327    pub scopes: Vec<String>,
328    /// Additional metadata
329    pub metadata: HashMap<String, serde_json::Value>,
330    /// State creation time
331    pub created_at: DateTime<Utc>,
332}
333
334pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
335pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
336const PKCE_RANDOM_BYTES: usize = 32;
337
338pub fn generate_pkce_pair() -> (String, String) {
339    let mut random = vec![0u8; PKCE_RANDOM_BYTES];
340    rand::thread_rng().fill_bytes(&mut random);
341    let verifier = URL_SAFE_NO_PAD.encode(&random);
342    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
343    (verifier, challenge)
344}
345
346pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
347    let mut url = Url::parse(auth_url)
348        .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
349    {
350        let mut pairs = url.query_pairs_mut();
351        pairs.append_pair("code_challenge", challenge);
352        pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
353    }
354    Ok(url.to_string())
355}
356
357impl OAuth2State {
358    /// Create a new OAuth2 state with provided state parameter
359    pub fn new_with_state(
360        state: String,
361        provider_name: String,
362        redirect_uri: Option<String>,
363        user_id: String,
364        scopes: Vec<String>,
365    ) -> Self {
366        Self {
367            state,
368            provider_name,
369            redirect_uri,
370            user_id,
371            scopes,
372            metadata: HashMap::new(),
373            created_at: Utc::now(),
374        }
375    }
376
377    /// Create a new OAuth2 state with auto-generated state parameter (deprecated)
378    pub fn new(
379        provider_name: String,
380        redirect_uri: Option<String>,
381        user_id: String,
382        scopes: Vec<String>,
383    ) -> Self {
384        Self::new_with_state(
385            uuid::Uuid::new_v4().to_string(),
386            provider_name,
387            redirect_uri,
388            user_id,
389            scopes,
390        )
391    }
392
393    /// Check if the state has expired (default 10 minutes)
394    pub fn is_expired(&self, max_age_seconds: i64) -> bool {
395        let max_age = chrono::Duration::seconds(max_age_seconds);
396        Utc::now() - self.created_at > max_age
397    }
398}
399
400/// Storage-only trait for authentication stores
401/// Implementations only need to handle storage operations
402#[async_trait]
403pub trait ToolAuthStore: Send + Sync {
404    /// Session Management
405    /// Get current authentication session for an entity
406    async fn get_session(
407        &self,
408        auth_entity: &str,
409        user_id: &str,
410    ) -> Result<Option<AuthSession>, AuthError>;
411
412    /// Store authentication session
413    async fn store_session(
414        &self,
415        auth_entity: &str,
416        user_id: &str,
417        session: AuthSession,
418    ) -> Result<(), AuthError>;
419
420    /// Remove authentication session
421    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
422
423    /// Secret Management
424    /// Store secret for a user (optionally scoped to auth_entity)
425    async fn store_secret(
426        &self,
427        user_id: &str,
428        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
429        secret: AuthSecret,
430    ) -> Result<(), AuthError>;
431
432    /// Get stored secret by key (optionally scoped to auth_entity)  
433    async fn get_secret(
434        &self,
435        user_id: &str,
436        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
437        key: &str,
438    ) -> Result<Option<AuthSecret>, AuthError>;
439
440    /// Remove stored secret by key (optionally scoped to auth_entity)
441    async fn remove_secret(
442        &self,
443        user_id: &str,
444        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
445        key: &str,
446    ) -> Result<bool, AuthError>;
447
448    /// State Management (for OAuth2 flows)
449    /// Store OAuth2 state for security
450    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
451
452    /// Get OAuth2 state by state parameter
453    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
454
455    /// Remove OAuth2 state (after successful callback)
456    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
457
458    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
459
460    async fn list_sessions(
461        &self,
462        _user_id: &str,
463    ) -> Result<HashMap<String, AuthSession>, AuthError>;
464}
465
466/// OAuth handler that works with any AuthStore implementation
467#[derive(Clone)]
468pub struct OAuthHandler {
469    store: Arc<dyn ToolAuthStore>,
470    provider_registry: Option<Arc<dyn ProviderRegistry>>,
471    redirect_uri: String,
472}
473
474/// Provider registry trait for getting auth providers
475#[async_trait]
476pub trait ProviderRegistry: Send + Sync {
477    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
478    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
479    async fn is_provider_available(&self, provider_name: &str) -> bool;
480    async fn list_providers(&self) -> Vec<String>;
481    async fn requires_pkce(&self, _provider_name: &str) -> bool {
482        false
483    }
484}
485
486impl OAuthHandler {
487    pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
488        Self {
489            store,
490            provider_registry: None,
491            redirect_uri,
492        }
493    }
494
495    pub fn with_provider_registry(
496        store: Arc<dyn ToolAuthStore>,
497        provider_registry: Arc<dyn ProviderRegistry>,
498        redirect_uri: String,
499    ) -> Self {
500        Self {
501            store,
502            provider_registry: Some(provider_registry),
503            redirect_uri,
504        }
505    }
506
507    /// Generate authorization URL for OAuth2 flow
508    pub async fn get_auth_url(
509        &self,
510        auth_entity: &str,
511        user_id: &str,
512        auth_config: &AuthType,
513        scopes: &[String],
514    ) -> Result<String, AuthError> {
515        tracing::debug!(
516            "Getting auth URL for entity: {} user: {:?}",
517            auth_entity,
518            user_id
519        );
520
521        match auth_config {
522            AuthType::OAuth2 {
523                flow_type: OAuth2FlowType::ClientCredentials,
524                ..
525            } => Err(AuthError::InvalidConfig(
526                "Client credentials flow doesn't require authorization URL".to_string(),
527            )),
528            auth_config @ AuthType::OAuth2 {
529                send_redirect_uri, ..
530            } => {
531                // Create OAuth2 state
532                let redirect_uri = if *send_redirect_uri {
533                    Some(self.redirect_uri.clone())
534                } else {
535                    None
536                };
537                let mut state = OAuth2State::new(
538                    auth_entity.to_string(),
539                    redirect_uri.clone(),
540                    user_id.to_string(),
541                    scopes.to_vec(),
542                );
543
544                let mut pkce_challenge = None;
545                if let Some(registry) = &self.provider_registry
546                    && registry.requires_pkce(auth_entity).await {
547                        let (verifier, challenge) = generate_pkce_pair();
548                        state.metadata.insert(
549                            PKCE_CODE_VERIFIER_KEY.to_string(),
550                            serde_json::Value::String(verifier.clone()),
551                        );
552                        pkce_challenge = Some(challenge);
553                    }
554
555                // Store the state
556                self.store.store_oauth2_state(state.clone()).await?;
557
558                // Get the appropriate provider using the auth_entity as provider name
559                let provider = self.get_provider(auth_entity).await?;
560
561                // Build the authorization URL
562                let mut auth_url = provider.build_auth_url(
563                    auth_config,
564                    &state.state,
565                    scopes,
566                    redirect_uri.as_deref(),
567                )?;
568
569                if let Some(challenge) = pkce_challenge {
570                    auth_url = append_pkce_challenge(&auth_url, &challenge)?;
571                }
572
573                tracing::debug!("Generated auth URL: {}", auth_url);
574                Ok(auth_url)
575            }
576            AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
577                "Secret authentication doesn't require authorization URL".to_string(),
578            )),
579            AuthType::None => Err(AuthError::InvalidConfig(
580                "No authentication doesn't require authorization URL".to_string(),
581            )),
582        }
583    }
584
585    /// Handle OAuth2 callback and exchange code for tokens
586    pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
587        tracing::debug!("Handling OAuth2 callback with state: {}", state);
588
589        // Get and remove the state
590        let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
591            AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
592        })?;
593
594        // Remove the used state
595        self.store.remove_oauth2_state(state).await?;
596
597        // Check if state is expired (10 minutes max)
598        if oauth2_state.is_expired(600) {
599            return Err(AuthError::OAuth2Flow(
600                "OAuth2 state has expired".to_string(),
601            ));
602        }
603
604        // Get auth config from provider registry
605        let auth_config = if let Some(registry) = &self.provider_registry {
606            registry
607                .get_auth_type(&oauth2_state.provider_name)
608                .await
609                .ok_or_else(|| {
610                    AuthError::InvalidConfig(format!(
611                        "No configuration found for provider: {}",
612                        oauth2_state.provider_name
613                    ))
614                })?
615        } else {
616            return Err(AuthError::InvalidConfig(
617                "No provider registry configured".to_string(),
618            ));
619        };
620
621        // Get the appropriate provider
622        let provider = self.get_provider(&oauth2_state.provider_name).await?;
623
624        // Exchange the authorization code for tokens
625        let redirect_uri = match &auth_config {
626            AuthType::OAuth2 {
627                send_redirect_uri, ..
628            } if *send_redirect_uri => oauth2_state
629                .redirect_uri
630                .clone()
631                .or_else(|| Some(self.redirect_uri.clone())),
632            AuthType::OAuth2 { .. } => None,
633            _ => None,
634        };
635        let pkce_code_verifier = oauth2_state
636            .metadata
637            .get(PKCE_CODE_VERIFIER_KEY)
638            .and_then(|v| v.as_str());
639
640        let session = provider
641            .exchange_code(
642                code,
643                redirect_uri.as_deref(),
644                &auth_config,
645                pkce_code_verifier,
646            )
647            .await?;
648
649        // Store the session
650        self.store
651            .store_session(
652                &oauth2_state.provider_name,
653                &oauth2_state.user_id,
654                session.clone(),
655            )
656            .await?;
657
658        tracing::debug!(
659            "Successfully stored auth session for entity: {}",
660            oauth2_state.provider_name
661        );
662        Ok(session)
663    }
664
665    /// Refresh an expired session
666    pub async fn refresh_session(
667        &self,
668        auth_entity: &str,
669        user_id: &str,
670        auth_config: &AuthType,
671    ) -> Result<AuthSession, AuthError> {
672        tracing::debug!(
673            "Refreshing session for entity: {} user: {:?}",
674            auth_entity,
675            user_id
676        );
677
678        // Get current session
679        let current_session = self
680            .store
681            .get_session(auth_entity, user_id)
682            .await?
683            .ok_or_else(|| {
684                AuthError::TokenRefreshFailed("No session found to refresh".to_string())
685            })?;
686
687        let refresh_token = current_session.refresh_token.ok_or_else(|| {
688            AuthError::TokenRefreshFailed("No refresh token available".to_string())
689        })?;
690
691        match auth_config {
692            AuthType::OAuth2 {
693                flow_type: OAuth2FlowType::ClientCredentials,
694                ..
695            } => {
696                // For client credentials, get a new token instead of refreshing
697                let provider = self.get_provider(auth_entity).await?;
698                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
699
700                // Store the new session
701                self.store
702                    .store_session(auth_entity, user_id, new_session.clone())
703                    .await?;
704                Ok(new_session)
705            }
706            auth_config @ AuthType::OAuth2 { .. } => {
707                // Get the appropriate provider
708                let provider = self.get_provider(auth_entity).await?;
709
710                // Refresh the token
711                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
712
713                // Store the new session
714                self.store
715                    .store_session(auth_entity, user_id, new_session.clone())
716                    .await?;
717                Ok(new_session)
718            }
719            _ => Err(AuthError::InvalidConfig(
720                "Cannot refresh non-OAuth2 session".to_string(),
721            )),
722        }
723    }
724
725    /// Get session, automatically refreshing if expired
726    pub async fn refresh_get_session(
727        &self,
728        auth_entity: &str,
729        user_id: &str,
730        auth_config: &AuthType,
731    ) -> Result<Option<AuthSession>, AuthError> {
732        match self.store.get_session(auth_entity, user_id).await? {
733            Some(session) => {
734                if session.needs_refresh() {
735                    tracing::debug!(
736                        "Session expired for {}:{:?}, attempting refresh",
737                        auth_entity,
738                        user_id
739                    );
740                    match self
741                        .refresh_session(auth_entity, user_id, auth_config)
742                        .await
743                    {
744                        Ok(refreshed_session) => {
745                            tracing::info!(
746                                "Successfully refreshed session for {}:{:?}",
747                                auth_entity,
748                                user_id
749                            );
750                            Ok(Some(refreshed_session))
751                        }
752                        Err(e) => {
753                            tracing::warn!(
754                                "Failed to refresh session for {}:{:?}: {}",
755                                auth_entity,
756                                user_id,
757                                e
758                            );
759                            Err(e)
760                        }
761                    }
762                } else {
763                    Ok(Some(session))
764                }
765            }
766            None => Ok(None),
767        }
768    }
769
770    async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
771        if let Some(registry) = &self.provider_registry {
772            registry
773                .get_provider(provider_name)
774                .await
775                .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
776        } else {
777            Err(AuthError::InvalidConfig(
778                "No provider registry configured".to_string(),
779            ))
780        }
781    }
782
783    // Storage delegation methods
784    pub async fn get_session(
785        &self,
786        auth_entity: &str,
787        user_id: &str,
788    ) -> Result<Option<AuthSession>, AuthError> {
789        self.store.get_session(auth_entity, user_id).await
790    }
791
792    pub async fn store_session(
793        &self,
794        auth_entity: &str,
795        user_id: &str,
796        session: AuthSession,
797    ) -> Result<(), AuthError> {
798        self.store
799            .store_session(auth_entity, user_id, session)
800            .await
801    }
802
803    pub async fn remove_session(
804        &self,
805        auth_entity: &str,
806        user_id: &str,
807    ) -> Result<bool, AuthError> {
808        self.store.remove_session(auth_entity, user_id).await
809    }
810
811    pub async fn store_secret(
812        &self,
813        user_id: &str,
814        auth_entity: Option<&str>,
815        secret: AuthSecret,
816    ) -> Result<(), AuthError> {
817        self.store.store_secret(user_id, auth_entity, secret).await
818    }
819
820    pub async fn get_secret(
821        &self,
822        user_id: &str,
823        auth_entity: Option<&str>,
824        key: &str,
825    ) -> Result<Option<AuthSecret>, AuthError> {
826        self.store.get_secret(user_id, auth_entity, key).await
827    }
828
829    pub async fn remove_secret(
830        &self,
831        user_id: &str,
832        auth_entity: Option<&str>,
833        key: &str,
834    ) -> Result<bool, AuthError> {
835        self.store.remove_secret(user_id, auth_entity, key).await
836    }
837
838    pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
839        self.store.store_oauth2_state(state).await
840    }
841
842    pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
843        self.store.get_oauth2_state(state).await
844    }
845
846    pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
847        self.store.remove_oauth2_state(state).await
848    }
849
850    pub async fn list_secrets(
851        &self,
852        user_id: &str,
853    ) -> Result<HashMap<String, AuthSecret>, AuthError> {
854        self.store.list_secrets(user_id).await
855    }
856
857    pub async fn list_sessions(
858        &self,
859        user_id: &str,
860    ) -> Result<HashMap<String, AuthSession>, AuthError> {
861        self.store.list_sessions(user_id).await
862    }
863}
864
865/// Authentication provider trait for different OAuth2 providers
866#[async_trait]
867pub trait AuthProvider: Send + Sync {
868    /// Provider name (e.g., "google", "github", "twitter")
869    fn provider_name(&self) -> &str;
870
871    /// Exchange authorization code for access token
872    async fn exchange_code(
873        &self,
874        code: &str,
875        redirect_uri: Option<&str>,
876        auth_config: &AuthType,
877        pkce_code_verifier: Option<&str>,
878    ) -> Result<AuthSession, AuthError>;
879
880    /// Refresh an access token
881    async fn refresh_token(
882        &self,
883        refresh_token: &str,
884        auth_config: &AuthType,
885    ) -> Result<AuthSession, AuthError>;
886
887    /// Build authorization URL
888    fn build_auth_url(
889        &self,
890        auth_config: &AuthType,
891        state: &str,
892        scopes: &[String],
893        redirect_uri: Option<&str>,
894    ) -> Result<String, AuthError>;
895}