1use std::{
2 path::{Path, PathBuf},
3 sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use tera::{Context, Tera};
9
10use crate::error::{CommitGenError, Result};
11
/// A prompt split into its system message and its user message.
///
/// As produced by `render_prompt_parts`, `system` is the static text above the
/// `======USER=======` separator (empty when the template has none) and `user` is the
/// Tera-rendered text below it; both are trimmed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Static system prompt text (never template-rendered).
    pub system: String,
    /// Rendered user prompt text.
    pub user: String,
}
17
/// Line that divides a prompt template into a static system section (above) and a
/// Tera-rendered user section (below).
const USER_SEPARATOR: &str = "\n======USER=======\n";

/// Split a prompt template at the first `======USER=======` separator line.
///
/// The separator must begin right after a newline and be followed by either `\n` or `\r\n`, so
/// templates saved with CRLF line endings are split the same way as LF ones instead of being
/// treated as having no system section at all.
///
/// Returns `(Some(system), user)` when a separator is found, otherwise `(None, template_content)`.
/// With CRLF input the returned system part may end in `\r`; callers trim it.
fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
    // "\n======USER=======" — the separator without its trailing line ending.
    let marker = USER_SEPARATOR.trim_end_matches('\n');

    // The marker starts with its only '\n', so occurrences cannot overlap and
    // `match_indices` visits every candidate in order.
    for (pos, _) in template_content.match_indices(marker) {
        let after_marker = &template_content[pos + marker.len()..];
        let user = after_marker
            .strip_prefix('\n')
            .or_else(|| after_marker.strip_prefix("\r\n"));
        if let Some(user) = user {
            return (Some(&template_content[..pos]), user);
        }
    }

    (None, template_content)
}
30
31fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
33 let has_template_tags = system_template.contains("{{")
34 || system_template.contains("{%")
35 || system_template.contains("{#");
36
37 if has_template_tags {
38 return Err(CommitGenError::Other(format!(
39 "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
40 content below ======USER=======."
41 )));
42 }
43
44 Ok(())
45}
46
47fn render_prompt_parts(
49 template_name: &str,
50 template_content: &str,
51 context: &Context,
52) -> Result<PromptParts> {
53 let (system_template, user_template) = split_prompt_template(template_content);
54
55 let system = if let Some(system_template) = system_template {
56 ensure_static_system_prompt(system_template, template_name)?;
57 system_template.trim().to_string()
58 } else {
59 String::new()
60 };
61
62 let mut tera = TERA.lock();
63 let rendered_user = tera.render_str(user_template, context).map_err(|e| {
64 CommitGenError::Other(format!("Failed to render {template_name} prompt template: {e}"))
65 })?;
66
67 Ok(PromptParts { system, user: rendered_user.trim().to_string() })
68}
69
/// Inputs for `render_analysis_prompt`.
///
/// `variant` selects the `analysis/<variant>.md` template. `stat`, `diff` and
/// `scope_candidates` are always exposed to the template under those names; each `Option` field
/// is exposed under its field name only when it is `Some`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnalysisParams<'a> {
    /// Template variant (file stem under `analysis/`).
    pub variant: &'a str,
    pub stat: &'a str,
    pub diff: &'a str,
    pub scope_candidates: &'a str,
    pub recent_commits: Option<&'a str>,
    pub common_scopes: Option<&'a str>,
    pub types_description: Option<&'a str>,
    pub project_context: Option<&'a str>,
}
82
/// Default prompt templates provided via `rust_embed` from the crate's `prompts/` folder.
///
/// Only `*.md` files are included (at any depth, per the `**/*.md` pattern). Keys look like
/// `<category>/<variant>.md`. They are the fallback in `load_template_file`, are registered with
/// `TERA` when no user template of the same name exists, and are copied into the user's prompts
/// directory by `ensure_prompts_dir`.
#[derive(RustEmbed)]
#[folder = "prompts/"]
#[include = "**/*.md"]
struct Prompts;
88
89static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
92 if let Err(e) = ensure_prompts_dir() {
94 eprintln!("Warning: Failed to initialize prompts directory: {e}");
95 }
96
97 let mut tera = Tera::default();
98
99 if let Some(prompts_dir) = get_user_prompts_dir() {
101 if let Err(e) =
102 register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
103 {
104 eprintln!("Warning: {e}");
105 }
106 if let Err(e) =
107 register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
108 {
109 eprintln!("Warning: {e}");
110 }
111 if let Err(e) =
112 register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
113 {
114 eprintln!("Warning: {e}");
115 }
116 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
117 eprintln!("Warning: {e}");
118 }
119 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
120 {
121 eprintln!("Warning: {e}");
122 }
123 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("fast"), "fast") {
124 eprintln!("Warning: {e}");
125 }
126 }
127
128 for file in Prompts::iter() {
130 if tera.get_template_names().any(|name| name == file.as_ref()) {
131 continue;
132 }
133
134 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
135 match std::str::from_utf8(embedded_file.data.as_ref()) {
136 Ok(content) => {
137 if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
138 eprintln!(
139 "Warning: Failed to register embedded template {}: {}",
140 file.as_ref(),
141 e
142 );
143 }
144 },
145 Err(e) => {
146 eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
147 },
148 }
149 }
150 }
151
152 tera.autoescape_on(vec![]);
154
155 Mutex::new(tera)
156});
157
/// Locate the user's prompt directory, `<home>/.llm-git/prompts`.
///
/// The home directory comes from `HOME`, falling back to `USERPROFILE`. Unset or empty values are
/// skipped: an empty `HOME` would otherwise produce the relative path `.llm-git/prompts`, and
/// `ensure_prompts_dir` would then create and write prompt files inside whatever the current
/// working directory is. Values are read as `OsString`, so a non-UTF-8 home path is still used
/// instead of being mistaken for "unset".
///
/// Returns `None` when neither variable holds a usable value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|&var| std::env::var_os(var))
        .find(|home| !home.is_empty())
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
165
166pub fn ensure_prompts_dir() -> Result<()> {
168 let Some(user_prompts_dir) = get_user_prompts_dir() else {
169 return Ok(());
172 };
173
174 let user_llm_git_dir = user_prompts_dir
176 .parent()
177 .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
178
179 if !user_llm_git_dir.exists() {
181 std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
182 CommitGenError::Other(format!(
183 "Failed to create directory {}: {}",
184 user_llm_git_dir.display(),
185 e
186 ))
187 })?;
188 }
189
190 if !user_prompts_dir.exists() {
192 std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
193 CommitGenError::Other(format!(
194 "Failed to create directory {}: {}",
195 user_prompts_dir.display(),
196 e
197 ))
198 })?;
199 }
200
201 for file in Prompts::iter() {
203 let file_path = user_prompts_dir.join(file.as_ref());
204
205 if let Some(parent) = file_path.parent() {
207 std::fs::create_dir_all(parent).map_err(|e| {
208 CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
209 })?;
210 }
211
212 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
213 let embedded_content = embedded_file.data;
214
215 let should_write = if file_path.exists() {
217 match std::fs::read(&file_path) {
218 Ok(existing_content) => existing_content != embedded_content.as_ref(),
219 Err(_) => true, }
221 } else {
222 true };
224
225 if should_write {
226 std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
227 CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
228 })?;
229 }
230 }
231 }
232
233 Ok(())
234}
235
236fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
237 if !directory.exists() {
238 return Ok(());
239 }
240
241 for entry in std::fs::read_dir(directory).map_err(|e| {
242 CommitGenError::Other(format!(
243 "Failed to read {} templates directory {}: {}",
244 category,
245 directory.display(),
246 e
247 ))
248 })? {
249 let entry = match entry {
250 Ok(entry) => entry,
251 Err(e) => {
252 eprintln!(
253 "Warning: Failed to iterate template entry in {}: {}",
254 directory.display(),
255 e
256 );
257 continue;
258 },
259 };
260
261 let path = entry.path();
262 if path.extension().and_then(|s| s.to_str()) != Some("md") {
263 continue;
264 }
265
266 let template_name = format!(
267 "{}/{}",
268 category,
269 path
270 .file_name()
271 .and_then(|s| s.to_str())
272 .unwrap_or_default()
273 );
274
275 if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
278 eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
279 }
280 }
281
282 Ok(())
283}
284
285fn load_template_file(category: &str, variant: &str) -> Result<String> {
287 if let Some(prompts_dir) = get_user_prompts_dir() {
289 let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
290 if template_path.exists() {
291 return std::fs::read_to_string(&template_path).map_err(|e| {
292 CommitGenError::Other(format!(
293 "Failed to read template file {}: {}",
294 template_path.display(),
295 e
296 ))
297 });
298 }
299 }
300
301 let embedded_key = format!("{category}/{variant}.md");
303 if let Some(bytes) = Prompts::get(&embedded_key) {
304 return std::str::from_utf8(bytes.data.as_ref())
305 .map(|s| s.to_string())
306 .map_err(|e| {
307 CommitGenError::Other(format!(
308 "Embedded template {embedded_key} is not valid UTF-8: {e}"
309 ))
310 });
311 }
312
313 Err(CommitGenError::Other(format!(
314 "Template variant '{variant}' in category '{category}' not found as user override or \
315 embedded default"
316 )))
317}
318
319pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
321 let template_content = load_template_file("analysis", p.variant)?;
323
324 let mut context = Context::new();
326 context.insert("stat", p.stat);
327 context.insert("diff", p.diff);
328 context.insert("scope_candidates", p.scope_candidates);
329 if let Some(commits) = p.recent_commits {
330 context.insert("recent_commits", commits);
331 }
332 if let Some(scopes) = p.common_scopes {
333 context.insert("common_scopes", scopes);
334 }
335 if let Some(types) = p.types_description {
336 context.insert("types_description", types);
337 }
338 if let Some(ctx) = p.project_context {
339 context.insert("project_context", ctx);
340 }
341
342 render_prompt_parts(&format!("analysis/{}.md", p.variant), &template_content, &context)
343}
344
345pub fn render_summary_prompt(
347 variant: &str,
348 commit_type: &str,
349 scope: &str,
350 chars: &str,
351 details: &str,
352 stat: &str,
353 user_context: Option<&str>,
354) -> Result<PromptParts> {
355 let template_content = load_template_file("summary", variant)?;
357
358 let mut context = Context::new();
360 context.insert("commit_type", commit_type);
361 context.insert("scope", scope);
362 context.insert("chars", chars);
363 context.insert("details", details);
364 context.insert("stat", stat);
365 if let Some(ctx) = user_context {
366 context.insert("user_context", ctx);
367 }
368
369 render_prompt_parts(&format!("summary/{variant}.md"), &template_content, &context)
370}
371
372pub fn render_changelog_prompt(
374 variant: &str,
375 changelog_path: &str,
376 is_package_changelog: bool,
377 stat: &str,
378 diff: &str,
379 existing_entries: Option<&str>,
380) -> Result<PromptParts> {
381 let template_content = load_template_file("changelog", variant)?;
383
384 let mut context = Context::new();
386 context.insert("changelog_path", changelog_path);
387 context.insert("is_package_changelog", &is_package_changelog);
388 context.insert("stat", stat);
389 context.insert("diff", diff);
390 if let Some(entries) = existing_entries {
391 context.insert("existing_entries", entries);
392 }
393
394 render_prompt_parts(&format!("changelog/{variant}.md"), &template_content, &context)
395}
396
397pub fn render_map_prompt(
399 variant: &str,
400 filename: &str,
401 diff: &str,
402 context_header: &str,
403) -> Result<PromptParts> {
404 let template_content = load_template_file("map", variant)?;
405
406 let mut context = Context::new();
407 context.insert("filename", filename);
408 context.insert("diff", diff);
409 if !context_header.is_empty() {
410 context.insert("context_header", context_header);
411 }
412
413 render_prompt_parts(&format!("map/{variant}.md"), &template_content, &context)
414}
415
416pub fn render_reduce_prompt(
418 variant: &str,
419 observations: &str,
420 stat: &str,
421 scope_candidates: &str,
422 types_description: Option<&str>,
423) -> Result<PromptParts> {
424 let template_content = load_template_file("reduce", variant)?;
425
426 let mut context = Context::new();
427 context.insert("observations", observations);
428 context.insert("stat", stat);
429 context.insert("scope_candidates", scope_candidates);
430 if let Some(types_desc) = types_description {
431 context.insert("types_description", types_desc);
432 }
433
434 render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
435}
436
/// Inputs for `render_fast_prompt`.
///
/// `variant` selects the `fast/<variant>.md` template. `stat`, `diff` and `scope_candidates` are
/// always exposed to the template; `user_context` only when it is `Some`. Derives the same
/// conveniences as `AnalysisParams` (including `Default`) so both parameter structs can be built
/// with struct-update syntax.
#[derive(Debug, Clone, Copy, Default)]
pub struct FastPromptParams<'a> {
    /// Template variant (file stem under `fast/`).
    pub variant: &'a str,
    pub stat: &'a str,
    pub diff: &'a str,
    pub scope_candidates: &'a str,
    pub user_context: Option<&'a str>,
}
445
446pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
448 let template_content = load_template_file("fast", p.variant)?;
449
450 let mut context = Context::new();
451 context.insert("stat", p.stat);
452 context.insert("diff", p.diff);
453 context.insert("scope_candidates", p.scope_candidates);
454 if let Some(ctx) = p.user_context {
455 context.insert("user_context", ctx);
456 }
457
458 render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
459}