//! anydoor 0.0.1
//!
//! A tool for forwarding traffic to a remote server.
use crate::forwarder::ForwardRule;
use serde_json::{from_reader, to_writer_pretty};
use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io;
use std::path::Path;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;

#[derive(Error, Debug)]
pub enum StorageError {
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    #[error("{0}")]
    Other(String),
}

impl From<String> for StorageError {
    fn from(error: String) -> Self {
        StorageError::Other(error)
    }
}

pub struct FileStorage {
    file_path: String,
    rules: Arc<RwLock<HashMap<String, ForwardRule>>>,
}

impl FileStorage {
    pub fn new(file_path: &str) -> Result<Self, StorageError> {
        // 确保目录存在
        if let Some(parent) = Path::new(file_path).parent() {
            fs::create_dir_all(parent).map_err(StorageError::Io)?;
        }

        let rules = if Path::new(file_path).exists() {
            // 如果文件存在,读取规则
            let file = File::open(file_path).map_err(StorageError::Io)?;
            from_reader(file).unwrap_or_else(|_| HashMap::new())
        } else {
            // 如果文件不存在,创建空的规则集
            HashMap::new()
        };

        Ok(Self {
            file_path: file_path.to_string(),
            rules: Arc::new(RwLock::new(rules)),
        })
    }

    pub async fn save_rule(&self, rule: ForwardRule) -> Result<(), StorageError> {
        let mut rules = self.rules.write().await;
        rules.insert(rule.id.clone(), rule);

        self.save_to_file(&rules).await
    }

    pub async fn remove_rule(&self, id: &str) -> Result<(), StorageError> {
        let mut rules = self.rules.write().await;
        rules.remove(id);

        self.save_to_file(&rules).await
    }

    async fn save_to_file(&self, rules: &HashMap<String, ForwardRule>) -> Result<(), StorageError> {
        // 将整个规则集写入临时文件
        let temp_path = format!("{}.tmp", self.file_path);
        let file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(&temp_path)
            .map_err(StorageError::Io)?;

        to_writer_pretty(file, &rules).map_err(StorageError::Json)?;

        // 原子地替换文件
        fs::rename(&temp_path, &self.file_path).map_err(StorageError::Io)?;

        Ok(())
    }

    pub async fn get_rules(&self) -> Vec<ForwardRule> {
        let rules = self.rules.read().await;
        rules.values().cloned().collect()
    }

    // pub async fn get_rule(&self, id: &str) -> Option<ForwardRule> {
    //     let rules = self.rules.read().await;
    //     rules.get(id).cloned()
    // }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarder::Protocol;
    use tempfile::tempdir;

    /// Builds the TCP rule shared by every test below.
    fn sample_rule() -> ForwardRule {
        ForwardRule {
            id: "test1".to_string(),
            local_port: 8080,
            remote_host: "example.com".to_string(),
            remote_port: 80,
            protocol: Protocol::TCP,
            enabled: true,
        }
    }

    #[tokio::test]
    async fn test_save_and_get_rules() -> Result<(), StorageError> {
        let dir = tempdir()?;
        let path = dir.path().join("test_rules.json");
        let storage = FileStorage::new(path.to_str().unwrap())?;

        // Persist a single rule...
        storage.save_rule(sample_rule()).await?;

        // ...and read it back from the store.
        let stored = storage.get_rules().await;
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, "test1");
        assert_eq!(stored[0].local_port, 8080);

        Ok(())
    }

    #[tokio::test]
    async fn test_remove_rule() -> Result<(), StorageError> {
        let dir = tempdir()?;
        let path = dir.path().join("test_rules.json");
        let storage = FileStorage::new(path.to_str().unwrap())?;

        // Add a rule, then delete it again.
        storage.save_rule(sample_rule()).await?;
        storage.remove_rule("test1").await?;

        // Nothing should be left behind.
        assert!(storage.get_rules().await.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_file_persistence() -> Result<(), StorageError> {
        let dir = tempdir()?;
        let path = dir.path().join("test_rules.json");
        let path_str = path.to_str().unwrap();

        // First instance writes a rule to disk and is then dropped.
        {
            let writer = FileStorage::new(path_str)?;
            writer.save_rule(sample_rule()).await?;
        }

        // A fresh instance must load that rule back from the file.
        {
            let reader = FileStorage::new(path_str)?;
            let loaded = reader.get_rules().await;
            assert_eq!(loaded.len(), 1);
            assert_eq!(loaded[0].id, "test1");
        }

        Ok(())
    }
}