1pub mod pkce;
7
8use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
9use chrono::{DateTime, Utc};
10use uuid::Uuid;
11
/// A client record as stored by this OAuth module.
///
/// Serializable with serde so it can be persisted or returned as JSON.
/// NOTE(review): presumably populated by dynamic client registration — confirm
/// against the registration handler.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OAuthClient {
    /// Identifier the client presents on authorization and token requests.
    pub client_id: String,
    /// Optional human-readable name for the client.
    pub client_name: Option<String>,
    /// Redirect URIs registered for this client; presumably the `registered`
    /// list passed to `validate_redirect_uri` — verify against callers.
    pub redirect_uris: Vec<String>,
    /// Token-endpoint authentication method, stored as free-form text.
    /// NOTE(review): presumably an RFC 7591 value such as `"none"` — confirm.
    pub token_endpoint_auth_method: String,
    /// When the client record was created (UTC).
    pub created_at: DateTime<Utc>,
}
21
22#[derive(Debug, Clone)]
24pub struct AuthorizationCode {
25 pub code: String,
26 pub client_id: String,
27 pub user_id: Uuid,
28 pub redirect_uri: String,
29 pub code_challenge: String,
30 pub code_challenge_method: String,
31 pub scopes: Vec<String>,
32 pub expires_at: DateTime<Utc>,
33}
34
35pub fn generate_random_token() -> String {
37 let a = Uuid::new_v4();
38 let b = Uuid::new_v4();
39 let mut bytes = [0u8; 32];
40 bytes[..16].copy_from_slice(a.as_bytes());
41 bytes[16..].copy_from_slice(b.as_bytes());
42 URL_SAFE_NO_PAD.encode(bytes)
43}
44
/// Checks whether `requested` is an acceptable redirect URI for a client whose
/// registered redirect URIs are `registered`.
///
/// A URI is accepted when either:
/// - it equals one of the registered URIs byte-for-byte; or
/// - it is a plain-`http` loopback URI (`http://localhost` or
///   `http://127.0.0.1`) that differs from a registered loopback URI only in
///   its port, as RFC 8252 §7.3 allows for native apps.
///
/// For the loopback exception, the host must be exactly the loopback name, and
/// the path, query and fragment must still match exactly. Look-alike hosts such
/// as `http://localhost.evil.com/...` or `http://localhost@evil.com/...` are
/// therefore rejected.
pub fn validate_redirect_uri(requested: &str, registered: &[String]) -> bool {
    let requested_is_loopback = is_localhost_uri(requested);
    registered.iter().any(|uri| {
        uri.as_str() == requested
            || (requested_is_loopback
                && is_localhost_uri(uri)
                && split_localhost_uri(uri) == split_localhost_uri(requested))
    })
}

/// Returns `true` when `uri` is an `http` URI whose host is exactly
/// `localhost` or `127.0.0.1`, optionally followed by a numeric port.
fn is_localhost_uri(uri: &str) -> bool {
    parse_loopback_uri(uri).is_some()
}

/// Splits a loopback URI into its `scheme://host` prefix and everything after
/// the authority (path, query and fragment), dropping the port. Two URIs that
/// differ only by port therefore compare equal.
///
/// Intended for URIs accepted by `is_localhost_uri`. Any other input is
/// returned whole as the first element with an empty remainder, so it only
/// compares equal to an identical string.
fn split_localhost_uri(uri: &str) -> (&str, &str) {
    parse_loopback_uri(uri).unwrap_or((uri, ""))
}

/// Parses `uri` as `http://localhost[:port]<rest>` or
/// `http://127.0.0.1[:port]<rest>`, returning `(scheme_host, rest)`.
///
/// The authority ends at the first `/`, `?` or `#`. Whatever sits between the
/// host name and that point must be either empty or `:` followed by a decimal
/// port that fits in a `u16`. Anything else (`.evil.com`, `@evil.com`, `evil`,
/// `:+80`, ...) means the host is not exactly a loopback host, and `None` is
/// returned.
fn parse_loopback_uri(uri: &str) -> Option<(&str, &str)> {
    const LOOPBACK_ORIGINS: [&str; 2] = ["http://localhost", "http://127.0.0.1"];

    LOOPBACK_ORIGINS.iter().find_map(|&scheme_host| {
        let rest = uri.strip_prefix(scheme_host)?;
        // The index is that of an ASCII delimiter, so `split_at` is on a char boundary.
        let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
        let (port, tail) = rest.split_at(authority_end);
        let port_ok = match port.strip_prefix(':') {
            // The explicit digit check rejects `+80`, which `u16::from_str` accepts.
            Some(digits) => {
                digits.bytes().all(|b| b.is_ascii_digit()) && digits.parse::<u16>().is_ok()
            }
            None => port.is_empty(),
        };
        port_ok.then_some((scheme_host, tail))
    })
}
85
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_random_token_length() {
        // 32 bytes in base64url without padding is ceil(32 * 4 / 3) = 43 chars.
        let code = generate_random_token();
        assert_eq!(code.len(), 43);
    }

    #[test]
    fn test_generate_random_token_is_url_safe() {
        let code = generate_random_token();
        assert!(
            code.bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
        );
    }

    #[test]
    fn test_generate_random_token_uniqueness() {
        let a = generate_random_token();
        let b = generate_random_token();
        assert_ne!(a, b);
    }

    #[test]
    fn test_validate_redirect_uri_exact_match() {
        let registered = vec!["https://example.com/callback".to_string()];
        assert!(validate_redirect_uri(
            "https://example.com/callback",
            &registered
        ));
        assert!(!validate_redirect_uri(
            "https://example.com/other",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_localhost_port_exception() {
        let registered = vec!["http://localhost:3000/callback".to_string()];
        assert!(validate_redirect_uri(
            "http://localhost:9999/callback",
            &registered
        ));
        assert!(validate_redirect_uri(
            "http://localhost:3000/callback",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_localhost_different_path() {
        let registered = vec!["http://localhost:3000/callback".to_string()];
        assert!(!validate_redirect_uri(
            "http://localhost:9999/other",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_no_localhost_exception_for_https() {
        let registered = vec!["https://localhost:3000/callback".to_string()];
        assert!(!validate_redirect_uri(
            "https://localhost:9999/callback",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_127_0_0_1() {
        let registered = vec!["http://127.0.0.1:3000/callback".to_string()];
        assert!(validate_redirect_uri(
            "http://127.0.0.1:9999/callback",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_loopback_host_must_match() {
        // The port exception does not let one loopback host stand in for the other.
        let registered = vec!["http://localhost:3000/callback".to_string()];
        assert!(!validate_redirect_uri(
            "http://127.0.0.1:3000/callback",
            &registered
        ));
    }

    #[test]
    fn test_validate_redirect_uri_empty_registry() {
        assert!(!validate_redirect_uri(
            "http://localhost:3000/callback",
            &[]
        ));
    }
}