Skip to main content

aster/map/
semantic_generator.rs

1//! AI 语义生成器
2//!
3//! 使用 LLM 为模块和符号生成业务语义描述
4
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8
9use super::types::ModuleNode;
10use super::types_enhanced::{
11    ArchitectureLayer, EnhancedAnalysisPhase, EnhancedAnalysisProgress, ProjectSemantic,
12    SemanticInfo, SymbolEntry,
13};
14
/// Model identifier used when the caller does not choose one.
const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
/// Maximum number of bytes of source code kept before the content is truncated.
const MAX_CODE_LENGTH: usize = 8000;
/// Default batch size.
const BATCH_SIZE: usize = 5;
/// Default concurrency level.
const CONCURRENCY: usize = 3;
23
24/// 语义生成器选项
25#[derive(Debug, Clone)]
26pub struct SemanticGeneratorOptions {
27    /// 使用的模型
28    pub model: String,
29    /// 并发数
30    pub concurrency: usize,
31    /// 批量大小
32    pub batch_size: usize,
33    /// 进度回调
34    pub on_progress: Option<fn(&EnhancedAnalysisProgress)>,
35}
36
37impl Default for SemanticGeneratorOptions {
38    fn default() -> Self {
39        Self {
40            model: DEFAULT_MODEL.to_string(),
41            concurrency: CONCURRENCY,
42            batch_size: BATCH_SIZE,
43            on_progress: None,
44        }
45    }
46}
47
48/// 模块语义响应
49#[derive(Debug, Clone)]
50struct ModuleSemanticResponse {
51    description: String,
52    responsibility: String,
53    business_domain: Option<String>,
54    architecture_layer: ArchitectureLayer,
55    tags: Vec<String>,
56}
57
58/// 项目语义响应
59#[derive(Debug, Clone)]
60struct ProjectSemanticResponse {
61    description: String,
62    purpose: String,
63    domains: Vec<String>,
64    key_concepts: Vec<KeyConceptResponse>,
65}
66
/// A single key concept: a name plus its explanation.
#[derive(Clone, Debug)]
struct KeyConceptResponse {
    /// Name of the concept.
    name: String,
    /// Meaning and role of the concept.
    description: String,
}
73
74/// 语义生成器
75pub struct SemanticGenerator {
76    root_path: PathBuf,
77    model: String,
78    concurrency: usize,
79    batch_size: usize,
80    on_progress: Option<fn(&EnhancedAnalysisProgress)>,
81}
82
83impl SemanticGenerator {
84    /// 创建新的生成器
85    pub fn new(root_path: impl AsRef<Path>, options: SemanticGeneratorOptions) -> Self {
86        Self {
87            root_path: root_path.as_ref().to_path_buf(),
88            model: options.model,
89            concurrency: options.concurrency,
90            batch_size: options.batch_size,
91            on_progress: options.on_progress,
92        }
93    }
94
95    /// 为单个模块生成语义描述
96    pub fn generate_module_semantic(&self, module: &ModuleNode) -> SemanticInfo {
97        // 读取文件内容
98        let file_path = self.root_path.join(&module.id);
99        let content = match fs::read_to_string(&file_path) {
100            Ok(c) => c,
101            Err(_) => return self.generate_fallback_semantic(module),
102        };
103
104        // 截断过长的代码
105        let content = if content.len() > MAX_CODE_LENGTH {
106            // Find safe UTF-8 boundary for truncation
107            let truncate_at = content
108                .char_indices()
109                .take_while(|(i, _)| *i < MAX_CODE_LENGTH)
110                .last()
111                .map(|(i, c)| i + c.len_utf8())
112                .unwrap_or(0);
113            format!(
114                "{}\n// ... (code truncated)",
115                content.get(..truncate_at).unwrap_or(&content)
116            )
117        } else {
118            content
119        };
120
121        // 构建提示词
122        let _prompt = self.build_module_prompt(module, &content);
123
124        // TODO: 调用 LLM API
125        // 目前返回基于规则的语义
126        self.generate_fallback_semantic(module)
127    }
128
129    /// 批量生成模块语义
130    pub fn batch_generate_module_semantics(
131        &self,
132        modules: &[ModuleNode],
133    ) -> HashMap<String, SemanticInfo> {
134        let mut results = HashMap::new();
135        let total = modules.len();
136
137        for (i, module) in modules.iter().enumerate() {
138            let semantic = self.generate_module_semantic(module);
139            results.insert(module.id.clone(), semantic);
140
141            if let Some(callback) = self.on_progress {
142                callback(&EnhancedAnalysisProgress {
143                    phase: EnhancedAnalysisPhase::Semantics,
144                    current: i + 1,
145                    total,
146                    current_file: Some(module.id.clone()),
147                    message: Some(format!("生成语义: {}", module.id)),
148                });
149            }
150        }
151
152        results
153    }
154
155    /// 生成项目级语义描述
156    pub fn generate_project_semantic(&self, modules: &[ModuleNode]) -> ProjectSemantic {
157        // 收集项目信息
158        let _module_list: Vec<_> = modules
159            .iter()
160            .take(50)
161            .map(|m| {
162                (
163                    m.id.clone(),
164                    m.classes.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
165                    m.functions
166                        .iter()
167                        .take(10)
168                        .map(|f| f.name.clone())
169                        .collect::<Vec<_>>(),
170                )
171            })
172            .collect();
173
174        // TODO: 调用 LLM API
175        // 目前返回基于规则的语义
176        self.generate_fallback_project_semantic(modules)
177    }
178
179    /// 为符号生成语义描述
180    pub fn generate_symbol_semantic(
181        &self,
182        symbol: &SymbolEntry,
183        _context: Option<&str>,
184    ) -> SemanticInfo {
185        // TODO: 调用 LLM API
186        let kind_str = format!("{:?}", symbol.kind);
187        SemanticInfo {
188            description: format!("{} {}", kind_str, symbol.name),
189            responsibility: kind_str,
190            business_domain: None,
191            architecture_layer: ArchitectureLayer::Infrastructure,
192            tags: vec![],
193            confidence: 0.3,
194            generated_at: chrono::Utc::now().to_rfc3339(),
195        }
196    }
197
198    // ========================================================================
199    // Prompt 构建
200    // ========================================================================
201
202    /// 构建模块提示词
203    fn build_module_prompt(&self, module: &ModuleNode, content: &str) -> String {
204        let classes: Vec<_> = module.classes.iter().map(|c| c.name.as_str()).collect();
205        let functions: Vec<_> = module
206            .functions
207            .iter()
208            .take(10)
209            .map(|f| f.name.as_str())
210            .collect();
211        let imports: Vec<_> = module
212            .imports
213            .iter()
214            .take(5)
215            .map(|i| i.source.as_str())
216            .collect();
217
218        format!(
219            r#"分析以下代码模块,生成简洁的业务描述。
220
221文件路径: {}
222语言: {}
223代码行数: {}
224类: {}
225函数: {}
226导入: {}
227
228代码内容:
229```{}
230{}
231```
232
233请返回 JSON 格式(不要包含 markdown 代码块标记):
234{{
235  "description": "这个模块做什么(1-2句话,用中文)",
236  "responsibility": "核心职责(1句话)",
237  "businessDomain": "所属业务领域(如:用户管理、支付、搜索等)",
238  "architectureLayer": "presentation|business|data|infrastructure|crossCutting",
239  "tags": ["关键词1", "关键词2", "关键词3"]
240}}
241
242architectureLayer 说明:
243- presentation: UI 组件、页面、视图渲染
244- business: 核心业务逻辑、领域模型、服务
245- data: API 调用、数据库、存储
246- infrastructure: 工具函数、配置、类型定义
247- crossCutting: 认证、日志、中间件、插件"#,
248            module.id,
249            module.language,
250            module.lines,
251            classes.join(", "),
252            functions.join(", "),
253            imports.join(", "),
254            module.language,
255            content
256        )
257    }
258
259    /// 构建项目提示词
260    fn build_project_prompt(&self, module_list: &[(String, Vec<String>, Vec<String>)]) -> String {
261        let modules_summary: String = module_list
262            .iter()
263            .map(|(path, classes, functions)| {
264                format!(
265                    "- {}: 类[{}], 函数[{}]",
266                    path,
267                    classes.join(", "),
268                    functions.join(", ")
269                )
270            })
271            .collect::<Vec<_>>()
272            .join("\n");
273
274        format!(
275            r#"分析以下项目结构,生成项目级语义描述。
276
277项目模块列表(前50个):
278{}
279
280请返回 JSON 格式(不要包含 markdown 代码块标记):
281{{
282  "description": "这个项目做什么(2-3句话,用中文)",
283  "purpose": "项目的核心价值和目的(1-2句话)",
284  "domains": ["业务领域1", "业务领域2", "业务领域3"],
285  "keyConcepts": [
286    {{
287      "name": "核心概念1",
288      "description": "这个概念的含义和作用"
289    }},
290    {{
291      "name": "核心概念2",
292      "description": "这个概念的含义和作用"
293    }}
294  ]
295}}"#,
296            modules_summary
297        )
298    }
299
300    // ========================================================================
301    // 辅助方法
302    // ========================================================================
303
304    /// 验证架构层
305    fn validate_layer(&self, layer: &str) -> ArchitectureLayer {
306        match layer {
307            "presentation" => ArchitectureLayer::Presentation,
308            "business" => ArchitectureLayer::Business,
309            "data" => ArchitectureLayer::Data,
310            "infrastructure" => ArchitectureLayer::Infrastructure,
311            "crossCutting" => ArchitectureLayer::CrossCutting,
312            _ => ArchitectureLayer::Infrastructure,
313        }
314    }
315
316    /// 查找相关模块
317    fn find_related_modules(&self, concept_name: &str, modules: &[ModuleNode]) -> Vec<String> {
318        let lower_name = concept_name.to_lowercase();
319        let mut related = Vec::new();
320
321        for module in modules {
322            let module_path = module.id.to_lowercase();
323            let has_matching_class = module
324                .classes
325                .iter()
326                .any(|c| c.name.to_lowercase().contains(&lower_name));
327            let has_matching_function = module
328                .functions
329                .iter()
330                .any(|f| f.name.to_lowercase().contains(&lower_name));
331
332            if module_path.contains(&lower_name) || has_matching_class || has_matching_function {
333                related.push(module.id.clone());
334            }
335        }
336
337        related.into_iter().take(10).collect()
338    }
339
340    /// 生成回退语义
341    fn generate_fallback_semantic(&self, module: &ModuleNode) -> SemanticInfo {
342        let path_parts: Vec<&str> = module.id.split('/').collect();
343        let file_name = path_parts.last().unwrap_or(&"module");
344
345        let (layer, description) =
346            if module.id.contains("/ui/") || module.id.contains("/components/") {
347                (
348                    ArchitectureLayer::Presentation,
349                    format!("UI 组件模块 {}", file_name),
350                )
351            } else if module.id.contains("/core/") || module.id.contains("/services/") {
352                (
353                    ArchitectureLayer::Business,
354                    format!("业务逻辑模块 {}", file_name),
355                )
356            } else if module.id.contains("/api/") || module.id.contains("/data/") {
357                (
358                    ArchitectureLayer::Data,
359                    format!("数据处理模块 {}", file_name),
360                )
361            } else {
362                (
363                    ArchitectureLayer::Infrastructure,
364                    format!("{} 模块", file_name),
365                )
366            };
367
368        let tags: Vec<String> = path_parts
369            .iter()
370            .filter(|p| **p != "src" && !p.contains('.'))
371            .map(|s| s.to_string())
372            .collect();
373
374        SemanticInfo {
375            description,
376            responsibility: format!("{} 的功能实现", file_name),
377            business_domain: None,
378            architecture_layer: layer,
379            tags,
380            confidence: 0.4,
381            generated_at: chrono::Utc::now().to_rfc3339(),
382        }
383    }
384
385    /// 生成回退项目语义
386    fn generate_fallback_project_semantic(&self, modules: &[ModuleNode]) -> ProjectSemantic {
387        let paths: Vec<&str> = modules.iter().map(|m| m.id.as_str()).collect();
388        let has_ui = paths
389            .iter()
390            .any(|p| p.contains("/ui/") || p.contains("/components/"));
391        let has_tools = paths.iter().any(|p| p.contains("/tools/"));
392        let has_core = paths.iter().any(|p| p.contains("/core/"));
393
394        let mut domains = Vec::new();
395        if has_ui {
396            domains.push("用户界面".to_string());
397        }
398        if has_tools {
399            domains.push("工具系统".to_string());
400        }
401        if has_core {
402            domains.push("核心引擎".to_string());
403        }
404
405        if domains.is_empty() {
406            domains.push("软件开发".to_string());
407        }
408
409        ProjectSemantic {
410            description: "代码项目(语义描述待生成)".to_string(),
411            purpose: "项目目的待分析".to_string(),
412            domains,
413            key_concepts: vec![],
414        }
415    }
416}
417
418// ============================================================================
419// 便捷函数
420// ============================================================================
421
422/// 快速生成模块语义
423pub fn generate_module_semantic(
424    root_path: impl AsRef<Path>,
425    module: &ModuleNode,
426    options: Option<SemanticGeneratorOptions>,
427) -> SemanticInfo {
428    let generator = SemanticGenerator::new(root_path, options.unwrap_or_default());
429    generator.generate_module_semantic(module)
430}
431
432/// 批量生成模块语义
433pub fn batch_generate_semantics(
434    root_path: impl AsRef<Path>,
435    modules: &[ModuleNode],
436    options: Option<SemanticGeneratorOptions>,
437) -> HashMap<String, SemanticInfo> {
438    let generator = SemanticGenerator::new(root_path, options.unwrap_or_default());
439    generator.batch_generate_module_semantics(modules)
440}
441
442/// 生成项目语义
443pub fn generate_project_semantic(
444    root_path: impl AsRef<Path>,
445    modules: &[ModuleNode],
446    options: Option<SemanticGeneratorOptions>,
447) -> ProjectSemantic {
448    let generator = SemanticGenerator::new(root_path, options.unwrap_or_default());
449    generator.generate_project_semantic(modules)
450}