1use std::sync::Arc;
2
3use chrono::Duration;
4use sqlx::SqlitePool;
5
6use crate::db::Db;
7use crate::error::AuthError;
8use crate::sessions::{self, SessionConfig};
9use crate::types::SessionToken;
10
11#[derive(Debug, thiserror::Error)]
13pub enum BuildError {
14 #[error("database error: {0}")]
16 Database(#[from] AuthError),
17
18 #[error("invalid configuration: {0}")]
21 InvalidConfig(&'static str),
22}
23
24enum PoolSource {
25 Url(String),
26 Pool(SqlitePool),
27}
28
29pub struct AllowThemBuilder {
31 pool_source: PoolSource,
32 session_ttl: Option<Duration>,
33 cookie_name: Option<&'static str>,
34 cookie_secure: Option<bool>,
35 cookie_domain: String,
36 mfa_key: Option<[u8; 32]>,
37}
38
39impl AllowThemBuilder {
40 pub fn new(url: impl Into<String>) -> Self {
45 Self {
46 pool_source: PoolSource::Url(url.into()),
47 session_ttl: None,
48 cookie_name: None,
49 cookie_secure: None,
50 cookie_domain: String::new(),
51 mfa_key: None,
52 }
53 }
54
55 pub fn with_pool(pool: SqlitePool) -> Self {
60 Self {
61 pool_source: PoolSource::Pool(pool),
62 session_ttl: None,
63 cookie_name: None,
64 cookie_secure: None,
65 cookie_domain: String::new(),
66 mfa_key: None,
67 }
68 }
69
70 pub fn session_ttl(mut self, ttl: Duration) -> Self {
72 self.session_ttl = Some(ttl);
73 self
74 }
75
76 pub fn cookie_name(mut self, name: &'static str) -> Self {
78 self.cookie_name = Some(name);
79 self
80 }
81
82 pub fn cookie_secure(mut self, secure: bool) -> Self {
86 self.cookie_secure = Some(secure);
87 self
88 }
89
90 pub fn cookie_domain(mut self, domain: impl Into<String>) -> Self {
95 self.cookie_domain = domain.into();
96 self
97 }
98
99 pub fn mfa_key(mut self, key: [u8; 32]) -> Self {
104 self.mfa_key = Some(key);
105 self
106 }
107
108 pub async fn build(self) -> Result<AllowThem, BuildError> {
113 let db = match self.pool_source {
114 PoolSource::Url(url) => Db::connect(&url).await?,
115 PoolSource::Pool(pool) => Db::new(pool).await?,
116 };
117
118 let defaults = SessionConfig::default();
119 let session_config = SessionConfig {
120 ttl: self.session_ttl.unwrap_or(defaults.ttl),
121 cookie_name: self.cookie_name.unwrap_or(defaults.cookie_name),
122 secure: self.cookie_secure.unwrap_or(defaults.secure),
123 };
124
125 Ok(AllowThem {
126 inner: Arc::new(Inner {
127 db,
128 session_config,
129 cookie_domain: self.cookie_domain,
130 mfa_key: self.mfa_key,
131 }),
132 })
133 }
134}
135
136struct Inner {
137 db: Db,
138 session_config: SessionConfig,
139 cookie_domain: String,
140 mfa_key: Option<[u8; 32]>,
141}
142
143#[derive(Clone)]
149pub struct AllowThem {
150 inner: Arc<Inner>,
151}
152
153impl AllowThem {
154 pub fn db(&self) -> &Db {
159 &self.inner.db
160 }
161
162 pub fn session_config(&self) -> &SessionConfig {
164 &self.inner.session_config
165 }
166
167 pub fn session_cookie(&self, token: &SessionToken) -> String {
172 sessions::session_cookie(token, &self.inner.session_config, &self.inner.cookie_domain)
173 }
174
175 pub(crate) fn mfa_key(&self) -> Result<&[u8; 32], AuthError> {
177 self.inner
178 .mfa_key
179 .as_ref()
180 .ok_or(AuthError::MfaNotConfigured)
181 }
182
183 pub fn parse_session_cookie(&self, cookie_header: &str) -> Option<SessionToken> {
187 sessions::parse_session_cookie(cookie_header, self.inner.session_config.cookie_name)
188 }
189}
190
#[cfg(test)]
mod tests {
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    use super::*;
    use crate::sessions::generate_token;
    use crate::types::Email;

    #[tokio::test]
    async fn build_with_url_defaults() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();

        let config = ath.session_config();
        assert_eq!(config.ttl, Duration::hours(24));
        assert_eq!(config.cookie_name, "allowthem_session");
        assert!(config.secure);

        // No domain was configured, so the cookie must not pin one.
        let token = generate_token();
        let cookie = ath.session_cookie(&token);
        assert!(!cookie.contains("; Domain="));
    }

    #[tokio::test]
    async fn build_with_pool() {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = sqlx::SqlitePool::connect_with(opts).await.unwrap();

        let ath = AllowThemBuilder::with_pool(pool).build().await.unwrap();

        // `unwrap` (rather than `assert!(is_ok())`) surfaces the underlying
        // error if user creation fails.
        let email = Email::new("test@example.com".into()).unwrap();
        ath.db()
            .create_user(email, "password123", None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn build_with_overrides() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .session_ttl(Duration::hours(48))
            .cookie_name("my_session")
            .cookie_secure(false)
            .cookie_domain("example.com")
            .build()
            .await
            .unwrap();

        let config = ath.session_config();
        assert_eq!(config.ttl, Duration::hours(48));
        assert_eq!(config.cookie_name, "my_session");
        assert!(!config.secure);
    }

    #[tokio::test]
    async fn session_cookie_uses_config() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_name("custom")
            .cookie_secure(false)
            .cookie_domain("example.com")
            .build()
            .await
            .unwrap();

        let token = generate_token();
        let cookie = ath.session_cookie(&token);

        assert!(cookie.contains("custom="));
        assert!(cookie.contains("; Domain=example.com"));
        assert!(!cookie.contains("; Secure"));
    }

    #[tokio::test]
    async fn parse_session_cookie_uses_config() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_name("custom")
            .build()
            .await
            .unwrap();

        let token = ath
            .parse_session_cookie("custom=abc123; other=xyz")
            .expect("cookie named `custom` should be found");
        assert_eq!(token.as_str(), "abc123");

        // Once a custom name is configured, the default name must not match.
        assert!(ath
            .parse_session_cookie("allowthem_session=abc123")
            .is_none());
    }

    #[tokio::test]
    async fn build_with_bad_url_fails() {
        let result = AllowThemBuilder::new("not-a-url").build().await;

        assert!(matches!(result, Err(BuildError::Database(_))));
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();
        let ath2 = ath.clone();

        let email = Email::new("shared@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let found = ath2
            .db()
            .get_user(user.id)
            .await
            .expect("a clone should see users created through the original");
        assert_eq!(found.id, user.id);
    }
}