oauth2_passkey/oauth2/types.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use sqlx::FromRow;
5
6use super::errors::OAuth2Error;
7use super::main::IdInfo as GoogleIdInfo;
8
9use crate::session::UserId;
10use crate::storage::CacheData;
11
12/// Represents an OAuth2 account linked to a user
13///
14/// This struct contains information about an OAuth2 account that has been
15/// authenticated and linked to a user in the system. It stores both
16/// the provider-specific information and internal tracking data.
17#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
18pub struct OAuth2Account {
19    /// Unique identifier for this OAuth2 account in our system
20    pub id: String,
21    /// Internal user ID this OAuth2 account is linked to
22    pub user_id: String,
23    /// OAuth2 provider name (e.g., "google")
24    pub provider: String,
25    /// User identifier from the OAuth2 provider
26    pub provider_user_id: String,
27    /// User's display name from the OAuth2 provider
28    pub name: String,
29    /// User's email address from the OAuth2 provider
30    pub email: String,
31    /// Optional URL to user's profile picture
32    pub picture: Option<String>,
33    /// Additional provider-specific metadata as JSON
34    pub metadata: Value,
35    /// When this OAuth2 account was first linked
36    pub created_at: DateTime<Utc>,
37    /// When this OAuth2 account was last updated
38    pub updated_at: DateTime<Utc>,
39}
40
41impl Default for OAuth2Account {
42    fn default() -> Self {
43        Self {
44            id: String::new(),
45            user_id: String::new(),
46            provider: String::new(),
47            provider_user_id: String::new(),
48            name: String::new(),
49            email: String::new(),
50            picture: None,
51            metadata: Value::Null,
52            created_at: Utc::now(),
53            updated_at: Utc::now(),
54        }
55    }
56}
57
58// The user data we'll get back from Google
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub(crate) struct GoogleUserInfo {
61    pub(crate) sub: String,
62    pub(crate) family_name: String,
63    pub name: String,
64    pub picture: Option<String>,
65    pub(crate) email: String,
66    pub(crate) given_name: String,
67    pub(crate) hd: Option<String>,
68    pub(crate) email_verified: bool,
69}
70
71// Add these implementations
72impl From<GoogleUserInfo> for OAuth2Account {
73    fn from(google_user: GoogleUserInfo) -> Self {
74        Self {
75            id: String::new(),      // Will be set during storage
76            user_id: String::new(), // Will be set during upsert process
77            name: google_user.name,
78            email: google_user.email,
79            picture: google_user.picture,
80            provider: "google".to_string(),
81            provider_user_id: format!("google_{}", google_user.sub),
82            metadata: json!({
83                "family_name": google_user.family_name,
84                "given_name": google_user.given_name,
85                "hd": google_user.hd,
86                "email_verified": google_user.email_verified,
87            }),
88            created_at: Utc::now(),
89            updated_at: Utc::now(),
90        }
91    }
92}
93
94impl From<GoogleIdInfo> for OAuth2Account {
95    fn from(idinfo: GoogleIdInfo) -> Self {
96        Self {
97            id: String::new(),      // Will be set during storage
98            user_id: String::new(), // Will be set during upsert process
99            name: idinfo.name,
100            email: idinfo.email,
101            picture: idinfo.picture,
102            provider: "google".to_string(),
103            provider_user_id: format!("google_{}", idinfo.sub),
104            metadata: json!({
105                "family_name": idinfo.family_name,
106                "given_name": idinfo.given_name,
107                "hd": idinfo.hd,
108                "verified_email": idinfo.email_verified,
109            }),
110            created_at: Utc::now(),
111            updated_at: Utc::now(),
112        }
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117pub(crate) struct StateParams {
118    pub(crate) csrf_id: String,
119    pub(crate) nonce_id: String,
120    pub(crate) pkce_id: String,
121    pub(crate) misc_id: Option<String>,
122    pub(crate) mode_id: Option<String>,
123}
124
125#[derive(Serialize, Clone, Deserialize, Debug)]
126pub(crate) struct StoredToken {
127    pub(crate) token: String,
128    pub(crate) expires_at: DateTime<Utc>,
129    pub(crate) user_agent: Option<String>,
130    pub(crate) ttl: u64,
131}
132
133/// Response from an OAuth2 authorization request
134///
135/// This struct represents the data received from an OAuth2 provider's
136/// authorization endpoint. It contains the authorization code and state
137/// parameter needed to complete the OAuth2 flow.
138#[derive(Debug, Deserialize)]
139pub struct AuthResponse {
140    /// Authorization code from the OAuth2 provider
141    pub(crate) code: String,
142    /// State parameter that was included in the original request
143    pub state: String,
144    /// Optional ID token if provided directly by the authorization endpoint
145    _id_token: Option<String>,
146}
147
148#[derive(Debug, Deserialize, Serialize)]
149pub(super) struct OidcTokenResponse {
150    pub(super) access_token: String,
151    token_type: String,
152    expires_in: u64,
153    refresh_token: Option<String>,
154    scope: String,
155    pub(super) id_token: Option<String>,
156}
157
158impl From<StoredToken> for CacheData {
159    fn from(data: StoredToken) -> Self {
160        Self {
161            value: serde_json::to_string(&data).expect("Failed to serialize StoredToken"),
162        }
163    }
164}
165
166impl TryFrom<CacheData> for StoredToken {
167    type Error = OAuth2Error;
168
169    fn try_from(data: CacheData) -> Result<Self, Self::Error> {
170        serde_json::from_str(&data.value).map_err(|e| OAuth2Error::Storage(e.to_string()))
171    }
172}
173
174/// Search field options for credential lookup
175#[allow(dead_code)]
176#[derive(Debug, PartialEq)]
177pub(crate) enum AccountSearchField {
178    /// Search by ID (type-safe)
179    Id(AccountId),
180    /// Search by user ID (database ID, type-safe)
181    UserId(UserId),
182    /// Search by provider (type-safe)
183    Provider(Provider),
184    /// Search by provider user ID (type-safe)
185    ProviderUserId(ProviderUserId),
186    /// Search by name (type-safe)
187    Name(DisplayName),
188    /// Search by email (type-safe)
189    Email(Email),
190}
191
192/// Mode of OAuth2 operation to explicitly indicate user intent.
193///
194/// This enum defines the available modes for OAuth2 authentication, determining
195/// the behavior when a user authenticates with an OAuth2 provider.
196#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum OAuth2Mode {
199    /// Add an OAuth2 account to an existing user.
200    ///
201    /// This mode is used when an authenticated user wants to link an additional
202    /// OAuth2 provider account to their existing account.
203    AddToUser,
204
205    /// Create a new user account from the OAuth2 provider data.
206    ///
207    /// This mode is used specifically for new user registration using OAuth2.
208    CreateUser,
209
210    /// Login with an existing OAuth2 account.
211    ///
212    /// This mode is used when a user wants to authenticate using a previously
213    /// linked OAuth2 provider account.
214    Login,
215
216    /// Create a new user if no matching account exists, otherwise login.
217    ///
218    /// This flexible mode attempts to login with an existing account if one matches
219    /// the OAuth2 provider data, or creates a new user account if none is found.
220    CreateUserOrLogin,
221}
222
223impl OAuth2Mode {
224    /// Converts the OAuth2Mode enum variant to its string representation.
225    ///
226    /// This method returns a static string representing the mode, which can be
227    /// used in URLs, API responses, or for logging purposes.
228    ///
229    /// # Returns
230    ///
231    /// * A string representation of the OAuth2Mode
232    pub fn as_str(&self) -> &'static str {
233        match self {
234            Self::AddToUser => "add_to_user",
235            Self::CreateUser => "create_user",
236            Self::Login => "login",
237            Self::CreateUserOrLogin => "create_user_or_login",
238        }
239    }
240}
241
242impl std::str::FromStr for OAuth2Mode {
243    type Err = OAuth2Error;
244
245    fn from_str(s: &str) -> Result<Self, Self::Err> {
246        match s {
247            "add_to_user" => Ok(Self::AddToUser),
248            "create_user" => Ok(Self::CreateUser),
249            "login" => Ok(Self::Login),
250            "create_user_or_login" => Ok(Self::CreateUserOrLogin),
251            _ => Err(OAuth2Error::InvalidMode(s.to_string())),
252        }
253    }
254}
255
256/// Type-safe wrapper for OAuth2 account identifiers.
257///
258/// This provides compile-time safety to prevent mixing up account IDs with other string types.
259/// Account IDs are database-specific identifiers for OAuth2 accounts.
260#[derive(Debug, Clone, PartialEq)]
261pub struct AccountId(String);
262
263impl AccountId {
264    /// Creates a new AccountId from a string with validation.
265    ///
266    /// # Arguments
267    /// * `id` - The account ID string
268    ///
269    /// # Returns
270    /// * `Ok(AccountId)` - If the ID is valid
271    /// * `Err(OAuth2Error)` - If the ID is invalid
272    ///
273    /// # Validation Rules
274    /// * Must not be empty
275    /// * Must contain only safe characters (alphanumeric + basic symbols)
276    /// * Must not contain control characters or dangerous sequences
277    pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
278        use crate::oauth2::OAuth2Error;
279
280        // Validate ID is not empty
281        if id.is_empty() {
282            return Err(OAuth2Error::Validation(
283                "Account ID cannot be empty".to_string(),
284            ));
285        }
286
287        // Validate ID length (reasonable bounds)
288        if id.len() > 255 {
289            return Err(OAuth2Error::Validation("Account ID too long".to_string()));
290        }
291
292        // Validate ID contains only safe characters
293        if !id
294            .chars()
295            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | '@' | '+'))
296        {
297            return Err(OAuth2Error::Validation(
298                "Account ID contains invalid characters".to_string(),
299            ));
300        }
301
302        // Check for dangerous sequences
303        if id.contains("..") || id.contains("--") || id.contains("__") {
304            return Err(OAuth2Error::Validation(
305                "Account ID contains dangerous character sequences".to_string(),
306            ));
307        }
308
309        Ok(AccountId(id))
310    }
311
312    /// Returns the account ID as a string slice.
313    ///
314    /// # Returns
315    /// * A string slice containing the account ID
316    pub fn as_str(&self) -> &str {
317        &self.0
318    }
319}
320
321/// Type-safe wrapper for OAuth2 provider names.
322///
323/// This provides compile-time safety to prevent mixing up provider names with other string types.
324/// Provider names identify the OAuth2 service (e.g., "google", "github").
325#[derive(Debug, Clone, PartialEq)]
326pub struct Provider(String);
327
328impl Provider {
329    /// Creates a new Provider from a string with validation.
330    ///
331    /// # Arguments
332    /// * `provider` - The provider name string
333    ///
334    /// # Returns
335    /// * `Ok(Provider)` - If the provider name is valid
336    /// * `Err(OAuth2Error)` - If the provider name is invalid
337    ///
338    /// # Validation Rules
339    /// * Must not be empty
340    /// * Must contain only safe characters (alphanumeric, hyphens, underscores, periods)
341    /// * Must not start with special characters
342    pub fn new(provider: String) -> Result<Self, crate::oauth2::OAuth2Error> {
343        use crate::oauth2::OAuth2Error;
344
345        // Validate provider is not empty
346        if provider.is_empty() {
347            return Err(OAuth2Error::Validation(
348                "Provider name cannot be empty".to_string(),
349            ));
350        }
351
352        // Validate provider length (reasonable bounds for provider names)
353        if provider.len() > 50 {
354            return Err(OAuth2Error::Validation(
355                "Provider name too long".to_string(),
356            ));
357        }
358
359        // Validate provider contains only safe characters (alphanumeric, hyphens, underscores, periods)
360        // Must not start with special characters
361        if !provider
362            .chars()
363            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
364        {
365            return Err(OAuth2Error::Validation(
366                "Provider name contains invalid characters".to_string(),
367            ));
368        }
369
370        if provider.starts_with('-') || provider.starts_with('_') || provider.starts_with('.') {
371            return Err(OAuth2Error::Validation(
372                "Provider name cannot start with special characters".to_string(),
373            ));
374        }
375
376        Ok(Provider(provider))
377    }
378
379    /// Returns the provider name as a string slice.
380    ///
381    /// # Returns
382    /// * A string slice containing the provider name
383    pub fn as_str(&self) -> &str {
384        &self.0
385    }
386}
387
388/// Type-safe wrapper for provider-specific user identifiers.
389///
390/// This provides compile-time safety to prevent mixing up provider user IDs with database user IDs.
391/// Provider user IDs are external identifiers from OAuth2 providers (e.g., Google user ID).
392#[derive(Debug, Clone, PartialEq)]
393pub struct ProviderUserId(String);
394
395impl ProviderUserId {
396    /// Creates a new ProviderUserId from a string with validation.
397    ///
398    /// # Arguments
399    /// * `id` - The provider user ID string
400    ///
401    /// # Returns
402    /// * `Ok(ProviderUserId)` - If the ID is valid
403    /// * `Err(OAuth2Error)` - If the ID is invalid
404    ///
405    /// # Validation Rules
406    /// * Must not be empty
407    /// * Must contain only safe characters (alphanumeric + basic symbols)
408    /// * Must not contain control characters or dangerous sequences
409    pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
410        use crate::oauth2::OAuth2Error;
411
412        // Validate ID is not empty
413        if id.is_empty() {
414            return Err(OAuth2Error::Validation(
415                "Provider user ID cannot be empty".to_string(),
416            ));
417        }
418
419        // Validate ID length (provider IDs can be long but reasonable bounds)
420        if id.len() > 512 {
421            return Err(OAuth2Error::Validation(
422                "Provider user ID too long".to_string(),
423            ));
424        }
425
426        // Validate ID contains only safe characters
427        if !id.chars().all(|c| {
428            c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | '@' | '+' | '=' | '(' | ')')
429        }) {
430            return Err(OAuth2Error::Validation(
431                "Provider user ID contains invalid characters".to_string(),
432            ));
433        }
434
435        // Check for dangerous sequences
436        if id.contains("..") || id.contains("--") || id.contains("__") {
437            return Err(OAuth2Error::Validation(
438                "Provider user ID contains dangerous character sequences".to_string(),
439            ));
440        }
441
442        Ok(ProviderUserId(id))
443    }
444
445    /// Returns the provider user ID as a string slice.
446    ///
447    /// # Returns
448    /// * A string slice containing the provider user ID
449    pub fn as_str(&self) -> &str {
450        &self.0
451    }
452}
453
454/// Type-safe wrapper for user display names.
455///
456/// This provides compile-time safety to prevent mixing up display names with other string types.
457/// Display names are user-facing names from OAuth2 providers.
458#[derive(Debug, Clone, PartialEq)]
459pub struct DisplayName(String);
460
461impl DisplayName {
462    /// Creates a new DisplayName from a string with validation.
463    ///
464    /// This constructor is part of the public type-safe search API and is used
465    /// internally by the AccountSearchField enum for database queries.
466    ///
467    /// # Arguments
468    /// * `name` - The display name string
469    ///
470    /// # Returns
471    /// * `Ok(DisplayName)` - If the name is valid
472    /// * `Err(OAuth2Error)` - If the name is invalid
473    ///
474    /// # Validation Rules
475    /// * Must not be empty
476    /// * Must not consist only of whitespace
477    /// * Must not contain dangerous sequences
478    #[allow(dead_code)] // Part of type-safe search API, used in tests but not by library's public interface
479    pub fn new(name: String) -> Result<Self, crate::oauth2::OAuth2Error> {
480        use crate::oauth2::OAuth2Error;
481
482        // Validate name is not empty
483        if name.is_empty() {
484            return Err(OAuth2Error::Validation(
485                "Display name cannot be empty".to_string(),
486            ));
487        }
488
489        // Validate name length (reasonable bounds for display names)
490        if name.len() > 100 {
491            return Err(OAuth2Error::Validation("Display name too long".to_string()));
492        }
493
494        // Validate name doesn't consist only of whitespace
495        if name.trim().is_empty() {
496            return Err(OAuth2Error::Validation(
497                "Display name cannot consist only of whitespace".to_string(),
498            ));
499        }
500
501        // Check for dangerous sequences
502        if name.contains("..") || name.contains("--") || name.contains("__") {
503            return Err(OAuth2Error::Validation(
504                "Display name contains dangerous character sequences".to_string(),
505            ));
506        }
507
508        Ok(DisplayName(name))
509    }
510
511    /// Returns the display name as a string slice.
512    ///
513    /// # Returns
514    /// * A string slice containing the display name
515    pub fn as_str(&self) -> &str {
516        &self.0
517    }
518}
519
520/// Type-safe wrapper for email addresses.
521///
522/// This provides compile-time safety to prevent mixing up email addresses with other string types.
523/// Email addresses are provided by OAuth2 providers for user identification.
524#[derive(Debug, Clone, PartialEq)]
525pub struct Email(String);
526
527impl Email {
528    /// Creates a new Email from a string with validation.
529    ///
530    /// This constructor is part of the public type-safe search API and is used
531    /// internally by the AccountSearchField enum for database queries.
532    ///
533    /// # Arguments
534    /// * `email` - The email address string
535    ///
536    /// # Returns
537    /// * `Ok(Email)` - If the email is valid
538    /// * `Err(OAuth2Error)` - If the email is invalid
539    ///
540    /// # Validation Rules
541    /// * Must not be empty
542    /// * Must contain @ symbol
543    /// * Must have reasonable length
544    #[allow(dead_code)] // Part of type-safe search API, used in tests but not by library's public interface
545    pub fn new(email: String) -> Result<Self, crate::oauth2::OAuth2Error> {
546        use crate::oauth2::OAuth2Error;
547
548        // Validate email is not empty
549        if email.is_empty() {
550            return Err(OAuth2Error::Validation("Email cannot be empty".to_string()));
551        }
552
553        // Validate email length (RFC 5321 limits: maximum 254 characters)
554        if email.len() < 3 {
555            return Err(OAuth2Error::Validation("Email too short".to_string()));
556        }
557
558        if email.len() > 254 {
559            return Err(OAuth2Error::Validation("Email too long".to_string()));
560        }
561
562        // Basic email format validation (must contain @ and reasonable structure)
563        if !email.contains('@') {
564            return Err(OAuth2Error::Validation(
565                "Email must contain @ symbol".to_string(),
566            ));
567        }
568
569        let parts: Vec<&str> = email.split('@').collect();
570        if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
571            return Err(OAuth2Error::Validation(
572                "Email format is invalid".to_string(),
573            ));
574        }
575
576        Ok(Email(email))
577    }
578
579    /// Returns the email address as a string slice.
580    ///
581    /// # Returns
582    /// * A string slice containing the email address
583    pub fn as_str(&self) -> &str {
584        &self.0
585    }
586}
587
588/// Type-safe wrapper for OAuth2 state parameters.
589///
590/// This provides compile-time safety to prevent mixing up OAuth2 state strings with other string types.
591/// OAuth2 state parameters are base64url-encoded JSON that carries CSRF protection and flow parameters
592/// between authorization requests and callbacks. Proper validation is critical for security.
593#[derive(Debug, Clone, PartialEq)]
594pub struct OAuth2State(String);
595
596impl std::fmt::Display for OAuth2State {
597    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598        write!(f, "{}", self.0)
599    }
600}
601
602impl OAuth2State {
603    /// Creates a new OAuth2State from a string with validation.
604    ///
605    /// This constructor validates the OAuth2 state format to ensure it meets
606    /// security requirements for CSRF protection and parameter integrity.
607    ///
608    /// # Arguments
609    /// * `state` - The OAuth2 state string (should be base64url-encoded)
610    ///
611    /// # Returns
612    /// * `Ok(OAuth2State)` - If the state is valid
613    /// * `Err(OAuth2Error)` - If the state is invalid
614    ///
615    /// # Validation Rules
616    /// * Must not be empty
617    /// * Must be valid base64url encoding
618    /// * Must contain valid JSON when decoded
619    /// * Must be reasonable length
620    pub fn new(state: String) -> Result<Self, super::errors::OAuth2Error> {
621        use super::errors::OAuth2Error;
622        use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
623
624        // Validate state is not empty
625        if state.is_empty() {
626            return Err(OAuth2Error::DecodeState(
627                "OAuth2 state cannot be empty".to_string(),
628            ));
629        }
630
631        // Validate state length (reasonable bounds)
632        if state.len() < 10 {
633            return Err(OAuth2Error::DecodeState(
634                "OAuth2 state too short".to_string(),
635            ));
636        }
637
638        if state.len() > 8192 {
639            return Err(OAuth2Error::DecodeState(
640                "OAuth2 state too long".to_string(),
641            ));
642        }
643
644        // Validate state is valid base64url
645        let decoded_bytes = URL_SAFE_NO_PAD
646            .decode(&state)
647            .map_err(|e| OAuth2Error::DecodeState(format!("Invalid base64url encoding: {e}")))?;
648
649        // Validate decoded content is valid UTF-8
650        let decoded_string = String::from_utf8(decoded_bytes).map_err(|e| {
651            OAuth2Error::DecodeState(format!("Invalid UTF-8 in decoded state: {e}"))
652        })?;
653
654        // Validate decoded content is valid JSON
655        let _: StateParams = serde_json::from_str(&decoded_string)
656            .map_err(|e| OAuth2Error::DecodeState(format!("Invalid JSON in decoded state: {e}")))?;
657
658        Ok(OAuth2State(state))
659    }
660
661    /// Returns the OAuth2 state as a string slice.
662    ///
663    /// # Returns
664    /// * A string slice containing the OAuth2 state
665    pub fn as_str(&self) -> &str {
666        &self.0
667    }
668
669    /// Checks if the state contains a substring.
670    ///
671    /// # Arguments
672    /// * `needle` - The substring to search for
673    ///
674    /// # Returns
675    /// * `true` if the substring is found, `false` otherwise
676    pub fn contains(&self, needle: char) -> bool {
677        self.0.contains(needle)
678    }
679}
680
/// Type-safe wrapper for OAuth2 token types.
///
/// This enum provides compile-time safety to prevent mixing up different types of OAuth2 tokens.
/// It ensures that token types are clearly defined and prevents typos in token type strings.
/// `Eq` and `Hash` are derived so token types can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// CSRF protection token for OAuth2 authorization flow
    Csrf,
    /// Nonce token for OpenID Connect
    Nonce,
    /// PKCE (Proof Key for Code Exchange) verifier token
    Pkce,
}

impl std::fmt::Display for TokenType {
    /// Writes the same lowercase name returned by [`TokenType::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl TokenType {
    /// Returns the lowercase name of the token type (`"csrf"`, `"nonce"`, `"pkce"`).
    ///
    /// The strings are literals, so the result is `&'static str`, matching
    /// `OAuth2Mode::as_str`. Callers can keep it without borrowing `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TokenType::Csrf => "csrf",
            TokenType::Nonce => "nonce",
            TokenType::Pkce => "pkce",
        }
    }
}
714
715#[cfg(test)]
716mod tests;