Skip to main content

sh_layer3/
guard_rails.rs

//! # Guard Rails
//!
//! Guard rails: safety checks applied to inputs and outputs (防护栏:输入输出安全检查。).
4
5use crate::types::Layer3Result;
6use async_trait::async_trait;
7
8/// 防护栏 trait
9///
10/// 定义输入输出检查接口。
11#[async_trait]
12pub trait GuardRail: Send + Sync {
13    /// 防护栏名称
14    fn name(&self) -> &str;
15
16    /// 检查输入
17    async fn check_input(&self, input: &str) -> Layer3Result<GuardResult>;
18
19    /// 检查输出
20    async fn check_output(&self, output: &str) -> Layer3Result<GuardResult>;
21
22    /// 修正输入(如果可能)
23    async fn fix_input(&self, input: &str) -> Layer3Result<String>;
24
25    /// 修正输出(如果可能)
26    async fn fix_output(&self, output: &str) -> Layer3Result<String>;
27}
28
/// Outcome of a single guard-rail check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardResult {
    /// Whether the checked text passed.
    pub passed: bool,
    /// The kind of problem found, if any.
    pub issue: Option<GuardIssue>,
    /// Optional human-readable hint on how to fix the problem.
    pub suggestion: Option<String>,
}

impl GuardResult {
    /// A passing result: no issue, no suggestion.
    #[must_use]
    pub fn pass() -> Self {
        Self {
            passed: true,
            issue: None,
            suggestion: None,
        }
    }

    /// A failing result tagged with `issue` and no suggestion yet.
    #[must_use]
    pub fn fail(issue: GuardIssue) -> Self {
        Self {
            passed: false,
            issue: Some(issue),
            suggestion: None,
        }
    }

    /// Attaches (or replaces) the suggestion on this result.
    #[must_use]
    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
        self.suggestion = Some(suggestion.into());
        self
    }
}

