Skip to main content

aiway_protocol/gateway/
firewall.rs

1use crate::common::constants::ENCRYPT_KEY;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::collections::HashSet;
4use std::fmt::{Debug, Display, Formatter};
5
6/// 防火墙配置
7///
8/// 防火墙中不配置插件,因为插件需要获取请求上下文,而上下文是在安全校验后才提取的,在防火墙执行阶段无法获取。
9/// 但是可以使用全局插件的方式在获取请求上下文后再校验。
10#[derive(Clone, Serialize, Deserialize)]
11pub struct Firewall {
12    /// IP策略模式,allow或deny
13    pub ip_policy_mode: AllowDenyPolicy,
14    /// IP策略值,例如:192.168.1.1
15    /// TODO 暂不支持网段,后面再支持
16    pub ip_policy: HashSet<String>,
17    /// 受信IP
18    ///
19    /// 受信IP将直接放行,不受访问策略的影响
20    pub trust_ips: HashSet<String>,
21    /// Referer策略模式,allow或deny
22    pub referer_policy_mode: AllowDenyPolicy,
23    /// Referer策略值,例如:https://aaa.com
24    pub referer_policy: HashSet<String>,
25    /// 是否允许空Referer
26    pub allow_empty_referer: bool,
27    /// 单个网关节点的最大连接数限制
28    // /// 例如:127.0.0.1:8080/1000,
29    // /// 对所有节点限制:*/2000,
30    // /// 如果配置了具体的节点限制,则优先使用具体配置。
31    pub max_connections: Option<usize>,
32    /// API密钥的加密密钥,长度固定为32位,由控制台验证长度。
33    /// 可能为空字符串,为空时使用默认密钥
34    #[serde(
35        default = "default_api_secret_encrypt_key",
36        serialize_with = "serialize_encrypt_key",
37        deserialize_with = "deserialize_encrypt_key"
38    )]
39    pub api_secret_encrypt_key: [u8; 32],
40    /// TLS证书
41    pub tls_cert: Option<Vec<u8>>,
42    /// TLS密钥
43    pub tls_key: Option<Vec<u8>>,
44}
45
46impl Default for Firewall {
47    fn default() -> Self {
48        Firewall {
49            ip_policy_mode: AllowDenyPolicy::Disable,
50            ip_policy: Default::default(),
51            trust_ips: Default::default(),
52            referer_policy_mode: Default::default(),
53            referer_policy: Default::default(),
54            allow_empty_referer: false,
55            max_connections: Default::default(),
56            api_secret_encrypt_key: *ENCRYPT_KEY,
57            tls_cert: Default::default(),
58            tls_key: Default::default(),
59        }
60    }
61}
62
63fn serialize_encrypt_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
64where
65    S: Serializer,
66{
67    let key_str = std::str::from_utf8(key).unwrap_or("");
68    serializer.serialize_str(key_str)
69}
70
71fn deserialize_encrypt_key<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
72where
73    D: Deserializer<'de>,
74{
75    let s = String::deserialize(deserializer)?;
76    let mut key = [0u8; 32];
77
78    if s.is_empty() {
79        key = default_api_secret_encrypt_key();
80    } else {
81        let bytes = s.as_bytes();
82        let len = std::cmp::min(32, bytes.len());
83        key[..len].copy_from_slice(&bytes[..len]);
84    }
85
86    Ok(key)
87}
88
89fn default_api_secret_encrypt_key() -> [u8; 32] {
90    *ENCRYPT_KEY
91}
92
93#[derive(Debug, Clone, Default, Eq, Ord, PartialOrd, PartialEq, Serialize, Deserialize)]
94pub enum AllowDenyPolicy {
95    /// 不启用该功能
96    #[default]
97    Disable,
98    /// 允许
99    Allow,
100    /// 拒绝
101    Deny,
102}
103
104impl From<&str> for AllowDenyPolicy {
105    fn from(value: &str) -> Self {
106        match value {
107            "allow" => AllowDenyPolicy::Allow,
108            "deny" => AllowDenyPolicy::Deny,
109            _ => panic!("invalid allow deny policy"),
110        }
111    }
112}
113
114impl Debug for Firewall {
115    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("Firewall")
117            .field("ip_policy_mode", &self.ip_policy_mode)
118            .field("ip_policy", &self.ip_policy)
119            .field("trust_ips", &self.trust_ips)
120            .field("referer_policy_mode", &self.referer_policy_mode)
121            .field("referer_policy", &self.referer_policy)
122            .field("allow_empty_referer", &self.allow_empty_referer)
123            .field("max_connections", &self.max_connections)
124            .field(
125                "api_secret_encrypt_key",
126                &format!(
127                    "{}***",
128                    String::from_utf8(self.api_secret_encrypt_key[0..5].to_vec()).unwrap()
129                ),
130            )
131            .finish()
132    }
133}