1use std::{
2 path::{Path, PathBuf},
3 sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use tera::{Context, Tera};
9
10use crate::error::{CommitGenError, Result};
11
/// A rendered prompt, split into a system part and a user part.
///
/// `system` is the trimmed, unrendered text above the `======USER=======` separator (empty
/// when the template has no separator). `user` is the trimmed, Tera-rendered text below it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Static system prompt; it is never passed through the template engine.
    pub system: String,
    /// Rendered user prompt.
    pub user: String,
}
17
/// Marker line separating the static system section of a prompt template from the
/// Tera-rendered user section.
const USER_SEPARATOR_MARKER: &str = "======USER=======";

/// Locates the first [`USER_SEPARATOR_MARKER`] in `content`.
///
/// Returns `(system_end, user_start)` as byte offsets:
/// - `content[..system_end]` is the text before the marker, minus one trailing line break
///   (`\n` or `\r\n`).
/// - `content[user_start..]` is the text after the marker, minus one leading line break.
///
/// Returns `None` when the marker is absent.
///
/// The line-break checks use `ends_with`/`starts_with` on slices cut at the marker's own
/// boundaries, which are always char boundaries. Indexing a fixed number of bytes before
/// the marker instead would panic if a multi-byte character preceded it.
fn find_user_separator(content: &str) -> Option<(usize, usize)> {
    let marker_pos = content.find(USER_SEPARATOR_MARKER)?;
    let after_marker = marker_pos + USER_SEPARATOR_MARKER.len();

    let before = &content[..marker_pos];
    let system_end = if before.ends_with("\r\n") {
        marker_pos - 2
    } else if before.ends_with('\n') {
        marker_pos - 1
    } else {
        marker_pos
    };

    let after = &content[after_marker..];
    let user_start = if after.starts_with("\r\n") {
        after_marker + 2
    } else if after.starts_with('\n') {
        after_marker + 1
    } else {
        after_marker
    };

    Some((system_end, user_start))
}
50
51fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
53 if let Some((system_end, user_start)) = find_user_separator(template_content) {
54 (Some(&template_content[..system_end]), &template_content[user_start..])
55 } else {
56 (None, template_content)
57 }
58}
59
60fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
62 let has_template_tags = system_template.contains("{{")
63 || system_template.contains("{%")
64 || system_template.contains("{#");
65
66 if has_template_tags {
67 return Err(CommitGenError::Other(format!(
68 "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
69 content below ======USER=======."
70 )));
71 }
72
73 Ok(())
74}
75
76fn render_prompt_parts(
78 template_name: &str,
79 template_content: &str,
80 context: &Context,
81) -> Result<PromptParts> {
82 let (system_template, user_template) = split_prompt_template(template_content);
83
84 let system = if let Some(system_template) = system_template {
85 ensure_static_system_prompt(system_template, template_name)?;
86 system_template.trim().to_string()
87 } else {
88 String::new()
89 };
90
91 let mut tera = TERA.lock();
92 let rendered_user = tera.render_str(user_template, context).map_err(|e| {
93 CommitGenError::Other(format!("Failed to render {template_name} prompt template: {e}"))
94 })?;
95
96 Ok(PromptParts { system, user: rendered_user.trim().to_string() })
97}
98
/// Inputs for `render_analysis_prompt`.
///
/// `variant` selects the `analysis/{variant}.md` template. The remaining fields are exposed
/// to the template under their own names. Optional fields that are `None` are not inserted
/// into the template context at all.
#[derive(Debug, Default, Clone, Copy)]
pub struct AnalysisParams<'a> {
    pub variant: &'a str,
    pub stat: &'a str,
    pub diff: &'a str,
    pub scope_candidates: &'a str,
    pub recent_commits: Option<&'a str>,
    pub common_scopes: Option<&'a str>,
    pub types_description: Option<&'a str>,
    pub project_context: Option<&'a str>,
}
111
/// Prompt templates compiled into the binary from the crate's `prompts/` folder.
///
/// Only `*.md` files are embedded. Keys are paths relative to `prompts/`, in the
/// `{category}/{variant}.md` form that `load_template_file` looks up. The same keys are
/// used as Tera template names when the embedded defaults are registered.
#[derive(RustEmbed)]
#[folder = "prompts/"]
#[include = "**/*.md"]
struct Prompts;
117
118static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
121 if let Err(e) = ensure_prompts_dir() {
123 eprintln!("Warning: Failed to initialize prompts directory: {e}");
124 }
125
126 let mut tera = Tera::default();
127
128 if let Some(prompts_dir) = get_user_prompts_dir() {
130 if let Err(e) =
131 register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
132 {
133 eprintln!("Warning: {e}");
134 }
135 if let Err(e) =
136 register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
137 {
138 eprintln!("Warning: {e}");
139 }
140 if let Err(e) =
141 register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
142 {
143 eprintln!("Warning: {e}");
144 }
145 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
146 eprintln!("Warning: {e}");
147 }
148 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
149 {
150 eprintln!("Warning: {e}");
151 }
152 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("fast"), "fast") {
153 eprintln!("Warning: {e}");
154 }
155 if let Err(e) = register_directory_templates(
156 &mut tera,
157 &prompts_dir.join("compose-intent"),
158 "compose-intent",
159 ) {
160 eprintln!("Warning: {e}");
161 }
162 if let Err(e) =
163 register_directory_templates(&mut tera, &prompts_dir.join("compose-bind"), "compose-bind")
164 {
165 eprintln!("Warning: {e}");
166 }
167 }
168
169 for file in Prompts::iter() {
171 if tera.get_template_names().any(|name| name == file.as_ref()) {
172 continue;
173 }
174
175 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
176 match std::str::from_utf8(embedded_file.data.as_ref()) {
177 Ok(content) => {
178 if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
179 eprintln!(
180 "Warning: Failed to register embedded template {}: {}",
181 file.as_ref(),
182 e
183 );
184 }
185 },
186 Err(e) => {
187 eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
188 },
189 }
190 }
191 }
192
193 tera.autoescape_on(vec![]);
195
196 Mutex::new(tera)
197});
198
/// Returns `<home>/.llm-git/prompts`.
///
/// The home directory is taken from `HOME`, falling back to `USERPROFILE`. Empty values
/// count as unset: an empty `HOME` would otherwise produce the relative path
/// `.llm-git/prompts`, and `ensure_prompts_dir` would then write prompt files into the
/// current working directory. Non-UTF-8 values are accepted.
///
/// Returns `None` if neither variable holds a non-empty value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|key| std::env::var_os(key).filter(|value| !value.is_empty()))
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
206
207pub fn ensure_prompts_dir() -> Result<()> {
209 let Some(user_prompts_dir) = get_user_prompts_dir() else {
210 return Ok(());
213 };
214
215 let user_llm_git_dir = user_prompts_dir
217 .parent()
218 .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
219
220 if !user_llm_git_dir.exists() {
222 std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
223 CommitGenError::Other(format!(
224 "Failed to create directory {}: {}",
225 user_llm_git_dir.display(),
226 e
227 ))
228 })?;
229 }
230
231 if !user_prompts_dir.exists() {
233 std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
234 CommitGenError::Other(format!(
235 "Failed to create directory {}: {}",
236 user_prompts_dir.display(),
237 e
238 ))
239 })?;
240 }
241
242 for file in Prompts::iter() {
244 let file_path = user_prompts_dir.join(file.as_ref());
245
246 if let Some(parent) = file_path.parent() {
248 std::fs::create_dir_all(parent).map_err(|e| {
249 CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
250 })?;
251 }
252
253 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
254 let embedded_content = embedded_file.data;
255
256 let should_write = if file_path.exists() {
258 match std::fs::read(&file_path) {
259 Ok(existing_content) => existing_content != embedded_content.as_ref(),
260 Err(_) => true, }
262 } else {
263 true };
265
266 if should_write {
267 std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
268 CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
269 })?;
270 }
271 }
272 }
273
274 Ok(())
275}
276
277fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
278 if !directory.exists() {
279 return Ok(());
280 }
281
282 for entry in std::fs::read_dir(directory).map_err(|e| {
283 CommitGenError::Other(format!(
284 "Failed to read {} templates directory {}: {}",
285 category,
286 directory.display(),
287 e
288 ))
289 })? {
290 let entry = match entry {
291 Ok(entry) => entry,
292 Err(e) => {
293 eprintln!(
294 "Warning: Failed to iterate template entry in {}: {}",
295 directory.display(),
296 e
297 );
298 continue;
299 },
300 };
301
302 let path = entry.path();
303 if path.extension().and_then(|s| s.to_str()) != Some("md") {
304 continue;
305 }
306
307 let template_name = format!(
308 "{}/{}",
309 category,
310 path
311 .file_name()
312 .and_then(|s| s.to_str())
313 .unwrap_or_default()
314 );
315
316 if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
319 eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
320 }
321 }
322
323 Ok(())
324}
325
326fn load_template_file(category: &str, variant: &str) -> Result<String> {
328 if let Some(prompts_dir) = get_user_prompts_dir() {
330 let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
331 if template_path.exists() {
332 return std::fs::read_to_string(&template_path).map_err(|e| {
333 CommitGenError::Other(format!(
334 "Failed to read template file {}: {}",
335 template_path.display(),
336 e
337 ))
338 });
339 }
340 }
341
342 let embedded_key = format!("{category}/{variant}.md");
344 if let Some(bytes) = Prompts::get(&embedded_key) {
345 return std::str::from_utf8(bytes.data.as_ref())
346 .map(|s| s.to_string())
347 .map_err(|e| {
348 CommitGenError::Other(format!(
349 "Embedded template {embedded_key} is not valid UTF-8: {e}"
350 ))
351 });
352 }
353
354 Err(CommitGenError::Other(format!(
355 "Template variant '{variant}' in category '{category}' not found as user override or \
356 embedded default"
357 )))
358}
359
360pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
362 let template_content = load_template_file("analysis", p.variant)?;
364
365 let mut context = Context::new();
367 context.insert("stat", p.stat);
368 context.insert("diff", p.diff);
369 context.insert("scope_candidates", p.scope_candidates);
370 if let Some(commits) = p.recent_commits {
371 context.insert("recent_commits", commits);
372 }
373 if let Some(scopes) = p.common_scopes {
374 context.insert("common_scopes", scopes);
375 }
376 if let Some(types) = p.types_description {
377 context.insert("types_description", types);
378 }
379 if let Some(ctx) = p.project_context {
380 context.insert("project_context", ctx);
381 }
382
383 render_prompt_parts(&format!("analysis/{}.md", p.variant), &template_content, &context)
384}
385
386pub fn render_summary_prompt(
388 variant: &str,
389 commit_type: &str,
390 scope: &str,
391 chars: &str,
392 details: &str,
393 stat: &str,
394 user_context: Option<&str>,
395) -> Result<PromptParts> {
396 let template_content = load_template_file("summary", variant)?;
398
399 let mut context = Context::new();
401 context.insert("commit_type", commit_type);
402 context.insert("scope", scope);
403 context.insert("chars", chars);
404 context.insert("details", details);
405 context.insert("stat", stat);
406 if let Some(ctx) = user_context {
407 context.insert("user_context", ctx);
408 }
409
410 render_prompt_parts(&format!("summary/{variant}.md"), &template_content, &context)
411}
412
413pub fn render_changelog_prompt(
415 variant: &str,
416 changelog_path: &str,
417 is_package_changelog: bool,
418 stat: &str,
419 diff: &str,
420 existing_entries: Option<&str>,
421) -> Result<PromptParts> {
422 let template_content = load_template_file("changelog", variant)?;
424
425 let mut context = Context::new();
427 context.insert("changelog_path", changelog_path);
428 context.insert("is_package_changelog", &is_package_changelog);
429 context.insert("stat", stat);
430 context.insert("diff", diff);
431 if let Some(entries) = existing_entries {
432 context.insert("existing_entries", entries);
433 }
434
435 render_prompt_parts(&format!("changelog/{variant}.md"), &template_content, &context)
436}
437
438pub fn render_map_prompt(
440 variant: &str,
441 filename: &str,
442 diff: &str,
443 context_header: &str,
444) -> Result<PromptParts> {
445 let template_content = load_template_file("map", variant)?;
446
447 let mut context = Context::new();
448 context.insert("filename", filename);
449 context.insert("diff", diff);
450 if !context_header.is_empty() {
451 context.insert("context_header", context_header);
452 }
453
454 render_prompt_parts(&format!("map/{variant}.md"), &template_content, &context)
455}
456
457pub fn render_reduce_prompt(
459 variant: &str,
460 observations: &str,
461 stat: &str,
462 scope_candidates: &str,
463 types_description: Option<&str>,
464) -> Result<PromptParts> {
465 let template_content = load_template_file("reduce", variant)?;
466
467 let mut context = Context::new();
468 context.insert("observations", observations);
469 context.insert("stat", stat);
470 context.insert("scope_candidates", scope_candidates);
471 if let Some(types_desc) = types_description {
472 context.insert("types_description", types_desc);
473 }
474
475 render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
476}
477
/// Inputs for `render_compose_intent_prompt`.
///
/// `variant` selects the `compose-intent/{variant}.md` template. Every other field is
/// exposed to the template under its own name.
#[derive(Debug, Clone, Copy)]
pub struct ComposeIntentPromptParams<'a> {
    pub variant: &'a str,
    pub max_commits: usize,
    pub stat: &'a str,
    pub snapshot_summary: &'a str,
    pub planning_targets: &'a str,
    pub planning_notes: &'a str,
    pub split_bias: &'a str,
}
488
489pub fn render_compose_intent_prompt(p: &ComposeIntentPromptParams<'_>) -> Result<PromptParts> {
491 let template_content = load_template_file("compose-intent", p.variant)?;
492
493 let mut context = Context::new();
494 context.insert("max_commits", &p.max_commits);
495 context.insert("stat", p.stat);
496 context.insert("snapshot_summary", p.snapshot_summary);
497 context.insert("planning_targets", p.planning_targets);
498 context.insert("planning_notes", p.planning_notes);
499 context.insert("split_bias", p.split_bias);
500
501 render_prompt_parts(&format!("compose-intent/{}.md", p.variant), &template_content, &context)
502}
503
/// Inputs for `render_compose_bind_prompt`.
///
/// `variant` selects the `compose-bind/{variant}.md` template. `groups` and
/// `ambiguous_files` are exposed to the template under those names.
#[derive(Debug, Clone, Copy)]
pub struct ComposeBindPromptParams<'a> {
    pub variant: &'a str,
    pub groups: &'a str,
    pub ambiguous_files: &'a str,
}
510
511pub fn render_compose_bind_prompt(p: &ComposeBindPromptParams<'_>) -> Result<PromptParts> {
513 let template_content = load_template_file("compose-bind", p.variant)?;
514
515 let mut context = Context::new();
516 context.insert("groups", p.groups);
517 context.insert("ambiguous_files", p.ambiguous_files);
518
519 render_prompt_parts(&format!("compose-bind/{}.md", p.variant), &template_content, &context)
520}
521
/// Inputs for `render_fast_prompt`.
///
/// `variant` selects the `fast/{variant}.md` template. The remaining fields are exposed
/// to the template under their own names; `user_context` is only inserted when it is `Some`.
#[derive(Debug, Clone, Copy)]
pub struct FastPromptParams<'a> {
    pub variant: &'a str,
    pub stat: &'a str,
    pub diff: &'a str,
    pub scope_candidates: &'a str,
    pub user_context: Option<&'a str>,
}
530
531pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
533 let template_content = load_template_file("fast", p.variant)?;
534
535 let mut context = Context::new();
536 context.insert("stat", p.stat);
537 context.insert("diff", p.diff);
538 context.insert("scope_candidates", p.scope_candidates);
539 if let Some(ctx) = p.user_context {
540 context.insert("user_context", ctx);
541 }
542
543 render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
544}
545
#[cfg(test)]
mod tests {
    use super::{
        ComposeBindPromptParams, ComposeIntentPromptParams, render_compose_bind_prompt,
        render_compose_intent_prompt, split_prompt_template,
    };

