1use std::sync::Arc;
20use std::time::{Duration, SystemTime, UNIX_EPOCH};
21
22use cookie::{Cookie, SameSite, time::Duration as CookieDuration};
23use rand::RngCore;
24use url::Url;
25
26use crate::error::Result;
27use crate::store::{Session, SessionStore};
28
/// Name of the `HttpOnly` cookie that carries the session id.
pub const SESSION_COOKIE: &str = "assay_session";

/// Name of the script-readable (non-`HttpOnly`) cookie that carries the CSRF token.
pub const CSRF_COOKIE: &str = "assay_csrf";

/// Lifetime given to new sessions by `SessionManager::with_default_duration`: 30 days.
pub const DEFAULT_SESSION_DURATION: Duration = Duration::from_secs(30 * 24 * 60 * 60);
41
42#[derive(Clone)]
46pub struct SessionManager {
47 store: Arc<dyn SessionStore>,
48 default_duration: Duration,
49}
50
51impl SessionManager {
52 pub fn new(store: Arc<dyn SessionStore>, default_duration: Duration) -> Self {
56 Self {
57 store,
58 default_duration,
59 }
60 }
61
62 pub fn with_default_duration(store: Arc<dyn SessionStore>) -> Self {
64 Self::new(store, DEFAULT_SESSION_DURATION)
65 }
66
67 pub async fn create(&self, user_id: &str) -> Result<Session> {
72 let id = format!("sess_{}", random_token());
73 let csrf_token = format!("csrf_{}", random_token());
74 let created_at = now_secs();
75 let expires_at = created_at + self.default_duration.as_secs_f64();
76 let session = Session {
77 id,
78 user_id: user_id.to_string(),
79 csrf_token,
80 created_at,
81 expires_at,
82 ip_hash: None,
83 user_agent_hash: None,
84 };
85 self.store.create(&session).await?;
86 Ok(session)
87 }
88
89 pub async fn resolve(&self, id: &str) -> Result<Option<Session>> {
94 let Some(session) = self.store.get(id).await? else {
95 return Ok(None);
96 };
97 if session.expires_at <= now_secs() {
98 return Ok(None);
99 }
100 Ok(Some(session))
101 }
102
103 pub async fn rotate(&self, old_id: &str) -> Result<Option<Session>> {
109 let Some(old) = self.store.get(old_id).await? else {
110 return Ok(None);
111 };
112 self.store.delete(old_id).await?;
116 let new_id = format!("sess_{}", random_token());
117 let csrf_token = format!("csrf_{}", random_token());
118 let session = Session {
119 id: new_id,
120 user_id: old.user_id,
121 csrf_token,
122 created_at: now_secs(),
123 expires_at: old.expires_at,
124 ip_hash: old.ip_hash,
125 user_agent_hash: old.user_agent_hash,
126 };
127 self.store.create(&session).await?;
128 Ok(Some(session))
129 }
130
131 pub async fn revoke(&self, id: &str) -> Result<bool> {
133 Ok(self.store.delete(id).await?)
134 }
135
136 pub async fn revoke_for_user(&self, user_id: &str) -> Result<u64> {
139 Ok(self.store.delete_for_user(user_id).await?)
140 }
141
142 pub fn store(&self) -> &Arc<dyn SessionStore> {
146 &self.store
147 }
148}
149
150pub fn cookie_for(session: &Session, public_url: &Url) -> Cookie<'static> {
158 let max_age = max_age_for(session);
159 let secure = is_secure(public_url);
160 Cookie::build((SESSION_COOKIE, session.id.clone()))
161 .path("/")
162 .secure(secure)
163 .http_only(true)
164 .same_site(SameSite::Lax)
165 .max_age(max_age)
166 .build()
167 .into_owned()
168}
169
170pub fn csrf_cookie_for(session: &Session) -> Cookie<'static> {
174 let max_age = max_age_for(session);
175 Cookie::build((CSRF_COOKIE, session.csrf_token.clone()))
176 .path("/")
177 .secure(true)
178 .http_only(false)
179 .same_site(SameSite::Lax)
180 .max_age(max_age)
181 .build()
182 .into_owned()
183}
184
185fn max_age_for(session: &Session) -> CookieDuration {
190 let secs = (session.expires_at - now_secs()).max(0.0) as i64;
191 CookieDuration::seconds(secs)
192}
193
194fn is_secure(public_url: &Url) -> bool {
200 public_url.scheme().eq_ignore_ascii_case("https")
201}
202
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Returns `0.0` if the system clock reads earlier than the epoch.
fn now_secs() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        Err(_) => 0.0,
    }
}
209
210fn random_token() -> String {
211 let mut buf = [0u8; 32];
212 rand::rng().fill_bytes(&mut buf);
213 data_encoding::BASE64URL_NOPAD.encode(&buf)
214}
215
216use axum::extract::{FromRef, State};
219use axum::http::{HeaderMap, StatusCode, header};
220use axum::response::{IntoResponse, Json, Response};
221use axum::routing::{delete, get, post};
222use axum::Router;
223use serde::Deserialize;
224use serde_json::json;
225
226use crate::ctx::AuthCtx;
227
228pub fn router<S>() -> Router<S>
231where
232 S: Clone + Send + Sync + 'static,
233 AuthCtx: FromRef<S>,
234{
235 Router::new()
236 .route("/login", post(login_post))
237 .route("/session", delete(logout_delete))
238 .route("/whoami", get(whoami_get))
239 .route("/passkey/register/start", post(passkey_register_start))
240 .route("/passkey/register/finish", post(passkey_register_finish))
241 .route("/passkey/auth/start", post(passkey_auth_start))
242 .route("/passkey/auth/finish", post(passkey_auth_finish))
243}
244
245#[derive(Deserialize)]
246struct LoginBody {
247 email: String,
248 password: String,
249}
250
251async fn login_post(State(ctx): State<AuthCtx>, Json(body): Json<LoginBody>) -> Response {
252 let user = match ctx.users.get_user_by_email(&body.email).await {
253 Ok(Some(u)) => u,
254 _ => return unauthorized("invalid credentials"),
255 };
256 let stored = match ctx.users.get_password_hash(&user.id).await {
257 Ok(Some(h)) => h,
258 _ => return unauthorized("invalid credentials"),
259 };
260 let hasher = crate::password::PasswordHasher::default();
261 let ok = match hasher.verify(&body.password, &stored) {
262 Ok(b) => b,
263 Err(_) => return unauthorized("invalid credentials"),
264 };
265 if !ok {
266 return unauthorized("invalid credentials");
267 }
268 let mgr = SessionManager::with_default_duration(ctx.sessions.clone());
269 let session = match mgr.create(&user.id).await {
270 Ok(s) => s,
271 Err(e) => return server_error(&format!("create session: {e}")),
272 };
273 let public_url = ctx
274 .oidc_provider
275 .as_ref()
276 .map(|p| p.public_url.clone())
277 .unwrap_or_else(|| url::Url::parse("http://localhost").unwrap());
278 let cookie = cookie_for(&session, &public_url);
279 let csrf = csrf_cookie_for(&session);
280 let mut response = (
281 StatusCode::OK,
282 Json(json!({
283 "user_id": user.id,
284 "email": user.email,
285 "csrf_token": session.csrf_token,
286 })),
287 )
288 .into_response();
289 if let Ok(value) = cookie.to_string().parse() {
290 response.headers_mut().append(header::SET_COOKIE, value);
291 }
292 if let Ok(value) = csrf.to_string().parse() {
293 response.headers_mut().append(header::SET_COOKIE, value);
294 }
295 response
296}
297
298async fn logout_delete(State(ctx): State<AuthCtx>, headers: HeaderMap) -> Response {
299 if let Some(sid) = parse_cookie(&headers, SESSION_COOKIE) {
300 let _ = ctx.sessions.delete(&sid).await;
301 }
302 let mut response = (StatusCode::NO_CONTENT, "").into_response();
303 let clear = format!(
304 "{}=; Path=/; HttpOnly; SameSite=Lax; Max-Age=0",
305 SESSION_COOKIE
306 );
307 if let Ok(v) = clear.parse() {
308 response.headers_mut().append(header::SET_COOKIE, v);
309 }
310 response
311}
312
313async fn whoami_get(State(ctx): State<AuthCtx>, headers: HeaderMap) -> Response {
314 let sid = match parse_cookie(&headers, SESSION_COOKIE) {
315 Some(s) => s,
316 None => return unauthorized("no session"),
317 };
318 let mgr = SessionManager::with_default_duration(ctx.sessions.clone());
319 let session = match mgr.resolve(&sid).await {
320 Ok(Some(s)) => s,
321 _ => return unauthorized("session unknown"),
322 };
323 let user = match ctx.users.get_user_by_id(&session.user_id).await {
324 Ok(Some(u)) => u,
325 _ => return unauthorized("user unknown"),
326 };
327 (
328 StatusCode::OK,
329 Json(json!({
330 "user_id": user.id,
331 "email": user.email,
332 "email_verified": user.email_verified,
333 "display_name": user.display_name,
334 })),
335 )
336 .into_response()
337}
338
339#[derive(Deserialize)]
340struct PasskeyRegisterStartBody {
341 user_id: String,
342 user_name: String,
343 display_name: String,
344}
345
346async fn passkey_register_start(
347 State(ctx): State<AuthCtx>,
348 Json(body): Json<PasskeyRegisterStartBody>,
349) -> Response {
350 let Some(mgr) = ctx.passkeys.as_ref() else {
351 return svc_unavailable("passkey manager not configured");
352 };
353 let uuid = uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, body.user_id.as_bytes());
354 match mgr
355 .start_registration(uuid, &body.user_name, &body.display_name, Some(&body.user_id))
356 .await
357 {
358 Ok((challenge, state)) => {
359 let state_blob = serde_json::to_string(&state).unwrap_or_default();
360 (
361 StatusCode::OK,
362 Json(json!({
363 "challenge": challenge,
364 "state": state_blob,
365 })),
366 )
367 .into_response()
368 }
369 Err(e) => bad_request(&format!("start_registration: {e}")),
370 }
371}
372
373#[derive(Deserialize)]
374struct PasskeyRegisterFinishBody {
375 user_id: String,
376 state: String,
377 response: serde_json::Value,
378}
379
380async fn passkey_register_finish(
381 State(ctx): State<AuthCtx>,
382 Json(body): Json<PasskeyRegisterFinishBody>,
383) -> Response {
384 let Some(mgr) = ctx.passkeys.as_ref() else {
385 return svc_unavailable("passkey manager not configured");
386 };
387 let state: webauthn_rs::prelude::PasskeyRegistration =
388 match serde_json::from_str(&body.state) {
389 Ok(s) => s,
390 Err(e) => return bad_request(&format!("decode state: {e}")),
391 };
392 let response: webauthn_rs::prelude::RegisterPublicKeyCredential =
393 match serde_json::from_value(body.response) {
394 Ok(r) => r,
395 Err(e) => return bad_request(&format!("decode response: {e}")),
396 };
397 let passkey = match mgr.finish_registration(&state, &response) {
398 Ok(p) => p,
399 Err(e) => return bad_request(&format!("finish_registration: {e}")),
400 };
401 let cred = crate::passkey::passkey_to_cred(&passkey, now_secs());
402 if let Err(e) = ctx.users.add_passkey(&body.user_id, &cred).await {
403 return server_error(&format!("persist passkey: {e}"));
404 }
405 (
406 StatusCode::OK,
407 Json(json!({"credential_id": data_encoding::BASE64URL_NOPAD.encode(&cred.credential_id)})),
408 )
409 .into_response()
410}
411
412#[derive(Deserialize)]
413struct PasskeyAuthStartBody {
414 user_id: String,
415 #[serde(default)]
419 passkeys: Vec<serde_json::Value>,
420}
421
422async fn passkey_auth_start(
423 State(ctx): State<AuthCtx>,
424 Json(body): Json<PasskeyAuthStartBody>,
425) -> Response {
426 let Some(mgr) = ctx.passkeys.as_ref() else {
427 return svc_unavailable("passkey manager not configured");
428 };
429 let _ = body.user_id;
430 let mut creds: Vec<webauthn_rs::prelude::Passkey> = Vec::with_capacity(body.passkeys.len());
431 for v in body.passkeys {
432 match serde_json::from_value(v) {
433 Ok(p) => creds.push(p),
434 Err(e) => return bad_request(&format!("decode passkey: {e}")),
435 }
436 }
437 if creds.is_empty() {
438 return bad_request("passkeys list is empty");
439 }
440 match mgr.start_authentication_with(&creds) {
441 Ok((challenge, state)) => (
442 StatusCode::OK,
443 Json(json!({
444 "challenge": challenge,
445 "state": serde_json::to_string(&state).unwrap_or_default(),
446 })),
447 )
448 .into_response(),
449 Err(e) => bad_request(&format!("start_authentication: {e}")),
450 }
451}
452
453#[derive(Deserialize)]
454struct PasskeyAuthFinishBody {
455 state: String,
456 response: serde_json::Value,
457}
458
459async fn passkey_auth_finish(
460 State(ctx): State<AuthCtx>,
461 Json(body): Json<PasskeyAuthFinishBody>,
462) -> Response {
463 let Some(mgr) = ctx.passkeys.as_ref() else {
464 return svc_unavailable("passkey manager not configured");
465 };
466 let state: webauthn_rs::prelude::PasskeyAuthentication =
467 match serde_json::from_str(&body.state) {
468 Ok(s) => s,
469 Err(e) => return bad_request(&format!("decode state: {e}")),
470 };
471 let response: webauthn_rs::prelude::PublicKeyCredential =
472 match serde_json::from_value(body.response) {
473 Ok(r) => r,
474 Err(e) => return bad_request(&format!("decode response: {e}")),
475 };
476 match mgr.finish_authentication(&state, &response) {
477 Ok(result) => (
478 StatusCode::OK,
479 Json(json!({
480 "credential_id": data_encoding::BASE64URL_NOPAD.encode(&result.credential_id),
481 "sign_count": result.sign_count,
482 "user_verified": result.user_verified,
483 })),
484 )
485 .into_response(),
486 Err(e) => unauthorized(&format!("finish_authentication: {e}")),
487 }
488}
489
490fn parse_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
491 let raw = headers.get(header::COOKIE)?.to_str().ok()?;
492 for kv in raw.split(';') {
493 let kv = kv.trim();
494 if let Some((k, v)) = kv.split_once('=')
495 && k == name {
496 return Some(v.to_string());
497 }
498 }
499 None
500}
501
502fn unauthorized(msg: &str) -> Response {
503 (StatusCode::UNAUTHORIZED, Json(json!({"error": msg}))).into_response()
504}
505fn bad_request(msg: &str) -> Response {
506 (StatusCode::BAD_REQUEST, Json(json!({"error": msg}))).into_response()
507}
508fn server_error(msg: &str) -> Response {
509 (
510 StatusCode::INTERNAL_SERVER_ERROR,
511 Json(json!({"error": msg})),
512 )
513 .into_response()
514}
515fn svc_unavailable(msg: &str) -> Response {
516 (
517 StatusCode::SERVICE_UNAVAILABLE,
518 Json(json!({"error": msg})),
519 )
520 .into_response()
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `SessionStore`, so `SessionManager` can be tested without a database.
    struct MemSessionStore(Mutex<HashMap<String, Session>>);

    impl MemSessionStore {
        fn new() -> Self {
            Self(Mutex::new(HashMap::new()))
        }
    }

    #[async_trait::async_trait]
    impl SessionStore for MemSessionStore {
        async fn create(&self, session: &Session) -> anyhow::Result<()> {
            self.0
                .lock()
                .unwrap()
                .insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn get(&self, id: &str) -> anyhow::Result<Option<Session>> {
            Ok(self.0.lock().unwrap().get(id).cloned())
        }

        async fn delete(&self, id: &str) -> anyhow::Result<bool> {
            Ok(self.0.lock().unwrap().remove(id).is_some())
        }

        async fn list_for_user(&self, user_id: &str) -> anyhow::Result<Vec<Session>> {
            Ok(self
                .0
                .lock()
                .unwrap()
                .values()
                .filter(|s| s.user_id == user_id)
                .cloned()
                .collect())
        }

        async fn delete_for_user(&self, user_id: &str) -> anyhow::Result<u64> {
            let mut guard = self.0.lock().unwrap();
            let before = guard.len();
            guard.retain(|_, s| s.user_id != user_id);
            Ok((before - guard.len()) as u64)
        }

        async fn purge_expired(&self, now: f64) -> anyhow::Result<u64> {
            let mut guard = self.0.lock().unwrap();
            let before = guard.len();
            guard.retain(|_, s| s.expires_at > now);
            Ok((before - guard.len()) as u64)
        }

        async fn list_all(
            &self,
            limit: i64,
            offset: i64,
            user_filter: Option<&str>,
        ) -> anyhow::Result<Vec<Session>> {
            let guard = self.0.lock().unwrap();
            let mut all: Vec<Session> = guard
                .values()
                .filter(|s| user_filter.is_none_or(|u| s.user_id == u))
                .cloned()
                .collect();
            all.sort_by(|a, b| {
                b.created_at
                    .partial_cmp(&a.created_at)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let off = offset.max(0) as usize;
            let lim = limit.clamp(1, 500) as usize;
            Ok(all.into_iter().skip(off).take(lim).collect())
        }

        async fn count_all(&self, user_filter: Option<&str>) -> anyhow::Result<i64> {
            let guard = self.0.lock().unwrap();
            Ok(guard
                .values()
                .filter(|s| user_filter.is_none_or(|u| s.user_id == u))
                .count() as i64)
        }
    }

    fn manager() -> SessionManager {
        SessionManager::with_default_duration(Arc::new(MemSessionStore::new()))
    }

    /// Fixed-id session that expires `ttl_secs` from now (negative = already expired).
    fn sample_session(ttl_secs: f64) -> Session {
        Session {
            id: "sess_abc".to_string(),
            user_id: "u".to_string(),
            csrf_token: "csrf_abc".to_string(),
            created_at: now_secs(),
            expires_at: now_secs() + ttl_secs,
            ip_hash: None,
            user_agent_hash: None,
        }
    }

    #[tokio::test]
    async fn create_then_resolve_returns_same_session() {
        let mgr = manager();
        let created = mgr.create("user_alice").await.unwrap();
        let resolved = mgr.resolve(&created.id).await.unwrap().unwrap();
        assert_eq!(resolved.id, created.id);
        assert_eq!(resolved.user_id, "user_alice");
        assert!(created.id.starts_with("sess_"));
        assert!(created.csrf_token.starts_with("csrf_"));
    }

    #[tokio::test]
    async fn resolve_returns_none_for_unknown_id() {
        let mgr = manager();
        assert!(mgr.resolve("sess_nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn resolve_returns_none_for_expired_session() {
        let store = Arc::new(MemSessionStore::new()) as Arc<dyn SessionStore>;
        let expired = Session {
            id: "sess_expired".to_string(),
            user_id: "user_x".to_string(),
            csrf_token: "csrf_expired".to_string(),
            created_at: now_secs() - 1000.0,
            expires_at: now_secs() - 1.0,
            ip_hash: None,
            user_agent_hash: None,
        };
        store.create(&expired).await.unwrap();
        let mgr = SessionManager::with_default_duration(store);
        assert!(mgr.resolve("sess_expired").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn rotate_returns_new_id_with_same_user_and_expiry() {
        let mgr = manager();
        let original = mgr.create("user_bob").await.unwrap();
        let rotated = mgr.rotate(&original.id).await.unwrap().unwrap();
        assert_ne!(rotated.id, original.id);
        assert_eq!(rotated.user_id, original.user_id);
        assert!((rotated.expires_at - original.expires_at).abs() < f64::EPSILON);
        // The old id must stop working once rotated.
        assert!(mgr.resolve(&original.id).await.unwrap().is_none());
        assert!(mgr.resolve(&rotated.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn rotate_returns_none_for_unknown_id() {
        let mgr = manager();
        assert!(mgr.rotate("sess_nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn revoke_drops_the_session() {
        let mgr = manager();
        let s = mgr.create("user_eve").await.unwrap();
        assert!(mgr.revoke(&s.id).await.unwrap());
        assert!(mgr.resolve(&s.id).await.unwrap().is_none());
        // Revoking twice reports that nothing was removed.
        assert!(!mgr.revoke(&s.id).await.unwrap());
    }

    #[tokio::test]
    async fn revoke_for_user_drops_every_session_for_that_user() {
        let mgr = manager();
        let s1 = mgr.create("user_multi").await.unwrap();
        let s2 = mgr.create("user_multi").await.unwrap();
        let other = mgr.create("user_other").await.unwrap();
        let dropped = mgr.revoke_for_user("user_multi").await.unwrap();
        assert_eq!(dropped, 2);
        assert!(mgr.resolve(&s1.id).await.unwrap().is_none());
        assert!(mgr.resolve(&s2.id).await.unwrap().is_none());
        // Sessions of other users are untouched.
        assert!(mgr.resolve(&other.id).await.unwrap().is_some());
    }

    #[test]
    fn cookie_for_https_url_is_secure_httponly_lax() {
        let session = sample_session(3600.0);
        let url = Url::parse("https://app.example.com").unwrap();
        let cookie = cookie_for(&session, &url);
        assert_eq!(cookie.name(), SESSION_COOKIE);
        assert_eq!(cookie.value(), "sess_abc");
        assert_eq!(cookie.http_only(), Some(true));
        assert_eq!(cookie.secure(), Some(true));
        assert_eq!(cookie.same_site(), Some(SameSite::Lax));
        assert_eq!(cookie.path(), Some("/"));
    }

    #[test]
    fn csrf_cookie_is_not_http_only() {
        let cookie = csrf_cookie_for(&sample_session(3600.0));
        assert_eq!(cookie.name(), CSRF_COOKIE);
        assert_eq!(cookie.value(), "csrf_abc");
        // Scripts must be able to read the token to echo it back.
        assert_eq!(cookie.http_only(), Some(false));
    }

    #[test]
    fn csrf_cookie_is_secure_lax_and_site_wide() {
        let cookie = csrf_cookie_for(&sample_session(3600.0));
        assert_eq!(cookie.secure(), Some(true));
        assert_eq!(cookie.same_site(), Some(SameSite::Lax));
        assert_eq!(cookie.path(), Some("/"));
    }

    #[test]
    fn cookie_for_http_url_is_not_secure() {
        let session = sample_session(3600.0);
        let url = Url::parse("http://localhost:3000").unwrap();
        let cookie = cookie_for(&session, &url);
        // Browsers would drop a Secure cookie set over plain http.
        assert_eq!(cookie.secure(), Some(false));
    }

    #[test]
    fn max_age_is_zero_for_expired_session() {
        assert_eq!(max_age_for(&sample_session(-10.0)), CookieDuration::ZERO);
    }

    #[test]
    fn max_age_tracks_remaining_lifetime() {
        let secs = max_age_for(&sample_session(3600.0)).whole_seconds();
        assert!((3598..=3600).contains(&secs));
    }

    #[test]
    fn parse_cookie_finds_named_cookie_among_others() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::COOKIE,
            "theme=dark; assay_session=sess_x; assay_csrf=csrf_y"
                .parse()
                .unwrap(),
        );
        assert_eq!(parse_cookie(&headers, SESSION_COOKIE).as_deref(), Some("sess_x"));
        assert_eq!(parse_cookie(&headers, CSRF_COOKIE).as_deref(), Some("csrf_y"));
        assert_eq!(parse_cookie(&headers, "missing"), None);
    }

    #[test]
    fn parse_cookie_without_header_is_none() {
        assert_eq!(parse_cookie(&HeaderMap::new(), SESSION_COOKIE), None);
    }
}
743}