use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
/// An operation a caller asks to perform, passed to [`AccessControl::is_allowed`].
///
/// `Hash` is derived alongside `Eq` so actions can be used as keys in
/// `HashSet`/`HashMap`-based policies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Action {
    /// Upload request.
    Upload,
    /// Download request.
    Download,
    /// Delete request.
    Delete,
    /// List request.
    List,
    /// Mirror request.
    Mirror,
    /// Administrative request.
    Admin,
}
/// Policy that decides whether a caller, identified by a pubkey string, may
/// perform a given [`Action`].
///
/// The `Send + Sync` bound lets one policy instance be shared across threads.
pub trait AccessControl: Send + Sync {
    /// Returns `true` if `pubkey` may perform `action`, `false` otherwise.
    ///
    /// The return value *is* the authorization decision; discarding it is
    /// always a bug, hence `#[must_use]`.
    #[must_use]
    fn is_allowed(&self, pubkey: &str, action: Action) -> bool;
}
/// [`AccessControl`] policy that permits every pubkey to perform every [`Action`].
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAccess;

impl AccessControl for OpenAccess {
    /// Always returns `true`; both `pubkey` and `action` are ignored.
    fn is_allowed(&self, _pubkey: &str, _action: Action) -> bool {
        true
    }
}
/// [`AccessControl`] policy that allows only pubkeys contained in a set.
///
/// The set is kept behind a lock so it can be changed at runtime through
/// `&self` (see `add`, `remove` and `reload`).
#[derive(Debug)]
pub struct Whitelist {
    pubkeys: Arc<RwLock<HashSet<String>>>,
}
impl Whitelist {
    /// Creates a whitelist from an already-built set of pubkeys.
    ///
    /// The keys are stored exactly as given; no validation or normalization
    /// is applied.
    pub fn new(pubkeys: HashSet<String>) -> Self {
        Self {
            pubkeys: Arc::new(RwLock::new(pubkeys)),
        }
    }

    /// Loads a whitelist from a text file with one hex pubkey per line.
    ///
    /// Blank lines and lines starting with `#` are ignored. Only entries of
    /// exactly 64 hex digits are accepted, stored lowercased. Anything else
    /// is skipped with a warning.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while reading `path`.
    pub fn from_file(path: &Path) -> std::io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Ok(Self::new(Self::parse_pubkeys(&content)))
    }

    /// Re-reads `path` and replaces the whole set of pubkeys with its contents.
    ///
    /// The file is read and parsed before the write lock is taken. The lock
    /// is released before the previous set is freed and the reload is logged.
    /// Keeping the critical section down to a single swap matters because the
    /// synchronous [`AccessControl::is_allowed`] check uses `try_read` and
    /// denies access whenever it cannot get the lock immediately.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while reading `path`. The current
    /// set is left untouched in that case.
    pub async fn reload(&self, path: &Path) -> std::io::Result<()> {
        let content = tokio::fs::read_to_string(path).await?;
        let new_keys = Self::parse_pubkeys(&content);
        let pubkey_count = new_keys.len();

        // Hold the write guard only for the swap itself.
        let previous = {
            let mut keys = self.pubkeys.write().await;
            std::mem::replace(&mut *keys, new_keys)
        };
        // Deallocate the old set outside the critical section.
        drop(previous);

        tracing::info!(
            access.backend = "whitelist",
            access.pubkey_count = pubkey_count,
            "whitelist reloaded"
        );
        Ok(())
    }

    /// Inserts `pubkey` into the whitelist.
    ///
    /// Unlike entries loaded from a file, the key is neither validated nor
    /// lowercased.
    pub async fn add(&self, pubkey: String) {
        self.pubkeys.write().await.insert(pubkey);
    }

    /// Removes `pubkey` from the whitelist; a missing key is a no-op.
    pub async fn remove(&self, pubkey: &str) {
        self.pubkeys.write().await.remove(pubkey);
    }

    /// Returns `true` if `pubkey` is currently whitelisted (waits for the lock).
    pub async fn contains(&self, pubkey: &str) -> bool {
        self.pubkeys.read().await.contains(pubkey)
    }

    /// Returns the number of whitelisted pubkeys.
    pub async fn len(&self) -> usize {
        self.pubkeys.read().await.len()
    }

    /// Returns `true` if no pubkeys are whitelisted.
    pub async fn is_empty(&self) -> bool {
        self.pubkeys.read().await.is_empty()
    }

    /// Parses whitelist file contents into a set of pubkeys.
    ///
    /// Each line is trimmed. Blank lines and `#` comments are ignored. An
    /// entry is accepted only if it is exactly 64 ASCII hex digits, and it is
    /// stored lowercased. Lowercasing lets upper/mixed-case entries match
    /// lowercase hex pubkeys from callers (assumed lowercase, as in Nostr
    /// NIP-01 — confirm against the auth layer). Every other line is skipped
    /// and reported with a warning carrying its 1-based line number.
    fn parse_pubkeys(content: &str) -> HashSet<String> {
        content
            .lines()
            .enumerate()
            .filter_map(|(index, raw)| {
                let line = raw.trim();
                if line.is_empty() || line.starts_with('#') {
                    return None;
                }
                if line.len() == 64 && line.bytes().all(|b| b.is_ascii_hexdigit()) {
                    Some(line.to_ascii_lowercase())
                } else {
                    // Previously dropped silently; surface it so operators notice typos.
                    tracing::warn!(
                        access.backend = "whitelist",
                        access.line = index + 1,
                        "skipping malformed whitelist entry: expected 64 hex characters"
                    );
                    None
                }
            })
            .collect()
    }
}
impl AccessControl for Whitelist {
fn is_allowed(&self, pubkey: &str, _action: Action) -> bool {
match self.pubkeys.try_read() {
Ok(keys) => keys.contains(pubkey),
Err(_) => false,
}
}
}
impl AccessControl for Arc<Whitelist> {
    /// Forwards the check to the [`Whitelist`] held by this `Arc`, with the
    /// same `pubkey` and `action`.
    fn is_allowed(&self, pubkey: &str, action: Action) -> bool {
        // Deref coercion: `&Arc<Whitelist>` -> `&Whitelist`, so the call below
        // resolves to the `Whitelist` impl rather than recursing into this one.
        let whitelist: &Whitelist = self;
        whitelist.is_allowed(pubkey, action)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir, removed on drop so that it
    /// is cleaned up even when an assertion in the test panics.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        fn new(prefix: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{}_{}", prefix, rand::random::<u32>()));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_open_access_allows_all() {
        let ac = OpenAccess;
        assert!(ac.is_allowed("anything", Action::Upload));
        assert!(ac.is_allowed("anything", Action::Delete));
        assert!(ac.is_allowed("anything", Action::Admin));
    }

