use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
use astrid_core::PrincipalId;
use astrid_core::dirs::AstridHome;
use base64::Engine;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
/// Number of random bytes drawn for one pairing token before it is encoded.
///
/// `generate_token` fills a buffer of this size and base64-encodes it without
/// padding, so 24 bytes become a 32-character token.
pub const TOKEN_RAW_LEN: usize = 24;
/// One hour, expressed in seconds.
///
/// NOTE(review): presumably the upper bound on a pairing token's lifetime; it
/// is not referenced by the code visible here — confirm against the callers.
pub const MAX_EXPIRY_SECS: u64 = 60 * 60;
/// One pairing-token record as it is persisted by `PairTokenStore`.
///
/// NOTE(review): `token_hash` presumably holds the output of `hash_token`
/// (hex SHA-256 of the raw token) rather than the raw token itself — confirm
/// against the code that issues tokens.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PairToken {
/// Hash string identifying the token.
pub token_hash: String,
/// Principal this record is associated with.
pub principal: PrincipalId,
/// Expiry time in seconds since the Unix epoch. `prune_expired` drops the
/// record once this value is no longer strictly greater than the current time.
pub expires_at_epoch: u64,
/// Issue time. Presumably seconds since the Unix epoch, like
/// `expires_at_epoch` — confirm with the issuer.
pub issued_at_epoch: u64,
/// Optional label. Omitted from the serialized form when `None`, and
/// defaults to `None` when the key is absent on load.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub label: Option<String>,
}
/// File-backed store for `PairToken` records.
///
/// The store itself holds no tokens in memory; it only remembers where the
/// TOML file lives and reads or rewrites that file on demand.
#[derive(Debug)]
pub struct PairTokenStore {
// Location of the TOML file that `load` reads and `save` replaces.
path: PathBuf,
}
impl PairTokenStore {
    /// Creates a store backed by the file at `path`.
    ///
    /// No I/O happens here; the file is only touched by [`Self::load`] and
    /// [`Self::save`].
    #[must_use]
    pub const fn new(path: PathBuf) -> Self {
        Self { path }
    }

    /// Returns the store location for `home`: `pair-tokens.toml` inside
    /// `home.etc_dir()`.
    #[must_use]
    pub fn path_for(home: &AstridHome) -> PathBuf {
        home.etc_dir().join("pair-tokens.toml")
    }

    /// Reads every persisted token.
    ///
    /// A missing file, or a file containing only whitespace, yields an empty
    /// list rather than an error.
    ///
    /// # Errors
    ///
    /// * [`PairTokenStoreError::Io`] when the file cannot be read for a reason
    ///   other than "not found", or when its content is not valid UTF-8
    ///   (reported with kind `InvalidData`).
    /// * [`PairTokenStoreError::Toml`] when the content is not valid TOML for
    ///   the persisted layout.
    pub fn load(&self) -> Result<Vec<PairToken>, PairTokenStoreError> {
        let bytes = match std::fs::read(&self.path) {
            Ok(b) => b,
            // A store that was never written is simply empty.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => return Err(PairTokenStoreError::Io(e)),
        };
        let text = std::str::from_utf8(&bytes).map_err(|e| {
            PairTokenStoreError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
        })?;
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }
        let parsed: PersistedFile = toml::from_str(text).map_err(PairTokenStoreError::Toml)?;
        Ok(parsed.pair_token)
    }

    /// Replaces the persisted token list with `tokens`.
    ///
    /// Missing parent directories are created first. On Unix the new content
    /// is written to a sibling temporary file (the store path with its
    /// extension replaced by `<pid>.tmp`, requested with mode `0o600`), synced
    /// to disk, and then renamed over the store path, so readers never observe
    /// a half-written file. If writing, syncing, or renaming fails, the
    /// temporary file is removed (best effort) before the error is returned.
    /// On other platforms the file is written in place.
    ///
    /// # Errors
    ///
    /// * [`PairTokenStoreError::TomlSer`] when the tokens cannot be serialized.
    /// * [`PairTokenStoreError::Io`] for any filesystem failure.
    pub fn save(&self, tokens: &[PairToken]) -> Result<(), PairTokenStoreError> {
        if let Some(parent) = self.path.parent() {
            std::fs::create_dir_all(parent).map_err(PairTokenStoreError::Io)?;
        }
        let body = PersistedFile {
            pair_token: tokens.to_vec(),
        };
        let text = toml::to_string_pretty(&body).map_err(PairTokenStoreError::TomlSer)?;
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::OpenOptionsExt;
            let tmp_path = self
                .path
                .with_extension(format!("{}.tmp", std::process::id()));
            let mut f = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&tmp_path)
                .map_err(PairTokenStoreError::Io)?;
            // From here on the temporary file exists, so every failure path
            // must remove it; otherwise a failed write or sync would leave a
            // stray file holding token hashes next to the store.
            let written = f.write_all(text.as_bytes()).and_then(|()| f.sync_all());
            // Close the handle before renaming or unlinking.
            drop(f);
            let replaced = written.and_then(|()| std::fs::rename(&tmp_path, &self.path));
            if let Err(e) = replaced {
                // Best effort: the original error is the one worth reporting.
                let _ = std::fs::remove_file(&tmp_path);
                return Err(PairTokenStoreError::Io(e));
            }
        }
        #[cfg(not(unix))]
        {
            std::fs::write(&self.path, text.as_bytes()).map_err(PairTokenStoreError::Io)?;
        }
        Ok(())
    }
}
/// Failures that `PairTokenStore::load` and `PairTokenStore::save` can report.
#[derive(Debug)]
pub enum PairTokenStoreError {
/// A filesystem operation failed, or the file content was not valid UTF-8
/// (wrapped as an `InvalidData` I/O error by `load`).
Io(std::io::Error),
/// The file content could not be parsed as TOML for the persisted layout.
Toml(toml::de::Error),
/// The token list could not be serialized to TOML.
TomlSer(toml::ser::Error),
}
/// Human-readable rendering: a fixed `pair-token store <stage>:` prefix
/// followed by the wrapped error's own message.
impl std::fmt::Display for PairTokenStoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => f.write_fmt(format_args!("pair-token store io: {e}")),
            Self::Toml(e) => f.write_fmt(format_args!("pair-token store parse: {e}")),
            Self::TomlSer(e) => f.write_fmt(format_args!("pair-token store serialise: {e}")),
        }
    }
}
// Marks the type as a standard error. `source()` is not overridden here, so
// the wrapped error is surfaced only through the `Display` text.
impl std::error::Error for PairTokenStoreError {}
/// Top-level shape of the store file: all tokens live under the `pair_token`
/// key.
#[derive(Debug, Default, Serialize, Deserialize)]
struct PersistedFile {
// `default` lets a file without the key deserialize to an empty list.
#[serde(default)]
pair_token: Vec<PairToken>,
}
/// Produces a fresh pairing token.
///
/// `TOKEN_RAW_LEN` bytes are drawn from the operating system's random source
/// and returned as unpadded URL-safe base64 text.
#[must_use]
pub fn generate_token() -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    let mut entropy = [0u8; TOKEN_RAW_LEN];
    let mut os_rng = rand::rngs::OsRng;
    os_rng.fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Returns the SHA-256 digest of `token`'s UTF-8 bytes as a lowercase hex
