Skip to main content

aster/prompt/
builder.rs

//! System prompt builder (系统提示词构建器).
//!
//! Assembles the complete, modular system prompt (组装完整的模块化系统提示词).
4
5use std::time::Instant;
6
7use super::attachments::AttachmentManager;
8use super::cache::{estimate_tokens, generate_cache_key, PromptCache};
9use super::templates::{
10    get_environment_info, get_permission_mode_description, EnvironmentInfo, CODING_GUIDELINES,
11    CORE_IDENTITY, GIT_GUIDELINES, OUTPUT_STYLE, SUBAGENT_SYSTEM, TASK_MANAGEMENT, TOOL_GUIDELINES,
12};
13use super::types::{
14    Attachment, BuildResult, PermissionMode, PromptContext, PromptTooLongError, SystemPromptOptions,
15};
16
17/// 系统提示词构建器
18pub struct SystemPromptBuilder {
19    attachment_manager: AttachmentManager,
20    cache: PromptCache,
21    debug: bool,
22}
23
24impl SystemPromptBuilder {
25    /// 创建新的构建器
26    pub fn new(debug: bool) -> Self {
27        Self {
28            attachment_manager: AttachmentManager::default(),
29            cache: PromptCache::default(),
30            debug,
31        }
32    }
33
34    /// 使用自定义组件创建构建器
35    pub fn with_components(
36        attachment_manager: AttachmentManager,
37        cache: PromptCache,
38        debug: bool,
39    ) -> Self {
40        Self {
41            attachment_manager,
42            cache,
43            debug,
44        }
45    }
46
47    /// 构建完整的系统提示词
48    pub fn build(
49        &mut self,
50        context: &PromptContext,
51        options: Option<SystemPromptOptions>,
52    ) -> Result<BuildResult, PromptTooLongError> {
53        let start_time = Instant::now();
54        let opts = options.unwrap_or_default();
55
56        // 检查缓存
57        if opts.enable_cache {
58            let cache_key = generate_cache_key(
59                &context.working_dir.display().to_string(),
60                context.model.as_deref(),
61                context
62                    .permission_mode
63                    .map(|m| format!("{:?}", m))
64                    .as_deref(),
65                context.plan_mode,
66            );
67
68            if let Some((content, hash_info)) = self.cache.get(&cache_key) {
69                if self.debug {
70                    eprintln!("[SystemPromptBuilder] Cache hit");
71                }
72                return Ok(BuildResult {
73                    content,
74                    hash_info,
75                    attachments: vec![],
76                    truncated: false,
77                    build_time_ms: start_time.elapsed().as_millis() as u64,
78                });
79            }
80        }
81
82        // 生成附件
83        let attachments = self.attachment_manager.generate_attachments(context);
84
85        // 构建各个部分
86        let mut parts: Vec<String> = Vec::new();
87
88        // 1. 核心身份
89        if opts.include_identity {
90            parts.push(CORE_IDENTITY.to_string());
91        }
92
93        // 2. 帮助信息
94        parts.push(
95            "If the user asks for help or wants to give feedback inform them of the following:\n\
96             - /help: Get help with using the agent\n\
97             - To give feedback, users should report the issue at the project repository"
98                .to_string(),
99        );
100
101        // 3. 输出风格
102        parts.push(OUTPUT_STYLE.to_string());
103
104        // 4. 任务管理
105        parts.push(TASK_MANAGEMENT.to_string());
106
107        // 5. 代码编写指南
108        parts.push(CODING_GUIDELINES.to_string());
109
110        // 6. 工具使用指南
111        if opts.include_tool_guidelines {
112            parts.push(TOOL_GUIDELINES.to_string());
113        }
114
115        // 7. Git 操作指南
116        parts.push(GIT_GUIDELINES.to_string());
117
118        // 8. 子代理系统
119        parts.push(SUBAGENT_SYSTEM.to_string());
120
121        // 9. 权限模式
122        if opts.include_permission_mode {
123            if let Some(mode) = context.permission_mode {
124                let mode_str = match mode {
125                    PermissionMode::Default => "default",
126                    PermissionMode::AcceptEdits => "accept_edits",
127                    PermissionMode::BypassPermissions => "bypass",
128                    PermissionMode::Plan => "plan",
129                    PermissionMode::Delegate => "delegate",
130                    PermissionMode::DontAsk => "dont_ask",
131                };
132                parts.push(get_permission_mode_description(mode_str).to_string());
133            }
134        }
135
136        // 10. 环境信息
137        let env_info = EnvironmentInfo {
138            working_dir: &context.working_dir.display().to_string(),
139            is_git_repo: context.is_git_repo,
140            platform: context.platform.as_deref().unwrap_or("unknown"),
141            today_date: context.today_date.as_deref().unwrap_or("unknown"),
142            model: context.model.as_deref(),
143        };
144        parts.push(get_environment_info(&env_info));
145
146        // 11. 附件内容
147        for attachment in &attachments {
148            if !attachment.content.is_empty() {
149                parts.push(attachment.content.clone());
150            }
151        }
152
153        // 组装完整提示词
154        let mut content = parts.join("\n\n");
155
156        // 检查长度限制
157        let mut truncated = false;
158        let estimated_tokens = estimate_tokens(&content);
159
160        if estimated_tokens > opts.max_tokens {
161            // 尝试截断附件
162            content = self.truncate_to_limit(&parts, &attachments, opts.max_tokens);
163            truncated = true;
164
165            // 再次检查
166            let final_tokens = estimate_tokens(&content);
167            if final_tokens > opts.max_tokens {
168                return Err(PromptTooLongError::new(final_tokens, opts.max_tokens));
169            }
170        }
171
172        // 计算哈希
173        let hash_info = self.cache.compute_hash(&content);
174
175        // 缓存结果
176        if opts.enable_cache {
177            let cache_key = generate_cache_key(
178                &context.working_dir.display().to_string(),
179                context.model.as_deref(),
180                context
181                    .permission_mode
182                    .map(|m| format!("{:?}", m))
183                    .as_deref(),
184                context.plan_mode,
185            );
186            self.cache
187                .set(cache_key, content.clone(), Some(hash_info.clone()));
188        }
189
190        let build_time_ms = start_time.elapsed().as_millis() as u64;
191
192        if self.debug {
193            eprintln!(
194                "[SystemPromptBuilder] Built in {}ms, {} tokens",
195                build_time_ms, hash_info.estimated_tokens
196            );
197        }
198
199        Ok(BuildResult {
200            content,
201            hash_info,
202            attachments,
203            truncated,
204            build_time_ms,
205        })
206    }
207
208    /// 截断到限制
209    fn truncate_to_limit(
210        &self,
211        parts: &[String],
212        _attachments: &[Attachment],
213        max_tokens: usize,
214    ) -> String {
215        // 优先保留核心部分
216        let core_parts: Vec<&String> = parts.iter().take(7).collect();
217        let remaining_parts: Vec<&String> = parts.iter().skip(7).collect();
218
219        // 计算核心部分的 tokens
220        let mut content = core_parts
221            .iter()
222            .map(|s| s.as_str())
223            .collect::<Vec<_>>()
224            .join("\n\n");
225        let mut current_tokens = estimate_tokens(&content);
226
227        // 添加剩余部分直到接近限制
228        let reserve_tokens = max_tokens / 10; // 保留 10% 空间
229        let target_tokens = max_tokens - reserve_tokens;
230
231        for part in remaining_parts {
232            let part_tokens = estimate_tokens(part);
233            if current_tokens + part_tokens < target_tokens {
234                content.push_str("\n\n");
235                content.push_str(part);
236                current_tokens += part_tokens;
237            }
238        }
239
240        // 添加截断提示
241        content.push_str("\n\n<system-reminder>\nSome context was truncated due to length limits. Use tools to gather additional information as needed.\n</system-reminder>");
242
243        content
244    }
245
246    /// 获取提示词预览
247    pub fn preview(&self, content: &str, max_length: usize) -> String {
248        if content.len() <= max_length {
249            return content.to_string();
250        }
251        format!(
252            "{}\n... [truncated, total {} chars]",
253            content.get(..max_length).unwrap_or(content),
254            content.len()
255        )
256    }
257
258    /// 获取调试信息
259    pub fn get_debug_info(&self, result: &BuildResult) -> String {
260        let mut lines = vec![
261            "=== System Prompt Debug Info ===".to_string(),
262            format!("Hash: {}", result.hash_info.hash),
263            format!("Length: {} chars", result.hash_info.length),
264            format!("Estimated Tokens: {}", result.hash_info.estimated_tokens),
265            format!("Build Time: {}ms", result.build_time_ms),
266            format!("Truncated: {}", result.truncated),
267            format!("Attachments: {}", result.attachments.len()),
268        ];
269
270        if !result.attachments.is_empty() {
271            lines.push("Attachment Details:".to_string());
272            for att in &result.attachments {
273                lines.push(format!(
274                    "  - {:?}: {} ({} chars)",
275                    att.attachment_type,
276                    att.label.as_deref().unwrap_or("no label"),
277                    att.content.len()
278                ));
279            }
280        }
281
282        lines.push("=================================".to_string());
283        lines.join("\n")
284    }
285
286    /// 清除缓存
287    pub fn clear_cache(&mut self) {
288        self.cache.clear();
289    }
290}
291
292impl Default for SystemPromptBuilder {
293    fn default() -> Self {
294        Self::new(false)
295    }
296}