Skip to main content

llm_git/
templates.rs

1use std::{
2   path::{Path, PathBuf},
3   sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use serde::Serialize;
9use tera::{Context, Tera};
10
11use crate::error::{CommitGenError, Result};
12
/// A rendered prompt divided into its two halves: the static system
/// instructions and the per-request user message.
pub struct PromptParts {
   /// Text taken verbatim from above the USER separator (empty when the
   /// template has no separator).
   pub system: String,
   /// Tera-rendered text from below the USER separator.
   pub user:   String,
}
18
/// Line that separates the static system section of a prompt template from
/// the Tera-rendered user section.
const USER_SEPARATOR_MARKER: &str = "======USER=======";

/// Locate the USER separator and return (`system_end`, `user_start`) byte
/// offsets into `content`, or `None` when the marker is absent.
///
/// `content[..system_end]` is the system section and `content[user_start..]`
/// is the user section. Both offsets are always `char` boundaries, so slicing
/// with them never panics.
///
/// The marker may be surrounded by either LF or CRLF line endings — the latter
/// happens on Windows checkouts where Git's default `core.autocrlf=true`
/// converts embedded `.md` templates to CRLF. We strip whichever line
/// terminator wraps the marker so the system section never includes a trailing
/// blank line and the user section never starts with one.
///
/// The line-ending checks use `strip_suffix`/`strip_prefix` rather than byte
/// arithmetic. Slicing `marker_pos - 2..marker_pos` panics whenever that byte
/// sits inside a multi-byte character. An example is a system section whose
/// last line ends in `—` followed by a plain LF before the marker.
fn find_user_separator(content: &str) -> Option<(usize, usize)> {
   let marker_pos = content.find(USER_SEPARATOR_MARKER)?;

   // System section ends immediately before the line break that precedes the
   // marker. Accept CRLF or LF; if neither is there (marker at start of file
   // or mid-line), the system section ends at the marker itself.
   let before = &content[..marker_pos];
   let system_end = before
      .strip_suffix("\r\n")
      .or_else(|| before.strip_suffix('\n'))
      .map_or(marker_pos, str::len);

   // User section starts right after the line break that follows the marker,
   // or directly after the marker when no line break follows it.
   let after_marker = marker_pos + USER_SEPARATOR_MARKER.len();
   let after = &content[after_marker..];
   let user_start = after
      .strip_prefix("\r\n")
      .or_else(|| after.strip_prefix('\n'))
      .map_or(after_marker, |rest| content.len() - rest.len());

   Some((system_end, user_start))
}
51
52/// Split a prompt template into static system text and templated user content.
53fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
54   if let Some((system_end, user_start)) = find_user_separator(template_content) {
55      (Some(&template_content[..system_end]), &template_content[user_start..])
56   } else {
57      (None, template_content)
58   }
59}
60
61/// Ensure system prompt does not include Tera interpolation tags.
62fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
63   let has_template_tags = system_template.contains("{{")
64      || system_template.contains("{%")
65      || system_template.contains("{#");
66
67   if has_template_tags {
68      return Err(CommitGenError::Other(format!(
69         "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
70          content below ======USER=======."
71      )));
72   }
73
74   Ok(())
75}
76
77/// Render a prompt template and enforce static system/user separation.
78fn render_prompt_parts(
79   template_name: &str,
80   template_content: &str,
81   context: &Context,
82) -> Result<PromptParts> {
83   let (system_template, user_template) = split_prompt_template(template_content);
84
85   let system = if let Some(system_template) = system_template {
86      ensure_static_system_prompt(system_template, template_name)?;
87      system_template.trim().to_string()
88   } else {
89      String::new()
90   };
91
92   let mut tera = TERA.lock();
93   let rendered_user = tera.render_str(user_template, context).map_err(|e| {
94      CommitGenError::Other(format!("Failed to render {template_name} prompt template: {e}"))
95   })?;
96
97   Ok(PromptParts { system, user: rendered_user.trim().to_string() })
98}
99
/// Inputs for `render_analysis_prompt`.
///
/// `variant` picks the `analysis/<variant>.md` template. Every other field is
/// inserted into the Tera context under its own name; `Option` fields are
/// inserted only when they are `Some`.
#[derive(Default)]
pub struct AnalysisParams<'a> {
   /// Template variant (file stem under `analysis/`).
   pub variant:           &'a str,
   /// Exposed to the template as `stat`.
   pub stat:              &'a str,
   /// Exposed to the template as `diff`.
   pub diff:              &'a str,
   /// Exposed to the template as `scope_candidates`.
   pub scope_candidates:  &'a str,
   /// Exposed as `recent_commits` when present.
   pub recent_commits:    Option<&'a str>,
   /// Exposed as `common_scopes` when present.
   pub common_scopes:     Option<&'a str>,
   /// Exposed as `types_description` when present.
   pub types_description: Option<&'a str>,
   /// Exposed as `project_context` when present.
   pub project_context:   Option<&'a str>,
}
112
113/// Embedded prompts folder (compiled into binary)
114#[derive(RustEmbed)]
115#[folder = "prompts/"]
116#[include = "**/*.md"]
117struct Prompts;
118
119/// Global Tera instance for template rendering (wrapped in Mutex for mutable
120/// access)
121static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
122   // Ensure prompts are initialized
123   if let Err(e) = ensure_prompts_dir() {
124      eprintln!("Warning: Failed to initialize prompts directory: {e}");
125   }
126
127   let mut tera = Tera::default();
128
129   // Load templates from user prompts directory first so they take precedence.
130   if let Some(prompts_dir) = get_user_prompts_dir() {
131      if let Err(e) =
132         register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
133      {
134         eprintln!("Warning: {e}");
135      }
136      if let Err(e) =
137         register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
138      {
139         eprintln!("Warning: {e}");
140      }
141      if let Err(e) =
142         register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
143      {
144         eprintln!("Warning: {e}");
145      }
146      if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
147         eprintln!("Warning: {e}");
148      }
149      if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
150      {
151         eprintln!("Warning: {e}");
152      }
153      if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("fast"), "fast") {
154         eprintln!("Warning: {e}");
155      }
156      if let Err(e) = register_directory_templates(
157         &mut tera,
158         &prompts_dir.join("compose-intent"),
159         "compose-intent",
160      ) {
161         eprintln!("Warning: {e}");
162      }
163      if let Err(e) =
164         register_directory_templates(&mut tera, &prompts_dir.join("compose-bind"), "compose-bind")
165      {
166         eprintln!("Warning: {e}");
167      }
168   }
169
170   // Register embedded templates that aren't overridden by user-provided files.
171   for file in Prompts::iter() {
172      if tera.get_template_names().any(|name| name == file.as_ref()) {
173         continue;
174      }
175
176      if let Some(embedded_file) = Prompts::get(file.as_ref()) {
177         match std::str::from_utf8(embedded_file.data.as_ref()) {
178            Ok(content) => {
179               if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
180                  eprintln!(
181                     "Warning: Failed to register embedded template {}: {}",
182                     file.as_ref(),
183                     e
184                  );
185               }
186            },
187            Err(e) => {
188               eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
189            },
190         }
191      }
192   }
193
194   // Disable auto-escaping for markdown files
195   tera.autoescape_on(vec![]);
196
197   Mutex::new(tera)
198});
199
/// Determine the user prompts directory (`<home>/.llm-git/prompts`).
///
/// The home directory comes from `HOME`, falling back to `USERPROFILE`.
///
/// Unset or empty variables are skipped. An empty `HOME` would otherwise
/// resolve to the relative path `.llm-git/prompts` and scatter prompts into
/// the current working directory.
///
/// Values are read as `OsString`, so a non-UTF-8 home path is used rather
/// than treated as missing.
///
/// Returns `None` when neither variable yields a usable value.
fn get_user_prompts_dir() -> Option<PathBuf> {
   let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

   non_empty("HOME")
      .or_else(|| non_empty("USERPROFILE"))
      .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
207
208/// Initialize prompts directory by unpacking embedded prompts if needed
209pub fn ensure_prompts_dir() -> Result<()> {
210   let Some(user_prompts_dir) = get_user_prompts_dir() else {
211      // No HOME/USERPROFILE, so we can't materialize templates on disk.
212      // We'll fall back to the embedded prompts in-memory.
213      return Ok(());
214   };
215
216   // Safety: prompts dir always has a parent (…/.llm-git/prompts)
217   let user_llm_git_dir = user_prompts_dir
218      .parent()
219      .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
220
221   // Create ~/.llm-git directory if it doesn't exist
222   if !user_llm_git_dir.exists() {
223      std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
224         CommitGenError::Other(format!(
225            "Failed to create directory {}: {}",
226            user_llm_git_dir.display(),
227            e
228         ))
229      })?;
230   }
231
232   // Create prompts subdirectory if it doesn't exist
233   if !user_prompts_dir.exists() {
234      std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
235         CommitGenError::Other(format!(
236            "Failed to create directory {}: {}",
237            user_prompts_dir.display(),
238            e
239         ))
240      })?;
241   }
242
243   // Unpack embedded prompts, updating if content differs
244   for file in Prompts::iter() {
245      let file_path = user_prompts_dir.join(file.as_ref());
246
247      // Create parent directories if needed
248      if let Some(parent) = file_path.parent() {
249         std::fs::create_dir_all(parent).map_err(|e| {
250            CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
251         })?;
252      }
253
254      if let Some(embedded_file) = Prompts::get(file.as_ref()) {
255         let embedded_content = embedded_file.data;
256
257         // Check if we need to write: file doesn't exist OR content differs
258         let should_write = if file_path.exists() {
259            match std::fs::read(&file_path) {
260               Ok(existing_content) => existing_content != embedded_content.as_ref(),
261               Err(_) => true, // Can't read, assume we should write
262            }
263         } else {
264            true // File doesn't exist
265         };
266
267         if should_write {
268            std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
269               CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
270            })?;
271         }
272      }
273   }
274
275   Ok(())
276}
277
278fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
279   if !directory.exists() {
280      return Ok(());
281   }
282
283   for entry in std::fs::read_dir(directory).map_err(|e| {
284      CommitGenError::Other(format!(
285         "Failed to read {} templates directory {}: {}",
286         category,
287         directory.display(),
288         e
289      ))
290   })? {
291      let entry = match entry {
292         Ok(entry) => entry,
293         Err(e) => {
294            eprintln!(
295               "Warning: Failed to iterate template entry in {}: {}",
296               directory.display(),
297               e
298            );
299            continue;
300         },
301      };
302
303      let path = entry.path();
304      if path.extension().and_then(|s| s.to_str()) != Some("md") {
305         continue;
306      }
307
308      let template_name = format!(
309         "{}/{}",
310         category,
311         path
312            .file_name()
313            .and_then(|s| s.to_str())
314            .unwrap_or_default()
315      );
316
317      // Add template (overwrites if exists, allowing user files to override embedded
318      // defaults)
319      if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
320         eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
321      }
322   }
323
324   Ok(())
325}
326
327/// Load template content from file (for dynamic user templates)
328fn load_template_file(category: &str, variant: &str) -> Result<String> {
329   // Prefer user-provided template if available.
330   if let Some(prompts_dir) = get_user_prompts_dir() {
331      let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
332      if template_path.exists() {
333         return std::fs::read_to_string(&template_path).map_err(|e| {
334            CommitGenError::Other(format!(
335               "Failed to read template file {}: {}",
336               template_path.display(),
337               e
338            ))
339         });
340      }
341   }
342
343   // Fallback to embedded template bundled with the binary.
344   let embedded_key = format!("{category}/{variant}.md");
345   if let Some(bytes) = Prompts::get(&embedded_key) {
346      return std::str::from_utf8(bytes.data.as_ref())
347         .map(|s| s.to_string())
348         .map_err(|e| {
349            CommitGenError::Other(format!(
350               "Embedded template {embedded_key} is not valid UTF-8: {e}"
351            ))
352         });
353   }
354
355   Err(CommitGenError::Other(format!(
356      "Template variant '{variant}' in category '{category}' not found as user override or \
357       embedded default"
358   )))
359}
360
361/// Render analysis prompt template
362pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
363   // Try to load template dynamically (supports user-added templates)
364   let template_content = load_template_file("analysis", p.variant)?;
365
366   // Create context with all the data
367   let mut context = Context::new();
368   context.insert("stat", p.stat);
369   context.insert("diff", p.diff);
370   context.insert("scope_candidates", p.scope_candidates);
371   if let Some(commits) = p.recent_commits {
372      context.insert("recent_commits", commits);
373   }
374   if let Some(scopes) = p.common_scopes {
375      context.insert("common_scopes", scopes);
376   }
377   if let Some(types) = p.types_description {
378      context.insert("types_description", types);
379   }
380   if let Some(ctx) = p.project_context {
381      context.insert("project_context", ctx);
382   }
383
384   render_prompt_parts(&format!("analysis/{}.md", p.variant), &template_content, &context)
385}
386
387/// Render summary prompt template
388pub fn render_summary_prompt(
389   variant: &str,
390   commit_type: &str,
391   scope: &str,
392   chars: &str,
393   details: &str,
394   stat: &str,
395   user_context: Option<&str>,
396) -> Result<PromptParts> {
397   // Try to load template dynamically (supports user-added templates)
398   let template_content = load_template_file("summary", variant)?;
399
400   // Create context with all the data
401   let mut context = Context::new();
402   context.insert("commit_type", commit_type);
403   context.insert("scope", scope);
404   context.insert("chars", chars);
405   context.insert("details", details);
406   context.insert("stat", stat);
407   if let Some(ctx) = user_context {
408      context.insert("user_context", ctx);
409   }
410
411   render_prompt_parts(&format!("summary/{variant}.md"), &template_content, &context)
412}
413
414/// Render changelog prompt template
415pub fn render_changelog_prompt(
416   variant: &str,
417   changelog_path: &str,
418   is_package_changelog: bool,
419   stat: &str,
420   diff: &str,
421   existing_entries: Option<&str>,
422) -> Result<PromptParts> {
423   // Try to load template dynamically (supports user-added templates)
424   let template_content = load_template_file("changelog", variant)?;
425
426   // Create context with all the data
427   let mut context = Context::new();
428   context.insert("changelog_path", changelog_path);
429   context.insert("is_package_changelog", &is_package_changelog);
430   context.insert("stat", stat);
431   context.insert("diff", diff);
432   if let Some(entries) = existing_entries {
433      context.insert("existing_entries", entries);
434   }
435
436   render_prompt_parts(&format!("changelog/{variant}.md"), &template_content, &context)
437}
438
439#[derive(Serialize)]
440pub struct MapFile<'a> {
441   pub path: &'a str,
442   pub diff: &'a str,
443}
444
445/// Render map prompt template (batched file observation extraction)
446pub fn render_map_prompt(
447   variant: &str,
448   files: &[MapFile<'_>],
449   context_header: &str,
450) -> Result<PromptParts> {
451   let template_content = load_template_file("map", variant)?;
452
453   let mut context = Context::new();
454   context.insert("files", files);
455   if !context_header.is_empty() {
456      context.insert("context_header", context_header);
457   }
458
459   render_prompt_parts(&format!("map/{variant}.md"), &template_content, &context)
460}
461
462/// Render reduce prompt template (synthesis from observations)
463pub fn render_reduce_prompt(
464   variant: &str,
465   observations: &str,
466   stat: &str,
467   scope_candidates: &str,
468   types_description: Option<&str>,
469) -> Result<PromptParts> {
470   let template_content = load_template_file("reduce", variant)?;
471
472   let mut context = Context::new();
473   context.insert("observations", observations);
474   context.insert("stat", stat);
475   context.insert("scope_candidates", scope_candidates);
476   if let Some(types_desc) = types_description {
477      context.insert("types_description", types_desc);
478   }
479
480   render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
481}
482
/// Inputs for `render_compose_intent_prompt`.
///
/// `variant` picks `compose-intent/<variant>.md`; every other field is
/// inserted into the Tera context under its own name.
pub struct ComposeIntentPromptParams<'a> {
   /// Template variant (file stem under `compose-intent/`).
   pub variant:          &'a str,
   /// Exposed to the template as `max_commits`.
   pub max_commits:      usize,
   /// Exposed to the template as `stat`.
   pub stat:             &'a str,
   /// Exposed to the template as `snapshot_summary`.
   pub snapshot_summary: &'a str,
   /// Exposed to the template as `planning_targets`.
   pub planning_targets: &'a str,
   /// Exposed to the template as `planning_notes`.
   pub planning_notes:   &'a str,
   /// Exposed to the template as `split_bias`.
   pub split_bias:       &'a str,
}
493
494/// Render compose intent prompt template.
495pub fn render_compose_intent_prompt(p: &ComposeIntentPromptParams<'_>) -> Result<PromptParts> {
496   let template_content = load_template_file("compose-intent", p.variant)?;
497
498   let mut context = Context::new();
499   context.insert("max_commits", &p.max_commits);
500   context.insert("stat", p.stat);
501   context.insert("snapshot_summary", p.snapshot_summary);
502   context.insert("planning_targets", p.planning_targets);
503   context.insert("planning_notes", p.planning_notes);
504   context.insert("split_bias", p.split_bias);
505
506   render_prompt_parts(&format!("compose-intent/{}.md", p.variant), &template_content, &context)
507}
508
/// Inputs for `render_compose_bind_prompt`.
///
/// `variant` picks `compose-bind/<variant>.md`; the other fields are inserted
/// into the Tera context under their own names.
pub struct ComposeBindPromptParams<'a> {
   /// Template variant (file stem under `compose-bind/`).
   pub variant:         &'a str,
   /// Exposed to the template as `groups`.
   pub groups:          &'a str,
   /// Exposed to the template as `ambiguous_files`.
   pub ambiguous_files: &'a str,
}
515
516/// Render compose bind prompt template.
517pub fn render_compose_bind_prompt(p: &ComposeBindPromptParams<'_>) -> Result<PromptParts> {
518   let template_content = load_template_file("compose-bind", p.variant)?;
519
520   let mut context = Context::new();
521   context.insert("groups", p.groups);
522   context.insert("ambiguous_files", p.ambiguous_files);
523
524   render_prompt_parts(&format!("compose-bind/{}.md", p.variant), &template_content, &context)
525}
526
/// Inputs for `render_fast_prompt`.
///
/// `variant` picks `fast/<variant>.md`; the other fields are inserted into the
/// Tera context under their own names, with `user_context` inserted only when
/// it is `Some`.
pub struct FastPromptParams<'a> {
   /// Template variant (file stem under `fast/`).
   pub variant:          &'a str,
   /// Exposed to the template as `stat`.
   pub stat:             &'a str,
   /// Exposed to the template as `diff`.
   pub diff:             &'a str,
   /// Exposed to the template as `scope_candidates`.
   pub scope_candidates: &'a str,
   /// Exposed as `user_context` when present.
   pub user_context:     Option<&'a str>,
}
535
536/// Render fast mode prompt template (single-call commit generation)
537pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
538   let template_content = load_template_file("fast", p.variant)?;
539
540   let mut context = Context::new();
541   context.insert("stat", p.stat);
542   context.insert("diff", p.diff);
543   context.insert("scope_candidates", p.scope_candidates);
544   if let Some(ctx) = p.user_context {
545      context.insert("user_context", ctx);
546   }
547
548   render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
549}
550
#[cfg(test)]
mod tests {
   use super::{
      AnalysisParams, ComposeBindPromptParams, ComposeIntentPromptParams, ensure_prompts_dir,
      render_analysis_prompt, render_compose_bind_prompt, render_compose_intent_prompt,
      render_reduce_prompt, render_summary_prompt, split_prompt_template,
   };

