Skip to main content

aster/agents/
prompt_manager.rs

1//! 提示词管理器
2//!
3//! 管理 Agent 的系统提示词,支持分层组合:
4//! 1. Identity(身份层)- 应用层可完全控制
5//! 2. Capabilities(能力层)- 框架提供的 Extensions 等能力描述
6//! 3. Context(上下文层)- 运行时注入的 hints 和额外指令
7
8#[cfg(test)]
9use chrono::DateTime;
10use chrono::Utc;
11use serde::Serialize;
12use serde_json::Value;
13use std::collections::HashMap;
14
15use super::identity::AgentIdentity;
16use crate::agents::extension::ExtensionInfo;
17use crate::hints::load_hints::{load_hint_files, AGENTS_MD_FILENAME, ASTER_HINTS_FILENAME};
18use crate::{
19    config::{AsterMode, Config},
20    prompt_template,
21    utils::sanitize_unicode_tags,
22};
23use std::path::Path;
24
25const MAX_EXTENSIONS: usize = 5;
26const MAX_TOOLS: usize = 50;
27
28pub struct PromptManager {
29    /// 完全覆盖系统提示词(向后兼容)
30    system_prompt_override: Option<String>,
31    /// 额外指令(追加到末尾)
32    system_prompt_extras: Vec<String>,
33    /// 当前时间戳
34    current_date_timestamp: String,
35    /// Agent 身份配置(新增)
36    identity: AgentIdentity,
37    /// Session 级别的系统提示词
38    session_prompt: Option<String>,
39}
40
41impl Default for PromptManager {
42    fn default() -> Self {
43        PromptManager::new()
44    }
45}
46
47/// 身份提示词上下文
48#[derive(Serialize)]
49struct IdentityContext {
50    agent_name: String,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    agent_creator: Option<String>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    agent_description: Option<String>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    language_preference: Option<String>,
57}
58
59/// 能力提示词上下文
60#[derive(Serialize)]
61struct SystemPromptContext {
62    extensions: Vec<ExtensionInfo>,
63    current_date_time: String,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    extension_tool_limits: Option<(usize, usize)>,
66    aster_mode: AsterMode,
67    is_autonomous: bool,
68    enable_subagents: bool,
69    max_extensions: usize,
70    max_tools: usize,
71    code_execution_mode: bool,
72}
73
74pub struct SystemPromptBuilder<'a, M> {
75    manager: &'a M,
76
77    extensions_info: Vec<ExtensionInfo>,
78    frontend_instructions: Option<String>,
79    extension_tool_count: Option<(usize, usize)>,
80    subagents_enabled: bool,
81    hints: Option<String>,
82    code_execution_mode: bool,
83    session_prompt: Option<String>,
84}
85
86impl<'a> SystemPromptBuilder<'a, PromptManager> {
87    pub fn with_extension(mut self, extension: ExtensionInfo) -> Self {
88        self.extensions_info.push(extension);
89        self
90    }
91
92    pub fn with_extensions(mut self, extensions: impl Iterator<Item = ExtensionInfo>) -> Self {
93        for extension in extensions {
94            self.extensions_info.push(extension);
95        }
96        self
97    }
98
99    pub fn with_frontend_instructions(mut self, frontend_instructions: Option<String>) -> Self {
100        self.frontend_instructions = frontend_instructions;
101        self
102    }
103
104    pub fn with_extension_and_tool_counts(
105        mut self,
106        extension_count: usize,
107        tool_count: usize,
108    ) -> Self {
109        self.extension_tool_count = Some((extension_count, tool_count));
110        self
111    }
112
113    pub fn with_code_execution_mode(mut self, enabled: bool) -> Self {
114        self.code_execution_mode = enabled;
115        self
116    }
117
118    pub fn with_hints(mut self, working_dir: &Path) -> Self {
119        let config = Config::global();
120        let hints_filenames = config
121            .get_param::<Vec<String>>("CONTEXT_FILE_NAMES")
122            .unwrap_or_else(|_| {
123                vec![
124                    ASTER_HINTS_FILENAME.to_string(),
125                    AGENTS_MD_FILENAME.to_string(),
126                ]
127            });
128        let ignore_patterns = {
129            let builder = ignore::gitignore::GitignoreBuilder::new(working_dir);
130            builder.build().unwrap_or_else(|_| {
131                ignore::gitignore::GitignoreBuilder::new(working_dir)
132                    .build()
133                    .expect("Failed to build default gitignore")
134            })
135        };
136
137        let hints = load_hint_files(working_dir, &hints_filenames, &ignore_patterns);
138
139        if !hints.is_empty() {
140            self.hints = Some(hints);
141        }
142        self
143    }
144
145    pub fn with_enable_subagents(mut self, subagents_enabled: bool) -> Self {
146        self.subagents_enabled = subagents_enabled;
147        self
148    }
149
150    /// 设置 session 级别的系统提示词
151    pub fn with_session_prompt(mut self, prompt: Option<String>) -> Self {
152        self.session_prompt = prompt;
153        self
154    }
155
156    pub fn build(self) -> String {
157        let mut extensions_info = self.extensions_info;
158
159        // Add frontend instructions to extensions_info to simplify json rendering
160        if let Some(frontend_instructions) = self.frontend_instructions {
161            extensions_info.push(ExtensionInfo::new(
162                "frontend",
163                &frontend_instructions,
164                false,
165            ));
166        }
167        // Stable tool ordering is important for multi session prompt caching.
168        extensions_info.sort_by(|a, b| a.name.cmp(&b.name));
169
170        let sanitized_extensions_info: Vec<ExtensionInfo> = extensions_info
171            .into_iter()
172            .map(|mut ext_info| {
173                ext_info.instructions = sanitize_unicode_tags(&ext_info.instructions);
174                ext_info
175            })
176            .collect();
177
178        let config = Config::global();
179        let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
180
181        let extension_tool_limits = self
182            .extension_tool_count
183            .filter(|(extensions, tools)| *extensions > MAX_EXTENSIONS || *tools > MAX_TOOLS);
184
185        let capabilities_context = SystemPromptContext {
186            extensions: sanitized_extensions_info,
187            current_date_time: self.manager.current_date_timestamp.clone(),
188            extension_tool_limits,
189            aster_mode,
190            is_autonomous: aster_mode == AsterMode::Auto,
191            enable_subagents: self.subagents_enabled,
192            max_extensions: MAX_EXTENSIONS,
193            max_tools: MAX_TOOLS,
194            code_execution_mode: self.code_execution_mode,
195        };
196
197        // 构建提示词:优先使用 override,否则使用分层结构
198        let base_prompt = if let Some(override_prompt) = &self.manager.system_prompt_override {
199            // 向后兼容:完全覆盖模式
200            let sanitized_override_prompt = sanitize_unicode_tags(override_prompt);
201            prompt_template::render_inline_once(&sanitized_override_prompt, &capabilities_context)
202                .unwrap_or_else(|_| override_prompt.clone())
203        } else {
204            // 新的分层模式:Identity + Session Context + Capabilities
205            Self::build_layered_prompt_with_session(
206                &self.manager.identity,
207                &self.session_prompt,
208                &capabilities_context,
209            )
210        };
211
212        let mut system_prompt_extras = self.manager.system_prompt_extras.clone();
213
214        // Add hints if provided
215        if let Some(hints) = self.hints {
216            system_prompt_extras.push(hints);
217        }
218
219        if aster_mode == AsterMode::Chat {
220            system_prompt_extras.push(
221                "Right now you are in the chat only mode, no access to any tool use and system."
222                    .to_string(),
223            );
224        }
225
226        let sanitized_system_prompt_extras: Vec<String> = system_prompt_extras
227            .into_iter()
228            .map(|extra| sanitize_unicode_tags(&extra))
229            .collect();
230
231        if sanitized_system_prompt_extras.is_empty() {
232            base_prompt
233        } else {
234            format!(
235                "{}\n\n# Additional Instructions:\n\n{}",
236                base_prompt,
237                sanitized_system_prompt_extras.join("\n\n")
238            )
239        }
240    }
241
242    /// 构建分层提示词:Identity + Capabilities(静态方法)
243    fn build_layered_prompt_static(
244        identity: &AgentIdentity,
245        capabilities_context: &SystemPromptContext,
246    ) -> String {
247        // 1. 构建身份层
248        let identity_prompt = if let Some(custom) = &identity.custom_prompt {
249            // 使用完全自定义的身份提示词
250            sanitize_unicode_tags(custom)
251        } else {
252            // 使用模板渲染身份
253            let identity_context = IdentityContext {
254                agent_name: identity.name.clone(),
255                agent_creator: identity.creator.clone(),
256                agent_description: identity.description.clone(),
257                language_preference: identity.language.clone(),
258            };
259            prompt_template::render_global_file("identity.md", &identity_context)
260                .unwrap_or_else(|_| format!("You are an AI agent called {}.", identity.name))
261        };
262
263        // 2. 构建能力层
264        let capabilities_prompt =
265            prompt_template::render_global_file("capabilities.md", capabilities_context)
266                .unwrap_or_default();
267
268        // 3. 组合
269        if capabilities_prompt.is_empty() {
270            identity_prompt
271        } else {
272            format!("{}\n\n{}", identity_prompt, capabilities_prompt)
273        }
274    }
275
276    /// 构建分层提示词(包含 session_prompt):Identity + Session Context + Capabilities
277    fn build_layered_prompt_with_session(
278        identity: &AgentIdentity,
279        session_prompt: &Option<String>,
280        capabilities_context: &SystemPromptContext,
281    ) -> String {
282        // 1. 构建身份层
283        let identity_prompt = if let Some(custom) = &identity.custom_prompt {
284            sanitize_unicode_tags(custom)
285        } else {
286            let identity_context = IdentityContext {
287                agent_name: identity.name.clone(),
288                agent_creator: identity.creator.clone(),
289                agent_description: identity.description.clone(),
290                language_preference: identity.language.clone(),
291            };
292            prompt_template::render_global_file("identity.md", &identity_context)
293                .unwrap_or_else(|_| format!("You are an AI agent called {}.", identity.name))
294        };
295
296        // 2. Session Context 层(如果有)
297        let session_section = if let Some(prompt) = session_prompt {
298            let sanitized = sanitize_unicode_tags(prompt);
299            format!("\n\n## Session Context\n\n{}", sanitized)
300        } else {
301            String::new()
302        };
303
304        // 3. 构建能力层
305        let capabilities_prompt =
306            prompt_template::render_global_file("capabilities.md", capabilities_context)
307                .unwrap_or_default();
308
309        // 4. 组合:Identity + Session Context + Capabilities
310        if capabilities_prompt.is_empty() {
311            format!("{}{}", identity_prompt, session_section)
312        } else {
313            format!(
314                "{}{}\n\n{}",
315                identity_prompt, session_section, capabilities_prompt
316            )
317        }
318    }
319}
320
321impl PromptManager {
322    pub fn new() -> Self {
323        PromptManager {
324            system_prompt_override: None,
325            system_prompt_extras: Vec::new(),
326            current_date_timestamp: Utc::now().format("%Y-%m-%d %H:00").to_string(),
327            identity: AgentIdentity::default(),
328            session_prompt: None,
329        }
330    }
331
332    /// 创建带自定义身份的 PromptManager
333    pub fn with_identity(identity: AgentIdentity) -> Self {
334        PromptManager {
335            system_prompt_override: None,
336            system_prompt_extras: Vec::new(),
337            current_date_timestamp: Utc::now().format("%Y-%m-%d %H:00").to_string(),
338            identity,
339            session_prompt: None,
340        }
341    }
342
343    #[cfg(test)]
344    pub fn with_timestamp(dt: DateTime<Utc>) -> Self {
345        PromptManager {
346            system_prompt_override: None,
347            system_prompt_extras: Vec::new(),
348            current_date_timestamp: dt.format("%Y-%m-%d %H:%M:%S").to_string(),
349            identity: AgentIdentity::default(),
350            session_prompt: None,
351        }
352    }
353
354    /// 设置 Agent 身份
355    pub fn set_identity(&mut self, identity: AgentIdentity) {
356        self.identity = identity;
357    }
358
359    /// 获取当前身份配置
360    pub fn identity(&self) -> &AgentIdentity {
361        &self.identity
362    }
363
364    /// 设置 session 级别的系统提示词
365    pub fn set_session_prompt(&mut self, prompt: Option<String>) {
366        self.session_prompt = prompt;
367    }
368
369    /// 获取当前 session 提示词
370    pub fn session_prompt(&self) -> Option<&String> {
371        self.session_prompt.as_ref()
372    }
373
374    /// 清除 session 提示词
375    pub fn clear_session_prompt(&mut self) {
376        self.session_prompt = None;
377    }
378
379    /// Add an additional instruction to the system prompt
380    pub fn add_system_prompt_extra(&mut self, instruction: String) {
381        self.system_prompt_extras.push(instruction);
382    }
383
384    /// Override the system prompt with custom text (向后兼容)
385    pub fn set_system_prompt_override(&mut self, template: String) {
386        self.system_prompt_override = Some(template);
387    }
388
389    pub fn builder<'a>(&'a self) -> SystemPromptBuilder<'a, Self> {
390        SystemPromptBuilder {
391            manager: self,
392
393            extensions_info: vec![],
394            frontend_instructions: None,
395            extension_tool_count: None,
396            subagents_enabled: false,
397            hints: None,
398            code_execution_mode: false,
399            session_prompt: None,
400        }
401    }
402
403    pub async fn get_recipe_prompt(&self) -> String {
404        let context: HashMap<&str, Value> = HashMap::new();
405        prompt_template::render_global_file("recipe.md", &context)
406            .unwrap_or_else(|_| "The recipe prompt is busted. Tell the user.".to_string())
407    }
408}
409
410#[cfg(test)]
411mod tests {
412    use insta::assert_snapshot;
413
414    use super::*;
415
416    #[test]
417    fn test_build_system_prompt_sanitizes_override() {
418        let mut manager = PromptManager::new();
419        let malicious_override = "System prompt\u{E0041}\u{E0042}\u{E0043}with hidden text";
420        manager.set_system_prompt_override(malicious_override.to_string());
421
422        let result = manager.builder().build();
423
424        assert!(!result.contains('\u{E0041}'));
425        assert!(!result.contains('\u{E0042}'));
426        assert!(!result.contains('\u{E0043}'));
427        assert!(result.contains("System prompt"));
428        assert!(result.contains("with hidden text"));
429    }
430
431    #[test]
432    fn test_build_system_prompt_sanitizes_extras() {
433        let mut manager = PromptManager::new();
434        let malicious_extra = "Extra instruction\u{E0041}\u{E0042}\u{E0043}hidden";
435        manager.add_system_prompt_extra(malicious_extra.to_string());
436
437        let result = manager.builder().build();
438
439        assert!(!result.contains('\u{E0041}'));
440        assert!(!result.contains('\u{E0042}'));
441        assert!(!result.contains('\u{E0043}'));
442        assert!(result.contains("Extra instruction"));
443        assert!(result.contains("hidden"));
444    }
445
446    #[test]
447    fn test_build_system_prompt_sanitizes_multiple_extras() {
448        let mut manager = PromptManager::new();
449        manager.add_system_prompt_extra("First\u{E0041}instruction".to_string());
450        manager.add_system_prompt_extra("Second\u{E0042}instruction".to_string());
451        manager.add_system_prompt_extra("Third\u{E0043}instruction".to_string());
452
453        let result = manager.builder().build();
454
455        assert!(!result.contains('\u{E0041}'));
456        assert!(!result.contains('\u{E0042}'));
457        assert!(!result.contains('\u{E0043}'));
458        assert!(result.contains("Firstinstruction"));
459        assert!(result.contains("Secondinstruction"));
460        assert!(result.contains("Thirdinstruction"));
461    }
462
463    #[test]
464    fn test_build_system_prompt_preserves_legitimate_unicode_in_extras() {
465        let mut manager = PromptManager::new();
466        let legitimate_unicode = "Instruction with 世界 and 🌍 emojis";
467        manager.add_system_prompt_extra(legitimate_unicode.to_string());
468
469        let result = manager.builder().build();
470
471        assert!(result.contains("世界"));
472        assert!(result.contains("🌍"));
473        assert!(result.contains("Instruction with"));
474        assert!(result.contains("emojis"));
475    }
476
477    #[test]
478    fn test_build_system_prompt_sanitizes_extension_instructions() {
479        let manager = PromptManager::new();
480        let malicious_extension_info = ExtensionInfo::new(
481            "test_extension",
482            "Extension help\u{E0041}\u{E0042}\u{E0043}hidden instructions",
483            false,
484        );
485
486        let result = manager
487            .builder()
488            .with_extension(malicious_extension_info)
489            .build();
490
491        assert!(!result.contains('\u{E0041}'));
492        assert!(!result.contains('\u{E0042}'));
493        assert!(!result.contains('\u{E0043}'));
494        assert!(result.contains("Extension help"));
495        assert!(result.contains("hidden instructions"));
496    }
497
498    #[test]
499    fn test_basic() {
500        let manager = PromptManager::with_timestamp(DateTime::<Utc>::from_timestamp(0, 0).unwrap());
501
502        let system_prompt = manager.builder().build();
503
504        assert_snapshot!(system_prompt)
505    }
506
507    #[test]
508    fn test_one_extension() {
509        let manager = PromptManager::with_timestamp(DateTime::<Utc>::from_timestamp(0, 0).unwrap());
510
511        let system_prompt = manager
512            .builder()
513            .with_extension(ExtensionInfo::new(
514                "test",
515                "how to use this extension",
516                true,
517            ))
518            .build();
519
520        assert_snapshot!(system_prompt)
521    }
522
523    #[test]
524    fn test_typical_setup() {
525        let manager = PromptManager::with_timestamp(DateTime::<Utc>::from_timestamp(0, 0).unwrap());
526
527        let system_prompt = manager
528            .builder()
529            .with_extension(ExtensionInfo::new(
530                "extension_A",
531                "<instructions on how to use extension A>",
532                true,
533            ))
534            .with_extension(ExtensionInfo::new(
535                "extension_B",
536                "<instructions on how to use extension B (no resources)>",
537                false,
538            ))
539            .with_extension_and_tool_counts(MAX_EXTENSIONS + 1, MAX_TOOLS + 1)
540            .build();
541
542        assert_snapshot!(system_prompt)
543    }
544}