use crate::forwarder::ForwardRule;
use serde_json::{from_reader, to_writer_pretty};
use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io;
use std::path::Path;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
/// Errors returned by [`FileStorage`] operations.
#[derive(Error, Debug)]
pub enum StorageError {
/// A filesystem operation failed; `?` converts `std::io::Error` into this variant.
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// The rule map could not be serialized to / deserialized from JSON;
/// `?` converts `serde_json::Error` into this variant.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Any other failure, carrying a free-form message that is displayed verbatim.
#[error("{0}")]
Other(String),
}
impl From<String> for StorageError {
fn from(error: String) -> Self {
StorageError::Other(error)
}
}
/// Store of forwarding rules ([`ForwardRule`]) persisted as a JSON file.
///
/// The rules are kept in memory behind an async `RwLock`, keyed by rule id, and the
/// whole map is rewritten to `file_path` on every mutation.
pub struct FileStorage {
/// Path of the JSON file the rule map is persisted to.
file_path: String,
/// In-memory copy of the rules, keyed by `ForwardRule::id`.
rules: Arc<RwLock<HashMap<String, ForwardRule>>>,
}
impl FileStorage {
pub fn new(file_path: &str) -> Result<Self, StorageError> {
if let Some(parent) = Path::new(file_path).parent() {
fs::create_dir_all(parent).map_err(StorageError::Io)?;
}
let rules = if Path::new(file_path).exists() {
let file = File::open(file_path).map_err(StorageError::Io)?;
from_reader(file).unwrap_or_else(|_| HashMap::new())
} else {
HashMap::new()
};
Ok(Self {
file_path: file_path.to_string(),
rules: Arc::new(RwLock::new(rules)),
})
}
pub async fn save_rule(&self, rule: ForwardRule) -> Result<(), StorageError> {
let mut rules = self.rules.write().await;
rules.insert(rule.id.clone(), rule);
self.save_to_file(&rules).await
}
pub async fn remove_rule(&self, id: &str) -> Result<(), StorageError> {
let mut rules = self.rules.write().await;
rules.remove(id);
self.save_to_file(&rules).await
}
async fn save_to_file(&self, rules: &HashMap<String, ForwardRule>) -> Result<(), StorageError> {
let temp_path = format!("{}.tmp", self.file_path);
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp_path)
.map_err(StorageError::Io)?;
to_writer_pretty(file, &rules).map_err(StorageError::Json)?;
fs::rename(&temp_path, &self.file_path).map_err(StorageError::Io)?;
Ok(())
}
pub async fn get_rules(&self) -> Vec<ForwardRule> {
let rules = self.rules.read().await;
rules.values().cloned().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarder::Protocol;
    use tempfile::tempdir;

    /// The single rule every test stores: TCP 8080 -> example.com:80, enabled, id "test1".
    fn sample_rule() -> ForwardRule {
        ForwardRule {
            id: "test1".to_string(),
            local_port: 8080,
            remote_host: "example.com".to_string(),
            remote_port: 80,
            protocol: Protocol::TCP,
            enabled: true,
        }
    }

    /// A saved rule is immediately visible through `get_rules`.
    #[tokio::test]
    async fn test_save_and_get_rules() -> Result<(), StorageError> {
        let scratch = tempdir()?;
        let store_path = scratch.path().join("test_rules.json");
        let storage = FileStorage::new(store_path.to_str().unwrap())?;

        storage.save_rule(sample_rule()).await?;

        let stored = storage.get_rules().await;
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, "test1");
        assert_eq!(stored[0].local_port, 8080);
        Ok(())
    }

    /// Removing the only rule leaves the store empty.
    #[tokio::test]
    async fn test_remove_rule() -> Result<(), StorageError> {
        let scratch = tempdir()?;
        let store_path = scratch.path().join("test_rules.json");
        let storage = FileStorage::new(store_path.to_str().unwrap())?;

        storage.save_rule(sample_rule()).await?;
        storage.remove_rule("test1").await?;

        assert!(storage.get_rules().await.is_empty());
        Ok(())
    }

    /// Rules written by one `FileStorage` are loaded by a fresh one on the same path.
    #[tokio::test]
    async fn test_file_persistence() -> Result<(), StorageError> {
        let scratch = tempdir()?;
        let store_path = scratch.path().join("test_rules.json");
        let store_path_str = store_path.to_str().unwrap();

        FileStorage::new(store_path_str)?
            .save_rule(sample_rule())
            .await?;

        let reopened = FileStorage::new(store_path_str)?;
        let stored = reopened.get_rules().await;
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, "test1");
        Ok(())
    }
}