oauth2_passkey/oauth2/types.rs
1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use sqlx::FromRow;
5
6use super::errors::OAuth2Error;
7use super::main::IdInfo as GoogleIdInfo;
8
9use crate::session::UserId;
10use crate::storage::CacheData;
11
12/// Represents an OAuth2 account linked to a user
13///
14/// This struct contains information about an OAuth2 account that has been
15/// authenticated and linked to a user in the system. It stores both
16/// the provider-specific information and internal tracking data.
17#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
18pub struct OAuth2Account {
19 /// Unique identifier for this OAuth2 account in our system
20 pub id: String,
21 /// Internal user ID this OAuth2 account is linked to
22 pub user_id: String,
23 /// OAuth2 provider name (e.g., "google")
24 pub provider: String,
25 /// User identifier from the OAuth2 provider
26 pub provider_user_id: String,
27 /// User's display name from the OAuth2 provider
28 pub name: String,
29 /// User's email address from the OAuth2 provider
30 pub email: String,
31 /// Optional URL to user's profile picture
32 pub picture: Option<String>,
33 /// Additional provider-specific metadata as JSON
34 pub metadata: Value,
35 /// When this OAuth2 account was first linked
36 pub created_at: DateTime<Utc>,
37 /// When this OAuth2 account was last updated
38 pub updated_at: DateTime<Utc>,
39}
40
41impl Default for OAuth2Account {
42 fn default() -> Self {
43 Self {
44 id: String::new(),
45 user_id: String::new(),
46 provider: String::new(),
47 provider_user_id: String::new(),
48 name: String::new(),
49 email: String::new(),
50 picture: None,
51 metadata: Value::Null,
52 created_at: Utc::now(),
53 updated_at: Utc::now(),
54 }
55 }
56}
57
58// The user data we'll get back from Google
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub(crate) struct GoogleUserInfo {
61 pub(crate) sub: String,
62 pub(crate) family_name: String,
63 pub name: String,
64 pub picture: Option<String>,
65 pub(crate) email: String,
66 pub(crate) given_name: String,
67 pub(crate) hd: Option<String>,
68 pub(crate) email_verified: bool,
69}
70
71// Add these implementations
72impl From<GoogleUserInfo> for OAuth2Account {
73 fn from(google_user: GoogleUserInfo) -> Self {
74 Self {
75 id: String::new(), // Will be set during storage
76 user_id: String::new(), // Will be set during upsert process
77 name: google_user.name,
78 email: google_user.email,
79 picture: google_user.picture,
80 provider: "google".to_string(),
81 provider_user_id: format!("google_{}", google_user.sub),
82 metadata: json!({
83 "family_name": google_user.family_name,
84 "given_name": google_user.given_name,
85 "hd": google_user.hd,
86 "email_verified": google_user.email_verified,
87 }),
88 created_at: Utc::now(),
89 updated_at: Utc::now(),
90 }
91 }
92}
93
94impl From<GoogleIdInfo> for OAuth2Account {
95 fn from(idinfo: GoogleIdInfo) -> Self {
96 Self {
97 id: String::new(), // Will be set during storage
98 user_id: String::new(), // Will be set during upsert process
99 name: idinfo.name,
100 email: idinfo.email,
101 picture: idinfo.picture,
102 provider: "google".to_string(),
103 provider_user_id: format!("google_{}", idinfo.sub),
104 metadata: json!({
105 "family_name": idinfo.family_name,
106 "given_name": idinfo.given_name,
107 "hd": idinfo.hd,
108 "verified_email": idinfo.email_verified,
109 }),
110 created_at: Utc::now(),
111 updated_at: Utc::now(),
112 }
113 }
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117pub(crate) struct StateParams {
118 pub(crate) csrf_id: String,
119 pub(crate) nonce_id: String,
120 pub(crate) pkce_id: String,
121 pub(crate) misc_id: Option<String>,
122 pub(crate) mode_id: Option<String>,
123}
124
125#[derive(Serialize, Clone, Deserialize, Debug)]
126pub(crate) struct StoredToken {
127 pub(crate) token: String,
128 pub(crate) expires_at: DateTime<Utc>,
129 pub(crate) user_agent: Option<String>,
130 pub(crate) ttl: u64,
131}
132
133/// Response from an OAuth2 authorization request
134///
135/// This struct represents the data received from an OAuth2 provider's
136/// authorization endpoint. It contains the authorization code and state
137/// parameter needed to complete the OAuth2 flow.
138#[derive(Debug, Deserialize)]
139pub struct AuthResponse {
140 /// Authorization code from the OAuth2 provider
141 pub(crate) code: String,
142 /// State parameter that was included in the original request
143 pub state: String,
144 /// Optional ID token if provided directly by the authorization endpoint
145 _id_token: Option<String>,
146}
147
148#[derive(Debug, Deserialize, Serialize)]
149pub(super) struct OidcTokenResponse {
150 pub(super) access_token: String,
151 token_type: String,
152 expires_in: u64,
153 refresh_token: Option<String>,
154 scope: String,
155 pub(super) id_token: Option<String>,
156}
157
158impl From<StoredToken> for CacheData {
159 fn from(data: StoredToken) -> Self {
160 Self {
161 value: serde_json::to_string(&data).expect("Failed to serialize StoredToken"),
162 }
163 }
164}
165
166impl TryFrom<CacheData> for StoredToken {
167 type Error = OAuth2Error;
168
169 fn try_from(data: CacheData) -> Result<Self, Self::Error> {
170 serde_json::from_str(&data.value).map_err(|e| OAuth2Error::Storage(e.to_string()))
171 }
172}
173
174/// Search field options for credential lookup
175#[allow(dead_code)]
176#[derive(Debug, PartialEq)]
177pub(crate) enum AccountSearchField {
178 /// Search by ID (type-safe)
179 Id(AccountId),
180 /// Search by user ID (database ID, type-safe)
181 UserId(UserId),
182 /// Search by provider (type-safe)
183 Provider(Provider),
184 /// Search by provider user ID (type-safe)
185 ProviderUserId(ProviderUserId),
186 /// Search by name (type-safe)
187 Name(DisplayName),
188 /// Search by email (type-safe)
189 Email(Email),
190}
191
192/// Mode of OAuth2 operation to explicitly indicate user intent.
193///
194/// This enum defines the available modes for OAuth2 authentication, determining
195/// the behavior when a user authenticates with an OAuth2 provider.
196#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
197#[serde(rename_all = "snake_case")]
198pub enum OAuth2Mode {
199 /// Add an OAuth2 account to an existing user.
200 ///
201 /// This mode is used when an authenticated user wants to link an additional
202 /// OAuth2 provider account to their existing account.
203 AddToUser,
204
205 /// Create a new user account from the OAuth2 provider data.
206 ///
207 /// This mode is used specifically for new user registration using OAuth2.
208 CreateUser,
209
210 /// Login with an existing OAuth2 account.
211 ///
212 /// This mode is used when a user wants to authenticate using a previously
213 /// linked OAuth2 provider account.
214 Login,
215
216 /// Create a new user if no matching account exists, otherwise login.
217 ///
218 /// This flexible mode attempts to login with an existing account if one matches
219 /// the OAuth2 provider data, or creates a new user account if none is found.
220 CreateUserOrLogin,
221}
222
223impl OAuth2Mode {
224 /// Converts the OAuth2Mode enum variant to its string representation.
225 ///
226 /// This method returns a static string representing the mode, which can be
227 /// used in URLs, API responses, or for logging purposes.
228 ///
229 /// # Returns
230 ///
231 /// * A string representation of the OAuth2Mode
232 pub fn as_str(&self) -> &'static str {
233 match self {
234 Self::AddToUser => "add_to_user",
235 Self::CreateUser => "create_user",
236 Self::Login => "login",
237 Self::CreateUserOrLogin => "create_user_or_login",
238 }
239 }
240}
241
242impl std::str::FromStr for OAuth2Mode {
243 type Err = OAuth2Error;
244
245 fn from_str(s: &str) -> Result<Self, Self::Err> {
246 match s {
247 "add_to_user" => Ok(Self::AddToUser),
248 "create_user" => Ok(Self::CreateUser),
249 "login" => Ok(Self::Login),
250 "create_user_or_login" => Ok(Self::CreateUserOrLogin),
251 _ => Err(OAuth2Error::InvalidMode(s.to_string())),
252 }
253 }
254}
255
256/// Type-safe wrapper for OAuth2 account identifiers.
257///
258/// This provides compile-time safety to prevent mixing up account IDs with other string types.
259/// Account IDs are database-specific identifiers for OAuth2 accounts.
260#[derive(Debug, Clone, PartialEq)]
261pub struct AccountId(String);
262
263impl AccountId {
264 /// Creates a new AccountId from a string with validation.
265 ///
266 /// # Arguments
267 /// * `id` - The account ID string
268 ///
269 /// # Returns
270 /// * `Ok(AccountId)` - If the ID is valid
271 /// * `Err(OAuth2Error)` - If the ID is invalid
272 ///
273 /// # Validation Rules
274 /// * Must not be empty
275 /// * Must contain only safe characters (alphanumeric + basic symbols)
276 /// * Must not contain control characters or dangerous sequences
277 pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
278 use crate::oauth2::OAuth2Error;
279
280 // Validate ID is not empty
281 if id.is_empty() {
282 return Err(OAuth2Error::Validation(
283 "Account ID cannot be empty".to_string(),
284 ));
285 }
286
287 // Validate ID length (reasonable bounds)
288 if id.len() > 255 {
289 return Err(OAuth2Error::Validation("Account ID too long".to_string()));
290 }
291
292 // Validate ID contains only safe characters
293 if !id
294 .chars()
295 .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | '@' | '+'))
296 {
297 return Err(OAuth2Error::Validation(
298 "Account ID contains invalid characters".to_string(),
299 ));
300 }
301
302 // Check for dangerous sequences
303 if id.contains("..") || id.contains("--") || id.contains("__") {
304 return Err(OAuth2Error::Validation(
305 "Account ID contains dangerous character sequences".to_string(),
306 ));
307 }
308
309 Ok(AccountId(id))
310 }
311
312 /// Returns the account ID as a string slice.
313 ///
314 /// # Returns
315 /// * A string slice containing the account ID
316 pub fn as_str(&self) -> &str {
317 &self.0
318 }
319}
320
321/// Type-safe wrapper for OAuth2 provider names.
322///
323/// This provides compile-time safety to prevent mixing up provider names with other string types.
324/// Provider names identify the OAuth2 service (e.g., "google", "github").
325#[derive(Debug, Clone, PartialEq)]
326pub struct Provider(String);
327
328impl Provider {
329 /// Creates a new Provider from a string with validation.
330 ///
331 /// # Arguments
332 /// * `provider` - The provider name string
333 ///
334 /// # Returns
335 /// * `Ok(Provider)` - If the provider name is valid
336 /// * `Err(OAuth2Error)` - If the provider name is invalid
337 ///
338 /// # Validation Rules
339 /// * Must not be empty
340 /// * Must contain only safe characters (alphanumeric, hyphens, underscores, periods)
341 /// * Must not start with special characters
342 pub fn new(provider: String) -> Result<Self, crate::oauth2::OAuth2Error> {
343 use crate::oauth2::OAuth2Error;
344
345 // Validate provider is not empty
346 if provider.is_empty() {
347 return Err(OAuth2Error::Validation(
348 "Provider name cannot be empty".to_string(),
349 ));
350 }
351
352 // Validate provider length (reasonable bounds for provider names)
353 if provider.len() > 50 {
354 return Err(OAuth2Error::Validation(
355 "Provider name too long".to_string(),
356 ));
357 }
358
359 // Validate provider contains only safe characters (alphanumeric, hyphens, underscores, periods)
360 // Must not start with special characters
361 if !provider
362 .chars()
363 .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
364 {
365 return Err(OAuth2Error::Validation(
366 "Provider name contains invalid characters".to_string(),
367 ));
368 }
369
370 if provider.starts_with('-') || provider.starts_with('_') || provider.starts_with('.') {
371 return Err(OAuth2Error::Validation(
372 "Provider name cannot start with special characters".to_string(),
373 ));
374 }
375
376 Ok(Provider(provider))
377 }
378
379 /// Returns the provider name as a string slice.
380 ///
381 /// # Returns
382 /// * A string slice containing the provider name
383 pub fn as_str(&self) -> &str {
384 &self.0
385 }
386}
387
388/// Type-safe wrapper for provider-specific user identifiers.
389///
390/// This provides compile-time safety to prevent mixing up provider user IDs with database user IDs.
391/// Provider user IDs are external identifiers from OAuth2 providers (e.g., Google user ID).
392#[derive(Debug, Clone, PartialEq)]
393pub struct ProviderUserId(String);
394
395impl ProviderUserId {
396 /// Creates a new ProviderUserId from a string with validation.
397 ///
398 /// # Arguments
399 /// * `id` - The provider user ID string
400 ///
401 /// # Returns
402 /// * `Ok(ProviderUserId)` - If the ID is valid
403 /// * `Err(OAuth2Error)` - If the ID is invalid
404 ///
405 /// # Validation Rules
406 /// * Must not be empty
407 /// * Must contain only safe characters (alphanumeric + basic symbols)
408 /// * Must not contain control characters or dangerous sequences
409 pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
410 use crate::oauth2::OAuth2Error;
411
412 // Validate ID is not empty
413 if id.is_empty() {
414 return Err(OAuth2Error::Validation(
415 "Provider user ID cannot be empty".to_string(),
416 ));
417 }
418
419 // Validate ID length (provider IDs can be long but reasonable bounds)
420 if id.len() > 512 {
421 return Err(OAuth2Error::Validation(
422 "Provider user ID too long".to_string(),
423 ));
424 }
425
426 // Validate ID contains only safe characters
427 if !id.chars().all(|c| {
428 c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | '@' | '+' | '=' | '(' | ')')
429 }) {
430 return Err(OAuth2Error::Validation(
431 "Provider user ID contains invalid characters".to_string(),
432 ));
433 }
434
435 // Check for dangerous sequences
436 if id.contains("..") || id.contains("--") || id.contains("__") {
437 return Err(OAuth2Error::Validation(
438 "Provider user ID contains dangerous character sequences".to_string(),
439 ));
440 }
441
442 Ok(ProviderUserId(id))
443 }
444
445 /// Returns the provider user ID as a string slice.
446 ///
447 /// # Returns
448 /// * A string slice containing the provider user ID
449 pub fn as_str(&self) -> &str {
450 &self.0
451 }
452}
453
454/// Type-safe wrapper for user display names.
455///
456/// This provides compile-time safety to prevent mixing up display names with other string types.
457/// Display names are user-facing names from OAuth2 providers.
458#[derive(Debug, Clone, PartialEq)]
459pub struct DisplayName(String);
460
461impl DisplayName {
462 /// Creates a new DisplayName from a string with validation.
463 ///
464 /// This constructor is part of the public type-safe search API and is used
465 /// internally by the AccountSearchField enum for database queries.
466 ///
467 /// # Arguments
468 /// * `name` - The display name string
469 ///
470 /// # Returns
471 /// * `Ok(DisplayName)` - If the name is valid
472 /// * `Err(OAuth2Error)` - If the name is invalid
473 ///
474 /// # Validation Rules
475 /// * Must not be empty
476 /// * Must not consist only of whitespace
477 /// * Must not contain dangerous sequences
478 #[allow(dead_code)] // Part of type-safe search API, used in tests but not by library's public interface
479 pub fn new(name: String) -> Result<Self, crate::oauth2::OAuth2Error> {
480 use crate::oauth2::OAuth2Error;
481
482 // Validate name is not empty
483 if name.is_empty() {
484 return Err(OAuth2Error::Validation(
485 "Display name cannot be empty".to_string(),
486 ));
487 }
488
489 // Validate name length (reasonable bounds for display names)
490 if name.len() > 100 {
491 return Err(OAuth2Error::Validation("Display name too long".to_string()));
492 }
493
494 // Validate name doesn't consist only of whitespace
495 if name.trim().is_empty() {
496 return Err(OAuth2Error::Validation(
497 "Display name cannot consist only of whitespace".to_string(),
498 ));
499 }
500
501 // Check for dangerous sequences
502 if name.contains("..") || name.contains("--") || name.contains("__") {
503 return Err(OAuth2Error::Validation(
504 "Display name contains dangerous character sequences".to_string(),
505 ));
506 }
507
508 Ok(DisplayName(name))
509 }
510
511 /// Returns the display name as a string slice.
512 ///
513 /// # Returns
514 /// * A string slice containing the display name
515 pub fn as_str(&self) -> &str {
516 &self.0
517 }
518}
519
520/// Type-safe wrapper for email addresses.
521///
522/// This provides compile-time safety to prevent mixing up email addresses with other string types.
523/// Email addresses are provided by OAuth2 providers for user identification.
524#[derive(Debug, Clone, PartialEq)]
525pub struct Email(String);
526
527impl Email {
528 /// Creates a new Email from a string with validation.
529 ///
530 /// This constructor is part of the public type-safe search API and is used
531 /// internally by the AccountSearchField enum for database queries.
532 ///
533 /// # Arguments
534 /// * `email` - The email address string
535 ///
536 /// # Returns
537 /// * `Ok(Email)` - If the email is valid
538 /// * `Err(OAuth2Error)` - If the email is invalid
539 ///
540 /// # Validation Rules
541 /// * Must not be empty
542 /// * Must contain @ symbol
543 /// * Must have reasonable length
544 #[allow(dead_code)] // Part of type-safe search API, used in tests but not by library's public interface
545 pub fn new(email: String) -> Result<Self, crate::oauth2::OAuth2Error> {
546 use crate::oauth2::OAuth2Error;
547
548 // Validate email is not empty
549 if email.is_empty() {
550 return Err(OAuth2Error::Validation("Email cannot be empty".to_string()));
551 }
552
553 // Validate email length (RFC 5321 limits: maximum 254 characters)
554 if email.len() < 3 {
555 return Err(OAuth2Error::Validation("Email too short".to_string()));
556 }
557
558 if email.len() > 254 {
559 return Err(OAuth2Error::Validation("Email too long".to_string()));
560 }
561
562 // Basic email format validation (must contain @ and reasonable structure)
563 if !email.contains('@') {
564 return Err(OAuth2Error::Validation(
565 "Email must contain @ symbol".to_string(),
566 ));
567 }
568
569 let parts: Vec<&str> = email.split('@').collect();
570 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
571 return Err(OAuth2Error::Validation(
572 "Email format is invalid".to_string(),
573 ));
574 }
575
576 Ok(Email(email))
577 }
578
579 /// Returns the email address as a string slice.
580 ///
581 /// # Returns
582 /// * A string slice containing the email address
583 pub fn as_str(&self) -> &str {
584 &self.0
585 }
586}
587
/// Validated OAuth2 `state` parameter.
///
/// The wrapped string is unpadded base64url-encoded JSON that carries the CSRF
/// and flow identifiers between the authorization request and the callback.
/// Because the inner field is private, code outside this module can only
/// obtain one through `OAuth2State::new`, which performs the validation.
#[derive(Clone, Debug, PartialEq)]
pub struct OAuth2State(String);
595
596impl std::fmt::Display for OAuth2State {
597 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 write!(f, "{}", self.0)
599 }
600}
601
602impl OAuth2State {
603 /// Creates a new OAuth2State from a string with validation.
604 ///
605 /// This constructor validates the OAuth2 state format to ensure it meets
606 /// security requirements for CSRF protection and parameter integrity.
607 ///
608 /// # Arguments
609 /// * `state` - The OAuth2 state string (should be base64url-encoded)
610 ///
611 /// # Returns
612 /// * `Ok(OAuth2State)` - If the state is valid
613 /// * `Err(OAuth2Error)` - If the state is invalid
614 ///
615 /// # Validation Rules
616 /// * Must not be empty
617 /// * Must be valid base64url encoding
618 /// * Must contain valid JSON when decoded
619 /// * Must be reasonable length
620 pub fn new(state: String) -> Result<Self, super::errors::OAuth2Error> {
621 use super::errors::OAuth2Error;
622 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
623
624 // Validate state is not empty
625 if state.is_empty() {
626 return Err(OAuth2Error::DecodeState(
627 "OAuth2 state cannot be empty".to_string(),
628 ));
629 }
630
631 // Validate state length (reasonable bounds)
632 if state.len() < 10 {
633 return Err(OAuth2Error::DecodeState(
634 "OAuth2 state too short".to_string(),
635 ));
636 }
637
638 if state.len() > 8192 {
639 return Err(OAuth2Error::DecodeState(
640 "OAuth2 state too long".to_string(),
641 ));
642 }
643
644 // Validate state is valid base64url
645 let decoded_bytes = URL_SAFE_NO_PAD
646 .decode(&state)
647 .map_err(|e| OAuth2Error::DecodeState(format!("Invalid base64url encoding: {e}")))?;
648
649 // Validate decoded content is valid UTF-8
650 let decoded_string = String::from_utf8(decoded_bytes).map_err(|e| {
651 OAuth2Error::DecodeState(format!("Invalid UTF-8 in decoded state: {e}"))
652 })?;
653
654 // Validate decoded content is valid JSON
655 let _: StateParams = serde_json::from_str(&decoded_string)
656 .map_err(|e| OAuth2Error::DecodeState(format!("Invalid JSON in decoded state: {e}")))?;
657
658 Ok(OAuth2State(state))
659 }
660
661 /// Returns the OAuth2 state as a string slice.
662 ///
663 /// # Returns
664 /// * A string slice containing the OAuth2 state
665 pub fn as_str(&self) -> &str {
666 &self.0
667 }
668
669 /// Checks if the state contains a substring.
670 ///
671 /// # Arguments
672 /// * `needle` - The substring to search for
673 ///
674 /// # Returns
675 /// * `true` if the substring is found, `false` otherwise
676 pub fn contains(&self, needle: char) -> bool {
677 self.0.contains(needle)
678 }
679}
680
/// The kinds of short-lived tokens used during an OAuth2 authorization flow.
///
/// Using an enum instead of bare strings rules out misspelled token kinds.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum TokenType {
    /// Anti-CSRF token for the authorization flow.
    Csrf,
    /// OpenID Connect nonce.
    Nonce,
    /// PKCE (Proof Key for Code Exchange) verifier.
    Pkce,
}

impl std::fmt::Display for TokenType {
    /// Writes the same lowercase name that `as_str` returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl TokenType {
    /// Returns the lowercase name of this token type: `"csrf"`, `"nonce"`, or `"pkce"`.
    pub fn as_str(&self) -> &str {
        match *self {
            Self::Csrf => "csrf",
            Self::Nonce => "nonce",
            Self::Pkce => "pkce",
        }
    }
}
714
715#[cfg(test)]
716mod tests;