/// string (64 characters).
#[must_use]
pub fn hash_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    hex::encode(digest)
}
/// Compares two hash strings without short-circuiting on the first
/// differing byte.
///
/// Strings of different length are rejected up front; equal-length inputs are
/// compared byte-wise through `ConstantTimeEq`.
#[must_use]
pub fn ct_hash_eq(a: &str, b: &str) -> bool {
    if a.len() == b.len() {
        let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
        lhs.ct_eq(rhs).into()
    } else {
        false
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch.
#[must_use]
pub fn now_epoch() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Removes, in place, every token whose `expires_at_epoch` is not strictly
/// greater than the current time, and returns how many were removed.
pub fn prune_expired(tokens: &mut Vec<PairToken>) -> usize {
    let cutoff = now_epoch();
    let original_len = tokens.len();
    tokens.retain(|entry| cutoff < entry.expires_at_epoch);
    original_len.saturating_sub(tokens.len())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a label-less token for the default principal that expires at
    /// `expires_at_epoch`.
    fn token_expiring_at(hash: &str, expires_at_epoch: u64) -> PairToken {
        PairToken {
            token_hash: hash.into(),
            principal: PrincipalId::default(),
            expires_at_epoch,
            issued_at_epoch: expires_at_epoch.saturating_sub(60),
            label: None,
        }
    }

    #[test]
    fn token_is_random_and_short() {
        let a = generate_token();
        let b = generate_token();
        assert_ne!(a, b);
        // 24 raw bytes -> 32 unpadded base64 characters.
        assert_eq!(a.len(), 32);
    }

    #[test]
    fn hash_is_deterministic_hex() {
        let h = hash_token("hello");
        assert_eq!(h.len(), 64);
        assert_eq!(h, hash_token("hello"));
        assert_ne!(h, hash_token("world"));
    }

    #[test]
    fn ct_hash_eq_accepts_only_identical_strings() {
        assert!(ct_hash_eq("", ""));
        assert!(ct_hash_eq("abc", "abc"));
        // Same length, different content.
        assert!(!ct_hash_eq("abc", "abd"));
        // Prefix relationship, different length.
        assert!(!ct_hash_eq("abc", "abcd"));
        assert!(!ct_hash_eq("abcd", "abc"));
    }

    #[test]
    fn round_trip_save_load() {
        let dir = tempfile::tempdir().unwrap();
        let store = PairTokenStore::new(dir.path().join("pair-tokens.toml"));
        let token = PairToken {
            token_hash: "abc".into(),
            principal: PrincipalId::new("alice").unwrap(),
            expires_at_epoch: 9_999_999_999,
            issued_at_epoch: 1,
            label: Some("phone".into()),
        };
        store.save(&[token.clone()]).unwrap();
        let loaded = store.load().unwrap();
        assert_eq!(loaded, vec![token]);
    }

    #[test]
    fn load_of_missing_file_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let store = PairTokenStore::new(dir.path().join("never-written.toml"));
        assert!(store.load().unwrap().is_empty());
    }

    #[test]
    fn load_of_blank_file_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("pair-tokens.toml");
        std::fs::write(&path, " \n\t\n").unwrap();
        let store = PairTokenStore::new(path);
        assert!(store.load().unwrap().is_empty());
    }

    #[test]
    fn prune_drops_expired() {
        let now = now_epoch();
        let mut v = vec![
            PairToken {
                token_hash: "a".into(),
                principal: PrincipalId::default(),
                expires_at_epoch: now.saturating_add(60),
                issued_at_epoch: now,
                label: None,
            },
            PairToken {
                token_hash: "b".into(),
                principal: PrincipalId::default(),
                expires_at_epoch: now.saturating_sub(60),
                issued_at_epoch: now.saturating_sub(120),
                label: None,
            },
        ];
        assert_eq!(prune_expired(&mut v), 1);
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn prune_treats_expiry_equal_to_now_as_expired() {
        // `prune_expired` keeps only tokens whose expiry is strictly in the
        // future, so a token expiring "now" must already be gone.
        let mut v = vec![token_expiring_at("edge", now_epoch())];
        assert_eq!(prune_expired(&mut v), 1);
        assert!(v.is_empty());
    }
}