1use std::collections::HashSet;
2
3use time::Duration;
4
5use crate::error::{AuthError, OAuthError};
6
7#[derive(Debug, Clone)]
9pub struct AuthConfig {
10 pub secret: String,
12 pub session_ttl: Duration,
14 pub verification_ttl: Duration,
16 pub reset_ttl: Duration,
18 pub token_length: usize,
20 pub email: EmailConfig,
22 pub cookie: CookieConfig,
24 pub oauth: OAuthConfig,
26}
27
/// Boolean switches controlling email verification and automatic sign-in
/// around sign-up. Defaults are set in the `Default` impl.
#[derive(Clone, Debug)]
pub struct EmailConfig {
    /// Send a verification email at sign-up (default: `true`).
    pub send_verification_on_signup: bool,
    /// Require a verified email before login is allowed (default: `false`).
    pub require_verification_to_login: bool,
    /// Sign the user in right after sign-up (default: `true`).
    pub auto_sign_in_after_signup: bool,
    /// Sign the user in right after verifying their email (default: `false`).
    pub auto_sign_in_after_verification: bool,
}
40
41#[derive(Debug, Clone)]
43pub struct CookieConfig {
44 pub name: String,
46 pub http_only: bool,
48 pub secure: bool,
50 pub same_site: SameSite,
52 pub path: String,
54 pub domain: Option<String>,
56}
57
/// Value for the cookie `SameSite` attribute.
///
/// Variant order is significant only for the implicit discriminants; it is
/// kept as `Strict`, `Lax`, `None`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SameSite {
    /// `SameSite=Strict`.
    Strict,
    /// `SameSite=Lax` (the `CookieConfig` default).
    Lax,
    /// `SameSite=None`.
    None,
}
68
69#[derive(Debug, Clone)]
71pub struct OAuthConfig {
72 pub providers: Vec<OAuthProviderEntry>,
74 pub allow_implicit_account_linking: bool,
76 pub success_redirect: Option<String>,
78 pub error_redirect: Option<String>,
80}
81
/// Client credentials and URLs for a single OAuth provider.
///
/// `Debug` is implemented by hand (instead of derived) so that `client_secret`
/// is never written into logs or panic messages: it is rendered as
/// `"<redacted>"`, or `"<empty>"` when the secret is an empty string.
#[derive(Clone)]
pub struct OAuthProviderEntry {
    /// Provider identifier; `OAuthConfig::validate` accepts only `"google"` and `"github"`.
    pub provider_id: String,
    /// OAuth client ID; must not be blank.
    pub client_id: String,
    /// OAuth client secret; must not be blank. Redacted in `Debug` output.
    pub client_secret: String,
    /// Redirect (callback) URL; must be non-blank and parse as a URL.
    pub redirect_url: String,
    /// Optional authorization URL; validated as a URL when set.
    pub auth_url: Option<String>,
    /// Optional token URL; validated as a URL when set.
    pub token_url: Option<String>,
    /// Optional user-info URL; validated as a URL when set.
    pub userinfo_url: Option<String>,
}

