1use std::{
2 path::{Path, PathBuf},
3 sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use serde::Serialize;
9use tera::{Context, Tera};
10
11use crate::error::{CommitGenError, Result};
12
/// A rendered prompt, split into the static system section and the rendered
/// user section.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Trimmed text that precedes the `======USER=======` separator (empty when
    /// the template has no separator). It is never passed through Tera.
    pub system: String,
    /// Trimmed, Tera-rendered text that follows the separator.
    pub user: String,
}
18
/// Marker that separates a template's static system section from its
/// Tera-rendered user section.
const USER_SEPARATOR_MARKER: &str = "======USER=======";

/// Locates the first [`USER_SEPARATOR_MARKER`] in `content`.
///
/// Returns `(system_end, user_start)` as byte offsets into `content`:
/// - `system_end` is where the system section ends. One `\r\n` or `\n` directly
///   before the marker is excluded.
/// - `user_start` is where the user section begins. One `\r\n` or `\n` directly
///   after the marker is skipped.
///
/// Returns `None` when the marker does not occur.
fn find_user_separator(content: &str) -> Option<(usize, usize)> {
    let marker_pos = content.find(USER_SEPARATOR_MARKER)?;

    // Strip line endings from `&str` views instead of slicing at
    // `marker_pos - 2` / `marker_pos - 1`. Those byte offsets can fall inside a
    // multi-byte character (e.g. "…\n" right before the marker), and slicing a
    // `str` off a char boundary panics.
    let before = &content[..marker_pos];
    let system_end = before
        .strip_suffix("\r\n")
        .or_else(|| before.strip_suffix('\n'))
        .unwrap_or(before)
        .len();

    let after_marker = marker_pos + USER_SEPARATOR_MARKER.len();
    let after = &content[after_marker..];
    let user_body = after
        .strip_prefix("\r\n")
        .or_else(|| after.strip_prefix('\n'))
        .unwrap_or(after);
    // `user_body` is a suffix of `content`, so its start offset is the length difference.
    let user_start = content.len() - user_body.len();

    Some((system_end, user_start))
}

