use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use rand::RngCore;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SessionError {
#[error("unknown session")]
Unknown,
#[error("session expired")]
Expired,
}
/// Opaque, randomly generated identifier for a login session.
///
/// The wrapped string works as a bearer secret: whoever presents it can resolve
/// the session through `SessionStore::lookup`. `Debug` is therefore implemented
/// by hand to redact the value, so that `{:?}` logging cannot leak it. Use
/// [`SessionId::as_str`] when the raw value is really needed.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Creates a fresh identifier from 32 bytes of OS randomness, hex-encoded
    /// into a 64-character string.
    pub fn generate() -> Self {
        let mut buf = [0u8; 32];
        rand::rngs::OsRng.fill_bytes(&mut buf);
        Self(hex::encode(buf))
    }

    /// Borrows the raw identifier string.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Wraps an externally supplied identifier as-is.
    ///
    /// No format validation is performed. A value that was never issued simply
    /// is not found by the store.
    pub fn from_raw(s: impl Into<String>) -> Self {
        Self(s.into())
    }
}

impl std::fmt::Debug for SessionId {
    // Deliberately never prints the token itself; see the type-level docs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SessionId(<redacted>)")
    }
}
/// An issued authorization code together with the request parameters it was
/// issued for.
///
/// `Debug` is implemented by hand so the redeemable `code` value is redacted
/// from `{:?}` output; every other field is printed normally.
#[derive(Clone)]
pub struct AuthCodeRecord {
    /// The code string itself; also the key under which the store keeps it.
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// Account the code was issued for.
    pub account_id: String,
    /// Redirect URI recorded at issuance.
    pub redirect_uri: String,
    /// Optional code challenge recorded at issuance (presumably PKCE — confirm
    /// against the caller).
    pub code_challenge: Option<String>,
    /// When the code was issued; compared against the store's code TTL.
    pub issued_at: Instant,
    /// Optional scope string recorded at issuance.
    pub requested_scope: Option<String>,
}

impl std::fmt::Debug for AuthCodeRecord {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthCodeRecord")
            // The code is a live credential until redeemed or expired.
            .field("code", &"<redacted>")
            .field("client_id", &self.client_id)
            .field("account_id", &self.account_id)
            .field("redirect_uri", &self.redirect_uri)
            .field("code_challenge", &self.code_challenge)
            .field("issued_at", &self.issued_at)
            .field("requested_scope", &self.requested_scope)
            .finish()
    }
}
/// Server-side state kept for one session.
#[derive(Debug, Clone)]
pub struct SessionRecord {
    /// Account the session belongs to.
    pub account_id: String,
    /// When the session was created.
    pub created_at: Instant,
    /// When the session was created or last successfully looked up; the store
    /// measures its idle TTL from this value.
    pub last_access: Instant,
}

impl SessionRecord {
    /// Builds a record for `account_id` whose creation and last-access times
    /// are the same instant.
    fn new(account_id: String) -> Self {
        // Read the clock once: two separate `Instant::now()` calls may differ,
        // which would make a brand-new session look as if it had been accessed
        // after creation.
        let now = Instant::now();
        Self {
            account_id,
            created_at: now,
            last_access: now,
        }
    }
}
#[derive(Clone, Default)]
pub struct SessionStore {
inner: Arc<RwLock<Inner>>,
session_ttl: Duration,
code_ttl: Duration,
}
/// The mutable state behind a `SessionStore`'s lock, shared by all its clones.
///
/// Both maps start empty via `Default`.
#[derive(Default)]
struct Inner {
    /// Stored sessions, keyed by the session id's string form.
    sessions: HashMap<String, SessionRecord>,
    /// Issued, not-yet-redeemed authorization codes, keyed by the code string.
    codes: HashMap<String, AuthCodeRecord>,
}
impl SessionStore {
    /// Session idle timeout used by [`SessionStore::new`]: 14 days.
    pub const DEFAULT_SESSION_TTL: Duration = Duration::from_secs(14 * 24 * 3600);
    /// Authorization-code lifetime used by [`SessionStore::new`]: 10 minutes.
    pub const DEFAULT_CODE_TTL: Duration = Duration::from_secs(10 * 60);