impl std::fmt::Debug for OAuthProviderEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is configured, never its value.
        let client_secret = if self.client_secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OAuthProviderEntry")
            .field("provider_id", &self.provider_id)
            .field("client_id", &self.client_id)
            .field("client_secret", &client_secret)
            .field("redirect_url", &self.redirect_url)
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .finish()
    }
}
100
101impl Default for OAuthConfig {
102 fn default() -> Self {
103 Self {
104 providers: vec![],
105 allow_implicit_account_linking: true,
106 success_redirect: None,
107 error_redirect: None,
108 }
109 }
110}
111
112impl OAuthConfig {
113 pub fn validate(&self) -> Result<(), AuthError> {
114 let mut seen_provider_ids = HashSet::new();
115
116 for provider in &self.providers {
117 if !seen_provider_ids.insert(provider.provider_id.as_str()) {
118 return Err(AuthError::OAuth(OAuthError::Misconfigured {
119 message: format!("duplicate provider_id: {}", provider.provider_id),
120 }));
121 }
122
123 match provider.provider_id.as_str() {
124 "google" | "github" => {}
125 _ => {
126 return Err(AuthError::OAuth(OAuthError::UnsupportedProvider {
127 provider: provider.provider_id.clone(),
128 }));
129 }
130 }
131
132 if provider.client_id.trim().is_empty() {
133 return Err(AuthError::OAuth(OAuthError::Misconfigured {
134 message: format!("provider {} has empty client_id", provider.provider_id),
135 }));
136 }
137
138 if provider.client_secret.trim().is_empty() {
139 return Err(AuthError::OAuth(OAuthError::Misconfigured {
140 message: format!("provider {} has empty client_secret", provider.provider_id),
141 }));
142 }
143
144 if provider.redirect_url.trim().is_empty() {
145 return Err(AuthError::OAuth(OAuthError::Misconfigured {
146 message: format!("provider {} has empty redirect_url", provider.provider_id),
147 }));
148 }
149
150 validate_url(
151 "redirect_url",
152 &provider.provider_id,
153 &provider.redirect_url,
154 )?;
155
156 if let Some(auth_url) = provider.auth_url.as_deref() {
157 validate_url("auth_url", &provider.provider_id, auth_url)?;
158 }
159
160 if let Some(token_url) = provider.token_url.as_deref() {
161 validate_url("token_url", &provider.provider_id, token_url)?;
162 }
163
164 if let Some(userinfo_url) = provider.userinfo_url.as_deref() {
165 validate_url("userinfo_url", &provider.provider_id, userinfo_url)?;
166 }
167 }
168
169 Ok(())
170 }
171}
172
173fn validate_url(field: &str, provider_id: &str, value: &str) -> Result<(), AuthError> {
174 if value.trim().is_empty() {
175 return Err(AuthError::OAuth(OAuthError::Misconfigured {
176 message: format!("provider {provider_id} has empty {field}"),
177 }));
178 }
179
180 reqwest::Url::parse(value).map_err(|e| {
181 AuthError::OAuth(OAuthError::Misconfigured {
182 message: format!("provider {provider_id} has invalid {field}: {e}"),
183 })
184 })?;
185
186 Ok(())
187}
188
189impl Default for AuthConfig {
190 fn default() -> Self {
191 Self {
192 secret: String::new(),
193 session_ttl: Duration::days(30),
194 verification_ttl: Duration::hours(1),
195 reset_ttl: Duration::hours(1),
196 token_length: 32,
197 email: EmailConfig::default(),
198 cookie: CookieConfig::default(),
199 oauth: OAuthConfig::default(),
200 }
201 }
202}
203
204impl Default for EmailConfig {
205 fn default() -> Self {
206 Self {
207 send_verification_on_signup: true,
208 require_verification_to_login: false,
209 auto_sign_in_after_signup: true,
210 auto_sign_in_after_verification: false,
211 }
212 }
213}
214
215impl Default for CookieConfig {
216 fn default() -> Self {
217 Self {
218 name: "rs_auth_session".to_string(),
219 http_only: true,
220 secure: true,
221 same_site: SameSite::Lax,
222 path: "/".to_string(),
223 domain: None,
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider entry that passes `OAuthConfig::validate`; individual
    /// tests overwrite single fields to trigger exactly one failure.
    fn valid_provider(provider_id: &str) -> OAuthProviderEntry {
        OAuthProviderEntry {
            provider_id: provider_id.to_string(),
            client_id: "a".to_string(),
            client_secret: "b".to_string(),
            redirect_url: format!("https://example.com/callback/{provider_id}"),
            auth_url: None,
            token_url: None,
            userinfo_url: None,
        }
    }

    /// Wraps `providers` in an otherwise-default `OAuthConfig`.
    fn oauth_config(providers: Vec<OAuthProviderEntry>) -> OAuthConfig {
        OAuthConfig {
            providers,
            ..Default::default()
        }
    }

    #[test]
    fn default_config_has_sane_values() {
        let config = AuthConfig::default();

        assert_eq!(
            config.session_ttl,
            Duration::days(30),
            "session_ttl should be 30 days"
        );
        assert_eq!(config.token_length, 32, "token_length should be 32");
        assert_eq!(
            config.cookie.name, "rs_auth_session",
            "cookie name should be 'rs_auth_session'"
        );
        assert_eq!(
            config.verification_ttl,
            Duration::hours(1),
            "verification_ttl should be 1 hour"
        );
        assert_eq!(
            config.reset_ttl,
            Duration::hours(1),
            "reset_ttl should be 1 hour"
        );
        assert!(config.cookie.http_only, "cookie should be http_only");
        assert!(config.cookie.secure, "cookie should be secure");
        assert_eq!(
            config.cookie.same_site,
            SameSite::Lax,
            "cookie same_site should be Lax"
        );
        assert_eq!(config.cookie.path, "/", "cookie path should be '/'");
        assert_eq!(config.cookie.domain, None, "cookie domain should be None");
        assert!(
            config.oauth.providers.is_empty(),
            "no OAuth providers should be configured by default"
        );
        assert!(
            config.oauth.allow_implicit_account_linking,
            "implicit account linking should default to enabled"
        );
        assert!(
            config.oauth.validate().is_ok(),
            "the default OAuth config should validate"
        );
    }

    #[test]
    fn oauth_config_accepts_valid_providers() {
        let mut github = valid_provider("github");
        github.auth_url = Some("https://github.com/login/oauth/authorize".to_string());
        github.token_url = Some("https://github.com/login/oauth/access_token".to_string());
        github.userinfo_url = Some("https://api.github.com/user".to_string());

        assert!(oauth_config(vec![valid_provider("google"), github])
            .validate()
            .is_ok());
    }

    #[test]
    fn oauth_config_rejects_duplicate_provider_ids() {
        let mut second = valid_provider("google");
        second.client_id = "c".to_string();
        second.client_secret = "d".to_string();
        second.redirect_url = "https://example.com/callback/google-2".to_string();

        let err = oauth_config(vec![valid_provider("google"), second])
            .validate()
            .unwrap_err();
        assert!(matches!(
            err,
            AuthError::OAuth(OAuthError::Misconfigured { .. })
        ));
    }

    #[test]
    fn oauth_config_rejects_unsupported_provider_ids() {
        let err = oauth_config(vec![valid_provider("gitlab")])
            .validate()
            .unwrap_err();
        assert!(matches!(
            err,
            AuthError::OAuth(OAuthError::UnsupportedProvider { .. })
        ));
    }

    #[test]
    fn oauth_config_rejects_blank_credentials() {
        let mut blank_id = valid_provider("google");
        blank_id.client_id = "   ".to_string();
        assert!(oauth_config(vec![blank_id]).validate().is_err());

        let mut blank_secret = valid_provider("google");
        blank_secret.client_secret = String::new();
        assert!(oauth_config(vec![blank_secret]).validate().is_err());
    }

    #[test]
    fn oauth_config_rejects_invalid_urls() {
        let mut bad_redirect = valid_provider("google");
        bad_redirect.redirect_url = "not-a-url".to_string();
        assert!(oauth_config(vec![bad_redirect]).validate().is_err());

        let mut blank_redirect = valid_provider("google");
        blank_redirect.redirect_url = "  ".to_string();
        assert!(oauth_config(vec![blank_redirect]).validate().is_err());

        let mut bad_userinfo = valid_provider("github");
        bad_userinfo.userinfo_url = Some("not-a-url".to_string());
        assert!(oauth_config(vec![bad_userinfo]).validate().is_err());
    }
}