/// Splits a prompt template into `(system_section, user_section)`.
///
/// Without a separator the whole template is the user section and the system
/// section is `None`.
fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
    match find_user_separator(template_content) {
        Some((system_end, user_start)) => {
            (Some(&template_content[..system_end]), &template_content[user_start..])
        },
        None => (None, template_content),
    }
}
60
61fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
63 let has_template_tags = system_template.contains("{{")
64 || system_template.contains("{%")
65 || system_template.contains("{#");
66
67 if has_template_tags {
68 return Err(CommitGenError::Other(format!(
69 "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
70 content below ======USER=======."
71 )));
72 }
73
74 Ok(())
75}
76
77fn render_prompt_parts(
79 template_name: &str,
80 template_content: &str,
81 context: &Context,
82) -> Result<PromptParts> {
83 let (system_template, user_template) = split_prompt_template(template_content);
84
85 let system = if let Some(system_template) = system_template {
86 ensure_static_system_prompt(system_template, template_name)?;
87 system_template.trim().to_string()
88 } else {
89 String::new()
90 };
91
92 let mut tera = TERA.lock();
93 let rendered_user = tera.render_str(user_template, context).map_err(|e| {
94 CommitGenError::Other(format!("Failed to render {template_name} prompt template: {e}"))
95 })?;
96
97 Ok(PromptParts { system, user: rendered_user.trim().to_string() })
98}
99
/// Inputs for [`render_analysis_prompt`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AnalysisParams<'a> {
    /// Template variant; selects `analysis/<variant>.md`.
    pub variant: &'a str,
    /// Inserted into the template context as `stat`.
    pub stat: &'a str,
    /// Inserted into the template context as `diff`.
    pub diff: &'a str,
    /// Inserted into the template context as `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Inserted as `recent_commits` only when `Some`.
    pub recent_commits: Option<&'a str>,
    /// Inserted as `common_scopes` only when `Some`.
    pub common_scopes: Option<&'a str>,
    /// Inserted as `types_description` only when `Some`.
    pub types_description: Option<&'a str>,
    /// Inserted as `project_context` only when `Some`.
    pub project_context: Option<&'a str>,
}
112
/// Default prompt templates provided through `rust_embed` from the crate's
/// `prompts/` folder, restricted to `*.md` files.
///
/// Keys are paths relative to `prompts/` (e.g. `analysis/default.md`). These
/// are the fallbacks when no user copy exists, and the source that
/// `ensure_prompts_dir` mirrors to disk.
#[derive(RustEmbed)]
#[folder = "prompts/"]
#[include = "**/*.md"]
struct Prompts;
118
119static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
122 if let Err(e) = ensure_prompts_dir() {
124 eprintln!("Warning: Failed to initialize prompts directory: {e}");
125 }
126
127 let mut tera = Tera::default();
128
129 if let Some(prompts_dir) = get_user_prompts_dir() {
131 if let Err(e) =
132 register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
133 {
134 eprintln!("Warning: {e}");
135 }
136 if let Err(e) =
137 register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
138 {
139 eprintln!("Warning: {e}");
140 }
141 if let Err(e) =
142 register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
143 {
144 eprintln!("Warning: {e}");
145 }
146 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
147 eprintln!("Warning: {e}");
148 }
149 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
150 {
151 eprintln!("Warning: {e}");
152 }
153 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("fast"), "fast") {
154 eprintln!("Warning: {e}");
155 }
156 if let Err(e) = register_directory_templates(
157 &mut tera,
158 &prompts_dir.join("compose-intent"),
159 "compose-intent",
160 ) {
161 eprintln!("Warning: {e}");
162 }
163 if let Err(e) =
164 register_directory_templates(&mut tera, &prompts_dir.join("compose-bind"), "compose-bind")
165 {
166 eprintln!("Warning: {e}");
167 }
168 }
169
170 for file in Prompts::iter() {
172 if tera.get_template_names().any(|name| name == file.as_ref()) {
173 continue;
174 }
175
176 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
177 match std::str::from_utf8(embedded_file.data.as_ref()) {
178 Ok(content) => {
179 if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
180 eprintln!(
181 "Warning: Failed to register embedded template {}: {}",
182 file.as_ref(),
183 e
184 );
185 }
186 },
187 Err(e) => {
188 eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
189 },
190 }
191 }
192 }
193
194 tera.autoescape_on(vec![]);
196
197 Mutex::new(tera)
198});
199
/// Returns `<home>/.llm-git/prompts`.
///
/// `<home>` is `$HOME`, falling back to `%USERPROFILE%`. Empty values are
/// ignored: otherwise a set-but-empty `HOME` would yield the relative path
/// `.llm-git/prompts`, resolved against the current working directory.
/// Non-UTF-8 values are accepted.
///
/// Returns `None` when neither variable holds a non-empty value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    non_empty_var("HOME")
        .or_else(|| non_empty_var("USERPROFILE"))
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
207
208pub fn ensure_prompts_dir() -> Result<()> {
210 let Some(user_prompts_dir) = get_user_prompts_dir() else {
211 return Ok(());
214 };
215
216 let user_llm_git_dir = user_prompts_dir
218 .parent()
219 .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
220
221 if !user_llm_git_dir.exists() {
223 std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
224 CommitGenError::Other(format!(
225 "Failed to create directory {}: {}",
226 user_llm_git_dir.display(),
227 e
228 ))
229 })?;
230 }
231
232 if !user_prompts_dir.exists() {
234 std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
235 CommitGenError::Other(format!(
236 "Failed to create directory {}: {}",
237 user_prompts_dir.display(),
238 e
239 ))
240 })?;
241 }
242
243 for file in Prompts::iter() {
245 let file_path = user_prompts_dir.join(file.as_ref());
246
247 if let Some(parent) = file_path.parent() {
249 std::fs::create_dir_all(parent).map_err(|e| {
250 CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
251 })?;
252 }
253
254 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
255 let embedded_content = embedded_file.data;
256
257 let should_write = if file_path.exists() {
259 match std::fs::read(&file_path) {
260 Ok(existing_content) => existing_content != embedded_content.as_ref(),
261 Err(_) => true, }
263 } else {
264 true };
266
267 if should_write {
268 std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
269 CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
270 })?;
271 }
272 }
273 }
274
275 Ok(())
276}
277
278fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
279 if !directory.exists() {
280 return Ok(());
281 }
282
283 for entry in std::fs::read_dir(directory).map_err(|e| {
284 CommitGenError::Other(format!(
285 "Failed to read {} templates directory {}: {}",
286 category,
287 directory.display(),
288 e
289 ))
290 })? {
291 let entry = match entry {
292 Ok(entry) => entry,
293 Err(e) => {
294 eprintln!(
295 "Warning: Failed to iterate template entry in {}: {}",
296 directory.display(),
297 e
298 );
299 continue;
300 },
301 };
302
303 let path = entry.path();
304 if path.extension().and_then(|s| s.to_str()) != Some("md") {
305 continue;
306 }
307
308 let template_name = format!(
309 "{}/{}",
310 category,
311 path
312 .file_name()
313 .and_then(|s| s.to_str())
314 .unwrap_or_default()
315 );
316
317 if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
320 eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
321 }
322 }
323
324 Ok(())
325}
326
327fn load_template_file(category: &str, variant: &str) -> Result<String> {
329 if let Some(prompts_dir) = get_user_prompts_dir() {
331 let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
332 if template_path.exists() {
333 return std::fs::read_to_string(&template_path).map_err(|e| {
334 CommitGenError::Other(format!(
335 "Failed to read template file {}: {}",
336 template_path.display(),
337 e
338 ))
339 });
340 }
341 }
342
343 let embedded_key = format!("{category}/{variant}.md");
345 if let Some(bytes) = Prompts::get(&embedded_key) {
346 return std::str::from_utf8(bytes.data.as_ref())
347 .map(|s| s.to_string())
348 .map_err(|e| {
349 CommitGenError::Other(format!(
350 "Embedded template {embedded_key} is not valid UTF-8: {e}"
351 ))
352 });
353 }
354
355 Err(CommitGenError::Other(format!(
356 "Template variant '{variant}' in category '{category}' not found as user override or \
357 embedded default"
358 )))
359}
360
361pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
363 let template_content = load_template_file("analysis", p.variant)?;
365
366 let mut context = Context::new();
368 context.insert("stat", p.stat);
369 context.insert("diff", p.diff);
370 context.insert("scope_candidates", p.scope_candidates);
371 if let Some(commits) = p.recent_commits {
372 context.insert("recent_commits", commits);
373 }
374 if let Some(scopes) = p.common_scopes {
375 context.insert("common_scopes", scopes);
376 }
377 if let Some(types) = p.types_description {
378 context.insert("types_description", types);
379 }
380 if let Some(ctx) = p.project_context {
381 context.insert("project_context", ctx);
382 }
383
384 render_prompt_parts(&format!("analysis/{}.md", p.variant), &template_content, &context)
385}
386
387pub fn render_summary_prompt(
389 variant: &str,
390 commit_type: &str,
391 scope: &str,
392 chars: &str,
393 details: &str,
394 stat: &str,
395 user_context: Option<&str>,
396) -> Result<PromptParts> {
397 let template_content = load_template_file("summary", variant)?;
399
400 let mut context = Context::new();
402 context.insert("commit_type", commit_type);
403 context.insert("scope", scope);
404 context.insert("chars", chars);
405 context.insert("details", details);
406 context.insert("stat", stat);
407 if let Some(ctx) = user_context {
408 context.insert("user_context", ctx);
409 }
410
411 render_prompt_parts(&format!("summary/{variant}.md"), &template_content, &context)
412}
413
414pub fn render_changelog_prompt(
416 variant: &str,
417 changelog_path: &str,
418 is_package_changelog: bool,
419 stat: &str,
420 diff: &str,
421 existing_entries: Option<&str>,
422) -> Result<PromptParts> {
423 let template_content = load_template_file("changelog", variant)?;
425
426 let mut context = Context::new();
428 context.insert("changelog_path", changelog_path);
429 context.insert("is_package_changelog", &is_package_changelog);
430 context.insert("stat", stat);
431 context.insert("diff", diff);
432 if let Some(entries) = existing_entries {
433 context.insert("existing_entries", entries);
434 }
435
436 render_prompt_parts(&format!("changelog/{variant}.md"), &template_content, &context)
437}
438
439#[derive(Serialize)]
440pub struct MapFile<'a> {
441 pub path: &'a str,
442 pub diff: &'a str,
443}
444
445pub fn render_map_prompt(
447 variant: &str,
448 files: &[MapFile<'_>],
449 context_header: &str,
450) -> Result<PromptParts> {
451 let template_content = load_template_file("map", variant)?;
452
453 let mut context = Context::new();
454 context.insert("files", files);
455 if !context_header.is_empty() {
456 context.insert("context_header", context_header);
457 }
458
459 render_prompt_parts(&format!("map/{variant}.md"), &template_content, &context)
460}
461
462pub fn render_reduce_prompt(
464 variant: &str,
465 observations: &str,
466 stat: &str,
467 scope_candidates: &str,
468 types_description: Option<&str>,
469) -> Result<PromptParts> {
470 let template_content = load_template_file("reduce", variant)?;
471
472 let mut context = Context::new();
473 context.insert("observations", observations);
474 context.insert("stat", stat);
475 context.insert("scope_candidates", scope_candidates);
476 if let Some(types_desc) = types_description {
477 context.insert("types_description", types_desc);
478 }
479
480 render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
481}
482
/// Inputs for [`render_compose_intent_prompt`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ComposeIntentPromptParams<'a> {
    /// Template variant; selects `compose-intent/<variant>.md`.
    pub variant: &'a str,
    /// Inserted into the template context as `max_commits`.
    pub max_commits: usize,
    /// Inserted into the template context as `stat`.
    pub stat: &'a str,
    /// Inserted into the template context as `snapshot_summary`.
    pub snapshot_summary: &'a str,
    /// Inserted into the template context as `planning_targets`.
    pub planning_targets: &'a str,
    /// Inserted into the template context as `planning_notes`.
    pub planning_notes: &'a str,
    /// Inserted into the template context as `split_bias`.
    pub split_bias: &'a str,
}
493
494pub fn render_compose_intent_prompt(p: &ComposeIntentPromptParams<'_>) -> Result<PromptParts> {
496 let template_content = load_template_file("compose-intent", p.variant)?;
497
498 let mut context = Context::new();
499 context.insert("max_commits", &p.max_commits);
500 context.insert("stat", p.stat);
501 context.insert("snapshot_summary", p.snapshot_summary);
502 context.insert("planning_targets", p.planning_targets);
503 context.insert("planning_notes", p.planning_notes);
504 context.insert("split_bias", p.split_bias);
505
506 render_prompt_parts(&format!("compose-intent/{}.md", p.variant), &template_content, &context)
507}
508
/// Inputs for [`render_compose_bind_prompt`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ComposeBindPromptParams<'a> {
    /// Template variant; selects `compose-bind/<variant>.md`.
    pub variant: &'a str,
    /// Inserted into the template context as `groups`.
    pub groups: &'a str,
    /// Inserted into the template context as `ambiguous_files`.
    pub ambiguous_files: &'a str,
}
515
516pub fn render_compose_bind_prompt(p: &ComposeBindPromptParams<'_>) -> Result<PromptParts> {
518 let template_content = load_template_file("compose-bind", p.variant)?;
519
520 let mut context = Context::new();
521 context.insert("groups", p.groups);
522 context.insert("ambiguous_files", p.ambiguous_files);
523
524 render_prompt_parts(&format!("compose-bind/{}.md", p.variant), &template_content, &context)
525}
526
/// Inputs for [`render_fast_prompt`].
#[derive(Debug, Clone, Copy, Default)]
pub struct FastPromptParams<'a> {
    /// Template variant; selects `fast/<variant>.md`.
    pub variant: &'a str,
    /// Inserted into the template context as `stat`.
    pub stat: &'a str,
    /// Inserted into the template context as `diff`.
    pub diff: &'a str,
    /// Inserted into the template context as `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Inserted as `user_context` only when `Some`.
    pub user_context: Option<&'a str>,
    /// Inserted as `types_description` only when `Some`.
    pub types_description: Option<&'a str>,
}
536
537pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
539 let template_content = load_template_file("fast", p.variant)?;
540
541 let mut context = Context::new();
542 context.insert("stat", p.stat);
543 context.insert("diff", p.diff);
544 context.insert("scope_candidates", p.scope_candidates);
545 if let Some(ctx) = p.user_context {
546 context.insert("user_context", ctx);
547 }
548 if let Some(types_desc) = p.types_description {
549 context.insert("types_description", types_desc);
550 }
551
552 render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
553}
554
#[cfg(test)]
mod tests {
    use super::{
        AnalysisParams, ComposeBindPromptParams, ComposeIntentPromptParams, FastPromptParams,
        ensure_prompts_dir, render_analysis_prompt, render_changelog_prompt, render_compose_bind_prompt,
        render_compose_intent_prompt, render_fast_prompt, render_reduce_prompt,
        render_summary_prompt, split_prompt_template,
    };

