infrarust 1.2.0

A Rust universal Minecraft proxy
Documentation
//! In-memory storage backend for the ban system.

use std::net::IpAddr;
use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::RwLock;
use tracing::{debug, info};

use super::{
    BanAuditLogEntry, BanEntry, BanError, BanStatistics, BanStorageBackend, index::BanIndex,
};

/// Ban storage backend that keeps all bans and audit-log entries in process memory.
///
/// Nothing in this type performs any I/O, so no state is persisted: bans live in a
/// `BanIndex` and audit-log entries in a `Vec` guarded by an async `RwLock`.
#[derive(Debug)]
pub struct MemoryBanStorage {
    /// Index over every stored ban; the `BanStorageBackend` impl looks bans up
    /// through it by id, IP, UUID and username.
    index: BanIndex,
    /// Audit-log entries in insertion order. The backend only ever appends to it.
    // NOTE(review): nothing prunes this list, so it grows for the life of the process.
    audit_logs: Arc<RwLock<Vec<BanAuditLogEntry>>>,
}

impl Default for MemoryBanStorage {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryBanStorage {
    /// Creates an empty store with no bans and no audit-log entries.
    ///
    /// Emits one `info`-level log line announcing the in-memory backend.
    pub fn new() -> Self {
        info!("Initializing in-memory ban storage");

        let index = BanIndex::new();
        let audit_logs = Arc::new(RwLock::new(Vec::new()));

        Self { index, audit_logs }
    }
}

#[async_trait]
impl BanStorageBackend for MemoryBanStorage {
    /// Stores a single ban in the index. Never fails for this backend.
    async fn add_ban(&self, ban: BanEntry) -> Result<(), BanError> {
        debug!("Adding ban: {:?}", ban);
        self.index.add(Arc::new(ban)).await;
        Ok(())
    }

    /// Stores each ban in order, one index insert per entry. Never fails for this backend.
    async fn add_bans_batch(&self, bans: Vec<BanEntry>) -> Result<(), BanError> {
        debug!("Adding {} bans in batch", bans.len());
        for ban in bans {
            self.index.add(Arc::new(ban)).await;
        }
        Ok(())
    }

    /// Removes the ban with `ban_id` and returns a copy of it.
    ///
    /// # Errors
    /// `BanError::NotFound` if no ban with that id is indexed.
    async fn remove_ban(&self, ban_id: &str) -> Result<BanEntry, BanError> {
        debug!("Removing ban with ID: {}", ban_id);
        match self.index.remove(ban_id).await {
            Some(ban) => {
                debug!("Ban removed: {:?}", ban);
                Ok((*ban).clone())
            }
            None => {
                debug!("Ban not found with ID: {}", ban_id);
                Err(BanError::NotFound)
            }
        }
    }

    /// Returns a copy of the ban with `ban_id`.
    ///
    /// # Errors
    /// `BanError::NotFound` if no ban with that id is indexed.
    async fn get_ban_by_id(&self, ban_id: &str) -> Result<BanEntry, BanError> {
        self.index
            .get_by_id(ban_id)
            .map(|ban| (*ban).clone())
            .ok_or(BanError::NotFound)
    }

    /// Returns copies of every ban recorded against `ip`, expired or not.
    async fn get_bans_by_ip(&self, ip: &IpAddr) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_by_ip(ip);
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    /// Returns copies of every ban recorded against `uuid`, expired or not.
    async fn get_bans_by_uuid(&self, uuid: &str) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_by_uuid(uuid);
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    /// Returns copies of every ban recorded against `username`, expired or not.
    async fn get_bans_by_username(&self, username: &str) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_by_username(username);
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    async fn is_ip_banned(&self, ip: &IpAddr) -> Result<bool, BanError> {
        Ok(self.index.is_ip_banned(ip).await)
    }

    async fn is_uuid_banned(&self, uuid: &str) -> Result<bool, BanError> {
        Ok(self.index.is_uuid_banned(uuid).await)
    }

    async fn is_username_banned(&self, username: &str) -> Result<bool, BanError> {
        Ok(self.index.is_username_banned(username).await)
    }

    /// Reason of the first ban on `ip` that is still in force, if any.
    async fn get_ban_reason_for_ip(&self, ip: &IpAddr) -> Result<Option<String>, BanError> {
        Ok(first_active_reason(self.index.get_by_ip(ip), unix_now_secs()))
    }

    /// Reason of the first ban on `uuid` that is still in force, if any.
    async fn get_ban_reason_for_uuid(&self, uuid: &str) -> Result<Option<String>, BanError> {
        Ok(first_active_reason(
            self.index.get_by_uuid(uuid),
            unix_now_secs(),
        ))
    }

    /// Reason of the first ban on `username` that is still in force, if any.
    async fn get_ban_reason_for_username(
        &self,
        username: &str,
    ) -> Result<Option<String>, BanError> {
        Ok(first_active_reason(
            self.index.get_by_username(username),
            unix_now_secs(),
        ))
    }

    async fn get_all_bans(&self) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_all();
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    async fn get_active_bans(&self) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_active().await;
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    /// Returns zero-based `page` of active bans (`page_size` per page) plus the total
    /// number of active bans. A page past the end, or one whose offset overflows,
    /// yields an empty list.
    async fn get_active_bans_paged(
        &self,
        page: usize,
        page_size: usize,
    ) -> Result<(Vec<BanEntry>, usize), BanError> {
        let all_bans = self.index.get_active().await;
        let total = all_bans.len();

        let paged_bans: Vec<BanEntry> = page_bounds(total, page, page_size)
            .map(|range| all_bans[range].iter().map(|b| (**b).clone()).collect())
            .unwrap_or_default();

        Ok((paged_bans, total))
    }

