enigmatick/
blocklist.rs

1use std::sync::Arc;
2
3use crate::{models::instances::Instance, schema::instances};
4use async_mutex::Mutex;
5use deadpool_diesel::postgres::Pool;
6use diesel::prelude::*;
7
/// Error raised when an access-control check rejects a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessControlError {
    /// The requested action is not allowed.
    Prohibited,
}

impl std::fmt::Display for AccessControlError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AccessControlError::Prohibited => f.write_str("access prohibited"),
        }
    }
}

// Implementing `Error` lets this type flow through `?` into `anyhow::Error`
// and `Box<dyn Error>` like any other error.
impl std::error::Error for AccessControlError {}
12
/// Boolean newtype recording whether an action was permitted.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Permitted(pub bool);

impl Permitted {
    /// Returns `true` when the wrapped flag is `true`.
    #[must_use]
    pub fn is_permitted(&self) -> bool {
        self.0
    }
}
21
22#[derive(Clone)]
23pub struct BlockList {
24    pub blocked_servers: Arc<Mutex<Vec<String>>>,
25}
26
27impl BlockList {
28    // Add this new async function specifically for Axum's pool type
29    pub async fn new_axum(pool: &Pool) -> anyhow::Result<Self> {
30        let conn = pool.get().await?;
31        // The `??` operator fails here because the error type returned by `interact`
32        // is not `Sync`, which is required by `anyhow`'s `From` trait implementation.
33        // We handle it in two steps to manually convert the error.
34        let query_result = conn
35            .interact(move |c| {
36                instances::table
37                    .filter(instances::blocked.eq(true))
38                    .get_results::<Instance>(c)
39            })
40            .await
41            .map_err(|e| anyhow::anyhow!("Database interaction failed: {:?}", e))?;
42
43        let instances = query_result?;
44
45        log::debug!("loading {:?} blocked servers for Axum", instances.len());
46        let blocked_servers: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(
47            instances.iter().map(|x| x.domain_name.clone()).collect(),
48        ));
49
50        Ok(BlockList { blocked_servers })
51    }
52
53    pub fn add(&mut self, server: String) {
54        log::debug!("adding {server} to BlockList");
55        if let Some(mut x) = self.blocked_servers.try_lock() {
56            x.push(server);
57        }
58    }
59
60    pub fn is_blocked(&self, server: String) -> bool {
61        log::debug!("checking {server} against BlockList");
62        if let Some(x) = self.blocked_servers.try_lock() {
63            x.contains(&server)
64        } else {
65            // fail open
66            false
67        }
68    }
69}