1use std::{
2 path::{Path, PathBuf},
3 sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use serde::Serialize;
9use tera::{Context, Tera};
10
11use crate::error::{CommitGenError, Result};
12
/// A prompt split into its system and user sections.
///
/// Built by `render_prompt_parts`: `system` is the static text found above the
/// `======USER=======` separator (empty when the template has no separator) and
/// `user` is the rendered remainder of the template.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Static system-section text with surrounding whitespace trimmed.
    pub system: String,
    /// Rendered user-section text with surrounding whitespace trimmed.
    pub user: String,
}
18
/// Literal marker that separates the static system section of a prompt template
/// (above it) from the user section (below it).
const USER_SEPARATOR_MARKER: &str = "======USER=======";

/// Locates the first `USER_SEPARATOR_MARKER` in `content`.
///
/// Returns `Some((system_end, user_start))` as byte offsets into `content`:
/// `system_end` is where the system section ends, excluding a single line break
/// (`\r\n` or `\n`) directly before the marker, and `user_start` is where the
/// user section begins, skipping a single line break directly after the marker.
/// Returns `None` when the marker does not occur.
///
/// Both offsets always fall on UTF-8 character boundaries, so slicing `content`
/// with them cannot panic.
fn find_user_separator(content: &str) -> Option<(usize, usize)> {
    let marker_pos = content.find(USER_SEPARATOR_MARKER)?;

    // Test the prefix with `ends_with` instead of slicing at `marker_pos - 1` or
    // `marker_pos - 2`: those offsets can land inside a multi-byte character
    // that sits right before the marker, and slicing there panics.
    let before = &content[..marker_pos];
    let system_end = if before.ends_with("\r\n") {
        marker_pos - 2
    } else if before.ends_with('\n') {
        marker_pos - 1
    } else {
        marker_pos
    };

    // The marker is ASCII, so the offset just past it is a valid boundary.
    let after_marker = marker_pos + USER_SEPARATOR_MARKER.len();
    let after = &content[after_marker..];
    let user_start = if after.starts_with("\r\n") {
        after_marker + 2
    } else if after.starts_with('\n') {
        after_marker + 1
    } else {
        after_marker
    };

    Some((system_end, user_start))
}
51
52fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
54 if let Some((system_end, user_start)) = find_user_separator(template_content) {
55 (Some(&template_content[..system_end]), &template_content[user_start..])
56 } else {
57 (None, template_content)
58 }
59}
60
61fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
63 let has_template_tags = system_template.contains("{{")
64 || system_template.contains("{%")
65 || system_template.contains("{#");
66
67 if has_template_tags {
68 return Err(CommitGenError::Other(format!(
69 "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
70 content below ======USER=======."
71 )));
72 }
73
74 Ok(())
75}
76
77fn render_prompt_parts(
79 template_name: &str,
80 template_content: &str,
81 context: &Context,
82) -> Result<PromptParts> {
83 let (system_template, user_template) = split_prompt_template(template_content);
84
85 let system = if let Some(system_template) = system_template {
86 ensure_static_system_prompt(system_template, template_name)?;
87 system_template.trim().to_string()
88 } else {
89 String::new()
90 };
91
92 let mut tera = TERA.lock();
93 let rendered_user = tera.render_str(user_template, context).map_err(|e| {
94 CommitGenError::Other(format!("Failed to render {template_name} prompt template: {e}"))
95 })?;
96
97 Ok(PromptParts { system, user: rendered_user.trim().to_string() })
98}
99
/// Inputs for `render_analysis_prompt`.
///
/// Every field except `variant` is exposed to the template under its own name;
/// the optional fields are exposed only when they are `Some`.
#[derive(Default)]
pub struct AnalysisParams<'a> {
    /// Template variant; selects `analysis/{variant}.md`.
    pub variant: &'a str,
    /// Value exposed to the template as `stat`.
    pub stat: &'a str,
    /// Value exposed to the template as `diff`.
    pub diff: &'a str,
    /// Value exposed to the template as `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Optional value exposed to the template as `recent_commits`.
    pub recent_commits: Option<&'a str>,
    /// Optional value exposed to the template as `common_scopes`.
    pub common_scopes: Option<&'a str>,
    /// Optional value exposed to the template as `types_description`.
    pub types_description: Option<&'a str>,
    /// Optional value exposed to the template as `project_context`.
    pub project_context: Option<&'a str>,
}
112
113#[derive(RustEmbed)]
115#[folder = "prompts/"]
116#[include = "**/*.md"]
117struct Prompts;
118
119static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
122 if let Err(e) = ensure_prompts_dir() {
124 eprintln!("Warning: Failed to initialize prompts directory: {e}");
125 }
126
127 let mut tera = Tera::default();
128
129 if let Some(prompts_dir) = get_user_prompts_dir() {
131 if let Err(e) =
132 register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
133 {
134 eprintln!("Warning: {e}");
135 }
136 if let Err(e) =
137 register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
138 {
139 eprintln!("Warning: {e}");
140 }
141 if let Err(e) =
142 register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
143 {
144 eprintln!("Warning: {e}");
145 }
146 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
147 eprintln!("Warning: {e}");
148 }
149 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
150 {
151 eprintln!("Warning: {e}");
152 }
153 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("fast"), "fast") {
154 eprintln!("Warning: {e}");
155 }
156 if let Err(e) = register_directory_templates(
157 &mut tera,
158 &prompts_dir.join("compose-intent"),
159 "compose-intent",
160 ) {
161 eprintln!("Warning: {e}");
162 }
163 if let Err(e) =
164 register_directory_templates(&mut tera, &prompts_dir.join("compose-bind"), "compose-bind")
165 {
166 eprintln!("Warning: {e}");
167 }
168 }
169
170 for file in Prompts::iter() {
172 if tera.get_template_names().any(|name| name == file.as_ref()) {
173 continue;
174 }
175
176 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
177 match std::str::from_utf8(embedded_file.data.as_ref()) {
178 Ok(content) => {
179 if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
180 eprintln!(
181 "Warning: Failed to register embedded template {}: {}",
182 file.as_ref(),
183 e
184 );
185 }
186 },
187 Err(e) => {
188 eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
189 },
190 }
191 }
192 }
193
194 tera.autoescape_on(vec![]);
196
197 Mutex::new(tera)
198});
199
/// Returns `<home>/.llm-git/prompts`, or `None` when no home directory is known.
///
/// The home directory is the first of the `HOME` and `USERPROFILE` environment
/// variables that is set to a non-empty value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    // `var_os` keeps home directories whose path is not valid Unicode usable.
    // Empty values are skipped so the result is never the relative path
    // `.llm-git/prompts` resolved against the current working directory.
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|home| !home.is_empty())
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
207
208pub fn ensure_prompts_dir() -> Result<()> {
210 let Some(user_prompts_dir) = get_user_prompts_dir() else {
211 return Ok(());
214 };
215
216 let user_llm_git_dir = user_prompts_dir
218 .parent()
219 .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
220
221 if !user_llm_git_dir.exists() {
223 std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
224 CommitGenError::Other(format!(
225 "Failed to create directory {}: {}",
226 user_llm_git_dir.display(),
227 e
228 ))
229 })?;
230 }
231
232 if !user_prompts_dir.exists() {
234 std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
235 CommitGenError::Other(format!(
236 "Failed to create directory {}: {}",
237 user_prompts_dir.display(),
238 e
239 ))
240 })?;
241 }
242
243 for file in Prompts::iter() {
245 let file_path = user_prompts_dir.join(file.as_ref());
246
247 if let Some(parent) = file_path.parent() {
249 std::fs::create_dir_all(parent).map_err(|e| {
250 CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
251 })?;
252 }
253
254 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
255 let embedded_content = embedded_file.data;
256
257 let should_write = if file_path.exists() {
259 match std::fs::read(&file_path) {
260 Ok(existing_content) => existing_content != embedded_content.as_ref(),
261 Err(_) => true, }
263 } else {
264 true };
266
267 if should_write {
268 std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
269 CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
270 })?;
271 }
272 }
273 }
274
275 Ok(())
276}
277
278fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
279 if !directory.exists() {
280 return Ok(());
281 }
282
283 for entry in std::fs::read_dir(directory).map_err(|e| {
284 CommitGenError::Other(format!(
285 "Failed to read {} templates directory {}: {}",
286 category,
287 directory.display(),
288 e
289 ))
290 })? {
291 let entry = match entry {
292 Ok(entry) => entry,
293 Err(e) => {
294 eprintln!(
295 "Warning: Failed to iterate template entry in {}: {}",
296 directory.display(),
297 e
298 );
299 continue;
300 },
301 };
302
303 let path = entry.path();
304 if path.extension().and_then(|s| s.to_str()) != Some("md") {
305 continue;
306 }
307
308 let template_name = format!(
309 "{}/{}",
310 category,
311 path
312 .file_name()
313 .and_then(|s| s.to_str())
314 .unwrap_or_default()
315 );
316
317 if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
320 eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
321 }
322 }
323
324 Ok(())
325}
326
327fn load_template_file(category: &str, variant: &str) -> Result<String> {
329 if let Some(prompts_dir) = get_user_prompts_dir() {
331 let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
332 if template_path.exists() {
333 return std::fs::read_to_string(&template_path).map_err(|e| {
334 CommitGenError::Other(format!(
335 "Failed to read template file {}: {}",
336 template_path.display(),
337 e
338 ))
339 });
340 }
341 }
342
343 let embedded_key = format!("{category}/{variant}.md");
345 if let Some(bytes) = Prompts::get(&embedded_key) {
346 return std::str::from_utf8(bytes.data.as_ref())
347 .map(|s| s.to_string())
348 .map_err(|e| {
349 CommitGenError::Other(format!(
350 "Embedded template {embedded_key} is not valid UTF-8: {e}"
351 ))
352 });
353 }
354
355 Err(CommitGenError::Other(format!(
356 "Template variant '{variant}' in category '{category}' not found as user override or \
357 embedded default"
358 )))
359}
360
361pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
363 let template_content = load_template_file("analysis", p.variant)?;
365
366 let mut context = Context::new();
368 context.insert("stat", p.stat);
369 context.insert("diff", p.diff);
370 context.insert("scope_candidates", p.scope_candidates);
371 if let Some(commits) = p.recent_commits {
372 context.insert("recent_commits", commits);
373 }
374 if let Some(scopes) = p.common_scopes {
375 context.insert("common_scopes", scopes);
376 }
377 if let Some(types) = p.types_description {
378 context.insert("types_description", types);
379 }
380 if let Some(ctx) = p.project_context {
381 context.insert("project_context", ctx);
382 }
383
384 render_prompt_parts(&format!("analysis/{}.md", p.variant), &template_content, &context)
385}
386
387pub fn render_summary_prompt(
389 variant: &str,
390 commit_type: &str,
391 scope: &str,
392 chars: &str,
393 details: &str,
394 stat: &str,
395 user_context: Option<&str>,
396) -> Result<PromptParts> {
397 let template_content = load_template_file("summary", variant)?;
399
400 let mut context = Context::new();
402 context.insert("commit_type", commit_type);
403 context.insert("scope", scope);
404 context.insert("chars", chars);
405 context.insert("details", details);
406 context.insert("stat", stat);
407 if let Some(ctx) = user_context {
408 context.insert("user_context", ctx);
409 }
410
411 render_prompt_parts(&format!("summary/{variant}.md"), &template_content, &context)
412}
413
414pub fn render_changelog_prompt(
416 variant: &str,
417 changelog_path: &str,
418 is_package_changelog: bool,
419 stat: &str,
420 diff: &str,
421 existing_entries: Option<&str>,
422) -> Result<PromptParts> {
423 let template_content = load_template_file("changelog", variant)?;
425
426 let mut context = Context::new();
428 context.insert("changelog_path", changelog_path);
429 context.insert("is_package_changelog", &is_package_changelog);
430 context.insert("stat", stat);
431 context.insert("diff", diff);
432 if let Some(entries) = existing_entries {
433 context.insert("existing_entries", entries);
434 }
435
436 render_prompt_parts(&format!("changelog/{variant}.md"), &template_content, &context)
437}
438
439#[derive(Serialize)]
440pub struct MapFile<'a> {
441 pub path: &'a str,
442 pub diff: &'a str,
443}
444
445pub fn render_map_prompt(
447 variant: &str,
448 files: &[MapFile<'_>],
449 context_header: &str,
450) -> Result<PromptParts> {
451 let template_content = load_template_file("map", variant)?;
452
453 let mut context = Context::new();
454 context.insert("files", files);
455 if !context_header.is_empty() {
456 context.insert("context_header", context_header);
457 }
458
459 render_prompt_parts(&format!("map/{variant}.md"), &template_content, &context)
460}
461
462pub fn render_reduce_prompt(
464 variant: &str,
465 observations: &str,
466 stat: &str,
467 scope_candidates: &str,
468 types_description: Option<&str>,
469) -> Result<PromptParts> {
470 let template_content = load_template_file("reduce", variant)?;
471
472 let mut context = Context::new();
473 context.insert("observations", observations);
474 context.insert("stat", stat);
475 context.insert("scope_candidates", scope_candidates);
476 if let Some(types_desc) = types_description {
477 context.insert("types_description", types_desc);
478 }
479
480 render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
481}
482
/// Inputs for `render_compose_intent_prompt`.
///
/// Every field except `variant` is exposed to the template under its own name.
pub struct ComposeIntentPromptParams<'a> {
    /// Template variant; selects `compose-intent/{variant}.md`.
    pub variant: &'a str,
    /// Value exposed to the template as `max_commits`.
    pub max_commits: usize,
    /// Value exposed to the template as `stat`.
    pub stat: &'a str,
    /// Value exposed to the template as `snapshot_summary`.
    pub snapshot_summary: &'a str,
    /// Value exposed to the template as `planning_targets`.
    pub planning_targets: &'a str,
    /// Value exposed to the template as `planning_notes`.
    pub planning_notes: &'a str,
    /// Value exposed to the template as `split_bias`.
    pub split_bias: &'a str,
}
493
494pub fn render_compose_intent_prompt(p: &ComposeIntentPromptParams<'_>) -> Result<PromptParts> {
496 let template_content = load_template_file("compose-intent", p.variant)?;
497
498 let mut context = Context::new();
499 context.insert("max_commits", &p.max_commits);
500 context.insert("stat", p.stat);
501 context.insert("snapshot_summary", p.snapshot_summary);
502 context.insert("planning_targets", p.planning_targets);
503 context.insert("planning_notes", p.planning_notes);
504 context.insert("split_bias", p.split_bias);
505
506 render_prompt_parts(&format!("compose-intent/{}.md", p.variant), &template_content, &context)
507}
508
/// Inputs for `render_compose_bind_prompt`.
pub struct ComposeBindPromptParams<'a> {
    /// Template variant; selects `compose-bind/{variant}.md`.
    pub variant: &'a str,
    /// Value exposed to the template as `groups`.
    pub groups: &'a str,
    /// Value exposed to the template as `ambiguous_files`.
    pub ambiguous_files: &'a str,
}
515
516pub fn render_compose_bind_prompt(p: &ComposeBindPromptParams<'_>) -> Result<PromptParts> {
518 let template_content = load_template_file("compose-bind", p.variant)?;
519
520 let mut context = Context::new();
521 context.insert("groups", p.groups);
522 context.insert("ambiguous_files", p.ambiguous_files);
523
524 render_prompt_parts(&format!("compose-bind/{}.md", p.variant), &template_content, &context)
525}
526
/// Inputs for `render_fast_prompt`.
pub struct FastPromptParams<'a> {
    /// Template variant; selects `fast/{variant}.md`.
    pub variant: &'a str,
    /// Value exposed to the template as `stat`.
    pub stat: &'a str,
    /// Value exposed to the template as `diff`.
    pub diff: &'a str,
    /// Value exposed to the template as `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Optional value exposed to the template as `user_context`.
    pub user_context: Option<&'a str>,
}
535
536pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
538 let template_content = load_template_file("fast", p.variant)?;
539
540 let mut context = Context::new();
541 context.insert("stat", p.stat);
542 context.insert("diff", p.diff);
543 context.insert("scope_candidates", p.scope_candidates);
544 if let Some(ctx) = p.user_context {
545 context.insert("user_context", ctx);
546 }
547
548 render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
549}
550
#[cfg(test)]
mod tests {
    use super::{
        ensure_prompts_dir, render_analysis_prompt, render_compose_bind_prompt,
        render_compose_intent_prompt, render_reduce_prompt, render_summary_prompt,
        split_prompt_template, AnalysisParams, ComposeBindPromptParams,
        ComposeIntentPromptParams,
    };

