distri_types/
auth.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use rand::RngCore;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Arc;
11use url::Url;
12
13use crate::McpSession;
14
/// Serde default for `send_redirect_uri`: unless a config says otherwise,
/// the redirect URI is included in provider requests.
fn default_send_redirect_uri() -> bool {
    // Opt-out rather than opt-in: configs that omit the field send the URI.
    true
}
18
19/// Authentication types supported by the tool auth system
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
21#[serde(tag = "type", content = "config")]
22pub enum AuthType {
23    /// No authentication required
24    #[serde(rename = "none")]
25    None,
26    /// OAuth2 authentication flows
27    #[serde(rename = "oauth2")]
28    OAuth2 {
29        /// OAuth2 flow type
30        flow_type: OAuth2FlowType,
31        /// Authorization URL
32        authorization_url: String,
33        /// Token URL
34        token_url: String,
35        /// Optional refresh URL
36        refresh_url: Option<String>,
37        /// Required scopes
38        scopes: Vec<String>,
39        /// Whether the provider should include redirect_uri in requests
40        #[serde(default = "default_send_redirect_uri")]
41        send_redirect_uri: bool,
42    },
43    /// Secret-based authentication (API keys etc.)
44    #[serde(rename = "secret")]
45    Secret {
46        provider: String,
47        #[serde(default)]
48        fields: Vec<SecretFieldSpec>,
49    },
50}
51
52/// OAuth2 flow types
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
54#[serde(rename_all = "snake_case")]
55pub enum OAuth2FlowType {
56    AuthorizationCode,
57    ClientCredentials,
58    Implicit,
59    Password,
60}
61
62/// OAuth2 authentication session - only contains OAuth tokens
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AuthSession {
65    /// Access token
66    pub access_token: String,
67    /// Optional refresh token
68    pub refresh_token: Option<String>,
69    /// Token expiry time
70    pub expires_at: Option<DateTime<Utc>>,
71    /// Token type (usually "Bearer")
72    pub token_type: String,
73    /// Granted scopes
74    pub scopes: Vec<String>,
75}
76
77/// Response for issuing access + refresh tokens.
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
79pub struct TokenResponse {
80    pub access_token: String,
81    pub refresh_token: String,
82    pub expires_at: i64,
83}
84
85/// Secret storage for API keys and other non-OAuth authentication
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct AuthSecret {
88    /// The secret value (API key, token, etc.)
89    pub secret: String,
90    /// Key name for this secret
91    pub key: String,
92}
93impl Into<McpSession> for AuthSession {
94    fn into(self) -> McpSession {
95        McpSession {
96            token: self.access_token,
97            expiry: self.expires_at.map(|dt| dt.into()),
98        }
99    }
100}
101
102impl Into<McpSession> for AuthSecret {
103    fn into(self) -> McpSession {
104        McpSession {
105            token: self.secret,
106            expiry: None, // Secrets don't expire
107        }
108    }
109}
110
111impl AuthSession {
112    /// Create a new OAuth auth session
113    pub fn new(
114        access_token: String,
115        token_type: Option<String>,
116        expires_in: Option<i64>,
117        refresh_token: Option<String>,
118        scopes: Vec<String>,
119    ) -> Self {
120        let now = Utc::now();
121        let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
122
123        AuthSession {
124            access_token,
125            refresh_token,
126            expires_at,
127            token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
128            scopes,
129        }
130    }
131
132    /// Check if the OAuth token is expired or will expire within the given buffer
133    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
134        match &self.expires_at {
135            Some(expires_at) => {
136                let buffer = chrono::Duration::seconds(buffer_seconds);
137                Utc::now() + buffer >= *expires_at
138            }
139            None => false, // No expiry means it doesn't expire
140        }
141    }
142
143    /// Check if the OAuth token needs refreshing (expired with 5 minute buffer)
144    pub fn needs_refresh(&self) -> bool {
145        self.is_expired(300) // 5 minutes buffer
146    }
147
148    /// Get access token for OAuth sessions
149    pub fn get_access_token(&self) -> &str {
150        &self.access_token
151    }
152
153    /// Update OAuth session with new token data
154    pub fn update_tokens(
155        &mut self,
156        access_token: String,
157        expires_in: Option<i64>,
158        refresh_token: Option<String>,
159    ) {
160        self.access_token = access_token;
161
162        if let Some(secs) = expires_in {
163            self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
164        }
165
166        if let Some(token) = refresh_token {
167            self.refresh_token = Some(token);
168        }
169    }
170}
171
172impl AuthSecret {
173    /// Create a new secret
174    pub fn new(key: String, secret: String) -> Self {
175        AuthSecret { secret, key }
176    }
177
178    /// Get the secret value
179    pub fn get_secret(&self) -> &str {
180        &self.secret
181    }
182
183    /// Get the provider name
184    pub fn get_provider(&self) -> &str {
185        &self.key
186    }
187}
188
189/// Authentication metadata trait for tools
190pub trait AuthMetadata: Send + Sync {
191    /// Get the auth entity identifier (e.g., "google", "twitter", "api_key_service")
192    fn get_auth_entity(&self) -> String;
193
194    /// Get the authentication type and configuration
195    fn get_auth_type(&self) -> AuthType;
196
197    /// Check if authentication is required for this tool
198    fn requires_auth(&self) -> bool {
199        !matches!(self.get_auth_type(), AuthType::None)
200    }
201
202    /// Get additional authentication configuration
203    fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
204        HashMap::new()
205    }
206}
207
208/// OAuth2 authentication error types
209#[derive(Debug, thiserror::Error)]
210pub enum AuthError {
211    #[error("OAuth2 flow error: {0}")]
212    OAuth2Flow(String),
213
214    #[error("Token expired and refresh failed: {0}")]
215    TokenRefreshFailed(String),
216
217    #[error("Invalid authentication configuration: {0}")]
218    InvalidConfig(String),
219
220    #[error("Authentication required but not configured for entity: {0}")]
221    AuthRequired(String),
222
223    #[error("API key not found for entity: {0}")]
224    ApiKeyNotFound(String),
225
226    #[error("Storage error: {0}")]
227    Storage(#[from] anyhow::Error),
228
229    #[error("Store error: {0}")]
230    StoreError(String),
231
232    #[error("Provider not found: {0}")]
233    ProviderNotFound(String),
234
235    #[error("Server error: {0}")]
236    ServerError(String),
237}
238
239/// Authentication requirement specification for tools
240#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
241#[serde(tag = "type")]
242pub enum AuthRequirement {
243    #[serde(rename = "oauth2")]
244    OAuth2 {
245        provider: String,
246        #[serde(default)]
247        scopes: Vec<String>,
248        #[serde(
249            rename = "authorizationUrl",
250            default,
251            skip_serializing_if = "Option::is_none"
252        )]
253        authorization_url: Option<String>,
254        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
255        token_url: Option<String>,
256        #[serde(
257            rename = "refreshUrl",
258            default,
259            skip_serializing_if = "Option::is_none"
260        )]
261        refresh_url: Option<String>,
262        #[serde(
263            rename = "sendRedirectUri",
264            default,
265            skip_serializing_if = "Option::is_none"
266        )]
267        send_redirect_uri: Option<bool>,
268    },
269    #[serde(rename = "secret")]
270    Secret {
271        provider: String,
272        #[serde(default)]
273        fields: Vec<SecretFieldSpec>,
274    },
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
278pub struct SecretFieldSpec {
279    pub key: String,
280    #[serde(default)]
281    pub label: Option<String>,
282    #[serde(default)]
283    pub description: Option<String>,
284    #[serde(default)]
285    pub optional: bool,
286}
287
288/// OAuth2 flow state for managing authorization flows
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct OAuth2State {
291    /// Random state parameter for security  
292    pub state: String,
293    /// Provider name for this OAuth flow
294    pub provider_name: String,
295    /// Redirect URI for the OAuth flow (if the provider requires it)
296    #[serde(default, skip_serializing_if = "Option::is_none")]
297    pub redirect_uri: Option<String>,
298    /// User ID if available
299    pub user_id: String,
300    /// Requested scopes
301    pub scopes: Vec<String>,
302    /// Additional metadata
303    pub metadata: HashMap<String, serde_json::Value>,
304    /// State creation time
305    pub created_at: DateTime<Utc>,
306}
307
/// Key under which the PKCE code verifier is stored in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// PKCE challenge method sent to providers: SHA-256 ("S256", RFC 7636).
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Random bytes behind each code verifier; base64url without padding turns
/// 32 bytes into a 43-character verifier.
const PKCE_RANDOM_BYTES: usize = 32;
311
312pub fn generate_pkce_pair() -> (String, String) {
313    let mut random = vec![0u8; PKCE_RANDOM_BYTES];
314    rand::thread_rng().fill_bytes(&mut random);
315    let verifier = URL_SAFE_NO_PAD.encode(&random);
316    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
317    (verifier, challenge)
318}
319
320pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
321    let mut url = Url::parse(auth_url)
322        .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
323    {
324        let mut pairs = url.query_pairs_mut();
325        pairs.append_pair("code_challenge", challenge);
326        pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
327    }
328    Ok(url.to_string())
329}
330
331impl OAuth2State {
332    /// Create a new OAuth2 state with provided state parameter
333    pub fn new_with_state(
334        state: String,
335        provider_name: String,
336        redirect_uri: Option<String>,
337        user_id: String,
338        scopes: Vec<String>,
339    ) -> Self {
340        Self {
341            state,
342            provider_name,
343            redirect_uri,
344            user_id,
345            scopes,
346            metadata: HashMap::new(),
347            created_at: Utc::now(),
348        }
349    }
350
351    /// Create a new OAuth2 state with auto-generated state parameter (deprecated)
352    pub fn new(
353        provider_name: String,
354        redirect_uri: Option<String>,
355        user_id: String,
356        scopes: Vec<String>,
357    ) -> Self {
358        Self::new_with_state(
359            uuid::Uuid::new_v4().to_string(),
360            provider_name,
361            redirect_uri,
362            user_id,
363            scopes,
364        )
365    }
366
367    /// Check if the state has expired (default 10 minutes)
368    pub fn is_expired(&self, max_age_seconds: i64) -> bool {
369        let max_age = chrono::Duration::seconds(max_age_seconds);
370        Utc::now() - self.created_at > max_age
371    }
372}
373
374/// Storage-only trait for authentication stores
375/// Implementations only need to handle storage operations
376#[async_trait]
377pub trait ToolAuthStore: Send + Sync {
378    /// Session Management
379
380    /// Get current authentication session for an entity
381    async fn get_session(
382        &self,
383        auth_entity: &str,
384        user_id: &str,
385    ) -> Result<Option<AuthSession>, AuthError>;
386
387    /// Store authentication session
388    async fn store_session(
389        &self,
390        auth_entity: &str,
391        user_id: &str,
392        session: AuthSession,
393    ) -> Result<(), AuthError>;
394
395    /// Remove authentication session
396    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
397
398    /// Secret Management
399
400    /// Store secret for a user (optionally scoped to auth_entity)
401    async fn store_secret(
402        &self,
403        user_id: &str,
404        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
405        secret: AuthSecret,
406    ) -> Result<(), AuthError>;
407
408    /// Get stored secret by key (optionally scoped to auth_entity)  
409    async fn get_secret(
410        &self,
411        user_id: &str,
412        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
413        key: &str,
414    ) -> Result<Option<AuthSecret>, AuthError>;
415
416    /// Remove stored secret by key (optionally scoped to auth_entity)
417    async fn remove_secret(
418        &self,
419        user_id: &str,
420        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
421        key: &str,
422    ) -> Result<bool, AuthError>;
423
424    /// State Management (for OAuth2 flows)
425
426    /// Store OAuth2 state for security
427    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
428
429    /// Get OAuth2 state by state parameter
430    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
431
432    /// Remove OAuth2 state (after successful callback)
433    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
434
435    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
436
437    async fn list_sessions(
438        &self,
439        _user_id: &str,
440    ) -> Result<HashMap<String, AuthSession>, AuthError>;
441}
442
443/// OAuth handler that works with any AuthStore implementation
444#[derive(Clone)]
445pub struct OAuthHandler {
446    store: Arc<dyn ToolAuthStore>,
447    provider_registry: Option<Arc<dyn ProviderRegistry>>,
448    redirect_uri: String,
449}
450
451/// Provider registry trait for getting auth providers
452#[async_trait]
453pub trait ProviderRegistry: Send + Sync {
454    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
455    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
456    async fn is_provider_available(&self, provider_name: &str) -> bool;
457    async fn list_providers(&self) -> Vec<String>;
458    async fn requires_pkce(&self, _provider_name: &str) -> bool {
459        false
460    }
461}
462
463impl OAuthHandler {
464    pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
465        Self {
466            store,
467            provider_registry: None,
468            redirect_uri,
469        }
470    }
471
472    pub fn with_provider_registry(
473        store: Arc<dyn ToolAuthStore>,
474        provider_registry: Arc<dyn ProviderRegistry>,
475        redirect_uri: String,
476    ) -> Self {
477        Self {
478            store,
479            provider_registry: Some(provider_registry),
480            redirect_uri,
481        }
482    }
483
484    /// Generate authorization URL for OAuth2 flow
485    pub async fn get_auth_url(
486        &self,
487        auth_entity: &str,
488        user_id: &str,
489        auth_config: &AuthType,
490        scopes: &[String],
491    ) -> Result<String, AuthError> {
492        tracing::debug!(
493            "Getting auth URL for entity: {} user: {:?}",
494            auth_entity,
495            user_id
496        );
497
498        match auth_config {
499            AuthType::OAuth2 {
500                flow_type: OAuth2FlowType::ClientCredentials,
501                ..
502            } => Err(AuthError::InvalidConfig(
503                "Client credentials flow doesn't require authorization URL".to_string(),
504            )),
505            auth_config @ AuthType::OAuth2 {
506                send_redirect_uri, ..
507            } => {
508                // Create OAuth2 state
509                let redirect_uri = if *send_redirect_uri {
510                    Some(self.redirect_uri.clone())
511                } else {
512                    None
513                };
514                let mut state = OAuth2State::new(
515                    auth_entity.to_string(),
516                    redirect_uri.clone(),
517                    user_id.to_string(),
518                    scopes.to_vec(),
519                );
520
521                let mut pkce_challenge = None;
522                if let Some(registry) = &self.provider_registry {
523                    if registry.requires_pkce(auth_entity).await {
524                        let (verifier, challenge) = generate_pkce_pair();
525                        state.metadata.insert(
526                            PKCE_CODE_VERIFIER_KEY.to_string(),
527                            serde_json::Value::String(verifier.clone()),
528                        );
529                        pkce_challenge = Some(challenge);
530                    }
531                }
532
533                // Store the state
534                self.store.store_oauth2_state(state.clone()).await?;
535
536                // Get the appropriate provider using the auth_entity as provider name
537                let provider = self.get_provider(auth_entity).await?;
538
539                // Build the authorization URL
540                let mut auth_url = provider.build_auth_url(
541                    auth_config,
542                    &state.state,
543                    scopes,
544                    redirect_uri.as_deref(),
545                )?;
546
547                if let Some(challenge) = pkce_challenge {
548                    auth_url = append_pkce_challenge(&auth_url, &challenge)?;
549                }
550
551                tracing::debug!("Generated auth URL: {}", auth_url);
552                Ok(auth_url)
553            }
554            AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
555                "Secret authentication doesn't require authorization URL".to_string(),
556            )),
557            AuthType::None => Err(AuthError::InvalidConfig(
558                "No authentication doesn't require authorization URL".to_string(),
559            )),
560        }
561    }
562
563    /// Handle OAuth2 callback and exchange code for tokens
564    pub async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthSession, AuthError> {
565        tracing::debug!("Handling OAuth2 callback with state: {}", state);
566
567        // Get and remove the state
568        let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
569            AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
570        })?;
571
572        // Remove the used state
573        self.store.remove_oauth2_state(state).await?;
574
575        // Check if state is expired (10 minutes max)
576        if oauth2_state.is_expired(600) {
577            return Err(AuthError::OAuth2Flow(
578                "OAuth2 state has expired".to_string(),
579            ));
580        }
581
582        // Get auth config from provider registry
583        let auth_config = if let Some(registry) = &self.provider_registry {
584            registry
585                .get_auth_type(&oauth2_state.provider_name)
586                .await
587                .ok_or_else(|| {
588                    AuthError::InvalidConfig(format!(
589                        "No configuration found for provider: {}",
590                        oauth2_state.provider_name
591                    ))
592                })?
593        } else {
594            return Err(AuthError::InvalidConfig(
595                "No provider registry configured".to_string(),
596            ));
597        };
598
599        // Get the appropriate provider
600        let provider = self.get_provider(&oauth2_state.provider_name).await?;
601
602        // Exchange the authorization code for tokens
603        let redirect_uri = match &auth_config {
604            AuthType::OAuth2 {
605                send_redirect_uri, ..
606            } if *send_redirect_uri => oauth2_state
607                .redirect_uri
608                .clone()
609                .or_else(|| Some(self.redirect_uri.clone())),
610            AuthType::OAuth2 { .. } => None,
611            _ => None,
612        };
613        let pkce_code_verifier = oauth2_state
614            .metadata
615            .get(PKCE_CODE_VERIFIER_KEY)
616            .and_then(|v| v.as_str());
617
618        let session = provider
619            .exchange_code(
620                code,
621                redirect_uri.as_deref(),
622                &auth_config,
623                pkce_code_verifier,
624            )
625            .await?;
626
627        // Store the session
628        self.store
629            .store_session(
630                &oauth2_state.provider_name,
631                &oauth2_state.user_id,
632                session.clone(),
633            )
634            .await?;
635
636        tracing::debug!(
637            "Successfully stored auth session for entity: {}",
638            oauth2_state.provider_name
639        );
640        Ok(session)
641    }
642
643    /// Refresh an expired session
644    pub async fn refresh_session(
645        &self,
646        auth_entity: &str,
647        user_id: &str,
648        auth_config: &AuthType,
649    ) -> Result<AuthSession, AuthError> {
650        tracing::debug!(
651            "Refreshing session for entity: {} user: {:?}",
652            auth_entity,
653            user_id
654        );
655
656        // Get current session
657        let current_session = self
658            .store
659            .get_session(auth_entity, &user_id)
660            .await?
661            .ok_or_else(|| {
662                AuthError::TokenRefreshFailed("No session found to refresh".to_string())
663            })?;
664
665        let refresh_token = current_session.refresh_token.ok_or_else(|| {
666            AuthError::TokenRefreshFailed("No refresh token available".to_string())
667        })?;
668
669        match auth_config {
670            AuthType::OAuth2 {
671                flow_type: OAuth2FlowType::ClientCredentials,
672                ..
673            } => {
674                // For client credentials, get a new token instead of refreshing
675                let provider = self.get_provider(auth_entity).await?;
676                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
677
678                // Store the new session
679                self.store
680                    .store_session(auth_entity, &user_id, new_session.clone())
681                    .await?;
682                Ok(new_session)
683            }
684            auth_config @ AuthType::OAuth2 { .. } => {
685                // Get the appropriate provider
686                let provider = self.get_provider(auth_entity).await?;
687
688                // Refresh the token
689                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
690
691                // Store the new session
692                self.store
693                    .store_session(auth_entity, &user_id, new_session.clone())
694                    .await?;
695                Ok(new_session)
696            }
697            _ => Err(AuthError::InvalidConfig(
698                "Cannot refresh non-OAuth2 session".to_string(),
699            )),
700        }
701    }
702
703    /// Get session, automatically refreshing if expired
704    pub async fn refresh_get_session(
705        &self,
706        auth_entity: &str,
707        user_id: &str,
708        auth_config: &AuthType,
709    ) -> Result<Option<AuthSession>, AuthError> {
710        match self.store.get_session(auth_entity, user_id).await? {
711            Some(session) => {
712                if session.needs_refresh() {
713                    tracing::debug!(
714                        "Session expired for {}:{:?}, attempting refresh",
715                        auth_entity,
716                        user_id
717                    );
718                    match self
719                        .refresh_session(auth_entity, user_id, auth_config)
720                        .await
721                    {
722                        Ok(refreshed_session) => {
723                            tracing::info!(
724                                "Successfully refreshed session for {}:{:?}",
725                                auth_entity,
726                                user_id
727                            );
728                            Ok(Some(refreshed_session))
729                        }
730                        Err(e) => {
731                            tracing::warn!(
732                                "Failed to refresh session for {}:{:?}: {}",
733                                auth_entity,
734                                user_id,
735                                e
736                            );
737                            Err(e)
738                        }
739                    }
740                } else {
741                    Ok(Some(session))
742                }
743            }
744            None => Ok(None),
745        }
746    }
747
748    async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
749        if let Some(registry) = &self.provider_registry {
750            registry
751                .get_provider(provider_name)
752                .await
753                .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
754        } else {
755            Err(AuthError::InvalidConfig(
756                "No provider registry configured".to_string(),
757            ))
758        }
759    }
760
761    // Storage delegation methods
762    pub async fn get_session(
763        &self,
764        auth_entity: &str,
765        user_id: &str,
766    ) -> Result<Option<AuthSession>, AuthError> {
767        self.store.get_session(auth_entity, user_id).await
768    }
769
770    pub async fn store_session(
771        &self,
772        auth_entity: &str,
773        user_id: &str,
774        session: AuthSession,
775    ) -> Result<(), AuthError> {
776        self.store
777            .store_session(auth_entity, user_id, session)
778            .await
779    }
780
781    pub async fn remove_session(
782        &self,
783        auth_entity: &str,
784        user_id: &str,
785    ) -> Result<bool, AuthError> {
786        self.store.remove_session(auth_entity, user_id).await
787    }
788
789    pub async fn store_secret(
790        &self,
791        user_id: &str,
792        auth_entity: Option<&str>,
793        secret: AuthSecret,
794    ) -> Result<(), AuthError> {
795        self.store.store_secret(user_id, auth_entity, secret).await
796    }
797
798    pub async fn get_secret(
799        &self,
800        user_id: &str,
801        auth_entity: Option<&str>,
802        key: &str,
803    ) -> Result<Option<AuthSecret>, AuthError> {
804        self.store.get_secret(user_id, auth_entity, key).await
805    }
806
807    pub async fn remove_secret(
808        &self,
809        user_id: &str,
810        auth_entity: Option<&str>,
811        key: &str,
812    ) -> Result<bool, AuthError> {
813        self.store.remove_secret(user_id, auth_entity, key).await
814    }
815
816    pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
817        self.store.store_oauth2_state(state).await
818    }
819
820    pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
821        self.store.get_oauth2_state(state).await
822    }
823
824    pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
825        self.store.remove_oauth2_state(state).await
826    }
827
828    pub async fn list_secrets(
829        &self,
830        user_id: &str,
831    ) -> Result<HashMap<String, AuthSecret>, AuthError> {
832        self.store.list_secrets(user_id).await
833    }
834
835    pub async fn list_sessions(
836        &self,
837        user_id: &str,
838    ) -> Result<HashMap<String, AuthSession>, AuthError> {
839        self.store.list_sessions(user_id).await
840    }
841}
842
843/// Authentication provider trait for different OAuth2 providers
844#[async_trait]
845pub trait AuthProvider: Send + Sync {
846    /// Provider name (e.g., "google", "github", "twitter")
847    fn provider_name(&self) -> &str;
848
849    /// Exchange authorization code for access token
850    async fn exchange_code(
851        &self,
852        code: &str,
853        redirect_uri: Option<&str>,
854        auth_config: &AuthType,
855        pkce_code_verifier: Option<&str>,
856    ) -> Result<AuthSession, AuthError>;
857
858    /// Refresh an access token
859    async fn refresh_token(
860        &self,
861        refresh_token: &str,
862        auth_config: &AuthType,
863    ) -> Result<AuthSession, AuthError>;
864
865    /// Build authorization URL
866    fn build_auth_url(
867        &self,
868        auth_config: &AuthType,
869        state: &str,
870        scopes: &[String],
871        redirect_uri: Option<&str>,
872    ) -> Result<String, AuthError>;
873}