Skip to main content

aster/blueprint/
boundary_checker.rs

//! 边界检查器
//!
//! 提供:
//! 1. 模块边界验证
//! 2. 受保护文件检测
//! 3. 技术栈与文件扩展名匹配检查
//! 4. 跨模块修改检测
9
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::Path;

use super::types::*;
15
// ============================================================================
// 边界检查结果
// ============================================================================
19
20/// 边界检查结果
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BoundaryCheckResult {
23    /// 是否允许
24    pub allowed: bool,
25    /// 原因
26    pub reason: Option<String>,
27    /// 违规类型
28    pub violation_type: Option<ViolationType>,
29    /// 建议
30    pub suggestion: Option<String>,
31}
32
33impl BoundaryCheckResult {
34    /// 创建允许的结果
35    pub fn allow() -> Self {
36        Self {
37            allowed: true,
38            reason: None,
39            violation_type: None,
40            suggestion: None,
41        }
42    }
43
44    /// 创建拒绝的结果
45    pub fn deny(reason: String, violation_type: ViolationType) -> Self {
46        Self {
47            allowed: false,
48            reason: Some(reason),
49            violation_type: Some(violation_type),
50            suggestion: None,
51        }
52    }
53
54    /// 添加建议
55    pub fn with_suggestion(mut self, suggestion: String) -> Self {
56        self.suggestion = Some(suggestion);
57        self
58    }
59}
60
61/// 违规类型
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
63#[serde(rename_all = "snake_case")]
64pub enum ViolationType {
65    /// 跨模块修改
66    CrossModule,
67    /// 修改受保护文件
68    ProtectedFile,
69    /// 技术栈不匹配
70    TechStackMismatch,
71    /// 修改配置文件
72    ConfigFile,
73    /// 超出根路径
74    OutOfScope,
75}
76
77/// 受保护文件模式
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ProtectedPattern {
80    pub pattern: String,
81    pub reason: String,
82}
83
// ============================================================================
// 边界检查器配置
// ============================================================================
87
88/// 边界检查器配置
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct BoundaryCheckerConfig {
91    /// 受保护的文件模式
92    pub protected_patterns: Vec<ProtectedPattern>,
93    /// 受保护的配置文件
94    pub protected_configs: Vec<String>,
95    /// 是否严格模式
96    pub strict_mode: bool,
97}
98
99impl Default for BoundaryCheckerConfig {
100    fn default() -> Self {
101        Self {
102            protected_patterns: vec![
103                ProtectedPattern {
104                    pattern: "package.json".to_string(),
105                    reason: "包配置文件".to_string(),
106                },
107                ProtectedPattern {
108                    pattern: "Cargo.toml".to_string(),
109                    reason: "Rust 项目配置".to_string(),
110                },
111                ProtectedPattern {
112                    pattern: "tsconfig.json".to_string(),
113                    reason: "TypeScript 配置".to_string(),
114                },
115                ProtectedPattern {
116                    pattern: ".env".to_string(),
117                    reason: "环境变量文件".to_string(),
118                },
119                ProtectedPattern {
120                    pattern: ".gitignore".to_string(),
121                    reason: "Git 忽略配置".to_string(),
122                },
123            ],
124            protected_configs: vec![
125                "package.json".to_string(),
126                "package-lock.json".to_string(),
127                "Cargo.toml".to_string(),
128                "Cargo.lock".to_string(),
129                "tsconfig.json".to_string(),
130                "vite.config.ts".to_string(),
131                "webpack.config.js".to_string(),
132            ],
133            strict_mode: true,
134        }
135    }
136}
137
// ============================================================================
// 边界检查器
// ============================================================================
141
142/// 边界检查器
143pub struct BoundaryChecker {
144    config: BoundaryCheckerConfig,
145    blueprint: Blueprint,
146    /// 模块根路径映射
147    module_paths: std::collections::HashMap<String, String>,
148}
149
150impl BoundaryChecker {
151    /// 创建新的边界检查器
152    pub fn new(blueprint: Blueprint, config: Option<BoundaryCheckerConfig>) -> Self {
153        let config = config.unwrap_or_default();
154
155        // 构建模块路径映射
156        let mut module_paths = std::collections::HashMap::new();
157        for module in &blueprint.modules {
158            let root_path = module
159                .root_path
160                .clone()
161                .unwrap_or_else(|| format!("src/{}", module.name.to_lowercase()));
162            module_paths.insert(module.id.clone(), root_path);
163        }
164
165        Self {
166            config,
167            blueprint,
168            module_paths,
169        }
170    }
171
172    /// 检查任务边界
173    pub fn check_task_boundary(
174        &self,
175        task_module_id: Option<&str>,
176        file_path: &str,
177    ) -> BoundaryCheckResult {
178        // 1. 检查是否是受保护文件
179        if let Some(result) = self.check_protected_file(file_path) {
180            return result;
181        }
182
183        // 2. 检查是否是配置文件
184        if let Some(result) = self.check_config_file(file_path) {
185            return result;
186        }
187
188        // 3. 如果没有指定模块,允许
189        let module_id = match task_module_id {
190            Some(id) => id,
191            None => return BoundaryCheckResult::allow(),
192        };
193
194        // 4. 检查是否在模块范围内
195        self.check_module_scope(module_id, file_path)
196    }
197
198    /// 检查受保护文件
199    fn check_protected_file(&self, file_path: &str) -> Option<BoundaryCheckResult> {
200        let file_name = Path::new(file_path)
201            .file_name()
202            .and_then(|n| n.to_str())
203            .unwrap_or(file_path);
204
205        for pattern in &self.config.protected_patterns {
206            if file_name == pattern.pattern || file_path.ends_with(&pattern.pattern) {
207                return Some(
208                    BoundaryCheckResult::deny(
209                        format!("不能修改受保护文件: {} ({})", file_path, pattern.reason),
210                        ViolationType::ProtectedFile,
211                    )
212                    .with_suggestion("请联系蜂王(主 Agent)处理此文件".to_string()),
213                );
214            }
215        }
216
217        None
218    }
219
220    /// 检查配置文件
221    fn check_config_file(&self, file_path: &str) -> Option<BoundaryCheckResult> {
222        let file_name = Path::new(file_path)
223            .file_name()
224            .and_then(|n| n.to_str())
225            .unwrap_or(file_path);
226
227        if self
228            .config
229            .protected_configs
230            .contains(&file_name.to_string())
231        {
232            return Some(
233                BoundaryCheckResult::deny(
234                    format!("不能修改配置文件: {}", file_path),
235                    ViolationType::ConfigFile,
236                )
237                .with_suggestion("配置文件修改需要蜂王审批".to_string()),
238            );
239        }
240
241        None
242    }
243
244    /// 检查模块范围
245    fn check_module_scope(&self, module_id: &str, file_path: &str) -> BoundaryCheckResult {
246        let module_root = match self.module_paths.get(module_id) {
247            Some(root) => root,
248            None => return BoundaryCheckResult::allow(),
249        };
250
251        // 规范化路径
252        let normalized_path = file_path.replace('\\', "/");
253        let normalized_root = module_root.replace('\\', "/");
254
255        // 检查文件是否在模块根路径下
256        if normalized_path.starts_with(&normalized_root) {
257            return BoundaryCheckResult::allow();
258        }
259
260        // 检查是否在其他模块的范围内
261        for (other_id, other_root) in &self.module_paths {
262            if other_id != module_id {
263                let other_normalized = other_root.replace('\\', "/");
264                if normalized_path.starts_with(&other_normalized) {
265                    return BoundaryCheckResult::deny(
266                        format!(
267                            "跨模块修改: 文件 {} 属于模块 {},但当前任务属于模块 {}",
268                            file_path, other_id, module_id
269                        ),
270                        ViolationType::CrossModule,
271                    )
272                    .with_suggestion(format!(
273                        "请在模块 {} 的范围内工作,或请求蜂王重新分配任务",
274                        module_id
275                    ));
276                }
277            }
278        }
279
280        // 文件不在任何已知模块范围内
281        if self.config.strict_mode {
282            BoundaryCheckResult::deny(
283                format!("文件 {} 不在模块 {} 的范围内", file_path, module_id),
284                ViolationType::OutOfScope,
285            )
286            .with_suggestion(format!("请确保文件在 {} 目录下", module_root))
287        } else {
288            BoundaryCheckResult::allow()
289        }
290    }
291
292    /// 检查技术栈匹配
293    pub fn check_tech_stack(&self, module_id: &str, file_path: &str) -> BoundaryCheckResult {
294        let module = match self.blueprint.modules.iter().find(|m| m.id == module_id) {
295            Some(m) => m,
296            None => return BoundaryCheckResult::allow(),
297        };
298
299        let tech_stack = match &module.tech_stack {
300            Some(ts) => ts,
301            None => return BoundaryCheckResult::allow(),
302        };
303
304        // 获取文件扩展名
305        let extension = Path::new(file_path)
306            .extension()
307            .and_then(|e| e.to_str())
308            .unwrap_or("");
309
310        // 检查扩展名是否与技术栈匹配
311        let allowed_extensions = self.get_extensions_from_tech_stack(tech_stack);
312
313        if allowed_extensions.is_empty() {
314            return BoundaryCheckResult::allow();
315        }
316
317        if allowed_extensions.contains(&extension.to_string()) {
318            BoundaryCheckResult::allow()
319        } else {
320            BoundaryCheckResult::deny(
321                format!(
322                    "文件扩展名 .{} 与模块 {} 的技术栈不匹配",
323                    extension, module.name
324                ),
325                ViolationType::TechStackMismatch,
326            )
327            .with_suggestion(format!("允许的扩展名: {}", allowed_extensions.join(", ")))
328        }
329    }
330
331    /// 根据技术栈获取允许的文件扩展名
332    fn get_extensions_from_tech_stack(&self, tech_stack: &[String]) -> Vec<String> {
333        let mut extensions = HashSet::new();
334
335        for tech in tech_stack {
336            let tech_lower = tech.to_lowercase();
337            match tech_lower.as_str() {
338                "typescript" => {
339                    extensions.insert("ts".to_string());
340                    extensions.insert("tsx".to_string());
341                }
342                "javascript" => {
343                    extensions.insert("js".to_string());
344                    extensions.insert("jsx".to_string());
345                }
346                "react" => {
347                    extensions.insert("tsx".to_string());
348                    extensions.insert("jsx".to_string());
349                }
350                "vue" => {
351                    extensions.insert("vue".to_string());
352                }
353                "python" => {
354                    extensions.insert("py".to_string());
355                }
356                "go" | "golang" => {
357                    extensions.insert("go".to_string());
358                }
359                "rust" => {
360                    extensions.insert("rs".to_string());
361                }
362                "java" => {
363                    extensions.insert("java".to_string());
364                }
365                "kotlin" => {
366                    extensions.insert("kt".to_string());
367                }
368                "swift" => {
369                    extensions.insert("swift".to_string());
370                }
371                _ => {}
372            }
373        }
374
375        extensions.into_iter().collect()
376    }
377
378    /// 获取模块信息
379    pub fn get_module(&self, module_id: &str) -> Option<&SystemModule> {
380        self.blueprint.modules.iter().find(|m| m.id == module_id)
381    }
382
383    /// 获取模块根路径
384    pub fn get_module_root(&self, module_id: &str) -> Option<&String> {
385        self.module_paths.get(module_id)
386    }
387
388    /// 获取所有模块 ID
389    pub fn get_module_ids(&self) -> Vec<&String> {
390        self.module_paths.keys().collect()
391    }
392
393    /// 批量检查文件
394    pub fn check_files(
395        &self,
396        task_module_id: Option<&str>,
397        file_paths: &[String],
398    ) -> Vec<(String, BoundaryCheckResult)> {
399        file_paths
400            .iter()
401            .map(|path| {
402                let result = self.check_task_boundary(task_module_id, path);
403                (path.clone(), result)
404            })
405            .collect()
406    }
407
408    /// 获取违规文件
409    pub fn get_violations(
410        &self,
411        task_module_id: Option<&str>,
412        file_paths: &[String],
413    ) -> Vec<(String, BoundaryCheckResult)> {
414        self.check_files(task_module_id, file_paths)
415            .into_iter()
416            .filter(|(_, result)| !result.allowed)
417            .collect()
418    }
419}
420
421/// 创建边界检查器
422pub fn create_boundary_checker(
423    blueprint: Blueprint,
424    config: Option<BoundaryCheckerConfig>,
425) -> BoundaryChecker {
426    BoundaryChecker::new(blueprint, config)
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a module fixture. Dependencies and interfaces are left empty.
    fn module(
        id: &str,
        name: &str,
        description: &str,
        module_type: ModuleType,
        responsibility: &str,
        tech: &[&str],
        root: &str,
    ) -> SystemModule {
        SystemModule {
            id: id.to_string(),
            name: name.to_string(),
            description: description.to_string(),
            module_type,
            responsibilities: vec![responsibility.to_string()],
            dependencies: vec![],
            interfaces: vec![],
            tech_stack: Some(tech.iter().map(|t| t.to_string()).collect()),
            root_path: Some(root.to_string()),
        }
    }

    /// Blueprint with a TypeScript/React frontend and a Rust backend.
    fn create_test_blueprint() -> Blueprint {
        let mut blueprint = Blueprint::new("测试项目".to_string(), "测试描述".to_string());
        blueprint.modules.push(module(
            "frontend",
            "前端模块",
            "前端 UI",
            ModuleType::Frontend,
            "用户界面",
            &["TypeScript", "React"],
            "src/frontend",
        ));
        blueprint.modules.push(module(
            "backend",
            "后端模块",
            "后端服务",
            ModuleType::Backend,
            "API 服务",
            &["Rust"],
            "src/backend",
        ));
        blueprint
    }

    /// Checker over the test blueprint with the default configuration.
    fn checker() -> BoundaryChecker {
        BoundaryChecker::new(create_test_blueprint(), None)
    }

    #[test]
    fn test_boundary_checker_creation() {
        assert_eq!(checker().get_module_ids().len(), 2);
    }

    #[test]
    fn test_protected_file_check() {
        let verdict = checker().check_task_boundary(Some("frontend"), "package.json");
        assert!(!verdict.allowed);
        assert_eq!(verdict.violation_type, Some(ViolationType::ProtectedFile));
    }

    #[test]
    fn test_module_scope_check() {
        let bc = checker();

        // Inside the task's own module.
        let inside =
            bc.check_task_boundary(Some("frontend"), "src/frontend/components/Button.tsx");
        assert!(inside.allowed);

        // Inside a different module.
        let cross = bc.check_task_boundary(Some("frontend"), "src/backend/api/handler.rs");
        assert!(!cross.allowed);
        assert_eq!(cross.violation_type, Some(ViolationType::CrossModule));
    }

    #[test]
    fn test_tech_stack_check() {
        let bc = checker();

        // Extension covered by the module's tech stack.
        assert!(bc.check_tech_stack("frontend", "src/frontend/App.tsx").allowed);

        // Extension outside the module's tech stack.
        let mismatch = bc.check_tech_stack("frontend", "src/frontend/main.rs");
        assert!(!mismatch.allowed);
        assert_eq!(
            mismatch.violation_type,
            Some(ViolationType::TechStackMismatch)
        );
    }
}