    /// Asserts that `haystack` contains every string in `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(needle), "expected rendered text to contain {needle:?}");
        }
    }

    // --- separator splitting -------------------------------------------------

    #[test]
    fn test_split_prompt_template_lf() {
        let (system, user) =
            split_prompt_template("system text\nmore system\n======USER=======\nuser body\n");

        assert_eq!(system, Some("system text\nmore system"));
        assert_eq!(user, "user body\n");
    }

    #[test]
    fn test_split_prompt_template_crlf() {
        let (system, user) = split_prompt_template(
            "system text\r\nmore system\r\n======USER=======\r\nuser body\r\n",
        );

        assert_eq!(system, Some("system text\r\nmore system"));
        assert_eq!(user, "user body\r\n");
    }

    #[test]
    fn test_split_prompt_template_no_separator() {
        let content = "no separator here";

        let (system, user) = split_prompt_template(content);

        assert_eq!(system, None);
        assert_eq!(user, content);
    }

    // --- rendering of the default templates ----------------------------------

    #[test]
    fn test_render_analysis_prompt_requests_holistic_summary() {
        ensure_prompts_dir().unwrap();

        let params = AnalysisParams {
            variant: "default",
            stat: "src/api/client.rs | 24 +++++++++++++++---------",
            diff: "diff --git a/src/api/client.rs b/src/api/client.rs\n",
            scope_candidates: "api",
            ..Default::default()
        };
        let parts = render_analysis_prompt(&params).unwrap();

        assert_contains_all(&parts.system, &[
            "Generate Summary",
            "\"summary\"",
            "umbrella headline for the whole changeset",
            "Does not copy detail #1",
        ]);
    }

    #[test]
    fn test_render_reduce_prompt_guides_grouped_synthesis() {
        ensure_prompts_dir().unwrap();

        let observations = r#"[{"file":"src/a.rs","observations":["Added retry handling."]}]"#;
        let parts =
            render_reduce_prompt("default", observations, "src/a.rs | 10 +++++-----", "api", None)
                .unwrap();

        assert_contains_all(&parts.system, &[
            "3-4 strong grouped details",
            "Synthesize repeated file observations",
            "over enumerating files",
        ]);
    }

    #[test]
    fn test_render_compose_intent_prompt() {
        let params = ComposeIntentPromptParams {
            variant: "default",
            max_commits: 3,
            stat: "src/foo.rs | 10 +++++-----",
            snapshot_summary: "- F1 src/foo.rs",
            planning_targets: "file IDs",
            planning_notes: "Prefer conservative grouping over speculative splitting.",
            split_bias: "Prefer fewer groups when the split is uncertain.",
        };
        let parts = render_compose_intent_prompt(&params).unwrap();

        assert_contains_all(&parts.system, &["create_compose_intent_plan"]);
        assert_contains_all(&parts.user, &["max_commits: 3", "src/foo.rs"]);
    }

    #[test]
    fn test_render_summary_prompt_guides_umbrella_title() {
        ensure_prompts_dir().unwrap();

        let parts = render_summary_prompt(
            "default",
            "feat",
            "api",
            "72",
            "Added websocket reconnects.\nUpdated client retry tests.",
            "src/api/client.rs | 24 +++++++++++++++---------",
            None,
        )
        .unwrap();

        assert_contains_all(&parts.system, &[
            "umbrella description for the whole changeset",
            "not as candidate titles to copy",
            "does not copy or narrowly paraphrase one detail point",
        ]);
        assert_contains_all(&parts.user, &[
            "<detail_points>",
            "Added websocket reconnects.",
            "Updated client retry tests.",
        ]);
    }

    #[test]
    fn test_render_compose_bind_prompt() {
        let params = ComposeBindPromptParams {
            variant: "default",
            groups: "- G1 [feat(api)] Added endpoint",
            ambiguous_files: "- F2 src/api.rs candidates: G1",
        };
        let parts = render_compose_bind_prompt(&params).unwrap();

        assert_contains_all(&parts.system, &["bind_compose_hunks"]);
        assert_contains_all(&parts.user, &["G1", "src/api.rs"]);
    }
}