    async fn get_expired_bans(&self) -> Result<Vec<BanEntry>, BanError> {
        let bans = self.index.get_expired().await;
        Ok(bans.into_iter().map(|b| (*b).clone()).collect())
    }

    /// Removes every currently expired ban and returns how many were actually removed.
    ///
    /// Only successful removals are counted. A ban that disappears between the
    /// expired-ban snapshot and its removal (e.g. removed concurrently) is not
    /// counted twice.
    async fn clear_expired_bans(&self) -> Result<usize, BanError> {
        let expired = self.index.get_expired().await;

        let mut removed = 0usize;
        for ban in expired {
            if self.index.remove(&ban.id).await.is_some() {
                removed += 1;
            }
        }

        debug!("Cleared {} expired bans", removed);
        Ok(removed)
    }

    async fn add_audit_log(&self, entry: BanAuditLogEntry) -> Result<(), BanError> {
        let mut logs = self.audit_logs.write().await;
        logs.push(entry);
        Ok(())
    }

    /// Appends all `entries` under a single write-lock acquisition.
    async fn add_audit_logs_batch(&self, entries: Vec<BanAuditLogEntry>) -> Result<(), BanError> {
        let mut logs = self.audit_logs.write().await;
        logs.extend(entries);
        Ok(())
    }

    /// Returns zero-based `page` of audit-log entries in insertion order, plus the total
    /// entry count. A page past the end, or one whose offset overflows, yields an
    /// empty list.
    async fn get_audit_logs_paged(
        &self,
        page: usize,
        page_size: usize,
    ) -> Result<(Vec<BanAuditLogEntry>, usize), BanError> {
        let logs = self.audit_logs.read().await;
        let total = logs.len();

        let paged_logs: Vec<BanAuditLogEntry> = page_bounds(total, page, page_size)
            .map(|range| logs[range].to_vec())
            .unwrap_or_default();

        Ok((paged_logs, total))
    }

    /// Delegates filtering to the index, then paginates the matches. Returns the
    /// requested page plus the total number of matches.
    async fn search_bans(
        &self,
        ip: Option<IpAddr>,
        uuid: Option<&str>,
        username: Option<&str>,
        reason_contains: Option<&str>,
        banned_by: Option<&str>,
        page: usize,
        page_size: usize,
    ) -> Result<(Vec<BanEntry>, usize), BanError> {
        let results = self
            .index
            .search(ip, uuid, username, reason_contains, banned_by)
            .await;
        let total = results.len();

        let paged_results: Vec<BanEntry> = page_bounds(total, page, page_size)
            .map(|range| results[range].iter().map(|b| (**b).clone()).collect())
            .unwrap_or_default();

        Ok((paged_results, total))
    }

    /// Aggregates ban counts.
    ///
    /// A single ban can count towards several of `ip_bans` / `uuid_bans` /
    /// `username_bans` when it targets more than one identifier.
    async fn get_statistics(&self) -> Result<BanStatistics, BanError> {
        // NOTE(review): these are three separate index reads; a concurrent add/remove
        // between them could presumably make the totals disagree slightly.
        let all_bans = self.index.get_all();
        let active_bans = self.index.get_active().await;
        let expired_bans = self.index.get_expired().await;

        let mut ip_bans = 0;
        let mut uuid_bans = 0;
        let mut username_bans = 0;
        let mut permanent_bans = 0;
        let mut temporary_bans = 0;

        for ban in &all_bans {
            if ban.ip.is_some() {
                ip_bans += 1;
            }
            if ban.uuid.is_some() {
                uuid_bans += 1;
            }
            if ban.username.is_some() {
                username_bans += 1;
            }
            // A ban without an expiry timestamp is permanent.
            if ban.expires_at.is_none() {
                permanent_bans += 1;
            } else {
                temporary_bans += 1;
            }
        }

        Ok(BanStatistics {
            total_bans: all_bans.len(),
            active_bans: active_bans.len(),
            expired_bans: expired_bans.len(),
            permanent_bans,
            temporary_bans,
            ip_bans,
            uuid_bans,
            username_bans,
        })
    }
}

/// Current Unix time in whole seconds; a system clock set before the epoch reads as 0.
fn unix_now_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}

/// Returns the reason of the first ban in `bans` that is still in force at `now`
/// (Unix seconds).
///
/// A ban is in force if it has no `expires_at` (permanent) or if `expires_at` is
/// strictly after `now`. Returns `None` for an empty input or when every ban has expired.
fn first_active_reason<I>(bans: I, now: u64) -> Option<String>
where
    I: IntoIterator,
    I::Item: std::ops::Deref<Target = BanEntry>,
{
    bans.into_iter()
        .find(|ban| match ban.expires_at {
            None => true,
            Some(expires_at) => expires_at > now,
        })
        .map(|ban| ban.reason.clone())
}

/// Computes the `start..end` slice bounds of zero-based `page` (each `page_size` items
/// long) over a collection of `total` items.
///
/// Returns `None` when the page starts at or past `total`, or when its start offset
/// overflows `usize`. The callers turn `None` into an empty page instead of panicking
/// or wrapping.
fn page_bounds(total: usize, page: usize, page_size: usize) -> Option<std::ops::Range<usize>> {
    let start = page.checked_mul(page_size)?;
    if start >= total {
        return None;
    }
    let end = start.saturating_add(page_size).min(total);
    Some(start..end)
}