use crate::migrate;
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use fs2::FileExt;
use rand::distr::Alphanumeric;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AuthEntry {
pub authorized_at: DateTime<Utc>,
#[serde(default)]
pub mention_only: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct Registry {
#[serde(default)]
pub users: HashMap<String, AuthEntry>, #[serde(default)]
pub channels: HashMap<String, AuthEntry>, }
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PendingToken {
pub token: String,
pub type_: String, pub id: String,
pub expires_at: DateTime<Utc>,
}
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PendingStore {
pub tokens: HashMap<String, PendingToken>, }
pub struct AuthManager {
auth_path: PathBuf,
pending_path: PathBuf,
}
impl AuthManager {
pub fn new() -> Self {
let base_dir = migrate::get_base_dir();
fs::create_dir_all(&base_dir).unwrap();
Self::with_paths(
base_dir.join("auth.json"),
base_dir.join("pending_tokens.json"),
)
}
pub fn with_paths(auth_path: PathBuf, pending_path: PathBuf) -> Self {
Self {
auth_path,
pending_path,
}
}
fn with_lock<T, F>(&self, path: PathBuf, default: T, f: F) -> Result<T>
where
T: serde::de::DeserializeOwned + serde::Serialize + Default,
F: FnOnce(&mut T) -> Result<()>,
{
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path)?;
file.lock_exclusive()?;
let mut content = String::new();
let mut reader = std::io::BufReader::new(&file);
reader.read_to_string(&mut content)?;
let mut data: T = if content.trim().is_empty() {
default
} else {
serde_json::from_str(&content).unwrap_or(default)
};
f(&mut data)?;
let json = serde_json::to_string_pretty(&data)?;
let mut file = file; file.set_len(0)?;
file.seek(SeekFrom::Start(0))?;
file.write_all(json.as_bytes())?;
file.unlock()?;
Ok(data)
}
pub fn is_authorized(&self, user_id: &str, channel_id: &str) -> (bool, bool) {
if let Ok(content) = fs::read_to_string(&self.auth_path) {
if let Ok(reg) = serde_json::from_str::<Registry>(&content) {
if reg.users.contains_key(user_id) {
return (true, false); }
if let Some(entry) = reg.channels.get(channel_id) {
return (true, entry.mention_only);
}
}
}
(false, false)
}
pub fn get_channel_mention_only(&self, channel_id: &str) -> Option<bool> {
if let Ok(content) = fs::read_to_string(&self.auth_path) {
if let Ok(reg) = serde_json::from_str::<Registry>(&content) {
return reg.channels.get(channel_id).map(|entry| entry.mention_only);
}
}
None
}
pub async fn is_authorized_with_thread(
&self,
ctx: &serenity::all::Context,
user_id: &str,
channel_id: serenity::model::id::ChannelId,
) -> (bool, bool) {
let id_str = channel_id.to_string();
let (auth, mention) = self.is_authorized(user_id, &id_str);
if auth {
return (auth, mention);
}
if let Ok(channel) = channel_id.to_channel(&ctx.http).await {
if let Some(guild_channel) = channel.guild() {
if let Some(parent_id) = guild_channel.parent_id {
return self.is_authorized(user_id, &parent_id.to_string());
}
}
}
(false, false)
}
pub fn create_token(&self, type_: &str, id: &str) -> Result<String> {
let token: String = rand::rng()
.sample_iter(&Alphanumeric)
.take(6)
.map(char::from)
.collect();
let entry = PendingToken {
token: token.clone(),
type_: type_.to_string(),
id: id.to_string(),
expires_at: Utc::now() + Duration::minutes(5),
};
self.with_lock(
self.pending_path.clone(),
PendingStore::default(),
|store| {
let now = Utc::now();
store.tokens.retain(|_, v| v.expires_at > now);
store.tokens.insert(token.clone(), entry);
Ok(())
},
)?;
Ok(token)
}
pub fn redeem_token(&self, token: &str) -> Result<(String, String)> {
let mut found_entry: Option<PendingToken> = None;
self.with_lock(
self.pending_path.clone(),
PendingStore::default(),
|store| {
let now = Utc::now();
store.tokens.retain(|_, v| v.expires_at > now);
if let Some(entry) = store.tokens.remove(token) {
found_entry = Some(entry);
}
Ok(())
},
)?;
let entry = found_entry.ok_or_else(|| anyhow::anyhow!("Invalid or expired token"))?;
self.with_lock(self.auth_path.clone(), Registry::default(), |reg| {
let auth_entry = AuthEntry {
authorized_at: Utc::now(),
mention_only: entry.type_ == "channel", };
match entry.type_.as_str() {
"user" => {
reg.users.insert(entry.id.clone(), auth_entry);
}
"channel" => {
reg.channels.insert(entry.id.clone(), auth_entry);
}
_ => {}
}
Ok(())
})?;
Ok((entry.type_, entry.id))
}
pub fn set_mention_only(&self, channel_id: &str, enable: bool) -> Result<()> {
self.with_lock(self.auth_path.clone(), Registry::default(), |reg| {
if let Some(entry) = reg.channels.get_mut(channel_id) {
entry.mention_only = enable;
} else {
anyhow::bail!("Channel not authorized yet.");
}
Ok(())
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::{tempdir, TempDir};
fn create_test_manager() -> anyhow::Result<(TempDir, AuthManager)> {
let dir = tempdir()?;
let manager = AuthManager::with_paths(
dir.path().join("auth.json"),
dir.path().join("pending_tokens.json"),
);
Ok((dir, manager))
}
#[test]
fn test_auth_token_flow() -> anyhow::Result<()> {
let (_dir, manager) = create_test_manager()?;
let token = manager.create_token("channel", "12345")?;
assert_eq!(token.len(), 6);
let (type_, id) = manager.redeem_token(&token)?;
assert_eq!(type_, "channel");
assert_eq!(id, "12345");
let (auth, mention) = manager.is_authorized("user_0", "12345");
assert!(auth);
assert!(mention);
Ok(())
}
#[test]
fn test_auth_user_override() -> anyhow::Result<()> {
let (_dir, manager) = create_test_manager()?;
let token = manager.create_token("channel", "chan_1")?;
let _ = manager.redeem_token(&token)?;
let u_token = manager.create_token("user", "user_god")?;
let _ = manager.redeem_token(&u_token)?;
let (auth, mention) = manager.is_authorized("user_god", "chan_1");
assert!(auth);
assert!(!mention);
Ok(())
}
}