    /// Asserts that `content` splits into exactly `expected_system` / `expected_user`.
    fn assert_split(content: &str, expected_system: Option<&str>, expected_user: &str) {
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, expected_system);
        assert_eq!(user, expected_user);
    }

    #[test]
    fn test_split_prompt_template_lf() {
        assert_split(
            "system text\nmore system\n======USER=======\nuser body\n",
            Some("system text\nmore system"),
            "user body\n",
        );
    }

    #[test]
    fn test_split_prompt_template_crlf() {
        assert_split(
            "system text\r\nmore system\r\n======USER=======\r\nuser body\r\n",
            Some("system text\r\nmore system"),
            "user body\r\n",
        );
    }

    #[test]
    fn test_split_prompt_template_no_separator() {
        let content = "no separator here";
        assert_split(content, None, content);
    }

    #[test]
    fn test_render_compose_intent_prompt() {
        let params = ComposeIntentPromptParams {
            variant: "default",
            max_commits: 3,
            stat: "src/foo.rs | 10 +++++-----",
            snapshot_summary: "- F1 src/foo.rs",
            planning_targets: "file IDs",
            planning_notes: "Prefer conservative grouping over speculative splitting.",
            split_bias: "Prefer fewer groups when the split is uncertain.",
        };
        let parts = render_compose_intent_prompt(&params).unwrap();

        assert!(parts.system.contains("create_compose_intent_plan"));
        assert!(parts.user.contains("max_commits: 3"));
        assert!(parts.user.contains("src/foo.rs"));
    }

    #[test]
    fn test_render_compose_bind_prompt() {
        let params = ComposeBindPromptParams {
            variant: "default",
            groups: "- G1 [feat(api)] Added endpoint",
            ambiguous_files: "- F2 src/api.rs candidates: G1",
        };
        let parts = render_compose_bind_prompt(&params).unwrap();

        assert!(parts.system.contains("bind_compose_hunks"));
        assert!(parts.user.contains("G1"));
        assert!(parts.user.contains("src/api.rs"));
    }
}