Skip to main content

distri_types/
auth.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `AuthType::OAuth2::send_redirect_uri`.
///
/// When a serialized OAuth2 config omits the field, the redirect URI is sent.
fn default_send_redirect_uri() -> bool {
    true
}
18
/// Authentication types supported by the tool auth system.
///
/// Serialized adjacently tagged: the variant name is written under `"type"`
/// and the variant's fields (if any) under `"config"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(tag = "type", content = "config")]
pub enum AuthType {
    /// No authentication required
    #[serde(rename = "none")]
    None,
    /// OAuth2 authentication flows
    #[serde(rename = "oauth2")]
    OAuth2 {
        /// OAuth2 flow type
        flow_type: OAuth2FlowType,
        /// Authorization URL
        authorization_url: String,
        /// Token URL
        token_url: String,
        /// Optional refresh URL
        refresh_url: Option<String>,
        /// Required scopes
        scopes: Vec<String>,
        /// Whether the provider should include redirect_uri in requests.
        /// Defaults to `true` when absent from the serialized config.
        #[serde(default = "default_send_redirect_uri")]
        send_redirect_uri: bool,
    },
    /// Secret-based authentication (API keys etc.)
    #[serde(rename = "secret")]
    Secret {
        /// Provider identifier.
        /// NOTE(review): presumably matches a provider registry name — confirm.
        provider: String,
        /// Secret fields the user must supply; empty when absent.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
51
/// OAuth2 flow types.
///
/// Serialized in snake_case (e.g. `"authorization_code"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OAuth2FlowType {
    /// Authorization code grant.
    AuthorizationCode,
    /// Client credentials grant; `OAuthHandler::get_auth_url` rejects it since
    /// no authorization URL is involved.
    ClientCredentials,
    /// Implicit grant.
    Implicit,
    /// Resource owner password credentials grant.
    Password,
}
61
/// OAuth2 authentication session - only contains OAuth tokens.
///
/// Persisted per `(auth_entity, user_id)` through [`ToolAuthStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSession {
    /// Access token
    pub access_token: String,
    /// Optional refresh token
    pub refresh_token: Option<String>,
    /// Token expiry time; `None` means the session is treated as never expiring
    pub expires_at: Option<DateTime<Utc>>,
    /// Token type (usually "Bearer"; `AuthSession::new` falls back to it)
    pub token_type: String,
    /// Granted scopes
    pub scopes: Vec<String>,
}
76
/// Usage limits that can be embedded in tokens.
///
/// Every limit is optional: `None` means unlimited and the field is omitted
/// from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct TokenLimits {
    /// Maximum tokens per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_tokens: Option<u64>,

    /// Maximum tokens per month (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub monthly_tokens: Option<u64>,

    /// Maximum API calls per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_calls: Option<u64>,

    /// Maximum API calls per week (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weekly_calls: Option<u64>,
}
96
/// Response for issuing access + refresh tokens.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
pub struct TokenResponse {
    /// Newly issued access token
    pub access_token: String,
    /// Newly issued refresh token
    pub refresh_token: String,
    /// Access-token expiry.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm with the issuer.
    pub expires_at: i64,
    /// Identifier for usage tracking (e.g., "blinksheets", "my-app")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identifier_id: Option<String>,
    /// Effective limits applied to this token
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limits: Option<TokenLimits>,
}
110
/// Secret storage for API keys and other non-OAuth authentication.
///
/// Unlike [`AuthSession`], a secret carries no expiry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSecret {
    /// The secret value (API key, token, etc.)
    pub secret: String,
    /// Key name for this secret.
    /// NOTE(review): presumably the `key` used by `ToolAuthStore::get_secret` lookups — confirm.
    pub key: String,
}
119impl Into<McpSession> for AuthSession {
120    fn into(self) -> McpSession {
121        McpSession {
122            token: self.access_token,
123            expiry: self.expires_at.map(|dt| dt.into()),
124        }
125    }
126}
127
128impl Into<McpSession> for AuthSecret {
129    fn into(self) -> McpSession {
130        McpSession {
131            token: self.secret,
132            expiry: None, // Secrets don't expire
133        }
134    }
135}
136
137impl AuthSession {
138    /// Create a new OAuth auth session
139    pub fn new(
140        access_token: String,
141        token_type: Option<String>,
142        expires_in: Option<i64>,
143        refresh_token: Option<String>,
144        scopes: Vec<String>,
145    ) -> Self {
146        let now = Utc::now();
147        let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
148
149        AuthSession {
150            access_token,
151            refresh_token,
152            expires_at,
153            token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
154            scopes,
155        }
156    }
157
158    /// Check if the OAuth token is expired or will expire within the given buffer
159    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
160        match &self.expires_at {
161            Some(expires_at) => {
162                let buffer = chrono::Duration::seconds(buffer_seconds);
163                Utc::now() + buffer >= *expires_at
164            }
165            None => false, // No expiry means it doesn't expire
166        }
167    }
168
169    /// Check if the OAuth token needs refreshing (expired with 5 minute buffer)
170    pub fn needs_refresh(&self) -> bool {
171        self.is_expired(300) // 5 minutes buffer
172    }
173
174    /// Get access token for OAuth sessions
175    pub fn get_access_token(&self) -> &str {
176        &self.access_token
177    }
178
179    /// Update OAuth session with new token data
180    pub fn update_tokens(
181        &mut self,
182        access_token: String,
183        expires_in: Option<i64>,
184        refresh_token: Option<String>,
185    ) {
186        self.access_token = access_token;
187
188        if let Some(secs) = expires_in {
189            self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
190        }
191
192        if let Some(token) = refresh_token {
193            self.refresh_token = Some(token);
194        }
195    }
196}
197
198impl AuthSecret {
199    /// Create a new secret
200    pub fn new(key: String, secret: String) -> Self {
201        AuthSecret { secret, key }
202    }
203
204    /// Get the secret value
205    pub fn get_secret(&self) -> &str {
206        &self.secret
207    }
208
209    /// Get the provider name
210    pub fn get_provider(&self) -> &str {
211        &self.key
212    }
213}
214
215/// Authentication metadata trait for tools
216pub trait AuthMetadata: Send + Sync {
217    /// Get the auth entity identifier (e.g., "google", "twitter", "api_key_service")
218    fn get_auth_entity(&self) -> String;
219
220    /// Get the authentication type and configuration
221    fn get_auth_type(&self) -> AuthType;
222
223    /// Check if authentication is required for this tool
224    fn requires_auth(&self) -> bool {
225        !matches!(self.get_auth_type(), AuthType::None)
226    }
227
228    /// Get additional authentication configuration
229    fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
230        HashMap::new()
231    }
232}
233
/// OAuth2 authentication error types.
///
/// Shared error type for the store, handler, registry and provider APIs in
/// this module.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// Failure during an OAuth2 flow (e.g. unknown or expired `state`).
    #[error("OAuth2 flow error: {0}")]
    OAuth2Flow(String),

    /// Refreshing a session failed (e.g. no stored session or no refresh token).
    #[error("Token expired and refresh failed: {0}")]
    TokenRefreshFailed(String),

    /// Auth configuration is missing or unusable for the requested operation.
    #[error("Invalid authentication configuration: {0}")]
    InvalidConfig(String),

    /// Authentication is required but nothing is configured for the entity.
    #[error("Authentication required but not configured for entity: {0}")]
    AuthRequired(String),

    /// No API key was found for the entity.
    #[error("API key not found for entity: {0}")]
    ApiKeyNotFound(String),

    /// Storage failure wrapped from `anyhow::Error`; `#[from]` lets `?` convert it.
    #[error("Storage error: {0}")]
    Storage(#[from] anyhow::Error),

    /// Store failure described by a message.
    #[error("Store error: {0}")]
    StoreError(String),

    /// No provider is registered under the given name.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// Error reported by a server.
    #[error("Server error: {0}")]
    ServerError(String),
}
264
/// Authentication requirement specification for tools.
///
/// Internally tagged: the variant is selected by a `"type"` key whose value
/// is `"oauth2"` or `"secret"`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type")]
pub enum AuthRequirement {
    /// OAuth2 requirement. The optional endpoint/behaviour overrides use
    /// camelCase keys and are omitted from serialized output when `None`.
    #[serde(rename = "oauth2")]
    OAuth2 {
        /// Provider name
        provider: String,
        /// Requested scopes; empty when absent
        #[serde(default)]
        scopes: Vec<String>,
        /// Optional authorization endpoint (`authorizationUrl`)
        #[serde(
            rename = "authorizationUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        authorization_url: Option<String>,
        /// Optional token endpoint (`tokenUrl`)
        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
        token_url: Option<String>,
        /// Optional refresh endpoint (`refreshUrl`)
        #[serde(
            rename = "refreshUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        refresh_url: Option<String>,
        /// Optional override for whether to send `redirect_uri` (`sendRedirectUri`).
        /// NOTE(review): presumably maps onto `AuthType::OAuth2::send_redirect_uri` — confirm.
        #[serde(
            rename = "sendRedirectUri",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        send_redirect_uri: Option<bool>,
    },
    /// Secret-based requirement (API keys etc.)
    #[serde(rename = "secret")]
    Secret {
        /// Provider name
        provider: String,
        /// Secret fields the user must supply; empty when absent
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
302
/// Description of a single secret field a user may be asked to provide.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SecretFieldSpec {
    /// Identifier of the field
    pub key: String,
    /// Optional human-readable label
    #[serde(default)]
    pub label: Option<String>,
    /// Optional longer description
    #[serde(default)]
    pub description: Option<String>,
    /// Whether the field may be left out; defaults to `false` when absent
    #[serde(default)]
    pub optional: bool,
}
313
/// OAuth2 flow state for managing authorization flows.
///
/// Stored when an authorization URL is generated and looked up, then removed,
/// when the provider calls back with the same `state` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2State {
    /// Opaque value sent to the provider and matched on callback
    pub state: String,
    /// Provider name for this OAuth flow; the resulting session is stored
    /// under this name as its auth entity
    pub provider_name: String,
    /// Redirect URI for the OAuth flow (if the provider requires it)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// ID of the user who started the flow; the resulting session is stored for this user
    pub user_id: String,
    /// Requested scopes
    pub scopes: Vec<String>,
    /// Additional metadata (e.g. the PKCE code verifier under `PKCE_CODE_VERIFIER_KEY`)
    pub metadata: HashMap<String, serde_json::Value>,
    /// State creation time, used to enforce a maximum age
    pub created_at: DateTime<Utc>,
}
333
/// Key under which the PKCE code verifier is kept in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// PKCE `code_challenge_method` sent with the authorization request (SHA-256).
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Random bytes per PKCE verifier; 32 bytes encode to a 43-character
/// unpadded base64url string.
const PKCE_RANDOM_BYTES: usize = 32;
337
338pub fn generate_pkce_pair() -> (String, String) {
339    let mut random = vec![0u8; PKCE_RANDOM_BYTES];
340    rand::thread_rng().fill_bytes(&mut random);
341    let verifier = URL_SAFE_NO_PAD.encode(&random);
342    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
343    (verifier, challenge)
344}
345
346pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
347    let mut url = Url::parse(auth_url)
348        .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
349    {
350        let mut pairs = url.query_pairs_mut();
351        pairs.append_pair("code_challenge", challenge);
352        pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
353    }
354    Ok(url.to_string())
355}
356
357impl OAuth2State {
358    /// Create a new OAuth2 state with provided state parameter
359    pub fn new_with_state(
360        state: String,
361        provider_name: String,
362        redirect_uri: Option<String>,
363        user_id: String,
364        scopes: Vec<String>,
365    ) -> Self {
366        Self {
367            state,
368            provider_name,
369            redirect_uri,
370            user_id,
371            scopes,
372            metadata: HashMap::new(),
373            created_at: Utc::now(),
374        }
375    }
376
377    /// Create a new OAuth2 state with auto-generated state parameter (deprecated)
378    pub fn new(
379        provider_name: String,
380        redirect_uri: Option<String>,
381        user_id: String,
382        scopes: Vec<String>,
383    ) -> Self {
384        Self::new_with_state(
385            uuid::Uuid::new_v4().to_string(),
386            provider_name,
387            redirect_uri,
388            user_id,
389            scopes,
390        )
391    }
392
393    /// Check if the state has expired (default 10 minutes)
394    pub fn is_expired(&self, max_age_seconds: i64) -> bool {
395        let max_age = chrono::Duration::seconds(max_age_seconds);
396        Utc::now() - self.created_at > max_age
397    }
398}
399
/// Storage-only trait for authentication stores.
///
/// Implementations only need to handle storage operations for OAuth sessions,
/// secrets and in-flight OAuth2 states; token exchange and refresh are driven
/// by [`OAuthHandler`] on top of this trait.
#[async_trait]
pub trait ToolAuthStore: Send + Sync {
    // ---- Session management ----

