acton_htmx/oauth2/
types.rs

//! Core OAuth2 types and configuration
//!
//! This module defines the foundational types for OAuth2 authentication,
//! including provider configurations, tokens, and user information.
5
6use oauth2::basic::BasicClient;
7use oauth2::{EndpointNotSet, EndpointSet};
8use serde::{Deserialize, Serialize};
9use std::str::FromStr;
10use std::time::{Duration, SystemTime};
11
12/// Type alias for a configured OAuth2 client with auth and token endpoints set
13///
14/// This is the standard client type used by all OAuth2 providers (Google, GitHub, OIDC).
15/// The type parameters indicate which endpoints are configured:
16/// - `EndpointSet` for `HasAuthUrl` - Authorization endpoint is configured
17/// - `EndpointNotSet` for `HasDeviceAuthUrl` - Device auth not used
18/// - `EndpointNotSet` for `HasIntrospectionUrl` - Token introspection not used
19/// - `EndpointNotSet` for `HasRevocationUrl` - Token revocation not used
20/// - `EndpointSet` for `HasTokenUrl` - Token exchange endpoint is configured
21pub type ConfiguredClient = BasicClient<
22    EndpointSet,    // HasAuthUrl
23    EndpointNotSet, // HasDeviceAuthUrl
24    EndpointNotSet, // HasIntrospectionUrl
25    EndpointNotSet, // HasRevocationUrl
26    EndpointSet,    // HasTokenUrl
27>;
28
29/// OAuth2 provider identifier
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
31#[serde(rename_all = "lowercase")]
32pub enum OAuthProvider {
33    /// Google OAuth2
34    Google,
35    /// GitHub OAuth2
36    GitHub,
37    /// Generic OpenID Connect provider
38    Oidc,
39}
40
41impl OAuthProvider {
42    /// Get the provider as a string (lowercase)
43    #[must_use]
44    pub const fn as_str(&self) -> &'static str {
45        match self {
46            Self::Google => "google",
47            Self::GitHub => "github",
48            Self::Oidc => "oidc",
49        }
50    }
51}
52
53impl FromStr for OAuthProvider {
54    type Err = OAuthError;
55
56    fn from_str(s: &str) -> Result<Self, Self::Err> {
57        match s.to_lowercase().as_str() {
58            "google" => Ok(Self::Google),
59            "github" => Ok(Self::GitHub),
60            "oidc" => Ok(Self::Oidc),
61            _ => Err(OAuthError::UnknownProvider(s.to_string())),
62        }
63    }
64}
65
66/// Configuration for an OAuth2 provider
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ProviderConfig {
69    /// OAuth2 client ID
70    pub client_id: String,
71    /// OAuth2 client secret
72    pub client_secret: String,
73    /// Redirect URI (callback URL)
74    pub redirect_uri: String,
75    /// OAuth2 scopes to request
76    pub scopes: Vec<String>,
77    /// Authorization endpoint (for generic OIDC)
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub auth_url: Option<String>,
80    /// Token endpoint (for generic OIDC)
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub token_url: Option<String>,
83    /// UserInfo endpoint (for generic OIDC)
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub userinfo_url: Option<String>,
86}
87
88/// Complete OAuth2 configuration for all providers
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct OAuthConfig {
91    /// Google OAuth2 configuration
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub google: Option<ProviderConfig>,
94    /// GitHub OAuth2 configuration
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub github: Option<ProviderConfig>,
97    /// Generic OIDC configuration
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub oidc: Option<ProviderConfig>,
100}
101
102impl OAuthConfig {
103    /// Create a new empty OAuth2 configuration
104    #[must_use]
105    pub const fn new() -> Self {
106        Self {
107            google: None,
108            github: None,
109            oidc: None,
110        }
111    }
112
113    /// Get a reference to the provider configuration option
114    ///
115    /// This is a helper method to reduce code duplication in provider lookups.
116    /// Returns `Option<&ProviderConfig>` following Rust idioms for optional references.
117    const fn provider_config(&self, provider: OAuthProvider) -> Option<&ProviderConfig> {
118        match provider {
119            OAuthProvider::Google => self.google.as_ref(),
120            OAuthProvider::GitHub => self.github.as_ref(),
121            OAuthProvider::Oidc => self.oidc.as_ref(),
122        }
123    }
124
125    /// Get configuration for a specific provider
126    ///
127    /// # Errors
128    ///
129    /// Returns error if the provider is not configured
130    pub fn get_provider(&self, provider: OAuthProvider) -> Result<&ProviderConfig, OAuthError> {
131        self.provider_config(provider)
132            .ok_or(OAuthError::ProviderNotConfigured(provider))
133    }
134
135    /// Check if a provider is configured
136    #[must_use]
137    pub const fn is_provider_configured(&self, provider: OAuthProvider) -> bool {
138        self.provider_config(provider).is_some()
139    }
140}
141
142impl Default for OAuthConfig {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148/// OAuth2 CSRF state token
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct OAuthState {
151    /// The state token
152    pub token: String,
153    /// Provider for this state
154    pub provider: OAuthProvider,
155    /// When the state expires
156    pub expires_at: SystemTime,
157}
158
159impl OAuthState {
160    /// Generate a new state token
161    #[must_use]
162    pub fn generate(provider: OAuthProvider) -> Self {
163        use rand::Rng;
164
165        // Generate 32 bytes of random data and encode as hex
166        let random_bytes: [u8; 32] = rand::rng().random();
167        let token = hex::encode(random_bytes);
168
169        Self {
170            token,
171            provider,
172            expires_at: SystemTime::now() + Duration::from_secs(600), // 10 minutes
173        }
174    }
175
176    /// Check if the state token has expired
177    #[must_use]
178    pub fn is_expired(&self) -> bool {
179        SystemTime::now() > self.expires_at
180    }
181}
182
183/// OAuth2 access token
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct OAuthToken {
186    /// Access token
187    pub access_token: String,
188    /// Refresh token (if provided)
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub refresh_token: Option<String>,
191    /// Token type (usually "Bearer")
192    pub token_type: String,
193    /// When the token expires
194    #[serde(skip_serializing_if = "Option::is_none")]
195    pub expires_at: Option<SystemTime>,
196    /// OAuth2 scopes granted
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub scopes: Option<Vec<String>>,
199}
200
201impl OAuthToken {
202    /// Check if the access token has expired
203    #[must_use]
204    pub fn is_expired(&self) -> bool {
205        self.expires_at
206            .is_some_and(|expires| SystemTime::now() > expires)
207    }
208}
209
210/// User information from OAuth2 provider
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct OAuthUserInfo {
213    /// Provider-specific user ID
214    pub provider_user_id: String,
215    /// Email address
216    pub email: String,
217    /// Display name
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub name: Option<String>,
220    /// Avatar/profile picture URL
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub avatar_url: Option<String>,
223    /// Whether email is verified
224    pub email_verified: bool,
225}
226
227/// OAuth2 errors
228#[derive(Debug, thiserror::Error)]
229pub enum OAuthError {
230    /// Unknown provider
231    #[error("Unknown OAuth2 provider: {0}")]
232    UnknownProvider(String),
233
234    /// Provider not configured
235    #[error("OAuth2 provider not configured: {0:?}")]
236    ProviderNotConfigured(OAuthProvider),
237
238    /// Invalid state token
239    #[error("Invalid or expired OAuth2 state token")]
240    InvalidState,
241
242    /// State token mismatch (potential CSRF attack)
243    #[error("OAuth2 state token mismatch (potential CSRF attack)")]
244    StateMismatch,
245
246    /// Authorization code exchange failed
247    #[error("Failed to exchange authorization code for token: {0}")]
248    TokenExchangeFailed(String),
249
250    /// Failed to fetch user info
251    #[error("Failed to fetch user information: {0}")]
252    UserInfoFailed(String),
253
254    /// Token expired
255    #[error("OAuth2 token has expired")]
256    TokenExpired,
257
258    /// Generic OAuth2 error
259    #[error("OAuth2 error: {0}")]
260    Generic(String),
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider configuration shared by the configuration tests.
    fn sample_provider_config() -> ProviderConfig {
        ProviderConfig {
            client_id: "test".to_string(),
            client_secret: "test".to_string(),
            redirect_uri: "http://localhost/callback".to_string(),
            scopes: vec!["email".to_string()],
            auth_url: None,
            token_url: None,
            userinfo_url: None,
        }
    }

    #[test]
    fn test_provider_as_str() {
        assert_eq!(OAuthProvider::Google.as_str(), "google");
        assert_eq!(OAuthProvider::GitHub.as_str(), "github");
        assert_eq!(OAuthProvider::Oidc.as_str(), "oidc");
    }

    #[test]
    fn test_provider_from_str() {
        assert_eq!(
            "google".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::Google
        );
        assert_eq!(
            "GOOGLE".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::Google
        );
        assert_eq!(
            "github".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::GitHub
        );
        assert_eq!(
            "GitHub".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::GitHub
        );
        assert_eq!(
            "oidc".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::Oidc
        );
        assert_eq!(
            "OiDc".parse::<OAuthProvider>().unwrap(),
            OAuthProvider::Oidc
        );
        assert!("invalid".parse::<OAuthProvider>().is_err());
        assert!("".parse::<OAuthProvider>().is_err());
    }

    #[test]
    fn test_provider_from_str_round_trips_as_str() {
        for provider in [
            OAuthProvider::Google,
            OAuthProvider::GitHub,
            OAuthProvider::Oidc,
        ] {
            assert_eq!(provider.as_str().parse::<OAuthProvider>().unwrap(), provider);
        }
    }

    #[test]
    fn test_provider_from_str_unknown_keeps_input() {
        match "Invalid".parse::<OAuthProvider>() {
            Err(OAuthError::UnknownProvider(input)) => assert_eq!(input, "Invalid"),
            other => panic!("expected UnknownProvider, got {other:?}"),
        }
    }

    #[test]
    fn test_oauth_config_default() {
        let config = OAuthConfig::default();
        assert!(config.google.is_none());
        assert!(config.github.is_none());
        assert!(config.oidc.is_none());
    }

    #[test]
    fn test_oauth_config_is_provider_configured() {
        let mut config = OAuthConfig::default();
        assert!(!config.is_provider_configured(OAuthProvider::Google));

        config.google = Some(sample_provider_config());

        assert!(config.is_provider_configured(OAuthProvider::Google));
        assert!(!config.is_provider_configured(OAuthProvider::GitHub));
    }

    #[test]
    fn test_oauth_config_get_provider() {
        let mut config = OAuthConfig::new();
        assert!(matches!(
            config.get_provider(OAuthProvider::Google),
            Err(OAuthError::ProviderNotConfigured(OAuthProvider::Google))
        ));

        config.google = Some(sample_provider_config());

        let google = config.get_provider(OAuthProvider::Google).unwrap();
        assert_eq!(google.client_id, "test");
        assert!(matches!(
            config.get_provider(OAuthProvider::Oidc),
            Err(OAuthError::ProviderNotConfigured(OAuthProvider::Oidc))
        ));
    }

    #[test]
    fn test_oauth_state_generation() {
        let state = OAuthState::generate(OAuthProvider::Google);
        assert_eq!(state.provider, OAuthProvider::Google);
        assert!(!state.is_expired());
        assert_eq!(state.token.len(), 64); // 32 bytes encoded as hex
        assert!(state.token.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_oauth_state_tokens_are_unique() {
        let first = OAuthState::generate(OAuthProvider::GitHub);
        let second = OAuthState::generate(OAuthProvider::GitHub);
        assert_ne!(first.token, second.token);
    }

    #[test]
    fn test_oauth_state_is_expired() {
        let state = OAuthState {
            token: "deadbeef".to_string(),
            provider: OAuthProvider::Oidc,
            expires_at: SystemTime::now() - Duration::from_secs(1),
        };
        assert!(state.is_expired());
    }

    #[test]
    fn test_oauth_token_is_expired() {
        let token = OAuthToken {
            access_token: "test".to_string(),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at: None,
            scopes: None,
        };
        assert!(!token.is_expired());

        let expired_token = OAuthToken {
            access_token: "test".to_string(),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at: Some(SystemTime::now() - Duration::from_secs(3600)),
            scopes: None,
        };
        assert!(expired_token.is_expired());

        let fresh_token = OAuthToken {
            access_token: "test".to_string(),
            refresh_token: Some("refresh".to_string()),
            token_type: "Bearer".to_string(),
            expires_at: Some(SystemTime::now() + Duration::from_secs(3600)),
            scopes: Some(vec!["email".to_string()]),
        };
        assert!(!fresh_token.is_expired());
    }

    #[test]
    fn test_error_messages() {
        assert_eq!(
            OAuthError::UnknownProvider("x".to_string()).to_string(),
            "Unknown OAuth2 provider: x"
        );
        assert_eq!(
            OAuthError::ProviderNotConfigured(OAuthProvider::GitHub).to_string(),
            "OAuth2 provider not configured: GitHub"
        );
        assert_eq!(
            OAuthError::TokenExpired.to_string(),
            "OAuth2 token has expired"
        );
    }
}
350}