Skip to main content

distri_types/
auth.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::sync::Arc;
10use url::Url;
11
12use crate::McpSession;
13
/// Serde default for the `send_redirect_uri` field of `AuthType::OAuth2`:
/// when a config omits the field, `redirect_uri` is sent to the provider.
fn default_send_redirect_uri() -> bool {
    true
}
17
// Serialized adjacently tagged (`tag = "type"`, `content = "config"`): the
// variant name ("none" / "oauth2" / "secret") goes under "type" and the
// variant's fields under "config".
// NOTE: `///` text on this type feeds the derived JsonSchema descriptions,
// so reviewer notes here use plain `//` comments.
/// Authentication types supported by the tool auth system
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(tag = "type", content = "config")]
pub enum AuthType {
    /// No authentication required
    #[serde(rename = "none")]
    None,
    /// OAuth2 authentication flows
    #[serde(rename = "oauth2")]
    OAuth2 {
        /// OAuth2 flow type
        flow_type: OAuth2FlowType,
        /// Authorization URL
        authorization_url: String,
        /// Token URL
        token_url: String,
        /// Optional refresh URL
        refresh_url: Option<String>,
        /// Required scopes
        scopes: Vec<String>,
        /// Whether the provider should include redirect_uri in requests
        // Falls back to `true` (via `default_send_redirect_uri`) when the
        // field is absent from the input.
        #[serde(default = "default_send_redirect_uri")]
        send_redirect_uri: bool,
    },
    /// Secret-based authentication (API keys etc.)
    #[serde(rename = "secret")]
    Secret {
        // Name of the provider the secret belongs to — presumably the same
        // identifier used as the auth entity; confirm with callers.
        provider: String,
        // Secret fields the user must supply; an empty list when omitted.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
50
// Serialized in snake_case: "authorization_code", "client_credentials",
// "implicit", "password".
// NOTE: `///` text on this enum feeds the derived JsonSchema description,
// so reviewer notes use plain `//` comments.
/// OAuth2 flow types
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OAuth2FlowType {
    AuthorizationCode,
    // Has no authorization URL; `OAuthHandler::get_auth_url` rejects it.
    ClientCredentials,
    Implicit,
    Password,
}
60
/// OAuth2 authentication session - only contains OAuth tokens.
///
/// API keys and other secrets are represented separately by `AuthSecret`.
// NOTE(review): `Debug` is derived, so formatting a session with `{:?}`
// prints the access and refresh tokens in clear text — avoid logging it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSession {
    /// Access token presented to the provider's API.
    pub access_token: String,
    /// Refresh token, if the provider issued one. Refreshing a session through
    /// `OAuthHandler::refresh_session` fails without it.
    pub refresh_token: Option<String>,
    /// Absolute expiry instant in UTC. `None` means no known expiry, which
    /// `AuthSession::is_expired` treats as never expiring.
    pub expires_at: Option<DateTime<Utc>>,
    /// Token type (usually "Bearer"; `AuthSession::new` defaults to it).
    pub token_type: String,
    /// Granted scopes.
    pub scopes: Vec<String>,
}
75
// `Default` yields every limit as `None`, i.e. no limits at all. Limits that
// are `None` are omitted from serialized output.
// NOTE: `///` text on this type feeds the derived JsonSchema descriptions,
// so reviewer notes use plain `//` comments.
/// Usage limits that can be embedded in tokens
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct TokenLimits {
    /// Maximum tokens per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_tokens: Option<u64>,

    /// Maximum tokens per month (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub monthly_tokens: Option<u64>,

    /// Maximum API calls per day (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub daily_calls: Option<u64>,

    /// Maximum API calls per week (None = unlimited)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weekly_calls: Option<u64>,
}
95
// NOTE: `///` text on this type feeds the derived JsonSchema descriptions,
// so reviewer notes use plain `//` comments.
/// Response for issuing access + refresh tokens.
///
/// distri-cloud extends this with an opaque `extensions: serde_json::Value`
/// field via its own response wrapper if it needs to surface authz metadata
/// (granted permissions, scope, role) to the caller. This base type stays
/// auth-agnostic so OSS consumers don't pull in cloud's authz model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
pub struct TokenResponse {
    pub access_token: String,
    pub refresh_token: String,
    // NOTE(review): the unit/epoch of `expires_at` is not visible here —
    // presumably Unix seconds; confirm with the issuer.
    pub expires_at: i64,
    /// Identifier for usage tracking (e.g., "blinksheets", "my-app")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identifier_id: Option<String>,
    /// Effective limits applied to this token
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limits: Option<TokenLimits>,
}
114
/// Secret storage for API keys and other non-OAuth authentication.
///
/// Converted into an `McpSession` with no expiry.
// NOTE(review): `Debug` is derived, so `{:?}` prints the secret value in
// clear text — avoid logging it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSecret {
    /// The secret value (API key, token, etc.)
    pub secret: String,
    /// Key name under which this secret is stored and looked up.
    pub key: String,
}
123impl From<AuthSession> for McpSession {
124    fn from(val: AuthSession) -> Self {
125        McpSession {
126            token: val.access_token,
127            expiry: val.expires_at.map(|dt| dt.into()),
128        }
129    }
130}
131
132impl From<AuthSecret> for McpSession {
133    fn from(val: AuthSecret) -> Self {
134        McpSession {
135            token: val.secret,
136            expiry: None, // Secrets don't expire
137        }
138    }
139}
140
141impl AuthSession {
142    /// Create a new OAuth auth session
143    pub fn new(
144        access_token: String,
145        token_type: Option<String>,
146        expires_in: Option<i64>,
147        refresh_token: Option<String>,
148        scopes: Vec<String>,
149    ) -> Self {
150        let now = Utc::now();
151        let expires_at = expires_in.map(|secs| now + chrono::Duration::seconds(secs));
152
153        AuthSession {
154            access_token,
155            refresh_token,
156            expires_at,
157            token_type: token_type.unwrap_or_else(|| "Bearer".to_string()),
158            scopes,
159        }
160    }
161
162    /// Check if the OAuth token is expired or will expire within the given buffer
163    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
164        match &self.expires_at {
165            Some(expires_at) => {
166                let buffer = chrono::Duration::seconds(buffer_seconds);
167                Utc::now() + buffer >= *expires_at
168            }
169            None => false, // No expiry means it doesn't expire
170        }
171    }
172
173    /// Check if the OAuth token needs refreshing (expired with 5 minute buffer)
174    pub fn needs_refresh(&self) -> bool {
175        self.is_expired(300) // 5 minutes buffer
176    }
177
178    /// Get access token for OAuth sessions
179    pub fn get_access_token(&self) -> &str {
180        &self.access_token
181    }
182
183    /// Update OAuth session with new token data
184    pub fn update_tokens(
185        &mut self,
186        access_token: String,
187        expires_in: Option<i64>,
188        refresh_token: Option<String>,
189    ) {
190        self.access_token = access_token;
191
192        if let Some(secs) = expires_in {
193            self.expires_at = Some(Utc::now() + chrono::Duration::seconds(secs));
194        }
195
196        if let Some(token) = refresh_token {
197            self.refresh_token = Some(token);
198        }
199    }
200}
201
202impl AuthSecret {
203    /// Create a new secret
204    pub fn new(key: String, secret: String) -> Self {
205        AuthSecret { secret, key }
206    }
207
208    /// Get the secret value
209    pub fn get_secret(&self) -> &str {
210        &self.secret
211    }
212
213    /// Get the provider name
214    pub fn get_provider(&self) -> &str {
215        &self.key
216    }
217}
218
219/// Authentication metadata trait for tools
220pub trait AuthMetadata: Send + Sync {
221    /// Get the auth entity identifier (e.g., "google", "twitter", "api_key_service")
222    fn get_auth_entity(&self) -> String;
223
224    /// Get the authentication type and configuration
225    fn get_auth_type(&self) -> AuthType;
226
227    /// Check if authentication is required for this tool
228    fn requires_auth(&self) -> bool {
229        !matches!(self.get_auth_type(), AuthType::None)
230    }
231
232    /// Get additional authentication configuration
233    fn get_auth_config(&self) -> HashMap<String, serde_json::Value> {
234        HashMap::new()
235    }
236}
237
/// OAuth2 authentication error types
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The OAuth2 authorization flow failed (e.g. unknown or expired `state`).
    #[error("OAuth2 flow error: {0}")]
    OAuth2Flow(String),

    /// A session could not be refreshed (e.g. no stored session or no
    /// refresh token).
    #[error("Token expired and refresh failed: {0}")]
    TokenRefreshFailed(String),

    /// The auth configuration is missing or unsuitable for the requested
    /// operation.
    #[error("Invalid authentication configuration: {0}")]
    InvalidConfig(String),

    /// Authentication is required for the entity but has not been configured.
    #[error("Authentication required but not configured for entity: {0}")]
    AuthRequired(String),

    /// No API key was found for the entity.
    #[error("API key not found for entity: {0}")]
    ApiKeyNotFound(String),

    /// Wraps an `anyhow::Error`; `#[from]` lets `?` convert one into this
    /// variant.
    #[error("Storage error: {0}")]
    Storage(#[from] anyhow::Error),

    /// Store failure described by a message.
    #[error("Store error: {0}")]
    StoreError(String),

    /// No provider is registered under the given name.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// Server-side failure described by a message.
    #[error("Server error: {0}")]
    ServerError(String),
}
268
// Internally tagged on "type" ("oauth2" or "secret"). Unlike `AuthType`, the
// optional OAuth2 URL/flag fields use camelCase wire names and are omitted
// from serialized output when `None`.
// NOTE: `///` text on this type feeds the derived JsonSchema descriptions,
// so reviewer notes use plain `//` comments.
/// Authentication requirement specification for tools
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type")]
pub enum AuthRequirement {
    #[serde(rename = "oauth2")]
    OAuth2 {
        provider: String,
        // Empty when omitted from the input.
        #[serde(default)]
        scopes: Vec<String>,
        #[serde(
            rename = "authorizationUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        authorization_url: Option<String>,
        #[serde(rename = "tokenUrl", default, skip_serializing_if = "Option::is_none")]
        token_url: Option<String>,
        #[serde(
            rename = "refreshUrl",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        refresh_url: Option<String>,
        // Tri-state: `None` means "not specified here" — presumably resolved
        // against the provider's `AuthType` config; confirm with consumers.
        #[serde(
            rename = "sendRedirectUri",
            default,
            skip_serializing_if = "Option::is_none"
        )]
        send_redirect_uri: Option<bool>,
    },
    #[serde(rename = "secret")]
    Secret {
        provider: String,
        // Empty when omitted from the input.
        #[serde(default)]
        fields: Vec<SecretFieldSpec>,
    },
}
306
// Describes one field of a secret-based auth requirement. Everything except
// `key` may be omitted from the input (`label`/`description` -> `None`,
// `optional` -> `false`).
// NOTE: `///` text here would become JsonSchema descriptions, so notes use `//`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SecretFieldSpec {
    // Key name the secret value is stored under — presumably matches
    // `AuthSecret::key`; confirm with callers.
    pub key: String,
    #[serde(default)]
    pub label: Option<String>,
    #[serde(default)]
    pub description: Option<String>,
    // `false` (i.e. required) unless stated otherwise.
    #[serde(default)]
    pub optional: bool,
}
317
318/// OAuth2 flow state for managing authorization flows
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct OAuth2State {
321    /// Random state parameter for security  
322    pub state: String,
323    /// Provider name for this OAuth flow
324    pub provider_name: String,
325    /// Redirect URI for the OAuth flow (if the provider requires it)
326    #[serde(default, skip_serializing_if = "Option::is_none")]
327    pub redirect_uri: Option<String>,
328    /// User ID if available
329    pub user_id: String,
330    /// Requested scopes
331    pub scopes: Vec<String>,
332    /// Additional metadata
333    pub metadata: HashMap<String, serde_json::Value>,
334    /// State creation time
335    pub created_at: DateTime<Utc>,
336}
337
/// Key under which the PKCE code verifier is stored in `OAuth2State::metadata`.
pub const PKCE_CODE_VERIFIER_KEY: &str = "pkce_code_verifier";
/// PKCE `code_challenge_method`: the challenge is derived with SHA-256.
pub const PKCE_CODE_CHALLENGE_METHOD: &str = "S256";
/// Random bytes behind each code verifier. As unpadded base64url this yields a
/// 43-character verifier, the minimum length allowed by RFC 7636.
const PKCE_RANDOM_BYTES: usize = 32;
341
342pub fn generate_pkce_pair() -> (String, String) {
343    let mut random = vec![0u8; PKCE_RANDOM_BYTES];
344    rand::fill(&mut random);
345    let verifier = URL_SAFE_NO_PAD.encode(&random);
346    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
347    (verifier, challenge)
348}
349
350pub fn append_pkce_challenge(auth_url: &str, challenge: &str) -> Result<String, AuthError> {
351    let mut url = Url::parse(auth_url)
352        .map_err(|e| AuthError::InvalidConfig(format!("Invalid authorization URL: {}", e)))?;
353    {
354        let mut pairs = url.query_pairs_mut();
355        pairs.append_pair("code_challenge", challenge);
356        pairs.append_pair("code_challenge_method", PKCE_CODE_CHALLENGE_METHOD);
357    }
358    Ok(url.to_string())
359}
360
361impl OAuth2State {
362    /// Create a new OAuth2 state with provided state parameter
363    pub fn new_with_state(
364        state: String,
365        provider_name: String,
366        redirect_uri: Option<String>,
367        user_id: String,
368        scopes: Vec<String>,
369    ) -> Self {
370        Self {
371            state,
372            provider_name,
373            redirect_uri,
374            user_id,
375            scopes,
376            metadata: HashMap::new(),
377            created_at: Utc::now(),
378        }
379    }
380
381    /// Create a new OAuth2 state with auto-generated state parameter (deprecated)
382    pub fn new(
383        provider_name: String,
384        redirect_uri: Option<String>,
385        user_id: String,
386        scopes: Vec<String>,
387    ) -> Self {
388        Self::new_with_state(
389            uuid::Uuid::new_v4().to_string(),
390            provider_name,
391            redirect_uri,
392            user_id,
393            scopes,
394        )
395    }
396
397    /// Check if the state has expired (default 10 minutes)
398    pub fn is_expired(&self, max_age_seconds: i64) -> bool {
399        let max_age = chrono::Duration::seconds(max_age_seconds);
400        Utc::now() - self.created_at > max_age
401    }
402}
403
404/// Storage-only trait for authentication stores
405/// Implementations only need to handle storage operations
406#[async_trait]
407pub trait ToolAuthStore: Send + Sync {
408    /// Session Management
409    /// Get current authentication session for an entity
410    async fn get_session(
411        &self,
412        auth_entity: &str,
413        user_id: &str,
414    ) -> Result<Option<AuthSession>, AuthError>;
415
416    /// Store authentication session
417    async fn store_session(
418        &self,
419        auth_entity: &str,
420        user_id: &str,
421        session: AuthSession,
422    ) -> Result<(), AuthError>;
423
424    /// Remove authentication session
425    async fn remove_session(&self, auth_entity: &str, user_id: &str) -> Result<bool, AuthError>;
426
427    /// Secret Management
428    /// Store secret for a user (optionally scoped to auth_entity)
429    async fn store_secret(
430        &self,
431        user_id: &str,
432        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
433        secret: AuthSecret,
434    ) -> Result<(), AuthError>;
435
436    /// Get stored secret by key (optionally scoped to auth_entity)  
437    async fn get_secret(
438        &self,
439        user_id: &str,
440        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
441        key: &str,
442    ) -> Result<Option<AuthSecret>, AuthError>;
443
444    /// Remove stored secret by key (optionally scoped to auth_entity)
445    async fn remove_secret(
446        &self,
447        user_id: &str,
448        auth_entity: Option<&str>, // None for global secrets, Some() for auth_entity-specific
449        key: &str,
450    ) -> Result<bool, AuthError>;
451
452    /// State Management (for OAuth2 flows)
453    /// Store OAuth2 state for security
454    async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError>;
455
456    /// Get OAuth2 state by state parameter
457    async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError>;
458
459    /// Remove OAuth2 state (after successful callback)
460    async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError>;
461
462    async fn list_secrets(&self, user_id: &str) -> Result<HashMap<String, AuthSecret>, AuthError>;
463
464    async fn list_sessions(
465        &self,
466        _user_id: &str,
467    ) -> Result<HashMap<String, AuthSession>, AuthError>;
468}
469
470/// OAuth handler that works with any AuthStore implementation
471#[derive(Clone)]
472pub struct OAuthHandler {
473    store: Arc<dyn ToolAuthStore>,
474    provider_registry: Option<Arc<dyn ProviderRegistry>>,
475    redirect_uri: String,
476}
477
/// Provider registry trait for getting auth providers
#[async_trait]
pub trait ProviderRegistry: Send + Sync {
    /// Returns the provider registered under `provider_name`, if any.
    async fn get_provider(&self, provider_name: &str) -> Option<Arc<dyn AuthProvider>>;
    /// Returns the auth configuration for `provider_name`, if any.
    async fn get_auth_type(&self, provider_name: &str) -> Option<AuthType>;
    /// Reports whether `provider_name` is available.
    // NOTE(review): presumably equivalent to `get_provider(..).is_some()` —
    // confirm with implementations.
    async fn is_provider_available(&self, provider_name: &str) -> bool;
    /// Lists the names of the registered providers.
    async fn list_providers(&self) -> Vec<String>;
    /// Whether authorization requests for `provider_name` must use PKCE
    /// (`S256`). Defaults to `false`.
    async fn requires_pkce(&self, _provider_name: &str) -> bool {
        false
    }
}
489
490impl OAuthHandler {
491    pub fn new(store: Arc<dyn ToolAuthStore>, redirect_uri: String) -> Self {
492        Self {
493            store,
494            provider_registry: None,
495            redirect_uri,
496        }
497    }
498
499    pub fn with_provider_registry(
500        store: Arc<dyn ToolAuthStore>,
501        provider_registry: Arc<dyn ProviderRegistry>,
502        redirect_uri: String,
503    ) -> Self {
504        Self {
505            store,
506            provider_registry: Some(provider_registry),
507            redirect_uri,
508        }
509    }
510
511    /// Generate authorization URL for OAuth2 flow. `extra_params` is
512    /// appended verbatim to the auth URL — provider-agnostic passthrough
513    /// for caller-supplied knobs (e.g. Slack `team=`, Microsoft `tenant=`).
514    /// `provider_override` lets the caller swap in a per-connection
515    /// `OAuth2Provider` built from BYOK creds or discovered metadata
516    /// instead of the registry's catalog provider.
517    pub async fn get_auth_url(
518        &self,
519        auth_entity: &str,
520        user_id: &str,
521        auth_config: &AuthType,
522        scopes: &[String],
523        extra_params: &HashMap<String, String>,
524        provider_override: Option<Arc<dyn AuthProvider>>,
525    ) -> Result<String, AuthError> {
526        tracing::debug!(
527            "Getting auth URL for entity: {} user: {:?}",
528            auth_entity,
529            user_id
530        );
531
532        match auth_config {
533            AuthType::OAuth2 {
534                flow_type: OAuth2FlowType::ClientCredentials,
535                ..
536            } => Err(AuthError::InvalidConfig(
537                "Client credentials flow doesn't require authorization URL".to_string(),
538            )),
539            auth_config @ AuthType::OAuth2 {
540                send_redirect_uri, ..
541            } => {
542                // Create OAuth2 state
543                let redirect_uri = if *send_redirect_uri {
544                    Some(self.redirect_uri.clone())
545                } else {
546                    None
547                };
548                let mut state = OAuth2State::new(
549                    auth_entity.to_string(),
550                    redirect_uri.clone(),
551                    user_id.to_string(),
552                    scopes.to_vec(),
553                );
554
555                let mut pkce_challenge = None;
556                if let Some(registry) = &self.provider_registry
557                    && registry.requires_pkce(auth_entity).await
558                {
559                    let (verifier, challenge) = generate_pkce_pair();
560                    state.metadata.insert(
561                        PKCE_CODE_VERIFIER_KEY.to_string(),
562                        serde_json::Value::String(verifier.clone()),
563                    );
564                    pkce_challenge = Some(challenge);
565                }
566
567                // Store the state
568                self.store.store_oauth2_state(state.clone()).await?;
569
570                // Get the appropriate provider — caller-supplied override if any
571                // (per-connection BYOK / DCR), else look up by entity name.
572                let provider = self
573                    .effective_provider(auth_entity, provider_override)
574                    .await?;
575
576                // Build the authorization URL
577                let mut auth_url = provider.build_auth_url(
578                    auth_config,
579                    &state.state,
580                    scopes,
581                    redirect_uri.as_deref(),
582                    extra_params,
583                )?;
584
585                if let Some(challenge) = pkce_challenge {
586                    auth_url = append_pkce_challenge(&auth_url, &challenge)?;
587                }
588
589                tracing::debug!("Generated auth URL: {}", auth_url);
590                Ok(auth_url)
591            }
592            AuthType::Secret { .. } => Err(AuthError::InvalidConfig(
593                "Secret authentication doesn't require authorization URL".to_string(),
594            )),
595            AuthType::None => Err(AuthError::InvalidConfig(
596                "No authentication doesn't require authorization URL".to_string(),
597            )),
598        }
599    }
600
601    /// Handle OAuth2 callback and exchange code for tokens. `provider_override`
602    /// lets the caller (e.g. cloud's ConnectionService) plug in a per-connection
603    /// provider built from BYOK creds or discovered metadata.
604    pub async fn handle_callback(
605        &self,
606        code: &str,
607        state: &str,
608        provider_override: Option<Arc<dyn AuthProvider>>,
609    ) -> Result<AuthSession, AuthError> {
610        tracing::debug!("Handling OAuth2 callback with state: {}", state);
611
612        // Get and remove the state
613        let oauth2_state = self.store.get_oauth2_state(state).await?.ok_or_else(|| {
614            AuthError::OAuth2Flow("Invalid or expired state parameter".to_string())
615        })?;
616
617        // Remove the used state
618        self.store.remove_oauth2_state(state).await?;
619
620        // Check if state is expired (10 minutes max)
621        if oauth2_state.is_expired(600) {
622            return Err(AuthError::OAuth2Flow(
623                "OAuth2 state has expired".to_string(),
624            ));
625        }
626
627        // Get auth config from provider registry
628        let auth_config = if let Some(registry) = &self.provider_registry {
629            registry
630                .get_auth_type(&oauth2_state.provider_name)
631                .await
632                .ok_or_else(|| {
633                    AuthError::InvalidConfig(format!(
634                        "No configuration found for provider: {}",
635                        oauth2_state.provider_name
636                    ))
637                })?
638        } else {
639            return Err(AuthError::InvalidConfig(
640                "No provider registry configured".to_string(),
641            ));
642        };
643
644        // Get the appropriate provider — caller-supplied override (BYOK / DCR)
645        // takes precedence over the registry lookup.
646        let provider = self
647            .effective_provider(&oauth2_state.provider_name, provider_override)
648            .await?;
649
650        // Exchange the authorization code for tokens
651        let redirect_uri = match &auth_config {
652            AuthType::OAuth2 {
653                send_redirect_uri, ..
654            } if *send_redirect_uri => oauth2_state
655                .redirect_uri
656                .clone()
657                .or_else(|| Some(self.redirect_uri.clone())),
658            AuthType::OAuth2 { .. } => None,
659            _ => None,
660        };
661        let pkce_code_verifier = oauth2_state
662            .metadata
663            .get(PKCE_CODE_VERIFIER_KEY)
664            .and_then(|v| v.as_str());
665
666        let session = provider
667            .exchange_code(
668                code,
669                redirect_uri.as_deref(),
670                &auth_config,
671                pkce_code_verifier,
672            )
673            .await?;
674
675        // Store the session
676        self.store
677            .store_session(
678                &oauth2_state.provider_name,
679                &oauth2_state.user_id,
680                session.clone(),
681            )
682            .await?;
683
684        tracing::debug!(
685            "Successfully stored auth session for entity: {}",
686            oauth2_state.provider_name
687        );
688        Ok(session)
689    }
690
691    /// Refresh an expired session. `provider_override` (when supplied) is used
692    /// instead of the registry-resolved provider — needed for BYOK and DCR
693    /// flows where the per-connection client creds are workspace secrets,
694    /// not in the catalog.
695    pub async fn refresh_session(
696        &self,
697        auth_entity: &str,
698        user_id: &str,
699        auth_config: &AuthType,
700        provider_override: Option<Arc<dyn AuthProvider>>,
701    ) -> Result<AuthSession, AuthError> {
702        tracing::debug!(
703            "Refreshing session for entity: {} user: {:?}",
704            auth_entity,
705            user_id
706        );
707
708        // Get current session
709        let current_session = self
710            .store
711            .get_session(auth_entity, user_id)
712            .await?
713            .ok_or_else(|| {
714                AuthError::TokenRefreshFailed("No session found to refresh".to_string())
715            })?;
716
717        let refresh_token = current_session.refresh_token.ok_or_else(|| {
718            AuthError::TokenRefreshFailed("No refresh token available".to_string())
719        })?;
720
721        match auth_config {
722            AuthType::OAuth2 {
723                flow_type: OAuth2FlowType::ClientCredentials,
724                ..
725            } => {
726                // For client credentials, get a new token instead of refreshing.
727                let provider = self
728                    .effective_provider(auth_entity, provider_override)
729                    .await?;
730                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
731
732                // Store the new session
733                self.store
734                    .store_session(auth_entity, user_id, new_session.clone())
735                    .await?;
736                Ok(new_session)
737            }
738            auth_config @ AuthType::OAuth2 { .. } => {
739                // Get the appropriate provider — override wins if supplied.
740                let provider = self
741                    .effective_provider(auth_entity, provider_override)
742                    .await?;
743
744                // Refresh the token
745                let new_session = provider.refresh_token(&refresh_token, auth_config).await?;
746
747                // Store the new session
748                self.store
749                    .store_session(auth_entity, user_id, new_session.clone())
750                    .await?;
751                Ok(new_session)
752            }
753            _ => Err(AuthError::InvalidConfig(
754                "Cannot refresh non-OAuth2 session".to_string(),
755            )),
756        }
757    }
758
759    /// Get session, automatically refreshing if expired. `provider_override`
760    /// is forwarded to `refresh_session` for BYOK / DCR client creds.
761    pub async fn refresh_get_session(
762        &self,
763        auth_entity: &str,
764        user_id: &str,
765        auth_config: &AuthType,
766        provider_override: Option<Arc<dyn AuthProvider>>,
767    ) -> Result<Option<AuthSession>, AuthError> {
768        match self.store.get_session(auth_entity, user_id).await? {
769            Some(session) => {
770                if session.needs_refresh() {
771                    tracing::debug!(
772                        "Session expired for {}:{:?}, attempting refresh",
773                        auth_entity,
774                        user_id
775                    );
776                    match self
777                        .refresh_session(auth_entity, user_id, auth_config, provider_override)
778                        .await
779                    {
780                        Ok(refreshed_session) => {
781                            tracing::info!(
782                                "Successfully refreshed session for {}:{:?}",
783                                auth_entity,
784                                user_id
785                            );
786                            Ok(Some(refreshed_session))
787                        }
788                        Err(e) => {
789                            tracing::warn!(
790                                "Failed to refresh session for {}:{:?}: {}",
791                                auth_entity,
792                                user_id,
793                                e
794                            );
795                            Err(e)
796                        }
797                    }
798                } else {
799                    Ok(Some(session))
800                }
801            }
802            None => Ok(None),
803        }
804    }
805
806    async fn get_provider(&self, provider_name: &str) -> Result<Arc<dyn AuthProvider>, AuthError> {
807        if let Some(registry) = &self.provider_registry {
808            registry
809                .get_provider(provider_name)
810                .await
811                .ok_or_else(|| AuthError::ProviderNotFound(provider_name.to_string()))
812        } else {
813            Err(AuthError::InvalidConfig(
814                "No provider registry configured".to_string(),
815            ))
816        }
817    }
818
819    /// Return the caller-supplied provider if any, else fall back to the
820    /// registry-resolved provider. Used by the three flow methods so the
821    /// caller (cloud's ConnectionService) can swap in a per-connection
822    /// `OAuth2Provider` built from BYOK creds or discovered metadata.
823    async fn effective_provider(
824        &self,
825        provider_name: &str,
826        override_provider: Option<Arc<dyn AuthProvider>>,
827    ) -> Result<Arc<dyn AuthProvider>, AuthError> {
828        match override_provider {
829            Some(p) => Ok(p),
830            None => self.get_provider(provider_name).await,
831        }
832    }
833
834    // Storage delegation methods
835    pub async fn get_session(
836        &self,
837        auth_entity: &str,
838        user_id: &str,
839    ) -> Result<Option<AuthSession>, AuthError> {
840        self.store.get_session(auth_entity, user_id).await
841    }
842
843    pub async fn store_session(
844        &self,
845        auth_entity: &str,
846        user_id: &str,
847        session: AuthSession,
848    ) -> Result<(), AuthError> {
849        self.store
850            .store_session(auth_entity, user_id, session)
851            .await
852    }
853
854    pub async fn remove_session(
855        &self,
856        auth_entity: &str,
857        user_id: &str,
858    ) -> Result<bool, AuthError> {
859        self.store.remove_session(auth_entity, user_id).await
860    }
861
862    pub async fn store_secret(
863        &self,
864        user_id: &str,
865        auth_entity: Option<&str>,
866        secret: AuthSecret,
867    ) -> Result<(), AuthError> {
868        self.store.store_secret(user_id, auth_entity, secret).await
869    }
870
871    pub async fn get_secret(
872        &self,
873        user_id: &str,
874        auth_entity: Option<&str>,
875        key: &str,
876    ) -> Result<Option<AuthSecret>, AuthError> {
877        self.store.get_secret(user_id, auth_entity, key).await
878    }
879
880    pub async fn remove_secret(
881        &self,
882        user_id: &str,
883        auth_entity: Option<&str>,
884        key: &str,
885    ) -> Result<bool, AuthError> {
886        self.store.remove_secret(user_id, auth_entity, key).await
887    }
888
889    pub async fn store_oauth2_state(&self, state: OAuth2State) -> Result<(), AuthError> {
890        self.store.store_oauth2_state(state).await
891    }
892
893    pub async fn get_oauth2_state(&self, state: &str) -> Result<Option<OAuth2State>, AuthError> {
894        self.store.get_oauth2_state(state).await
895    }
896
897    pub async fn remove_oauth2_state(&self, state: &str) -> Result<(), AuthError> {
898        self.store.remove_oauth2_state(state).await
899    }
900
901    pub async fn list_secrets(
902        &self,
903        user_id: &str,
904    ) -> Result<HashMap<String, AuthSecret>, AuthError> {
905        self.store.list_secrets(user_id).await
906    }
907
908    pub async fn list_sessions(
909        &self,
910        user_id: &str,
911    ) -> Result<HashMap<String, AuthSession>, AuthError> {
912        self.store.list_sessions(user_id).await
913    }
914}
915
/// Authentication provider trait for different OAuth2 providers
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Provider name (e.g., "google", "github", "twitter")
    fn provider_name(&self) -> &str;

    /// Exchange authorization code for access token
    ///
    /// `redirect_uri` is the value to echo to the token endpoint, or `None`
    /// when the provider config disables sending it. `pkce_code_verifier` is
    /// the PKCE verifier recorded for this flow, or `None` when PKCE was not
    /// used.
    async fn exchange_code(
        &self,
        code: &str,
        redirect_uri: Option<&str>,
        auth_config: &AuthType,
        pkce_code_verifier: Option<&str>,
    ) -> Result<AuthSession, AuthError>;

    /// Refresh an access token
    ///
    /// Returns the replacement session built from the token endpoint's
    /// response to `refresh_token`.
    async fn refresh_token(
        &self,
        refresh_token: &str,
        auth_config: &AuthType,
    ) -> Result<AuthSession, AuthError>;

    /// Build authorization URL. `extra_params` is appended verbatim to the
    /// query string — no provider-name awareness in this layer; callers
    /// pass through provider-specific knobs (e.g. Slack `team=`).
    ///
    /// `state` is the OAuth2 `state` value to embed. When PKCE is required,
    /// the PKCE query parameters are appended to the returned string
    /// afterwards, so it must parse as an absolute URL.
    fn build_auth_url(
        &self,
        auth_config: &AuthType,
        state: &str,
        scopes: &[String],
        redirect_uri: Option<&str>,
        extra_params: &HashMap<String, String>,
    ) -> Result<String, AuthError>;
}