Skip to main content

aiway_protocol/gateway/
firewall.rs

1use crate::common::constants::ENCRYPT_KEY;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::collections::HashSet;
4use std::fmt::{Debug, Display, Formatter};
5
6/// 防火墙配置
7///
8/// 防火墙中不配置插件,因为插件需要获取请求上下文,而上下文是在安全校验后才提取的,在防火墙执行阶段无法获取。
9/// 但是可以使用全局插件的方式在获取请求上下文后再校验。
10#[derive(Clone, Serialize, Deserialize)]
11pub struct Firewall {
12    /// IP策略模式,allow或deny
13    pub ip_policy_mode: AllowDenyPolicy,
14    /// IP策略值,支持单个IP和CIDR网段,例如:192.168.1.1, 192.168.1.0/24
15    /// 注意:IP匹配逻辑由调用方实现
16    pub ip_policy: HashSet<String>,
17    /// 受信IP
18    ///
19    /// 受信IP将直接放行,不受访问策略的影响,支持单个IP和CIDR网段
20    /// 注意:IP匹配逻辑由调用方实现
21    pub trust_ips: HashSet<String>,
22    /// Referer策略模式,allow或deny
23    pub referer_policy_mode: AllowDenyPolicy,
24    /// Referer策略值,例如:https://aaa.com
25    pub referer_policy: HashSet<String>,
26    /// 是否允许空Referer
27    pub allow_empty_referer: bool,
28    /// 单个网关节点的最大连接数限制
29    // /// 例如:127.0.0.1:8080/1000,
30    // /// 对所有节点限制:*/2000,
31    // /// 如果配置了具体的节点限制,则优先使用具体配置。
32    pub max_connections: Option<usize>,
33    /// API密钥的加密密钥,长度固定为32位,由控制台验证长度。
34    /// 可能为空字符串,为空时使用默认密钥
35    #[serde(
36        default = "default_api_secret_encrypt_key",
37        serialize_with = "serialize_encrypt_key",
38        deserialize_with = "deserialize_encrypt_key"
39    )]
40    pub api_secret_encrypt_key: [u8; 32],
41    /// TLS证书
42    pub tls_cert: Option<Vec<u8>>,
43    /// TLS密钥
44    pub tls_key: Option<Vec<u8>>,
45}
46
47impl Default for Firewall {
48    fn default() -> Self {
49        Firewall {
50            ip_policy_mode: AllowDenyPolicy::Disable,
51            ip_policy: Default::default(),
52            trust_ips: Default::default(),
53            referer_policy_mode: Default::default(),
54            referer_policy: Default::default(),
55            allow_empty_referer: false,
56            max_connections: Default::default(),
57            api_secret_encrypt_key: *ENCRYPT_KEY,
58            tls_cert: Default::default(),
59            tls_key: Default::default(),
60        }
61    }
62}
63
64fn serialize_encrypt_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
65where
66    S: Serializer,
67{
68    let key_str = std::str::from_utf8(key).unwrap_or("");
69    serializer.serialize_str(key_str)
70}
71
72fn deserialize_encrypt_key<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
73where
74    D: Deserializer<'de>,
75{
76    let s = String::deserialize(deserializer)?;
77    let mut key = [0u8; 32];
78
79    if s.is_empty() {
80        key = default_api_secret_encrypt_key();
81    } else {
82        let bytes = s.as_bytes();
83        let len = std::cmp::min(32, bytes.len());
84        key[..len].copy_from_slice(&bytes[..len]);
85    }
86
87    Ok(key)
88}
89
90fn default_api_secret_encrypt_key() -> [u8; 32] {
91    *ENCRYPT_KEY
92}
93
94#[derive(Debug, Clone, Default, Eq, Ord, PartialOrd, PartialEq, Serialize, Deserialize)]
95pub enum AllowDenyPolicy {
96    /// 不启用该功能
97    #[default]
98    Disable,
99    /// 允许
100    Allow,
101    /// 拒绝
102    Deny,
103}
104
105impl From<&str> for AllowDenyPolicy {
106    fn from(value: &str) -> Self {
107        match value {
108            "allow" => AllowDenyPolicy::Allow,
109            "deny" => AllowDenyPolicy::Deny,
110            _ => panic!("invalid allow deny policy"),
111        }
112    }
113}
114
115impl Debug for Firewall {
116    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117        f.debug_struct("Firewall")
118            .field("ip_policy_mode", &self.ip_policy_mode)
119            .field("ip_policy", &self.ip_policy)
120            .field("trust_ips", &self.trust_ips)
121            .field("referer_policy_mode", &self.referer_policy_mode)
122            .field("referer_policy", &self.referer_policy)
123            .field("allow_empty_referer", &self.allow_empty_referer)
124            .field("max_connections", &self.max_connections)
125            .field(
126                "api_secret_encrypt_key",
127                &format!(
128                    "{}***",
129                    String::from_utf8(self.api_secret_encrypt_key[0..5].to_vec()).unwrap()
130                ),
131            )
132            .finish()
133    }
134}