    /// Creates an empty store using [`Self::DEFAULT_SESSION_TTL`] and
    /// [`Self::DEFAULT_CODE_TTL`].
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(Inner::default())),
            session_ttl: Self::DEFAULT_SESSION_TTL,
            code_ttl: Self::DEFAULT_CODE_TTL,
        }
    }

    /// Returns this handle with both TTLs replaced.
    ///
    /// Only this handle's TTLs change. Clones taken earlier keep their own TTLs
    /// but still share the same session and code maps.
    #[must_use]
    pub fn with_ttls(mut self, session_ttl: Duration, code_ttl: Duration) -> Self {
        self.session_ttl = session_ttl;
        self.code_ttl = code_ttl;
        self
    }

    /// Starts a session for `account_id` and returns its newly generated id.
    pub fn create_session(&self, account_id: impl Into<String>) -> SessionId {
        let id = SessionId::generate();
        // Build the record before taking the lock to keep the critical section short.
        let record = SessionRecord::new(account_id.into());
        self.inner
            .write()
            .sessions
            .insert(id.as_str().to_owned(), record);
        id
    }

    /// Resolves `id` to a copy of its session record and refreshes its
    /// last-access time.
    ///
    /// The TTL is an idle timeout. It is measured from the last successful
    /// lookup (or creation), not from `created_at`.
    ///
    /// # Errors
    ///
    /// - [`SessionError::Unknown`] if no session is stored under `id`.
    /// - [`SessionError::Expired`] if the session has been idle longer than the
    ///   session TTL. The session is removed before returning.
    pub fn lookup(&self, id: &SessionId) -> Result<SessionRecord, SessionError> {
        let mut inner = self.inner.write();
        let entry = inner
            .sessions
            .get_mut(id.as_str())
            .ok_or(SessionError::Unknown)?;
        if entry.last_access.elapsed() > self.session_ttl {
            inner.sessions.remove(id.as_str());
            return Err(SessionError::Expired);
        }
        entry.last_access = Instant::now();
        Ok(entry.clone())
    }

    /// Removes the session stored under `id`. Does nothing if it is unknown.
    pub fn revoke(&self, id: &SessionId) {
        self.inner.write().sessions.remove(id.as_str());
    }

    /// Mints a new authorization code, stores it, and returns a copy of its record.
    ///
    /// The code is 32 bytes of OS randomness, hex-encoded. It can be redeemed
    /// once via [`SessionStore::take_code`].
    pub fn issue_code(
        &self,
        client_id: impl Into<String>,
        account_id: impl Into<String>,
        redirect_uri: impl Into<String>,
        code_challenge: Option<String>,
        requested_scope: Option<String>,
    ) -> AuthCodeRecord {
        let mut buf = [0u8; 32];
        rand::rngs::OsRng.fill_bytes(&mut buf);
        let code = hex::encode(buf);
        let rec = AuthCodeRecord {
            code: code.clone(),
            client_id: client_id.into(),
            account_id: account_id.into(),
            redirect_uri: redirect_uri.into(),
            code_challenge,
            issued_at: Instant::now(),
            requested_scope,
        };
        self.inner.write().codes.insert(code, rec.clone());
        rec
    }

    /// Redeems `code`, returning its record if it is known and still within
    /// the code TTL.
    ///
    /// A found code is removed even when it turns out to be expired, so every
    /// code can be taken at most once. Returns `None` for unknown, already
    /// redeemed, or expired codes.
    pub fn take_code(&self, code: &str) -> Option<AuthCodeRecord> {
        let mut inner = self.inner.write();
        let rec = inner.codes.remove(code)?;
        if rec.issued_at.elapsed() > self.code_ttl {
            return None;
        }
        Some(rec)
    }

    /// Evicts every expired session and authorization code and returns how
    /// many entries were removed.
    ///
    /// `lookup` and `take_code` only evict the single entry they touch.
    /// Abandoned sessions and never-redeemed codes would otherwise stay in
    /// memory for the life of the store, so call this periodically to bound
    /// its size. It uses the same expiry rules as `lookup` and `take_code`.
    pub fn purge_expired(&self) -> usize {
        let (session_ttl, code_ttl) = (self.session_ttl, self.code_ttl);
        let mut inner = self.inner.write();
        let before = inner.sessions.len() + inner.codes.len();
        inner
            .sessions
            .retain(|_, s| s.last_access.elapsed() <= session_ttl);
        inner.codes.retain(|_, c| c.issued_at.elapsed() <= code_ttl);
        // `retain` only removes entries, so this cannot underflow.
        before - (inner.sessions.len() + inner.codes.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated ids must differ and be 64 characters long.
    #[test]
    fn session_ids_are_unique() {
        let (first, second) = (SessionId::generate(), SessionId::generate());
        assert_ne!(first.as_str(), second.as_str());
        assert_eq!(first.as_str().len(), 64);
    }

    /// A created session resolves to its account and is gone after revocation.
    #[test]
    fn session_create_lookup_revoke_roundtrip() {
        let store = SessionStore::new();
        let sid = store.create_session("acct-1");
        let record = store.lookup(&sid).unwrap();
        assert_eq!(record.account_id, "acct-1");

        store.revoke(&sid);
        let after_revoke = store.lookup(&sid);
        assert!(matches!(after_revoke, Err(SessionError::Unknown)));
    }

    /// A session idle past its (1 ms) TTL is reported as expired.
    #[test]
    fn session_expiry_is_enforced() {
        let session_ttl = Duration::from_millis(1);
        let code_ttl = Duration::from_secs(60);
        let store = SessionStore::new().with_ttls(session_ttl, code_ttl);
        let sid = store.create_session("acct-2");

        std::thread::sleep(Duration::from_millis(10));

        match store.lookup(&sid) {
            Err(SessionError::Expired) => {}
            other => panic!("expected Err(Expired), got {other:?}"),
        }
    }

    /// An authorization code can be redeemed exactly once.
    #[test]
    fn auth_code_is_single_use() {
        let store = SessionStore::new();
        let issued = store.issue_code("c-1", "acct-3", "https://app/cb", None, None);

        let redeemed = store.take_code(&issued.code).unwrap();
        assert_eq!(redeemed.account_id, "acct-3");

        let second_attempt = store.take_code(&issued.code);
        assert!(second_attempt.is_none());
    }

    /// A code older than its (1 ms) TTL cannot be redeemed.
    #[test]
    fn auth_code_expires() {
        let session_ttl = Duration::from_secs(60);
        let code_ttl = Duration::from_millis(1);
        let store = SessionStore::new().with_ttls(session_ttl, code_ttl);
        let issued = store.issue_code("c-1", "acct-4", "https://app/cb", None, None);

        std::thread::sleep(Duration::from_millis(10));

        assert!(store.take_code(&issued.code).is_none());
    }
}