use crate::common::constants::ENCRYPT_KEY;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter};
/// Firewall configuration, (de)serialized with serde.
///
/// Holds IP and HTTP-referer filtering policies, an optional connection limit and a
/// 32-byte `api_secret_encrypt_key`.
///
/// NOTE(review): the per-field semantics below are inferred from field names and the
/// `AllowDenyPolicy` variants; the enforcement code is not in this file — confirm there.
/// Field order is also the serialized field order, so keep it stable.
#[derive(Clone, Serialize, Deserialize)]
pub struct Firewall {
    /// How `ip_policy` is applied (`Disable` / `Allow` / `Deny`); `Disable` in `Firewall::default()`.
    pub ip_policy_mode: AllowDenyPolicy,
    /// IP entries that `ip_policy_mode` applies to (entry format is not validated here).
    pub ip_policy: HashSet<String>,
    /// IP entries presumably exempt from filtering — confirm against the enforcement code.
    pub trust_ips: HashSet<String>,
    /// How `referer_policy` is applied; `Disable` in `Firewall::default()`.
    pub referer_policy_mode: AllowDenyPolicy,
    /// Referer entries that `referer_policy_mode` applies to.
    pub referer_policy: HashSet<String>,
    /// Presumably whether requests without a referer are accepted; `false` by default.
    pub allow_empty_referer: bool,
    /// Optional connection cap; `None` by default (presumably meaning "no limit" — confirm).
    pub max_connections: Option<usize>,
    /// 32-byte key, presumably used to encrypt API secrets.
    ///
    /// Stored in config as a string: written by `serialize_encrypt_key`, read by
    /// `deserialize_encrypt_key` (empty string -> default key; otherwise the string's
    /// bytes truncated or zero-padded to 32). A missing field falls back to
    /// `default_api_secret_encrypt_key`.
    #[serde(
        default = "default_api_secret_encrypt_key",
        serialize_with = "serialize_encrypt_key",
        deserialize_with = "deserialize_encrypt_key"
    )]
    pub api_secret_encrypt_key: [u8; 32],
}
impl Default for Firewall {
fn default() -> Self {
Firewall {
ip_policy_mode: AllowDenyPolicy::Disable,
ip_policy: Default::default(),
trust_ips: Default::default(),
referer_policy_mode: Default::default(),
referer_policy: Default::default(),
allow_empty_referer: false,
max_connections: Default::default(),
api_secret_encrypt_key: *ENCRYPT_KEY,
}
}
}
fn serialize_encrypt_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let key_str = std::str::from_utf8(key).unwrap_or("");
serializer.serialize_str(key_str)
}
/// Serde `deserialize_with` hook for `Firewall::api_secret_encrypt_key`.
///
/// Reads a string and turns it into a fixed 32-byte key:
/// - an empty string yields `default_api_secret_encrypt_key()`;
/// - otherwise the string's UTF-8 bytes are copied in. Input longer than 32 bytes is
///   truncated; shorter input leaves the remaining bytes as zero.
///
/// NOTE(review): truncation is by byte, so it can split a multi-byte character. The
/// behavior is kept as-is because changing it would change the derived key bytes.
///
/// # Errors
/// Propagates any error from `String::deserialize`.
fn deserialize_encrypt_key<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
where
    D: Deserializer<'de>,
{
    let text = String::deserialize(deserializer)?;
    if text.is_empty() {
        return Ok(default_api_secret_encrypt_key());
    }

    let mut out = [0u8; 32];
    // `zip` stops at the shorter side: long input is truncated, short input keeps zeros.
    for (slot, byte) in out.iter_mut().zip(text.bytes()) {
        *slot = byte;
    }
    Ok(out)
}
/// Serde `default` for `Firewall::api_secret_encrypt_key`: returns a copy of `ENCRYPT_KEY`.
fn default_api_secret_encrypt_key() -> [u8; 32] {
    let key: &[u8; 32] = &ENCRYPT_KEY;
    *key
}
/// Mode selector for a `Firewall` entry list (used for both `ip_policy_mode` and
/// `referer_policy_mode`).
///
/// Variant order is significant: the derived `Ord`/`PartialOrd` compare by declaration order.
///
/// NOTE(review): with no serde rename attributes, serde uses the variant names as
/// written (`"Disable"`, `"Allow"`, `"Deny"`). `From<&str>` for this type matches
/// lowercase strings instead — keep the two in mind when accepting user input.
#[derive(Debug, Clone, Default, Eq, Ord, PartialOrd, PartialEq, Serialize, Deserialize)]
pub enum AllowDenyPolicy {
    /// Default variant; presumably the associated list is not enforced — confirm.
    #[default]
    Disable,
    /// Presumably: only entries in the associated list are allowed — confirm.
    Allow,
    /// Presumably: entries in the associated list are rejected — confirm.
    Deny,
}
impl From<&str> for AllowDenyPolicy {
    /// Parses a lowercase policy name: `"allow"`, `"deny"` or `"disable"`.
    ///
    /// `"disable"` maps to `AllowDenyPolicy::Disable`. Previously every variant but this
    /// one was reachable, and `"disable"` panicked.
    ///
    /// # Panics
    /// Panics on any other input. The message includes the rejected value.
    fn from(value: &str) -> Self {
        match value {
            "allow" => AllowDenyPolicy::Allow,
            "deny" => AllowDenyPolicy::Deny,
            "disable" => AllowDenyPolicy::Disable,
            other => panic!("invalid allow deny policy: {other:?}"),
        }
    }
}
impl Debug for Firewall {
    /// Formats every field, but masks `api_secret_encrypt_key` as its first five bytes
    /// followed by `***`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Decode lossily: the first five key bytes need not be valid UTF-8 (e.g. a
        // multi-byte character cut at byte 5), and `{:?}` must never panic.
        // NOTE(review): five bytes of the secret still reach any log that prints this;
        // consider dropping the prefix if it is not needed to identify the key.
        let masked_key = format!(
            "{}***",
            String::from_utf8_lossy(&self.api_secret_encrypt_key[..5])
        );
        f.debug_struct("Firewall")
            .field("ip_policy_mode", &self.ip_policy_mode)
            .field("ip_policy", &self.ip_policy)
            .field("trust_ips", &self.trust_ips)
            .field("referer_policy_mode", &self.referer_policy_mode)
            .field("referer_policy", &self.referer_policy)
            .field("allow_empty_referer", &self.allow_empty_referer)
            .field("max_connections", &self.max_connections)
            .field("api_secret_encrypt_key", &masked_key)
            .finish()
    }
}