    #[test]
    fn test_split_prompt_template_lf() {
        let content = "system text\nmore system\n======USER=======\nuser body\n";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, Some("system text\nmore system"));
        assert_eq!(user, "user body\n");
    }

    #[test]
    fn test_split_prompt_template_crlf() {
        let content = "system text\r\nmore system\r\n======USER=======\r\nuser body\r\n";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, Some("system text\r\nmore system"));
        assert_eq!(user, "user body\r\n");
    }

    #[test]
    fn test_split_prompt_template_no_separator() {
        let content = "no separator here";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, None);
        assert_eq!(user, content);
    }

    #[test]
    fn test_render_analysis_prompt_requests_holistic_summary() {
        ensure_prompts_dir().unwrap();

        let parts = render_analysis_prompt(&AnalysisParams {
            variant: "default",
            stat: "src/api/client.rs | 24 +++++++++++++++---------",
            diff: "diff --git a/src/api/client.rs b/src/api/client.rs\n",
            scope_candidates: "api",
            recent_commits: None,
            common_scopes: None,
            types_description: None,
            project_context: None,
        })
        .unwrap();

        assert!(parts.system.contains("Generate Summary"));
        assert!(parts.system.contains("\"summary\""));
        assert!(
            parts
                .system
                .contains("umbrella headline for the whole changeset")
        );
        assert!(parts.system.contains("Does not copy detail #1"));
    }

    #[test]
    fn test_render_changelog_prompt_variants_render() {
        ensure_prompts_dir().unwrap();

        for variant in ["default", "markdown"] {
            let parts = render_changelog_prompt(
                variant,
                "CHANGELOG.md",
                false,
                "src/api.rs | 4 ++--",
                "diff --git a/src/api.rs b/src/api.rs\n",
                Some("- Added existing entry"),
            )
            .unwrap_or_else(|e| panic!("{variant} changelog prompt failed to render: {e}"));

            assert!(parts.user.contains("src/api.rs"), "{variant}: diff missing");
            assert!(
                parts.user.contains("Added existing entry"),
                "{variant}: existing entries missing"
            );

            match variant {
                "markdown" => {
                    assert!(
                        parts.system.contains("# Added"),
                        "markdown variant must advertise markdown sections"
                    );
                    assert!(
                        !parts.system.contains("{\"entries\""),
                        "markdown variant must not advertise JSON output"
                    );
                },
                "default" => assert!(
                    parts.system.contains("{\"entries\""),
                    "default variant must advertise JSON output"
                ),
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_render_fast_prompt_surfaces_type_guidance() {
        ensure_prompts_dir().unwrap();

        let parts = render_fast_prompt(&FastPromptParams {
            variant: "default",
            stat: "prompts/analysis/default.md | 5 +++++",
            diff: "diff --git a/prompts/analysis/default.md \
                   b/prompts/analysis/default.md\n",
            scope_candidates: "prompts",
            user_context: None,
            types_description: Some(
                "**docs**: Documentation only changes\n Note: Excludes prompt template files.",
            ),
        })
        .unwrap();

        assert!(parts.user.contains("<commit_types>"));
        assert!(parts.user.contains("Excludes prompt template files."));
        assert!(parts.system.contains("not `docs`"));
    }

    #[test]
    fn test_render_fast_prompt_omits_commit_types_when_absent() {
        ensure_prompts_dir().unwrap();

        let parts = render_fast_prompt(&FastPromptParams {
            variant: "default",
            stat: "src/main.rs | 5 +++++",
            diff: "diff --git a/src/main.rs b/src/main.rs\n",
            scope_candidates: "",
            user_context: None,
            types_description: None,
        })
        .unwrap();

        assert!(!parts.user.contains("<commit_types>"));
    }

    #[test]
    fn test_render_reduce_prompt_guides_grouped_synthesis() {
        ensure_prompts_dir().unwrap();

        let parts = render_reduce_prompt(
            "default",
            r#"[{"file":"src/a.rs","observations":["Added retry handling."]}]"#,
            "src/a.rs | 10 +++++-----",
            "api",
            None,
        )
        .unwrap();

        assert!(parts.system.contains("3-4 strong grouped details"));
        assert!(
            parts
                .system
                .contains("Synthesize repeated file observations")
        );
        assert!(parts.system.contains("over enumerating files"));
    }

    #[test]
    fn test_render_compose_intent_prompt() {
        // Sync first like the other render tests: `load_template_file` prefers a
        // user copy on disk, which could otherwise be stale.
        ensure_prompts_dir().unwrap();

        let parts = render_compose_intent_prompt(&ComposeIntentPromptParams {
            variant: "default",
            max_commits: 3,
            stat: "src/foo.rs | 10 +++++-----",
            snapshot_summary: "- F1 src/foo.rs",
            planning_targets: "file IDs",
            planning_notes: "Prefer conservative grouping over speculative splitting.",
            split_bias: "Prefer fewer groups when the split is uncertain.",
        })
        .unwrap();

        assert!(parts.system.contains("create_compose_intent_plan"));
        assert!(parts.user.contains("max_commits: 3"));
        assert!(parts.user.contains("src/foo.rs"));
    }

    #[test]
    fn test_render_summary_prompt_guides_umbrella_title() {
        ensure_prompts_dir().unwrap();

        let parts = render_summary_prompt(
            "default",
            "feat",
            "api",
            "72",
            "Added websocket reconnects.\nUpdated client retry tests.",
            "src/api/client.rs | 24 +++++++++++++++---------",
            None,
        )
        .unwrap();

        assert!(
            parts
                .system
                .contains("umbrella description for the whole changeset")
        );
        assert!(parts.system.contains("not as candidate titles to copy"));
        assert!(
            parts
                .system
                .contains("does not copy or narrowly paraphrase one detail point")
        );
        assert!(parts.user.contains("<detail_points>"));
        assert!(parts.user.contains("Added websocket reconnects."));
        assert!(parts.user.contains("Updated client retry tests."));
    }

    #[test]
    fn test_render_compose_bind_prompt() {
        // Sync first like the other render tests: `load_template_file` prefers a
        // user copy on disk, which could otherwise be stale.
        ensure_prompts_dir().unwrap();

        let parts = render_compose_bind_prompt(&ComposeBindPromptParams {
            variant: "default",
            groups: "- G1 [feat(api)] Added endpoint",
            ambiguous_files: "- F2 src/api.rs candidates: G1",
        })
        .unwrap();

        assert!(parts.system.contains("bind_compose_hunks"));
        assert!(parts.user.contains("G1"));
        assert!(parts.user.contains("src/api.rs"));
    }
}