claude_agent/output_style/
generator.rs1use std::path::PathBuf;
7
8use super::{
9 ChainOutputStyleProvider, InMemoryOutputStyleProvider, OutputStyle, builtin_styles,
10 default_style, file_output_style_provider,
11};
12use crate::client::DEFAULT_MODEL;
13use crate::common::Provider;
14use crate::common::SourceType;
15use crate::prompts::{
16 base::{BASE_SYSTEM_PROMPT, TOOL_USAGE_POLICY},
17 coding,
18 environment::{current_platform, environment_block, is_git_repository, os_version},
19 identity::CLI_IDENTITY,
20};
21
22#[derive(Debug, Clone)]
48pub struct SystemPromptGenerator {
49 style: OutputStyle,
50 working_dir: Option<PathBuf>,
51 model_name: String,
52 model_id: String,
53 require_cli_identity: bool,
54}
55
56impl Default for SystemPromptGenerator {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl SystemPromptGenerator {
63 pub fn new() -> Self {
66 Self {
67 style: default_style(),
68 working_dir: None,
69 model_name: "Claude".to_string(),
70 model_id: DEFAULT_MODEL.to_string(),
71 require_cli_identity: false,
72 }
73 }
74
75 pub fn cli_identity() -> Self {
78 Self {
79 style: default_style(),
80 working_dir: None,
81 model_name: "Claude".to_string(),
82 model_id: DEFAULT_MODEL.to_string(),
83 require_cli_identity: true,
84 }
85 }
86
87 pub fn require_cli_identity(mut self, required: bool) -> Self {
90 self.require_cli_identity = required;
91 self
92 }
93
94 pub fn output_style(mut self, style: OutputStyle) -> Self {
96 self.style = style;
97 self
98 }
99
100 pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
102 self.working_dir = Some(dir.into());
103 self
104 }
105
106 pub fn model(mut self, model_id: impl Into<String>) -> Self {
108 let id = model_id.into();
109 self.model_name = derive_model_name(&id);
110 self.model_id = id;
111 self
112 }
113
114 pub fn model_name(mut self, name: impl Into<String>) -> Self {
116 self.model_name = name.into();
117 self
118 }
119
120 pub async fn style_name(mut self, name: &str) -> crate::Result<Self> {
127 let builtins = InMemoryOutputStyleProvider::new()
128 .items(builtin_styles())
129 .priority(0)
130 .source_type(SourceType::Builtin);
131
132 let mut chain = ChainOutputStyleProvider::new().provider(builtins);
133
134 if let Some(ref working_dir) = self.working_dir {
135 let project = file_output_style_provider()
136 .project_path(working_dir)
137 .priority(20)
138 .source_type(SourceType::Project);
139 chain = chain.provider(project);
140 }
141
142 let user = file_output_style_provider()
143 .user_path()
144 .priority(10)
145 .source_type(SourceType::User);
146 chain = chain.provider(user);
147
148 if let Some(style) = chain.get(name).await? {
149 self.style = style;
150 Ok(self)
151 } else {
152 Err(crate::Error::Config(format!(
153 "Output style '{}' not found",
154 name
155 )))
156 }
157 }
158
159 pub fn generate(&self) -> String {
170 let mut parts = Vec::new();
171
172 if self.require_cli_identity {
174 parts.push(CLI_IDENTITY.to_string());
175 }
176
177 parts.push(BASE_SYSTEM_PROMPT.to_string());
179
180 parts.push(TOOL_USAGE_POLICY.to_string());
182
183 if self.style.keep_coding_instructions {
185 parts.push(coding::coding_instructions(&self.model_name));
186 }
187
188 if !self.style.prompt.is_empty() {
190 parts.push(self.style.prompt.clone());
191 }
192
193 let is_git = is_git_repository(self.working_dir.as_deref());
195 let platform = current_platform();
196 let os_ver = os_version();
197
198 parts.push(environment_block(
199 self.working_dir.as_deref(),
200 is_git,
201 platform,
202 &os_ver,
203 &self.model_name,
204 &self.model_id,
205 ));
206
207 parts.join("\n\n")
208 }
209
210 pub fn generate_with_context(&self, additional_context: &str) -> String {
214 let mut prompt = self.generate();
215 if !additional_context.is_empty() {
216 prompt.push_str("\n\n");
217 prompt.push_str(additional_context);
218 }
219 prompt
220 }
221
222 pub fn style(&self) -> &OutputStyle {
224 &self.style
225 }
226
227 pub fn has_coding_instructions(&self) -> bool {
229 self.style.keep_coding_instructions
230 }
231}
232
/// Derives a human-readable model name from a model identifier.
///
/// The family (`opus`, `sonnet`, `haiku`) is detected case-insensitively and
/// checked in that order. The version is read from the identifier's short
/// numeric components:
///
/// - after the family (current naming): `claude-opus-4-6` gives
///   `"Claude Opus 4.6"`, and `claude-sonnet-4-5-20250929` gives
///   `"Claude Sonnet 4.5"`;
/// - before the family (Claude 3.x naming): `claude-3-5-sonnet-20241022`
///   gives `"Claude 3.5 Sonnet"`.
///
/// Date stamps such as `20250929` are never treated as version numbers.
///
/// If the family is recognised but no version can be parsed (for example a
/// bare alias like `"opus"`), a fixed default version for that family is used.
/// An identifier with no recognised family yields plain `"Claude"`.
fn derive_model_name(model_id: &str) -> String {
    // Each entry is (substring identifying the family, display name,
    // version assumed when the identifier carries none).
    const FAMILIES: [(&str, &str, &str); 3] = [
        ("opus", "Opus", "4.6"),
        ("sonnet", "Sonnet", "4.5"),
        ("haiku", "Haiku", "4.5"),
    ];
    // A version has at most a major and a minor component.
    const MAX_VERSION_PARTS: usize = 2;

    // A version component is a short, purely numeric token such as `4` or
    // `6`; longer numeric tokens are date stamps.
    fn is_version_part(token: &str) -> bool {
        !token.is_empty() && token.len() <= 2 && token.bytes().all(|b| b.is_ascii_digit())
    }

    let id = model_id.to_ascii_lowercase();
    let (key, display, default_version) =
        match FAMILIES.iter().copied().find(|&(key, _, _)| id.contains(key)) {
            Some(family) => family,
            None => return "Claude".to_string(),
        };

    let tokens: Vec<&str> = id.split(|c: char| !c.is_ascii_alphanumeric()).collect();
    let pos = match tokens.iter().position(|&t| t == key) {
        Some(pos) => pos,
        // The family name is only embedded in a larger word, so there is no
        // reliable version to read; fall back to the default.
        None => return format!("Claude {} {}", display, default_version),
    };

    // Current naming: version follows the family, e.g. `claude-opus-4-6`.
    let after: Vec<&str> = tokens[pos + 1..]
        .iter()
        .copied()
        .take_while(|t| is_version_part(t))
        .take(MAX_VERSION_PARTS)
        .collect();
    if !after.is_empty() {
        return format!("Claude {} {}", display, after.join("."));
    }

    // Claude 3.x naming: version precedes the family, e.g.
    // `claude-3-5-sonnet-20241022`. Walk backwards from the family, then
    // restore the original order.
    let mut before: Vec<&str> = tokens[..pos]
        .iter()
        .rev()
        .copied()
        .take_while(|t| is_version_part(t))
        .take(MAX_VERSION_PARTS)
        .collect();
    if before.is_empty() {
        return format!("Claude {} {}", display, default_version);
    }
    before.reverse();
    format!("Claude {} {}", before.join("."), display)
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_style::SourceType;

    /// Checks the sections every prompt that keeps coding instructions must
    /// contain: the coding section and the environment block.
    fn assert_coding_and_env(prompt: &str) {
        assert!(prompt.contains("Doing tasks"));
        assert!(prompt.contains("<env>"));
    }

    #[test]
    fn test_generator_default_no_cli_identity() {
        let prompt = SystemPromptGenerator::new().generate();

        assert!(!prompt.starts_with(CLI_IDENTITY));
        assert_coding_and_env(&prompt);
    }

    #[test]
    fn test_generator_cli_identity() {
        let prompt = SystemPromptGenerator::cli_identity().generate();

        assert!(prompt.starts_with(CLI_IDENTITY));
        assert_coding_and_env(&prompt);
    }

    #[test]
    fn test_generator_with_custom_style_keep_coding() {
        let keep_coding = OutputStyle::new("test", "Test style", "Custom instructions here")
            .source_type(SourceType::User)
            .keep_coding_instructions(true);

        let prompt = SystemPromptGenerator::cli_identity()
            .output_style(keep_coding)
            .generate();

        assert!(prompt.starts_with(CLI_IDENTITY));
        assert_coding_and_env(&prompt);
        assert!(prompt.contains("Custom instructions here"));
    }

    #[test]
    fn test_generator_with_custom_style_no_coding() {
        let concise = OutputStyle::new("concise", "Be concise", "Keep responses short.")
            .source_type(SourceType::User)
            .keep_coding_instructions(false);

        let prompt = SystemPromptGenerator::cli_identity()
            .output_style(concise)
            .generate();

        assert!(prompt.starts_with(CLI_IDENTITY));
        assert!(!prompt.contains("Doing tasks"));
        assert!(prompt.contains("Keep responses short."));
        assert!(prompt.contains("<env>"));
    }

    #[test]
    fn test_generator_working_dir() {
        let prompt = SystemPromptGenerator::new()
            .working_dir("/test/project")
            .generate();

        assert!(prompt.contains("/test/project"));
    }

    #[test]
    fn test_generator_model() {
        let prompt = SystemPromptGenerator::new()
            .model("claude-opus-4-6")
            .generate();

        for expected in ["claude-opus-4-6", "Claude Opus 4.6"] {
            assert!(prompt.contains(expected));
        }
    }

    #[test]
    fn test_derive_model_name() {
        let cases = [
            ("claude-opus-4-6", "Claude Opus 4.6"),
            ("claude-sonnet-4-5-20250929", "Claude Sonnet 4.5"),
            ("claude-haiku-4-5-20251001", "Claude Haiku 4.5"),
            ("unknown-model", "Claude"),
        ];
        for (model_id, expected) in cases {
            assert_eq!(derive_model_name(model_id), expected);
        }
    }

    #[test]
    fn test_generator_with_context() {
        let prompt =
            SystemPromptGenerator::new().generate_with_context("# Dynamic Rules\nSome dynamic content");

        assert!(prompt.contains("# Dynamic Rules"));
        assert!(prompt.contains("Some dynamic content"));
    }

    #[test]
    fn test_has_coding_instructions() {
        assert!(SystemPromptGenerator::new().has_coding_instructions());

        let no_coding = OutputStyle::new("no-coding", "", "").keep_coding_instructions(false);
        assert!(!SystemPromptGenerator::new()
            .output_style(no_coding)
            .has_coding_instructions());
    }

    #[test]
    fn test_cli_identity_cannot_be_replaced_by_custom_prompt() {
        // A style that tries to claim a different identity must still come
        // after the CLI identity, never in place of it.
        let impostor = OutputStyle::new(
            "custom",
            "Custom identity",
            "I am a different assistant.",
        )
        .keep_coding_instructions(false);

        let prompt = SystemPromptGenerator::cli_identity()
            .output_style(impostor)
            .generate();

        assert!(prompt.starts_with(CLI_IDENTITY));
        assert!(prompt.contains("I am a different assistant."));
    }
}