   /// Assert that `text` contains every string in `needles`.
   fn assert_contains_all(text: &str, needles: &[&str]) {
      for &needle in needles {
         assert!(text.contains(needle));
      }
   }

   #[test]
   fn test_split_prompt_template_lf() {
      let (system, user) =
         split_prompt_template("system text\nmore system\n======USER=======\nuser body\n");
      assert_eq!(system, Some("system text\nmore system"));
      assert_eq!(user, "user body\n");
   }

   #[test]
   fn test_split_prompt_template_crlf() {
      // Windows checkouts under Git's default core.autocrlf=true produce CRLF
      // separators; the splitter must locate the marker line regardless.
      let (system, user) = split_prompt_template(
         "system text\r\nmore system\r\n======USER=======\r\nuser body\r\n",
      );
      assert_eq!(system, Some("system text\r\nmore system"));
      assert_eq!(user, "user body\r\n");
   }

   #[test]
   fn test_split_prompt_template_no_separator() {
      let content = "no separator here";
      assert_eq!(split_prompt_template(content), (None, content));
   }

   #[test]
   fn test_render_analysis_prompt_requests_holistic_summary() {
      ensure_prompts_dir().unwrap();

      let params = AnalysisParams {
         variant: "default",
         stat: "src/api/client.rs | 24 +++++++++++++++---------",
         diff: "diff --git a/src/api/client.rs b/src/api/client.rs\n",
         scope_candidates: "api",
         ..AnalysisParams::default()
      };
      let parts = render_analysis_prompt(&params).unwrap();

      assert_contains_all(&parts.system, &[
         "Generate Summary",
         "\"summary\"",
         "umbrella headline for the whole changeset",
         "Does not copy detail #1",
      ]);
   }

