Skip to main content

allowthem_core/
handle.rs

1use std::sync::Arc;
2
3use chrono::Duration;
4use sqlx::SqlitePool;
5
6use crate::db::Db;
7use crate::error::AuthError;
8use crate::sessions::{self, SessionConfig};
9use crate::types::SessionToken;
10
11/// Error type for builder construction and validation failures.
12#[derive(Debug, thiserror::Error)]
13pub enum BuildError {
14    /// Database connection or migration failure.
15    #[error("database error: {0}")]
16    Database(#[from] AuthError),
17
18    /// Invalid builder configuration.
19    /// Reserved for future validation; not currently produced.
20    #[error("invalid configuration: {0}")]
21    InvalidConfig(&'static str),
22}
23
24enum PoolSource {
25    Url(String),
26    Pool(SqlitePool),
27}
28
29/// Builder for constructing a configured [`AllowThem`] handle.
30pub struct AllowThemBuilder {
31    pool_source: PoolSource,
32    session_ttl: Option<Duration>,
33    cookie_name: Option<&'static str>,
34    cookie_secure: Option<bool>,
35    cookie_domain: String,
36    mfa_key: Option<[u8; 32]>,
37}
38
39impl AllowThemBuilder {
40    /// Start building from a database URL.
41    ///
42    /// At build time, calls `Db::connect(url)` which creates the pool,
43    /// sets pragmas (foreign_keys, WAL, busy_timeout), and runs migrations.
44    pub fn new(url: impl Into<String>) -> Self {
45        Self {
46            pool_source: PoolSource::Url(url.into()),
47            session_ttl: None,
48            cookie_name: None,
49            cookie_secure: None,
50            cookie_domain: String::new(),
51            mfa_key: None,
52        }
53    }
54
55    /// Start building from an existing pool.
56    ///
57    /// At build time, calls `Db::new(pool)` which runs migrations.
58    /// The caller is responsible for pragma configuration on their pool.
59    pub fn with_pool(pool: SqlitePool) -> Self {
60        Self {
61            pool_source: PoolSource::Pool(pool),
62            session_ttl: None,
63            cookie_name: None,
64            cookie_secure: None,
65            cookie_domain: String::new(),
66            mfa_key: None,
67        }
68    }
69
70    /// Override session TTL. Default: 24 hours.
71    pub fn session_ttl(mut self, ttl: Duration) -> Self {
72        self.session_ttl = Some(ttl);
73        self
74    }
75
76    /// Override session cookie name. Default: `"allowthem_session"`.
77    pub fn cookie_name(mut self, name: &'static str) -> Self {
78        self.cookie_name = Some(name);
79        self
80    }
81
82    /// Set the Secure attribute on session cookies.
83    ///
84    /// Default: `true`. Set to `false` for local development over HTTP.
85    pub fn cookie_secure(mut self, secure: bool) -> Self {
86        self.cookie_secure = Some(secure);
87        self
88    }
89
90    /// Set the Domain attribute on session cookies.
91    ///
92    /// Default: empty (omitted). When set, the cookie is sent to the domain
93    /// and all its subdomains.
94    pub fn cookie_domain(mut self, domain: impl Into<String>) -> Self {
95        self.cookie_domain = domain.into();
96        self
97    }
98
99    /// Set the AES-256-GCM encryption key for MFA secrets.
100    ///
101    /// When not set, all MFA operations return `AuthError::MfaNotConfigured`.
102    /// This keeps MFA opt-in for embedded integrators who don't need it.
103    pub fn mfa_key(mut self, key: [u8; 32]) -> Self {
104        self.mfa_key = Some(key);
105        self
106    }
107
108    /// Construct the [`AllowThem`] handle.
109    ///
110    /// Connects to (or wraps) the database, runs migrations, and assembles
111    /// the session configuration from overrides plus defaults.
112    pub async fn build(self) -> Result<AllowThem, BuildError> {
113        let db = match self.pool_source {
114            PoolSource::Url(url) => Db::connect(&url).await?,
115            PoolSource::Pool(pool) => Db::new(pool).await?,
116        };
117
118        let defaults = SessionConfig::default();
119        let session_config = SessionConfig {
120            ttl: self.session_ttl.unwrap_or(defaults.ttl),
121            cookie_name: self.cookie_name.unwrap_or(defaults.cookie_name),
122            secure: self.cookie_secure.unwrap_or(defaults.secure),
123        };
124
125        Ok(AllowThem {
126            inner: Arc::new(Inner {
127                db,
128                session_config,
129                cookie_domain: self.cookie_domain,
130                mfa_key: self.mfa_key,
131            }),
132        })
133    }
134}
135
136struct Inner {
137    db: Db,
138    session_config: SessionConfig,
139    cookie_domain: String,
140    mfa_key: Option<[u8; 32]>,
141}
142
143/// Configured allowthem handle.
144///
145/// Bundles a `Db`, `SessionConfig`, and cookie domain into a single value
146/// that is cheaply cloneable and safe to share across Axum handlers via
147/// `State<AllowThem>` or `Extension<AllowThem>`.
148#[derive(Clone)]
149pub struct AllowThem {
150    inner: Arc<Inner>,
151}
152
153impl AllowThem {
154    /// Access the underlying database handle.
155    ///
156    /// Escape hatch for callers who need direct `Db` access for operations
157    /// not yet wrapped by `AllowThem` methods (e.g., user CRUD, role management).
158    pub fn db(&self) -> &Db {
159        &self.inner.db
160    }
161
162    /// Access the session configuration.
163    pub fn session_config(&self) -> &SessionConfig {
164        &self.inner.session_config
165    }
166
167    /// Build a `Set-Cookie` header value for the given session token.
168    ///
169    /// Uses the stored `SessionConfig` and cookie domain. Delegates to
170    /// `sessions::session_cookie()`.
171    pub fn session_cookie(&self, token: &SessionToken) -> String {
172        sessions::session_cookie(token, &self.inner.session_config, &self.inner.cookie_domain)
173    }
174
175    /// Returns the MFA encryption key, or `Err(MfaNotConfigured)` if not set.
176    pub(crate) fn mfa_key(&self) -> Result<&[u8; 32], AuthError> {
177        self.inner
178            .mfa_key
179            .as_ref()
180            .ok_or(AuthError::MfaNotConfigured)
181    }
182
183    /// Extract the session token from a `Cookie` header value.
184    ///
185    /// Uses the stored cookie name. Delegates to `sessions::parse_session_cookie()`.
186    pub fn parse_session_cookie(&self, cookie_header: &str) -> Option<SessionToken> {
187        sessions::parse_session_cookie(cookie_header, self.inner.session_config.cookie_name)
188    }
189}
190
#[cfg(test)]
mod tests {
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    use super::*;
    use crate::sessions::generate_token;
    use crate::types::Email;

    #[tokio::test]
    async fn build_with_url_defaults() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();

        let config = ath.session_config();
        assert_eq!(config.ttl, Duration::hours(24));
        assert_eq!(config.cookie_name, "allowthem_session");
        assert!(config.secure);

        let token = generate_token();
        let cookie = ath.session_cookie(&token);
        assert!(!cookie.contains("; Domain="));
    }

    #[tokio::test]
    async fn build_with_pool() {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = sqlx::SqlitePool::connect_with(opts).await.unwrap();

        let ath = AllowThemBuilder::with_pool(pool).build().await.unwrap();

        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().create_user(email, "password123", None).await;
        assert!(user.is_ok());
    }

    #[tokio::test]
    async fn build_with_overrides() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .session_ttl(Duration::hours(48))
            .cookie_name("my_session")
            .cookie_secure(false)
            .cookie_domain("example.com")
            .build()
            .await
            .unwrap();

        let config = ath.session_config();
        assert_eq!(config.ttl, Duration::hours(48));
        assert_eq!(config.cookie_name, "my_session");
        assert!(!config.secure);

        // The domain override is not part of SessionConfig; check it via the cookie.
        let cookie = ath.session_cookie(&generate_token());
        assert!(cookie.contains("; Domain=example.com"));
    }

    #[tokio::test]
    async fn session_cookie_uses_config() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_name("custom")
            .cookie_secure(false)
            .cookie_domain("example.com")
            .build()
            .await
            .unwrap();

        let token = generate_token();
        let cookie = ath.session_cookie(&token);

        assert!(cookie.contains("custom="));
        assert!(cookie.contains("; Domain=example.com"));
        assert!(!cookie.contains("; Secure"));
    }

    #[tokio::test]
    async fn parse_session_cookie_uses_config() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_name("custom")
            .build()
            .await
            .unwrap();

        let header = "custom=abc123; other=xyz";
        let result = ath.parse_session_cookie(header);

        assert!(result.is_some());
        assert_eq!(result.unwrap().as_str(), "abc123");
    }

    #[tokio::test]
    async fn parse_session_cookie_absent_returns_none() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_name("custom")
            .build()
            .await
            .unwrap();

        assert!(ath.parse_session_cookie("other=xyz").is_none());
    }

    #[tokio::test]
    async fn build_with_bad_url_fails() {
        let result = AllowThemBuilder::new("not-a-url").build().await;

        assert!(matches!(result, Err(BuildError::Database(_))));
    }

    #[tokio::test]
    async fn mfa_key_not_configured_by_default() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();

        assert!(matches!(ath.mfa_key(), Err(AuthError::MfaNotConfigured)));
    }

    #[tokio::test]
    async fn mfa_key_returns_configured_key() {
        let key = [7u8; 32];
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .mfa_key(key)
            .build()
            .await
            .unwrap();

        assert_eq!(ath.mfa_key().unwrap(), &key);
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();
        let ath2 = ath.clone();

        let email = Email::new("shared@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let found = ath2.db().get_user(user.id).await;
        assert!(found.is_ok());
        assert_eq!(found.unwrap().id, user.id);
    }
}