use std::collections::HashMap;
use std::time::Duration;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use rand::rngs::OsRng;
use rand::RngCore;
use thiserror::Error;
/// An invite record, identified by its opaque `token`.
///
/// This is plain data. Nothing visible in this module reads `max_uses` or
/// `expires_at`; enforcing them is left to callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Invite {
    /// Opaque identifier for the invite.
    ///
    /// [`InMemoryInviteStore`] uses it as its map key, and [`InviteStore::get`]
    /// looks invites up by it.
    pub token: String,
    /// Optional cap on how many times the invite may be redeemed.
    /// NOTE(review): presumably `None` means unlimited uses — confirm with the
    /// code that redeems invites.
    pub max_uses: Option<u32>,
    /// Optional expiry instant, in UTC.
    /// NOTE(review): presumably `None` means the invite never expires — confirm.
    pub expires_at: Option<DateTime<Utc>>,
}
/// Errors returned by [`InviteStore`] implementations.
#[derive(Debug, Error)]
pub enum InviteStoreError {
    /// The underlying storage backend failed. The payload is a human-readable
    /// description, and the error displays as `backend: <description>`.
    ///
    /// [`InMemoryInviteStore`] never produces this variant.
    #[error("backend: {0}")]
    Backend(String),
}
/// Asynchronous storage for [`Invite`]s, addressed by their `token`.
///
/// Implementors must be `Send + Sync + 'static`, so a store can be shared
/// across threads and held in owned handles.
/// NOTE(review): presumably intended for use as e.g. `Arc<dyn InviteStore>` —
/// confirm with callers.
#[async_trait]
pub trait InviteStore: Send + Sync + 'static {
    /// Persists `invite`.
    ///
    /// This trait does not specify what happens when an invite with the same
    /// token already exists. [`InMemoryInviteStore`] keeps the existing entry and
    /// still returns `Ok(())`.
    async fn insert(&self, invite: Invite) -> Result<(), InviteStoreError>;
    /// Fetches the invite stored under `token`.
    ///
    /// `Ok(None)` means no invite is stored under that token.
    /// `Err` is reserved for backend failures.
    async fn get(&self, token: &str) -> Result<Option<Invite>, InviteStoreError>;
}
/// In-process [`InviteStore`] backed by a `HashMap` behind a `parking_lot::RwLock`.
///
/// Nothing is persisted, so all invites are lost when the store is dropped.
/// `Default` yields an empty store.
#[derive(Default)]
pub struct InMemoryInviteStore {
    /// Stored invites, keyed by `Invite::token`.
    inner: RwLock<HashMap<String, Invite>>,
}
impl InMemoryInviteStore {
pub fn new() -> Self {
Self::default()
}
pub fn snapshot(&self) -> Vec<Invite> {
self.inner.read().values().cloned().collect()
}
}
#[async_trait]
impl InviteStore for InMemoryInviteStore {
async fn insert(&self, invite: Invite) -> Result<(), InviteStoreError> {
self.inner
.write()
.entry(invite.token.clone())
.or_insert(invite);
Ok(())
}
async fn get(&self, token: &str) -> Result<Option<Invite>, InviteStoreError> {
Ok(self.inner.read().get(token).cloned())
}
}
pub fn mint_token() -> String {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let mut buf = [0u8; 32];
OsRng.fill_bytes(&mut buf);
URL_SAFE_NO_PAD.encode(buf)
}
/// Parses a human-friendly duration such as `"30s"`, `"5m"`, `"2h"`, `"7d"` or `"1w"`.
///
/// Surrounding whitespace is ignored. A bare integer (e.g. `"60"`) is read as
/// seconds. Otherwise the input must be a run of ASCII digits followed directly
/// by exactly one unit suffix:
/// - `s` = seconds
/// - `m` = minutes
/// - `h` = hours
/// - `d` = days
/// - `w` = weeks
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
/// - the input is empty;
/// - the numeric part is missing or does not fit in a `u64`;
/// - the unit is not recognised;
/// - the total number of seconds would overflow a `u64`.
pub fn parse_duration(input: &str) -> Result<Duration, String> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err("empty duration".to_string());
    }
    // A bare integer means seconds. Note that `u64::from_str` also accepts a
    // leading `+`.
    if let Ok(n) = trimmed.parse::<u64>() {
        return Ok(Duration::from_secs(n));
    }
    // Split at the first non-digit. If there is none, the input is all digits
    // but failed to parse above, which means it overflows `u64`. Splitting at the
    // end lets the number parse below report that, instead of wrongly claiming a
    // missing unit.
    let split = trimmed
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(trimmed.len());
    let (num_part, unit) = trimmed.split_at(split);
    let n: u64 = num_part
        .parse()
        .map_err(|e| format!("invalid number {num_part:?}: {e}"))?;
    let seconds_per_unit: u64 = match unit {
        "s" => 1,
        "m" => 60,
        "h" => 3_600,
        "d" => 86_400,
        "w" => 604_800,
        other => return Err(format!("unknown duration unit {other:?}")),
    };
    // Reject overflow rather than saturating. A silently clamped `u64::MAX`
    // seconds is not the duration the caller asked for.
    let secs = n
        .checked_mul(seconds_per_unit)
        .ok_or_else(|| format!("duration {trimmed:?} is too large"))?;
    Ok(Duration::from_secs(secs))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn inmemory_store_round_trips() {
        let store = InMemoryInviteStore::new();
        let original = Invite {
            token: "tok-1".into(),
            max_uses: Some(3),
            expires_at: None,
        };
        store.insert(original.clone()).await.unwrap();

        let fetched = store.get("tok-1").await.unwrap();
        assert_eq!(fetched.as_ref(), Some(&original));

        let absent = store.get("missing").await.unwrap();
        assert!(absent.is_none());
    }

    #[test]
    fn mint_token_is_base64url_and_uniqueish() {
        let first = mint_token();
        let second = mint_token();
        // Two 256-bit random draws colliding would point to a broken RNG.
        assert_ne!(first, second);

        let is_url_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        assert!(first.chars().all(is_url_safe));
        // 32 bytes encode to ceil(32 * 4 / 3) = 43 characters without padding.
        assert_eq!(first.len(), 43);
    }

    #[test]
    fn parse_duration_accepts_common_units() {
        let cases: [(&str, u64); 6] = [
            ("30s", 30),
            ("5m", 300),
            ("2h", 7_200),
            ("7d", 604_800),
            ("1w", 604_800),
            ("60", 60),
        ];
        for &(input, secs) in cases.iter() {
            assert_eq!(
                parse_duration(input).unwrap(),
                Duration::from_secs(secs),
                "input {input:?}"
            );
        }
    }

    #[test]
    fn parse_duration_rejects_bad_input() {
        for bad in ["", "1y", "abc"] {
            assert!(parse_duration(bad).is_err(), "expected {bad:?} to be rejected");
        }
    }
}