   #[test]
   fn test_render_reduce_prompt_guides_grouped_synthesis() {
      ensure_prompts_dir().unwrap();

      let observations = r#"[{"file":"src/a.rs","observations":["Added retry handling."]}]"#;
      let parts =
         render_reduce_prompt("default", observations, "src/a.rs | 10 +++++-----", "api", None)
            .unwrap();

      assert_contains_all(&parts.system, &[
         "3-4 strong grouped details",
         "Synthesize repeated file observations",
         "over enumerating files",
      ]);
   }

   #[test]
   fn test_render_compose_intent_prompt() {
      let params = ComposeIntentPromptParams {
         variant:          "default",
         max_commits:      3,
         stat:             "src/foo.rs | 10 +++++-----",
         snapshot_summary: "- F1 src/foo.rs",
         planning_targets: "file IDs",
         planning_notes:   "Prefer conservative grouping over speculative splitting.",
         split_bias:       "Prefer fewer groups when the split is uncertain.",
      };
      let parts = render_compose_intent_prompt(&params).unwrap();

      assert_contains_all(&parts.system, &["create_compose_intent_plan"]);
      assert_contains_all(&parts.user, &["max_commits: 3", "src/foo.rs"]);
   }

   #[test]
   fn test_render_summary_prompt_guides_umbrella_title() {
      ensure_prompts_dir().unwrap();

      let details = "Added websocket reconnects.\nUpdated client retry tests.";
      let stat = "src/api/client.rs | 24 +++++++++++++++---------";
      let parts =
         render_summary_prompt("default", "feat", "api", "72", details, stat, None).unwrap();

      assert_contains_all(&parts.system, &[
         "umbrella description for the whole changeset",
         "not as candidate titles to copy",
         "does not copy or narrowly paraphrase one detail point",
      ]);
      assert_contains_all(&parts.user, &[
         "<detail_points>",
         "Added websocket reconnects.",
         "Updated client retry tests.",
      ]);
   }

   #[test]
   fn test_render_compose_bind_prompt() {
      let params = ComposeBindPromptParams {
         variant:         "default",
         groups:          "- G1 [feat(api)] Added endpoint",
         ambiguous_files: "- F2 src/api.rs candidates: G1",
      };
      let parts = render_compose_bind_prompt(&params).unwrap();

      assert_contains_all(&parts.system, &["bind_compose_hunks"]);
      assert_contains_all(&parts.user, &["G1", "src/api.rs"]);
   }
}