    /// Get current authentication session for an entity
    async fn get_session(
        &self,
        auth_entity: &str,
        user_id: &str,
    ) -> Result<Option<AuthSession>, AuthError>;

    /// Store authentication session
    async fn store_session(
        &self,
        auth_entity: &str,
        user_id: &str,
        session: AuthSession,
    ) -> Result<(), AuthError>;

    /// Remove authentication session.
    ///
    /// NOTE(review): the `bool` presumably reports whether a session existed — confirm.
    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;

    // ---- Secret management ----

    /// Store secret for a user (optionally scoped to auth_entity)
    async fn store_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
        secret: AuthSecret,
    ) -> Result<(), AuthError>;

    /// Get stored secret by key (optionally scoped to auth_entity)
    async fn get_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
        key: &str,
    ) -> Result<Option<AuthSecret>, AuthError>;

    /// Remove stored secret by key (optionally scoped to auth_entity).
    ///
    /// NOTE(review): the `bool` presumably reports whether a secret existed — confirm.
    async fn remove_secret(
        &self,
        user_id: &str,
        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
        key: &str,
    ) -> Result<bool, AuthError>;

    // ---- State management (for OAuth2 flows) ----

    /// Store OAuth2 state for security
    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;

    /// Get OAuth2 state by state parameter
    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;

    /// Remove OAuth2 state (after successful callback)
    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;

    /// List all secrets stored for `user_id`.
    ///
    /// NOTE(review): map keys are presumably secret keys — confirm with implementations.
    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;

    /// List all OAuth sessions stored for a user.
    ///
    /// NOTE(review): map keys are presumably auth entities — confirm with implementations.
    async fn list_sessions(
        &self,
        _user_id: &str,
    ) -> Result<HashMap<String, AuthSession>, AuthError>;
}
468
/// OAuth handler that works with any [`ToolAuthStore`] implementation.
///
/// Drives OAuth2 authorization URL generation, callback handling and token
/// refresh, delegating persistence to `store` and provider lookup to
/// `provider_registry`.
#[derive(Clone)]
pub struct OAuthHandler {
    /// Backing store for sessions, secrets and in-flight OAuth2 states
    store: Arc<dyn ToolAuthStore>,
    /// Provider lookup; provider-dependent operations fail with
    /// `AuthError::InvalidConfig` when this is `None`
    provider_registry: Option<Arc<dyn ProviderRegistry>>,
    /// Redirect URI sent to providers whose config enables `send_redirect_uri`
    redirect_uri: String,
}
476
/// Provider registry trait for getting auth providers.
#[async_trait]
pub trait ProviderRegistry: Send + Sync {
    /// Returns the provider registered under `provider_name`, if any.
    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
    /// Returns the auth configuration for `provider_name`, if any.
    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
    /// Whether a provider named `provider_name` is available.
    async fn is_provider_available(&self, provider_name: &str) -> bool;
    /// Names of all registered providers.
    async fn list_providers(&self) -> Vec<String>;
    /// Whether authorization requests for `provider_name` must use PKCE.
    /// Defaults to `false`; when `true`, `OAuthHandler::get_auth_url` adds a
    /// PKCE challenge and keeps the verifier in the flow state's metadata.
    async fn requires_pkce(&self, _provider_name: &str) -> bool {
        false
    }
}
488
489impl OAuthHandler {
490    pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
491        Self {
492            store,
493            provider_registry: None,
494            redirect_uri,
495        }
496    }
497
498    pub fn with_provider_registry(
499        store: Arc<dyn ToolAuthStore>,
500        provider_registry: Arc<dyn ProviderRegistry>,
501        redirect_uri: String,
502    ) -> Self {
503        Self {
504            store,
505            provider_registry: Some(provider_registry),
506            redirect_uri,
507        }
508    }
509
510    /// Generate authorization URL for OAuth2 flow
511    pub async fn get_auth_url(
512        &self,
513        auth_entity: &str,
514        user_id: &str,
515        auth_config: &AuthType,
516        scopes: &[String],
517    ) -> Result<String, AuthError> {
518        tracing::debug!(
519            "Getting auth URL for entity: {} user: {:?}",
520            auth_entity,
521            user_id
522        );
523
524        match auth_config {
525            AuthType::OAuth2 {
526                flow_type: OAuth2FlowType::ClientCredentials,
527                ..
528            } => Err(AuthError::InvalidConfig(
529                "Client credentials flow doesn't require authorization URL".to_string(),
530            )),
531            auth_config @ AuthType::OAuth2 {
532                send_redirect_uri, ..
533            } => {
534                // Create OAuth2 state
535                let redirect_uri = if *send_redirect_uri {
536                    Some(self.redirect_uri.clone())
537                } else {
538                    None
539                };
540                let mut state = OAuth2State::new(
541                    auth_entity.to_string(),
542                    redirect_uri.clone(),
543                    user_id.to_string(),
544                    scopes.to_vec(),
545                );
546
547                let mut pkce_challenge = None;
548                if let Some(registry) = &self.provider_registry {
549                    if registry.requires_pkce(auth_entity).await {
550                        let (verifier, challenge) = generate_pkce_pair();
551                        state.metadata.insert(
552                            PKCE_CODE_VERIFIER_KEY.to_string(),
553                            serde_json::Value::String(verifier.clone()),
554                        );
555                        pkce_challenge = Some(challenge);
556                    }
557                }
558
559                // Store the state
560                self.store.store_oauth2_state(state.clone()).await?;
561
562                // Get the appropriate provider using the auth_entity as provider name
563                let provider = self.get_provider(auth_entity).await?;
564
565                // Build the authorization URL
566                let mut auth_url = provider.build_auth_url(
567                    auth_config,
568                    &state.state,
569                    scopes,
570                    redirect_uri.as_deref(),
571                )?;
572
573                if let Some(challenge) = pkce_challenge {
574                    auth_url = append_pkce_challenge(&auth_url, &challenge)?;
575                }
576
577                tracing::debug!("Generated auth URL: {}", auth_url);
578                Ok(auth_url)
579            }
580            AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
581                "Secret authentication doesn't require authorization URL".to_string(),
582            )),
583            AuthType::None => Err(AuthError::InvalidConfig(
584                "No authentication doesn't require authorization URL".to_string(),
585            )),
586        }
587    }
588
589    /// Handle OAuth2 callback and exchange code for tokens
590    pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
591        tracing::debug!("Handling OAuth2 callback with state: {}", state);
592
593        // Get and remove the state
594        let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
595            AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
596        })?;
597
598        // Remove the used state
599        self.store.remove_oauth2_state(state).await?;
600
601        // Check if state is expired (10 minutes max)
602        if oauth2_state.is_expired(600) {
603            return Err(AuthError::OAuth2Flow(
604                "OAuth2 state has expired".to_string(),
605            ));
606        }
607
608        // Get auth config from provider registry
609        let auth_config = if let Some(registry) = &self.provider_registry {
610            registry
611                .get_auth_type(&oauth2_state.provider_name)
612                .await
613                .ok_or_else(|| {
614                    AuthError::InvalidConfig(format!(
615                        "No configuration found for provider: {}",
616                        oauth2_state.provider_name
617                    ))
618                })?
619        } else {
620            return Err(AuthError::InvalidConfig(
621                "No provider registry configured".to_string(),
622            ));
623        };
624
625        // Get the appropriate provider
626        let provider = self.get_provider(&oauth2_state.provider_name).await?;
627
628        // Exchange the authorization code for tokens
629        let redirect_uri = match &auth_config {
630            AuthType::OAuth2 {
631                send_redirect_uri, ..
632            } if *send_redirect_uri => oauth2_state
633                .redirect_uri
634                .clone()
635                .or_else(|| Some(self.redirect_uri.clone())),
636            AuthType::OAuth2 { .. } => None,
637            _ => None,
638        };
639        let pkce_code_verifier = oauth2_state
640            .metadata
641            .get(PKCE_CODE_VERIFIER_KEY)
642            .and_then(|v| v.as_str());
643
644        let session = provider
645            .exchange_code(
646                code,
647                redirect_uri.as_deref(),
648                &auth_config,
649                pkce_code_verifier,
650            )
651            .await?;
652
653        // Store the session
654        self.store
655            .store_session(
656                &oauth2_state.provider_name,
657                &oauth2_state.user_id,
658                session.clone(),
659            )
660            .await?;
661
662        tracing::debug!(
663            "Successfully stored auth session for entity: {}",
664            oauth2_state.provider_name
665        );
666        Ok(session)
667    }
668
669    /// Refresh an expired session
670    pub async fn refresh_session(
671        &self,
672        auth_entity: &str,
673        user_id: &str,
674        auth_config: &AuthType,
675    ) -> Result<AuthSession, AuthError> {
676        tracing::debug!(
677            "Refreshing session for entity: {} user: {:?}",
678            auth_entity,
679            user_id
680        );
681
682        // Get current session
683        let current_session = self
684            .store
685            .get_session(auth_entity, &user_id)
686            .await?
687            .ok_or_else(|| {
688                AuthError::TokenRefreshFailed("No session found to refresh".to_string())
689            })?;
690
691        let refresh_token = current_session.refresh_token.ok_or_else(|| {
692            AuthError::TokenRefreshFailed("No refresh token available".to_string())
693        })?;
694
695        match auth_config {
696            AuthType::OAuth2 {
697                flow_type: OAuth2FlowType::ClientCredentials,
698                ..
699            } => {
700                // For client credentials, get a new token instead of refreshing
701                let provider = self.get_provider(auth_entity).await?;
702                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
703
704                // Store the new session
705                self.store
706                    .store_session(auth_entity, &user_id, new_session.clone())
707                    .await?;
708                Ok(new_session)
709            }
710            auth_config @ AuthType::OAuth2 { .. } => {
711                // Get the appropriate provider
712                let provider = self.get_provider(auth_entity).await?;
713
714                // Refresh the token
715                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
716
717                // Store the new session
718                self.store
719                    .store_session(auth_entity, &user_id, new_session.clone())
720                    .await?;
721                Ok(new_session)
722            }
723            _ => Err(AuthError::InvalidConfig(
724                "Cannot refresh non-OAuth2 session".to_string(),
725            )),
726        }
727    }
728
729    /// Get session, automatically refreshing if expired
730    pub async fn refresh_get_session(
731        &self,
732        auth_entity: &str,
733        user_id: &str,
734        auth_config: &AuthType,
735    ) -> Result<Option<AuthSession>, AuthError> {
736        match self.store.get_session(auth_entity, user_id).await? {
737            Some(session) => {
738                if session.needs_refresh() {
739                    tracing::debug!(
740                        "Session expired for {}:{:?}, attempting refresh",
741                        auth_entity,
742                        user_id
743                    );
744                    match self
745                        .refresh_session(auth_entity, user_id, auth_config)
746                        .await
747                    {
748                        Ok(refreshed_session) => {
749                            tracing::info!(
750                                "Successfully refreshed session for {}:{:?}",
751                                auth_entity,
752                                user_id
753                            );
754                            Ok(Some(refreshed_session))
755                        }
756                        Err(e) => {
757                            tracing::warn!(
758                                "Failed to refresh session for {}:{:?}: {}",
759                                auth_entity,
760                                user_id,
761                                e
762                            );
763                            Err(e)
764                        }
765                    }
766                } else {
767                    Ok(Some(session))
768                }
769            }
770            None => Ok(None),
771        }
772    }
773
774    async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
775        if let Some(registry) = &self.provider_registry {
776            registry
777                .get_provider(provider_name)
778                .await
779                .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
780        } else {
781            Err(AuthError::InvalidConfig(
782                "No provider registry configured".to_string(),
783            ))
784        }
785    }
786
787    // Storage delegation methods
788    pub async fn get_session(
789        &self,
790        auth_entity: &str,
791        user_id: &str,
792    ) -> Result<Option<AuthSession>, AuthError> {
793        self.store.get_session(auth_entity, user_id).await
794    }
795
796    pub async fn store_session(
797        &self,
798        auth_entity: &str,
799        user_id: &str,
800        session: AuthSession,
801    ) -> Result<(), AuthError> {
802        self.store
803            .store_session(auth_entity, user_id, session)
804            .await
805    }
806
807    pub async fn remove_session(
808        &self,
809        auth_entity: &str,
810        user_id: &str,
811    ) -> Result<bool, AuthError> {
812        self.store.remove_session(auth_entity, user_id).await
813    }
814
815    pub async fn store_secret(
816        &self,
817        user_id: &str,
818        auth_entity: Option<&str>,
819        secret: AuthSecret,
820    ) -> Result<(), AuthError> {
821        self.store.store_secret(user_id, auth_entity, secret).await
822    }
823
824    pub async fn get_secret(
825        &self,
826        user_id: &str,
827        auth_entity: Option<&str>,
828        key: &str,
829    ) -> Result<Option<AuthSecret>, AuthError> {
830        self.store.get_secret(user_id, auth_entity, key).await
831    }
832
833    pub async fn remove_secret(
834        &self,
835        user_id: &str,
836        auth_entity: Option<&str>,
837        key: &str,
838    ) -> Result<bool, AuthError> {
839        self.store.remove_secret(user_id, auth_entity, key).await
840    }
841
842    pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
843        self.store.store_oauth2_state(state).await
844    }
845
846    pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
847        self.store.get_oauth2_state(state).await
848    }
849
850    pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
851        self.store.remove_oauth2_state(state).await
852    }
853
854    pub async fn list_secrets(
855        &self,
856        user_id: &str,
857    ) -> Result<HashMap<String, AuthSecret>, AuthError> {
858        self.store.list_secrets(user_id).await
859    }
860
861    pub async fn list_sessions(
862        &self,
863        user_id: &str,
864    ) -> Result<HashMap<String, AuthSession>, AuthError> {
865        self.store.list_sessions(user_id).await
866    }
867}
868
/// Authentication provider trait for different OAuth2 providers.
///
/// `OAuthHandler` obtains implementations through a [`ProviderRegistry`] and
/// calls them to build authorization URLs, exchange codes and refresh tokens.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Provider name (e.g., "google", "github", "twitter")
    fn provider_name(&self) -> &str;

    /// Exchange authorization code for access token.
    ///
    /// `redirect_uri` is `Some` only when the provider config asks for it to
    /// be sent; `pkce_code_verifier` is `Some` when the flow used PKCE.
    async fn exchange_code(
        &self,
        code: &str,
        redirect_uri: Option<&str>,
        auth_config: &AuthType,
        pkce_code_verifier: Option<&str>,
    ) -> Result<AuthSession, AuthError>;

    /// Refresh an access token using `refresh_token`.
    async fn refresh_token(
        &self,
        refresh_token: &str,
        auth_config: &AuthType,
    ) -> Result<AuthSession, AuthError>;

    /// Build authorization URL.
    ///
    /// `state` is the opaque value the provider must echo back on callback;
    /// `redirect_uri` is `None` when the config disables sending it.
    fn build_auth_url(
        &self,
        auth_config: &AuthType,
        state: &str,
        scopes: &[String],
        redirect_uri: Option<&str>,
    ) -> Result<String, AuthError>;
}