use anyhow::Result;
use parking_lot::{Mutex, RwLock};
use rusqlite::{Connection, params};
use rusqlite_migration::{M, Migrations};
use ustr::{Ustr, UstrMap};
use crate::{error::BlacklistError, utils};
/// Interface for managing a blacklist of unit IDs.
///
/// Every operation reports failures as a [`BlacklistError`].
pub trait Blacklist {
/// Adds the unit with the given ID to the blacklist.
fn add_to_blacklist(&mut self, unit_id: Ustr) -> Result<(), BlacklistError>;
/// Removes the unit with the given ID from the blacklist.
fn remove_from_blacklist(&mut self, unit_id: Ustr) -> Result<(), BlacklistError>;
/// Removes from the blacklist all the units whose ID starts with the given prefix.
fn remove_prefix_from_blacklist(&mut self, prefix: &str) -> Result<(), BlacklistError>;
/// Returns whether the unit with the given ID is currently blacklisted.
fn blacklisted(&self, unit_id: Ustr) -> Result<bool, BlacklistError>;
/// Returns the IDs of all the units currently in the blacklist.
fn get_blacklist_entries(&self) -> Result<Vec<Ustr>, BlacklistError>;
}
/// A blacklist stored in a SQLite database, fronted by an in-memory cache of lookups.
pub struct LocalBlacklist {
/// Maps a unit ID to whether it is blacklisted. It is filled with every stored entry when the
/// blacklist is created, and lookups of unknown IDs are recorded in it as `false`.
cache: RwLock<UstrMap<bool>>,
/// The connection to the database holding the `blacklist` table.
connection: Mutex<Connection>,
}
impl LocalBlacklist {
    /// Returns the migrations that create the `blacklist` table and the index on its `unit_id`
    /// column.
    fn migrations() -> Migrations<'static> {
        Migrations::new(vec![
            // One row per blacklisted unit; the UNIQUE constraint rejects duplicate IDs.
            M::up("CREATE TABLE blacklist(unit_id TEXT NOT NULL UNIQUE);")
                .down("DROP TABLE blacklist"),
            // NOTE(review): the UNIQUE constraint above presumably already gives SQLite an index
            // on `unit_id`. This migration is kept as-is because existing databases may have
            // applied it already - confirm before touching it.
            M::up("CREATE INDEX unit_id_index ON blacklist (unit_id);")
                .down("DROP INDEX unit_id_index"),
        ])
    }

    /// Runs the migrations so that the database schema is at the latest version.
    fn init(&mut self) -> Result<()> {
        let migrations = Self::migrations();
        let mut connection = self.connection.lock();
        migrations.to_latest(&mut connection)?;
        Ok(())
    }

    /// Creates a blacklist on top of the given connection.
    ///
    /// The schema is migrated to the latest version and every stored entry is loaded into the
    /// cache, so later lookups do not need to query the database.
    fn new(connection: Connection) -> Result<LocalBlacklist> {
        let mut blacklist = LocalBlacklist {
            cache: RwLock::new(UstrMap::default()),
            connection: Mutex::new(connection),
        };
        blacklist.init()?;

        // Read all the entries first so that the cache lock is taken once instead of once per
        // entry.
        let entries = blacklist.get_blacklist_entries()?;
        blacklist
            .cache
            .write()
            .extend(entries.into_iter().map(|unit_id| (unit_id, true)));
        Ok(blacklist)
    }

    /// Creates a blacklist backed by the database at the given path.
    pub fn new_from_disk(db_path: &str) -> Result<LocalBlacklist> {
        let connection = utils::new_connection(db_path)?;
        Self::new(connection)
    }

    /// Returns whether the given unit is blacklisted, looking only at the cache.
    ///
    /// The cache is loaded with every entry at construction and updated on every change, so an ID
    /// that is missing from it is not blacklisted. Such IDs are recorded as `false`.
    #[inline]
    fn has_entry(&self, unit_id: Ustr) -> bool {
        // Copy the value out so that the read guard is released at the end of this statement,
        // before the write lock below is requested.
        let cached = self.cache.read().get(&unit_id).copied();
        if let Some(has_entry) = cached {
            return has_entry;
        }
        self.cache.write().insert(unit_id, false);
        false
    }

    /// Inserts the unit into the database and the cache. Does nothing if the unit is already
    /// blacklisted.
    fn add_to_blacklist_helper(&mut self, unit_id: Ustr) -> Result<()> {
        // The column is UNIQUE, so inserting an ID that is already stored would fail.
        if self.has_entry(unit_id) {
            return Ok(());
        }

        let connection = self.connection.lock();
        let mut stmt = connection.prepare_cached("INSERT INTO blacklist (unit_id) VALUES (?1)")?;
        stmt.execute(params![unit_id.as_str()])?;
        self.cache.write().insert(unit_id, true);
        Ok(())
    }

    /// Deletes the unit from the database and records it as not blacklisted in the cache.
    fn remove_from_blacklist_helper(&mut self, unit_id: Ustr) -> Result<()> {
        let connection = self.connection.lock();
        let mut stmt = connection.prepare_cached("DELETE FROM blacklist WHERE unit_id = $1")?;
        stmt.execute(params![unit_id.as_str()])?;
        self.cache.write().insert(unit_id, false);
        Ok(())
    }

    /// Deletes every unit whose ID starts with `prefix` from the database, records each of them
    /// as not blacklisted in the cache, and then vacuums the database.
    ///
    /// The prefix is matched literally and case-sensitively.
    fn remove_prefix_from_blacklist_helper(&mut self, prefix: &str) -> Result<()> {
        let connection = self.connection.lock();

        // Gather the matching IDs before deleting anything so that the table is not modified
        // while a cursor over it is still open.
        let matching_ids = {
            let mut select = connection
                .prepare_cached("SELECT unit_id from blacklist WHERE unit_id LIKE $1;")?;
            let mut rows = select.query(params![format!("{}%", prefix)])?;
            let mut ids: Vec<String> = Vec::new();
            while let Some(row) = rows.next()? {
                let unit_id: String = row.get(0)?;
                // LIKE is only a coarse filter: it treats `_` and `%` inside the prefix as
                // wildcards and ignores ASCII case, so it can return IDs that do not actually
                // start with the prefix. Those must not be deleted.
                if unit_id.starts_with(prefix) {
                    ids.push(unit_id);
                }
            }
            ids
        };

        {
            let mut delete =
                connection.prepare_cached("DELETE FROM blacklist WHERE unit_id = $1")?;
            let mut cache = self.cache.write();
            for unit_id in matching_ids {
                delete.execute(params![unit_id])?;
                cache.insert(unit_id.into(), false);
            }
        }

        connection.execute_batch("VACUUM;")?;
        Ok(())
    }

    /// Reads the IDs of all the blacklisted units from the database.
    ///
    /// The query has no `ORDER BY` clause, so callers should not rely on the order of the result.
    fn all_blacklist_entries_helper(&self) -> Result<Vec<Ustr>> {
        let connection = self.connection.lock();
        let mut stmt = connection.prepare_cached("SELECT unit_id from blacklist;")?;
        let mut rows = stmt.query(params![])?;
        let mut entries = Vec::new();
        while let Some(row) = rows.next()? {
            let unit_id: String = row.get(0)?;
            entries.push(Ustr::from(&unit_id));
        }
        Ok(entries)
    }
}
// Each method delegates to the matching helper on `LocalBlacklist` and wraps any failure in the
// `BlacklistError` variant for that operation.
impl Blacklist for LocalBlacklist {
/// Adds the unit to the blacklist. Failures are reported as `BlacklistError::AddUnit` together
/// with the unit ID.
fn add_to_blacklist(&mut self, unit_id: Ustr) -> Result<(), BlacklistError> {
self.add_to_blacklist_helper(unit_id)
.map_err(|e| BlacklistError::AddUnit(unit_id, e))
}
/// Removes the unit from the blacklist. Failures are reported as `BlacklistError::RemoveUnit`
/// together with the unit ID.
fn remove_from_blacklist(&mut self, unit_id: Ustr) -> Result<(), BlacklistError> {
self.remove_from_blacklist_helper(unit_id)
.map_err(|e| BlacklistError::RemoveUnit(unit_id, e))
}
/// Removes the units matching the prefix from the blacklist. Failures are reported as
/// `BlacklistError::RemovePrefix` together with the prefix.
fn remove_prefix_from_blacklist(&mut self, prefix: &str) -> Result<(), BlacklistError> {
self.remove_prefix_from_blacklist_helper(prefix)
.map_err(|e| BlacklistError::RemovePrefix(prefix.into(), e))
}
/// Returns whether the unit is blacklisted. The answer comes from the in-memory cache and this
/// implementation always returns `Ok`.
#[inline]
fn blacklisted(&self, unit_id: Ustr) -> Result<bool, BlacklistError> {
Ok(self.has_entry(unit_id))
}
/// Returns the IDs of all the blacklisted units, read from the database. Failures are reported
/// as `BlacklistError::GetEntries`.
fn get_blacklist_entries(&self) -> Result<Vec<Ustr>, BlacklistError> {
self.all_blacklist_entries_helper()
.map_err(BlacklistError::GetEntries)
}
}
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod test {
    use anyhow::Result;
    use rusqlite::Connection;
    use tempfile::tempdir;
    use ustr::Ustr;

    use crate::blacklist::{Blacklist, LocalBlacklist};

    /// Creates an empty blacklist backed by an in-memory database.
    fn new_test_blacklist() -> Result<Box<dyn Blacklist>> {
        let connection = Connection::open_in_memory()?;
        let blacklist = LocalBlacklist::new(connection)?;
        Ok(Box::new(blacklist))
    }

    /// Verifies that a unit which was never added is not reported as blacklisted.
    #[test]
    fn not_in_blacklist() -> Result<()> {
        let blacklist = new_test_blacklist()?;
        assert!(!blacklist.blacklisted(Ustr::from("unit_id"))?);
        Ok(())
    }

    /// Verifies that a unit can be added to the blacklist and then removed from it.
    #[test]
    fn add_and_remove_from_blacklist() -> Result<()> {
        let mut blacklist = new_test_blacklist()?;

        let unit_id = Ustr::from("unit_id");
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);
        blacklist.remove_from_blacklist(unit_id)?;
        assert!(!blacklist.blacklisted(unit_id)?);
        Ok(())
    }

    /// Verifies that removing a prefix only removes the units whose ID starts with it.
    #[test]
    fn remove_prefix_from_blacklist() -> Result<()> {
        let mut blacklist = new_test_blacklist()?;
        let units = vec![
            Ustr::from("a"),
            Ustr::from("a::a"),
            Ustr::from("b"),
            Ustr::from("b::a"),
            Ustr::from("c"),
            Ustr::from("c::a"),
        ];
        for unit in &units {
            blacklist.add_to_blacklist(*unit)?;
        }

        blacklist.remove_prefix_from_blacklist("a")?;
        for unit in units {
            // Only the units under the removed prefix should have left the blacklist.
            let should_remain = !unit.as_str().starts_with('a');
            assert_eq!(blacklist.blacklisted(unit)?, should_remain);
        }
        Ok(())
    }

    /// Verifies that repeated lookups and a repeated add of the same unit succeed.
    #[test]
    fn blacklist_cache() -> Result<()> {
        let mut blacklist = new_test_blacklist()?;
        let unit_id = Ustr::from("unit_id");
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);
        assert!(blacklist.blacklisted(unit_id)?);

        // Adding a unit that is already blacklisted must not fail.
        blacklist.add_to_blacklist(unit_id)?;
        Ok(())
    }

    /// Verifies that a unit can be added again after it was removed.
    #[test]
    fn readd_to_blacklist() -> Result<()> {
        let mut blacklist = new_test_blacklist()?;
        let unit_id = Ustr::from("unit_id");
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);
        blacklist.remove_from_blacklist(unit_id)?;
        assert!(!blacklist.blacklisted(unit_id)?);
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);
        Ok(())
    }

    /// Verifies that all the blacklisted units are returned when listing the entries.
    #[test]
    fn all_entries() -> Result<()> {
        let mut blacklist = new_test_blacklist()?;
        let unit_id = Ustr::from("unit_id");
        let unit_id2 = Ustr::from("unit_id2");
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);
        blacklist.add_to_blacklist(unit_id2)?;
        assert!(blacklist.blacklisted(unit_id2)?);

        // The entries are not returned in any guaranteed order, so sort them before comparing.
        let mut entries = blacklist.get_blacklist_entries()?;
        entries.sort_by(|a, b| a.as_str().cmp(b.as_str()));
        assert_eq!(entries, vec![unit_id, unit_id2]);
        Ok(())
    }

    /// Verifies that entries written to disk are visible to a blacklist opened on the same file.
    #[test]
    fn reopen_blacklist() -> Result<()> {
        let dir = tempdir()?;
        let mut blacklist =
            LocalBlacklist::new_from_disk(dir.path().join("blacklist.db").to_str().unwrap())?;
        let unit_id = Ustr::from("unit_id");
        blacklist.add_to_blacklist(unit_id)?;
        assert!(blacklist.blacklisted(unit_id)?);

        let new_blacklist =
            LocalBlacklist::new_from_disk(dir.path().join("blacklist.db").to_str().unwrap())?;
        assert!(new_blacklist.blacklisted(unit_id)?);
        Ok(())
    }
}