    #[test]
    fn test_whitelist_allows_listed() {
        let pubkey = "a".repeat(64);
        let mut keys = HashSet::new();
        keys.insert(pubkey.clone());
        let wl = Whitelist::new(keys);
        assert!(wl.is_allowed(&pubkey, Action::Upload));
        assert!(wl.is_allowed(&pubkey, Action::Download));
    }

    #[test]
    fn test_whitelist_denies_unlisted() {
        let wl = Whitelist::new(HashSet::new());
        let pubkey = "b".repeat(64);
        assert!(!wl.is_allowed(&pubkey, Action::Upload));
    }

    #[test]
    fn test_arc_whitelist_delegates() {
        let pubkey = "9".repeat(64);
        let wl = Arc::new(Whitelist::new(HashSet::from([pubkey.clone()])));
        assert!(wl.is_allowed(&pubkey, Action::Mirror));
        assert!(!wl.is_allowed(&"8".repeat(64), Action::Mirror));
    }

    #[test]
    fn test_parse_pubkeys_from_content() {
        let content = format!(
            "# This is a comment\n\n{}\n{}\ninvalid-short\n \n{}",
            "a".repeat(64),
            "b".repeat(64),
            "c".repeat(64),
        );
        let keys = Whitelist::parse_pubkeys(&content);
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"a".repeat(64)));
        assert!(keys.contains(&"b".repeat(64)));
        assert!(keys.contains(&"c".repeat(64)));
        assert!(!keys.contains("invalid-short"));
    }

    #[tokio::test]
    async fn test_whitelist_add_remove() {
        let wl = Whitelist::new(HashSet::new());
        let pk = "d".repeat(64);
        assert!(!wl.contains(&pk).await);
        wl.add(pk.clone()).await;
        assert!(wl.contains(&pk).await);
        assert_eq!(wl.len().await, 1);
        wl.remove(&pk).await;
        assert!(!wl.contains(&pk).await);
        assert!(wl.is_empty().await);
    }

    #[test]
    fn test_whitelist_from_file() {
        let tmp = TempDir::new("blossom_wl");
        let file = tmp.path().join("whitelist.txt");
        let content = format!("# allowed users\n{}\n{}\n", "e".repeat(64), "f".repeat(64));
        std::fs::write(&file, &content).unwrap();
        let wl = Whitelist::from_file(&file).unwrap();
        assert!(wl.is_allowed(&"e".repeat(64), Action::Upload));
        assert!(wl.is_allowed(&"f".repeat(64), Action::Download));
        assert!(!wl.is_allowed(&"0".repeat(64), Action::Upload));
    }

    #[tokio::test]
    async fn test_whitelist_reload_replaces_keys() {
        let tmp = TempDir::new("blossom_wl_reload");
        let file = tmp.path().join("whitelist.txt");
        std::fs::write(&file, format!("{}\n", "1".repeat(64))).unwrap();
        let wl = Whitelist::from_file(&file).unwrap();
        assert!(wl.contains(&"1".repeat(64)).await);

        std::fs::write(&file, format!("{}\n", "2".repeat(64))).unwrap();
        wl.reload(&file).await.unwrap();
        assert!(!wl.contains(&"1".repeat(64)).await);
        assert!(wl.contains(&"2".repeat(64)).await);
        assert_eq!(wl.len().await, 1);
    }

    #[tokio::test]
    async fn test_whitelist_reload_missing_file_keeps_keys() {
        let pk = "3".repeat(64);
        let wl = Whitelist::new(HashSet::from([pk.clone()]));
        let tmp = TempDir::new("blossom_wl_missing");
        let missing = tmp.path().join("does-not-exist.txt");
        assert!(wl.reload(&missing).await.is_err());
        assert!(wl.contains(&pk).await);
        assert_eq!(wl.len().await, 1);
    }
}