1use oauth2::basic::BasicClient;
7use oauth2::{EndpointNotSet, EndpointSet};
8use serde::{Deserialize, Serialize};
9use std::str::FromStr;
10use std::time::{Duration, SystemTime};
11
12pub type ConfiguredClient = BasicClient<
22 EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet, >;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
31#[serde(rename_all = "lowercase")]
32pub enum OAuthProvider {
33 Google,
35 GitHub,
37 Oidc,
39}
40
41impl OAuthProvider {
42 #[must_use]
44 pub const fn as_str(&self) -> &'static str {
45 match self {
46 Self::Google => "google",
47 Self::GitHub => "github",
48 Self::Oidc => "oidc",
49 }
50 }
51}
52
53impl FromStr for OAuthProvider {
54 type Err = OAuthError;
55
56 fn from_str(s: &str) -> Result<Self, Self::Err> {
57 match s.to_lowercase().as_str() {
58 "google" => Ok(Self::Google),
59 "github" => Ok(Self::GitHub),
60 "oidc" => Ok(Self::Oidc),
61 _ => Err(OAuthError::UnknownProvider(s.to_string())),
62 }
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ProviderConfig {
69 pub client_id: String,
71 pub client_secret: String,
73 pub redirect_uri: String,
75 pub scopes: Vec<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
79 pub auth_url: Option<String>,
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub token_url: Option<String>,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub userinfo_url: Option<String>,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct OAuthConfig {
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub google: Option<ProviderConfig>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub github: Option<ProviderConfig>,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub oidc: Option<ProviderConfig>,
100}
101
102impl OAuthConfig {
103 #[must_use]
105 pub const fn new() -> Self {
106 Self {
107 google: None,
108 github: None,
109 oidc: None,
110 }
111 }
112
113 const fn provider_config(&self, provider: OAuthProvider) -> Option<&ProviderConfig> {
118 match provider {
119 OAuthProvider::Google => self.google.as_ref(),
120 OAuthProvider::GitHub => self.github.as_ref(),
121 OAuthProvider::Oidc => self.oidc.as_ref(),
122 }
123 }
124
125 pub fn get_provider(&self, provider: OAuthProvider) -> Result<&ProviderConfig, OAuthError> {
131 self.provider_config(provider)
132 .ok_or(OAuthError::ProviderNotConfigured(provider))
133 }
134
135 #[must_use]
137 pub const fn is_provider_configured(&self, provider: OAuthProvider) -> bool {
138 self.provider_config(provider).is_some()
139 }
140}
141
142impl Default for OAuthConfig {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct OAuthState {
151 pub token: String,
153 pub provider: OAuthProvider,
155 pub expires_at: SystemTime,
157}
158
159impl OAuthState {
160 #[must_use]
162 pub fn generate(provider: OAuthProvider) -> Self {
163 use rand::Rng;
164
165 let random_bytes: [u8; 32] = rand::rng().random();
167 let token = hex::encode(random_bytes);
168
169 Self {
170 token,
171 provider,
172 expires_at: SystemTime::now() + Duration::from_secs(600), }
174 }
175
176 #[must_use]
178 pub fn is_expired(&self) -> bool {
179 SystemTime::now() > self.expires_at
180 }
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct OAuthToken {
186 pub access_token: String,
188 #[serde(skip_serializing_if = "Option::is_none")]
190 pub refresh_token: Option<String>,
191 pub token_type: String,
193 #[serde(skip_serializing_if = "Option::is_none")]
195 pub expires_at: Option<SystemTime>,
196 #[serde(skip_serializing_if = "Option::is_none")]
198 pub scopes: Option<Vec<String>>,
199}
200
201impl OAuthToken {
202 #[must_use]
204 pub fn is_expired(&self) -> bool {
205 self.expires_at
206 .is_some_and(|expires| SystemTime::now() > expires)
207 }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct OAuthUserInfo {
213 pub provider_user_id: String,
215 pub email: String,
217 #[serde(skip_serializing_if = "Option::is_none")]
219 pub name: Option<String>,
220 #[serde(skip_serializing_if = "Option::is_none")]
222 pub avatar_url: Option<String>,
223 pub email_verified: bool,
225}
226
227#[derive(Debug, thiserror::Error)]
229pub enum OAuthError {
230 #[error("Unknown OAuth2 provider: {0}")]
232 UnknownProvider(String),
233
234 #[error("OAuth2 provider not configured: {0:?}")]
236 ProviderNotConfigured(OAuthProvider),
237
238 #[error("Invalid or expired OAuth2 state token")]
240 InvalidState,
241
242 #[error("OAuth2 state token mismatch (potential CSRF attack)")]
244 StateMismatch,
245
246 #[error("Failed to exchange authorization code for token: {0}")]
248 TokenExchangeFailed(String),
249
250 #[error("Failed to fetch user information: {0}")]
252 UserInfoFailed(String),
253
254 #[error("OAuth2 token has expired")]
256 TokenExpired,
257
258 #[error("OAuth2 error: {0}")]
260 Generic(String),
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Every provider variant, for table-driven checks.
    const ALL_PROVIDERS: [OAuthProvider; 3] =
        [OAuthProvider::Google, OAuthProvider::GitHub, OAuthProvider::Oidc];

    /// Minimal provider settings used by the config tests.
    fn sample_provider_config() -> ProviderConfig {
        ProviderConfig {
            client_id: "test".to_string(),
            client_secret: "test".to_string(),
            redirect_uri: "http://localhost/callback".to_string(),
            scopes: vec!["email".to_string()],
            auth_url: None,
            token_url: None,
            userinfo_url: None,
        }
    }

    /// A bearer token with the given expiry and no other optional data.
    fn token_expiring_at(expires_at: Option<SystemTime>) -> OAuthToken {
        OAuthToken {
            access_token: "test".to_string(),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at,
            scopes: None,
        }
    }

    #[test]
    fn test_provider_as_str() {
        assert_eq!(OAuthProvider::Google.as_str(), "google");
        assert_eq!(OAuthProvider::GitHub.as_str(), "github");
        assert_eq!(OAuthProvider::Oidc.as_str(), "oidc");
    }

    #[test]
    fn test_provider_from_str() {
        let cases = [
            ("google", OAuthProvider::Google),
            ("GOOGLE", OAuthProvider::Google),
            ("github", OAuthProvider::GitHub),
            ("GitHub", OAuthProvider::GitHub),
            ("oidc", OAuthProvider::Oidc),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(input.parse::<OAuthProvider>().unwrap(), *expected, "input {input:?}");
        }

        let err = "invalid".parse::<OAuthProvider>().unwrap_err();
        assert!(matches!(err, OAuthError::UnknownProvider(ref name) if name == "invalid"));
        assert!("".parse::<OAuthProvider>().is_err());
    }

    #[test]
    fn test_provider_as_str_round_trips() {
        for provider in ALL_PROVIDERS.iter().copied() {
            assert_eq!(provider.as_str().parse::<OAuthProvider>().unwrap(), provider);
        }
    }

    #[test]
    fn test_oauth_config_default() {
        let config = OAuthConfig::default();
        assert!(config.google.is_none());
        assert!(config.github.is_none());
        assert!(config.oidc.is_none());
        for provider in ALL_PROVIDERS.iter().copied() {
            assert!(!config.is_provider_configured(provider));
        }
    }

    #[test]
    fn test_oauth_config_is_provider_configured() {
        let mut config = OAuthConfig::default();
        assert!(!config.is_provider_configured(OAuthProvider::Google));

        config.google = Some(sample_provider_config());

        assert!(config.is_provider_configured(OAuthProvider::Google));
        assert!(!config.is_provider_configured(OAuthProvider::GitHub));
    }

    #[test]
    fn test_oauth_config_get_provider() {
        let mut config = OAuthConfig::default();
        assert!(matches!(
            config.get_provider(OAuthProvider::Oidc),
            Err(OAuthError::ProviderNotConfigured(OAuthProvider::Oidc))
        ));

        config.oidc = Some(sample_provider_config());
        assert_eq!(config.get_provider(OAuthProvider::Oidc).unwrap().client_id, "test");
    }

    #[test]
    fn test_oauth_state_generation() {
        let state = OAuthState::generate(OAuthProvider::Google);
        assert_eq!(state.provider, OAuthProvider::Google);
        assert!(!state.is_expired());
        // 32 random bytes, hex-encoded.
        assert_eq!(state.token.len(), 64);
        assert!(state.token.chars().all(|c| c.is_ascii_hexdigit()));

        let other = OAuthState::generate(OAuthProvider::Google);
        assert_ne!(state.token, other.token);
    }

    #[test]
    fn test_oauth_token_is_expired() {
        assert!(!token_expiring_at(None).is_expired());
        assert!(token_expiring_at(Some(SystemTime::now() - Duration::from_secs(3600))).is_expired());
        assert!(!token_expiring_at(Some(SystemTime::now() + Duration::from_secs(3600))).is_expired());
    }
}