/// Category of problem a guard rail can report.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GuardIssue {
    /// Contains sensitive information.
    SensitiveData,
    /// Malformed content.
    FormatError,
    /// Content is too long.
    TooLong,
    /// Content is too short.
    TooShort,
    /// Contains a dangerous instruction.
    DangerousInstruction,
    /// Content is off-topic.
    OffTopic,
    /// Any other issue, described by the contained text.
    Custom(String),
}
58
59/// 防护栏组合器
60pub struct GuardRailsComposite {
61    rails: Vec<Box<dyn GuardRail>>,
62}
63
64impl GuardRailsComposite {
65    pub fn new() -> Self {
66        Self { rails: Vec::new() }
67    }
68
69    pub fn add(&mut self, rail: Box<dyn GuardRail>) {
70        self.rails.push(rail);
71    }
72
73    pub async fn check_input_all(&self, input: &str) -> Layer3Result<Vec<GuardResult>> {
74        let mut results = Vec::new();
75        for rail in &self.rails {
76            results.push(rail.check_input(input).await?);
77        }
78        Ok(results)
79    }
80
81    pub async fn check_output_all(&self, output: &str) -> Layer3Result<Vec<GuardResult>> {
82        let mut results = Vec::new();
83        for rail in &self.rails {
84            results.push(rail.check_output(output).await?);
85        }
86        Ok(results)
87    }
88}
89
90impl Default for GuardRailsComposite {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
/// Guard rail that bounds the length of text between a minimum and a maximum.
pub struct LengthGuard {
    /// Smallest accepted length (inclusive).
    min_length: usize,
    /// Largest accepted length (inclusive).
    max_length: usize,
}

impl LengthGuard {
    /// Creates a guard accepting lengths in `min_length..=max_length`.
    pub fn new(min_length: usize, max_length: usize) -> Self {
        LengthGuard {
            max_length,
            min_length,
        }
    }
}

impl Default for LengthGuard {
    /// Accepts lengths from 1 up to 10 000, inclusive.
    fn default() -> Self {
        const DEFAULT_MIN: usize = 1;
        const DEFAULT_MAX: usize = 10_000;
        LengthGuard::new(DEFAULT_MIN, DEFAULT_MAX)
    }
}
116
117#[async_trait]
118impl GuardRail for LengthGuard {
119    fn name(&self) -> &str {
120        "length"
121    }
122
123    async fn check_input(&self, input: &str) -> Layer3Result<GuardResult> {
124        let len = input.len();
125        if len < self.min_length {
126            return Ok(GuardResult {
127                passed: false,
128                issue: Some(GuardIssue::TooShort),
129                suggestion: Some(format!("Minimum length: {}", self.min_length)),
130            });
131        }
132        if len > self.max_length {
133            return Ok(GuardResult {
134                passed: false,
135                issue: Some(GuardIssue::TooLong),
136                suggestion: Some(format!("Maximum length: {}", self.max_length)),
137            });
138        }
139        Ok(GuardResult {
140            passed: true,
141            issue: None,
142            suggestion: None,
143        })
144    }
145
146    async fn check_output(&self, output: &str) -> Layer3Result<GuardResult> {
147        self.check_input(output).await
148    }
149
150    async fn fix_input(&self, input: &str) -> Layer3Result<String> {
151        Ok(input.to_string())
152    }
153
154    async fn fix_output(&self, output: &str) -> Layer3Result<String> {
155        if output.len() > self.max_length {
156            Ok(output[..self.max_length].to_string())
157        } else {
158            Ok(output.to_string())
159        }
160    }
161}
162
163/// 正则防护栏
164pub struct RegexGuard {
165    pattern: regex::Regex,
166    block_matches: bool,
167    name: String,
168}
169
170impl RegexGuard {
171    pub fn new(pattern: regex::Regex, block_matches: bool, name: impl Into<String>) -> Self {
172        Self {
173            pattern,
174            block_matches,
175            name: name.into(),
176        }
177    }
178}
179
180#[async_trait]
181impl GuardRail for RegexGuard {
182    fn name(&self) -> &str {
183        &self.name
184    }
185
186    async fn check_input(&self, input: &str) -> Layer3Result<GuardResult> {
187        let matches = self.pattern.is_match(input);
188        let passed = if self.block_matches {
189            !matches
190        } else {
191            matches
192        };
193        Ok(GuardResult {
194            passed,
195            issue: if passed {
196                None
197            } else {
198                Some(GuardIssue::FormatError)
199            },
200            suggestion: None,
201        })
202    }
203
204    async fn check_output(&self, output: &str) -> Layer3Result<GuardResult> {
205        self.check_input(output).await
206    }
207
208    async fn fix_input(&self, input: &str) -> Layer3Result<String> {
209        Ok(self.pattern.replace_all(input, "").to_string())
210    }
211
212    async fn fix_output(&self, output: &str) -> Layer3Result<String> {
213        self.fix_input(output).await
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_length_guard() {
        let guard = LengthGuard::new(5, 100);
        let result = guard.check_input("hello").await.unwrap();
        assert!(result.passed);
        assert!(result.issue.is_none());
    }

    #[tokio::test]
    async fn test_length_guard_too_short() {
        let guard = LengthGuard::new(10, 100);
        let result = guard.check_input("hi").await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.issue, Some(GuardIssue::TooShort));
    }

    #[tokio::test]
    async fn test_length_guard_default_rejects_empty() {
        let guard = LengthGuard::default();
        let result = guard.check_input("").await.unwrap();
        assert_eq!(result.issue, Some(GuardIssue::TooShort));
    }

    #[tokio::test]
    async fn test_length_guard_too_long() {
        let guard = LengthGuard::new(1, 3);
        let result = guard.check_output("abcd").await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.issue, Some(GuardIssue::TooLong));
        assert_eq!(result.suggestion.as_deref(), Some("Maximum length: 3"));
    }

    #[tokio::test]
    async fn test_length_guard_fix_output_truncates_ascii() {
        let guard = LengthGuard::new(1, 3);
        assert_eq!(guard.fix_output("abcdef").await.unwrap(), "abc");
        assert_eq!(guard.fix_output("ab").await.unwrap(), "ab");
        // fix_input never modifies its argument.
        assert_eq!(guard.fix_input("abcdef").await.unwrap(), "abcdef");
    }

    #[tokio::test]
    async fn test_composite_collects_one_result_per_rail() {
        let mut composite = GuardRailsComposite::new();
        composite.add(Box::new(LengthGuard::new(1, 3)));
        composite.add(Box::new(LengthGuard::new(5, 10)));

        let results = composite.check_input_all("abcd").await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].issue, Some(GuardIssue::TooLong));
        assert_eq!(results[1].issue, Some(GuardIssue::TooShort));
    }

    #[tokio::test]
    async fn test_regex_guard_block_mode() {
        let guard = RegexGuard::new(regex::Regex::new(r"\d+").unwrap(), true, "no_digits");
        assert_eq!(guard.name(), "no_digits");

        assert!(guard.check_input("abc").await.unwrap().passed);

        let failed = guard.check_input("a1").await.unwrap();
        assert!(!failed.passed);
        assert_eq!(failed.issue, Some(GuardIssue::FormatError));

        assert_eq!(guard.fix_input("a12b3").await.unwrap(), "ab");
    }
}