use crate::AuthUser;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use umbral::orm::ForeignKey;
/// Literal prefix that `generate_plaintext` places in front of every
/// plaintext token it produces.
pub const TOKEN_PREFIX: &str = "umbral_";
/// Database row describing one API token issued to an [`AuthUser`].
///
/// The row carries only the digest of the token (see [`digest_token`]) in
/// `key_hash`; `AuthToken::create_for` returns the plaintext separately in a
/// [`PlaintextToken`] and does not put it on the row.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
pub struct AuthToken {
/// Row identifier. `create_for` passes `0` here when inserting — presumably
/// the ORM/database assigns the real value (TODO confirm).
pub id: i64,
/// The user this token belongs to. Declared with `on_delete = "cascade"`;
/// presumably tokens are removed together with their user — confirm
/// against the ORM's documentation.
#[umbral(on_delete = "cascade")]
pub user_id: ForeignKey<AuthUser>,
/// Output of [`digest_token`] for the plaintext token: unpadded URL-safe
/// base64 of a SHA-256 digest (the module tests assert 43 characters, which
/// is within the declared maximum of 64). Declared `unique`.
#[umbral(max_length = 64, unique)]
pub key_hash: String,
/// Caller-supplied label, declared with a maximum length of 80.
/// `create_for` stores `"default"` when it is given an empty name.
#[umbral(max_length = 80)]
pub name: String,
/// Creation time; `create_for` sets it to `Utc::now()`.
pub created_at: DateTime<Utc>,
/// `None` when the row is created; `touch_last_used` writes the current
/// UTC time into this column.
pub last_used_at: Option<DateTime<Utc>>,
}
/// Newtype wrapping the plaintext form of a token.
///
/// `AuthToken::create_for` returns one of these next to the stored row.
/// The derived `Debug` implementation prints the inner string, so `{:?}`
/// output contains the secret itself.
#[derive(Debug, Clone)]
pub struct PlaintextToken(pub String);

impl std::fmt::Display for PlaintextToken {
    /// Writes the raw token text straight to the formatter's output.
    ///
    /// The text is emitted through `write_str`, so width, fill and alignment
    /// flags given in the format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(secret) = self;
        f.write_str(secret.as_str())
    }
}
/// Produces a new random plaintext token.
///
/// Fills a 32-byte buffer from `rand::rngs::OsRng`, encodes it as unpadded
/// URL-safe base64 and returns it with [`TOKEN_PREFIX`] in front.
fn generate_plaintext() -> String {
    let mut entropy = [0u8; 32];
    rand::rngs::OsRng.fill_bytes(&mut entropy);

    let encoded = URL_SAFE_NO_PAD.encode(entropy);
    [TOKEN_PREFIX, encoded.as_str()].concat()
}
/// Hashes a plaintext token into the form stored in `AuthToken::key_hash`.
///
/// Returns the SHA-256 digest of the bytes of `plaintext`, encoded as
/// unpadded URL-safe base64. The same input always yields the same output,
/// which is what lets `AuthToken::lookup` find a row by digest.
pub fn digest_token(plaintext: &str) -> String {
    let fingerprint = Sha256::digest(plaintext.as_bytes());
    URL_SAFE_NO_PAD.encode(fingerprint)
}
impl AuthToken {
    /// Issues a new token for `user` and inserts a row holding its digest.
    ///
    /// An empty `name` is stored as `"default"`; any other value is stored
    /// as given. On success returns the inserted row together with the
    /// plaintext token, which is not part of the row.
    ///
    /// # Errors
    ///
    /// Propagates a failure of the insert as [`crate::AuthError`].
    pub async fn create_for(
        user: &AuthUser,
        name: &str,
    ) -> Result<(Self, PlaintextToken), crate::AuthError> {
        let secret = generate_plaintext();
        let label = match name {
            "" => "default",
            given => given,
        };

        let pending = AuthToken {
            id: 0,
            user_id: ForeignKey::new(user.id),
            // Only the digest goes into the row; the plaintext is returned below.
            key_hash: digest_token(&secret),
            name: label.to_string(),
            created_at: Utc::now(),
            last_used_at: None,
        };

        let stored = AuthToken::objects().create(pending).await?;
        Ok((stored, PlaintextToken(secret)))
    }

    /// Finds the token row whose `key_hash` equals the digest of `plaintext`.
    ///
    /// Returns whatever the query's `first()` yields, i.e. `Ok(None)` when
    /// no row matches.
    ///
    /// # Errors
    ///
    /// Propagates a failure of the query as [`crate::AuthError`].
    pub async fn lookup(plaintext: &str) -> Result<Option<Self>, crate::AuthError> {
        let found = AuthToken::objects()
            .filter(auth_token::KEY_HASH.eq(digest_token(plaintext)))
            .first()
            .await?;
        Ok(found)
    }

    /// Deletes the row(s) whose id equals this token's `id`.
    ///
    /// # Errors
    ///
    /// Propagates a failure of the delete as [`crate::AuthError`].
    pub async fn revoke(&self) -> Result<(), crate::AuthError> {
        let token_id = self.id;
        AuthToken::objects()
            .filter(auth_token::ID.eq(token_id))
            .delete()
            .await?;
        Ok(())
    }

    /// Writes the current UTC time into `last_used_at` for this token's row.
    ///
    /// Best-effort: the outcome of the update is discarded, so a failed
    /// write is not reported to the caller.
    pub(crate) async fn touch_last_used(&self) {
        let stamp = Utc::now();
        let delta: serde_json::Map<String, serde_json::Value> =
            std::iter::once(("last_used_at".to_string(), serde_json::json!(stamp))).collect();

        // Deliberately ignore the result; see the doc comment above.
        let _ = AuthToken::objects()
            .filter(auth_token::ID.eq(self.id))
            .update_values(delta)
            .await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated token starts with the prefix and is followed by 43
    /// characters (32 bytes in unpadded base64).
    #[test]
    fn generated_plaintext_has_prefix_and_decent_length() {
        let t = generate_plaintext();
        let expected_len = TOKEN_PREFIX.len() + 43;

        assert!(t.strip_prefix(TOKEN_PREFIX).is_some(), "missing prefix: {t}");
        assert_eq!(t.len(), expected_len, "unexpected length: {t}");
    }

    /// Two back-to-back calls must not return the same token.
    #[test]
    fn generated_plaintext_is_unique_per_call() {
        let (a, b) = (generate_plaintext(), generate_plaintext());
        assert_ne!(
            a, b,
            "two consecutive tokens collided (statistically impossible)"
        );
    }

    /// Equal inputs hash equally, distinct inputs hash differently, and the
    /// digest is 43 characters long.
    #[test]
    fn digest_is_deterministic_and_unique() {
        let [a, b, c] = ["umbral_AAAAA", "umbral_AAAAA", "umbral_BBBBB"].map(digest_token);

        assert_eq!(a, b, "digest is supposed to be deterministic");
        assert_ne!(a, c, "different inputs must produce different digests");
        assert_eq!(a.len(), 43, "unexpected digest length: {a}");
    }
}