1use std::{
2 path::{Path, PathBuf},
3 sync::LazyLock,
4};
5
6use parking_lot::Mutex;
7use rust_embed::RustEmbed;
8use tera::{Context, Tera};
9
10use crate::error::{CommitGenError, Result};
11
/// A rendered prompt divided into its system and user messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Text before the `======USER=======` marker line (trimmed); empty when the
    /// template has no marker.
    pub system: String,
    /// Text after the marker line, or the whole rendered output when there is no
    /// marker (trimmed).
    pub user: String,
}

/// Splits rendered template output at the first line consisting of the
/// `======USER=======` marker.
///
/// The marker must start its line; trailing whitespace on that line (including
/// the `\r` of CRLF line endings, common in user-edited templates on Windows)
/// is ignored. Both halves are trimmed. Without a marker the entire output is
/// treated as the user message and `system` is empty.
fn split_prompt_parts(rendered: &str) -> PromptParts {
    const MARKER: &str = "======USER=======";

    // Byte offset of the start of the current line within `rendered`.
    let mut offset = 0;
    for line in rendered.split_inclusive('\n') {
        if line.trim_end() == MARKER {
            return PromptParts {
                system: rendered[..offset].trim().to_string(),
                user: rendered[offset + line.len()..].trim().to_string(),
            };
        }
        offset += line.len();
    }

    PromptParts { system: String::new(), user: rendered.trim().to_string() }
}
31
/// Inputs for [`render_analysis_prompt`].
///
/// All fields borrow from the caller, so the struct is cheap to copy. Optional
/// fields left as `None` are not inserted into the template context at all.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnalysisParams<'a> {
    /// Template variant; selects `analysis/<variant>.md`.
    pub variant: &'a str,
    /// Exposed to the template as `stat`.
    pub stat: &'a str,
    /// Exposed to the template as `diff`.
    pub diff: &'a str,
    /// Exposed to the template as `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Exposed as `recent_commits` when present.
    pub recent_commits: Option<&'a str>,
    /// Exposed as `common_scopes` when present.
    pub common_scopes: Option<&'a str>,
    /// Exposed as `types_description` when present.
    pub types_description: Option<&'a str>,
    /// Exposed as `project_context` when present.
    pub project_context: Option<&'a str>,
}
44
/// Default prompt templates bundled via `rust_embed` from the `prompts/` folder.
///
/// Keys are paths relative to that folder; this module looks them up as
/// `<category>/<variant>.md` and also mirrors them into the user prompts
/// directory.
#[derive(RustEmbed)]
#[folder = "prompts/"]
struct Prompts;
49
50static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
53 if let Err(e) = ensure_prompts_dir() {
55 eprintln!("Warning: Failed to initialize prompts directory: {e}");
56 }
57
58 let mut tera = Tera::default();
59
60 if let Some(prompts_dir) = get_user_prompts_dir() {
62 if let Err(e) =
63 register_directory_templates(&mut tera, &prompts_dir.join("analysis"), "analysis")
64 {
65 eprintln!("Warning: {e}");
66 }
67 if let Err(e) =
68 register_directory_templates(&mut tera, &prompts_dir.join("summary"), "summary")
69 {
70 eprintln!("Warning: {e}");
71 }
72 if let Err(e) =
73 register_directory_templates(&mut tera, &prompts_dir.join("changelog"), "changelog")
74 {
75 eprintln!("Warning: {e}");
76 }
77 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("map"), "map") {
78 eprintln!("Warning: {e}");
79 }
80 if let Err(e) = register_directory_templates(&mut tera, &prompts_dir.join("reduce"), "reduce")
81 {
82 eprintln!("Warning: {e}");
83 }
84 }
85
86 for file in Prompts::iter() {
88 if tera.get_template_names().any(|name| name == file.as_ref()) {
89 continue;
90 }
91
92 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
93 match std::str::from_utf8(embedded_file.data.as_ref()) {
94 Ok(content) => {
95 if let Err(e) = tera.add_raw_template(file.as_ref(), content) {
96 eprintln!(
97 "Warning: Failed to register embedded template {}: {}",
98 file.as_ref(),
99 e
100 );
101 }
102 },
103 Err(e) => {
104 eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", file.as_ref(), e);
105 },
106 }
107 }
108 }
109
110 tera.autoescape_on(vec![]);
112
113 Mutex::new(tera)
114});
115
/// Returns the user prompt directory, `<home>/.llm-git/prompts`.
///
/// The home directory comes from `HOME`, falling back to `USERPROFILE`
/// (Windows). Variables that are unset or empty are skipped. An empty `HOME`
/// would otherwise turn into the relative path `.llm-git/prompts`, and files
/// would be created under the current working directory. `var_os` is used so
/// that a home path that is not valid UTF-8 is still honoured.
///
/// Returns `None` when neither variable yields a usable value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(std::env::var_os)
        .find(|home| !home.is_empty())
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
123
124pub fn ensure_prompts_dir() -> Result<()> {
126 let Some(user_prompts_dir) = get_user_prompts_dir() else {
127 return Ok(());
130 };
131
132 let user_llm_git_dir = user_prompts_dir
134 .parent()
135 .ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
136
137 if !user_llm_git_dir.exists() {
139 std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
140 CommitGenError::Other(format!(
141 "Failed to create directory {}: {}",
142 user_llm_git_dir.display(),
143 e
144 ))
145 })?;
146 }
147
148 if !user_prompts_dir.exists() {
150 std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
151 CommitGenError::Other(format!(
152 "Failed to create directory {}: {}",
153 user_prompts_dir.display(),
154 e
155 ))
156 })?;
157 }
158
159 for file in Prompts::iter() {
161 let file_path = user_prompts_dir.join(file.as_ref());
162
163 if let Some(parent) = file_path.parent() {
165 std::fs::create_dir_all(parent).map_err(|e| {
166 CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
167 })?;
168 }
169
170 if let Some(embedded_file) = Prompts::get(file.as_ref()) {
171 let embedded_content = embedded_file.data;
172
173 let should_write = if file_path.exists() {
175 match std::fs::read(&file_path) {
176 Ok(existing_content) => existing_content != embedded_content.as_ref(),
177 Err(_) => true, }
179 } else {
180 true };
182
183 if should_write {
184 std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
185 CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
186 })?;
187 }
188 }
189 }
190
191 Ok(())
192}
193
194fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
195 if !directory.exists() {
196 return Ok(());
197 }
198
199 for entry in std::fs::read_dir(directory).map_err(|e| {
200 CommitGenError::Other(format!(
201 "Failed to read {} templates directory {}: {}",
202 category,
203 directory.display(),
204 e
205 ))
206 })? {
207 let entry = match entry {
208 Ok(entry) => entry,
209 Err(e) => {
210 eprintln!(
211 "Warning: Failed to iterate template entry in {}: {}",
212 directory.display(),
213 e
214 );
215 continue;
216 },
217 };
218
219 let path = entry.path();
220 if path.extension().and_then(|s| s.to_str()) != Some("md") {
221 continue;
222 }
223
224 let template_name = format!(
225 "{}/{}",
226 category,
227 path
228 .file_name()
229 .and_then(|s| s.to_str())
230 .unwrap_or_default()
231 );
232
233 if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
236 eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
237 }
238 }
239
240 Ok(())
241}
242
243fn load_template_file(category: &str, variant: &str) -> Result<String> {
245 if let Some(prompts_dir) = get_user_prompts_dir() {
247 let template_path = prompts_dir.join(category).join(format!("{variant}.md"));
248 if template_path.exists() {
249 return std::fs::read_to_string(&template_path).map_err(|e| {
250 CommitGenError::Other(format!(
251 "Failed to read template file {}: {}",
252 template_path.display(),
253 e
254 ))
255 });
256 }
257 }
258
259 let embedded_key = format!("{category}/{variant}.md");
261 if let Some(bytes) = Prompts::get(&embedded_key) {
262 return std::str::from_utf8(bytes.data.as_ref())
263 .map(|s| s.to_string())
264 .map_err(|e| {
265 CommitGenError::Other(format!(
266 "Embedded template {embedded_key} is not valid UTF-8: {e}"
267 ))
268 });
269 }
270
271 Err(CommitGenError::Other(format!(
272 "Template variant '{variant}' in category '{category}' not found as user override or \
273 embedded default"
274 )))
275}
276
277pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
279 let template_content = load_template_file("analysis", p.variant)?;
281
282 let mut context = Context::new();
284 context.insert("stat", p.stat);
285 context.insert("diff", p.diff);
286 context.insert("scope_candidates", p.scope_candidates);
287 if let Some(commits) = p.recent_commits {
288 context.insert("recent_commits", commits);
289 }
290 if let Some(scopes) = p.common_scopes {
291 context.insert("common_scopes", scopes);
292 }
293 if let Some(types) = p.types_description {
294 context.insert("types_description", types);
295 }
296 if let Some(ctx) = p.project_context {
297 context.insert("project_context", ctx);
298 }
299
300 let mut tera = TERA.lock();
302
303 let rendered = tera.render_str(&template_content, &context).map_err(|e| {
304 CommitGenError::Other(format!(
305 "Failed to render analysis prompt template '{}': {e}",
306 p.variant
307 ))
308 })?;
309 Ok(split_prompt_parts(&rendered))
310}
311
312pub fn render_summary_prompt(
314 variant: &str,
315 commit_type: &str,
316 scope: &str,
317 chars: &str,
318 details: &str,
319 stat: &str,
320 user_context: Option<&str>,
321) -> Result<PromptParts> {
322 let template_content = load_template_file("summary", variant)?;
324
325 let mut context = Context::new();
327 context.insert("commit_type", commit_type);
328 context.insert("scope", scope);
329 context.insert("chars", chars);
330 context.insert("details", details);
331 context.insert("stat", stat);
332 if let Some(ctx) = user_context {
333 context.insert("user_context", ctx);
334 }
335
336 let mut tera = TERA.lock();
338 let rendered = tera.render_str(&template_content, &context).map_err(|e| {
339 CommitGenError::Other(format!("Failed to render summary prompt template '{variant}': {e}"))
340 })?;
341 Ok(split_prompt_parts(&rendered))
342}
343
344pub fn render_changelog_prompt(
346 variant: &str,
347 changelog_path: &str,
348 is_package_changelog: bool,
349 stat: &str,
350 diff: &str,
351 existing_entries: Option<&str>,
352) -> Result<PromptParts> {
353 let template_content = load_template_file("changelog", variant)?;
355
356 let mut context = Context::new();
358 context.insert("changelog_path", changelog_path);
359 context.insert("is_package_changelog", &is_package_changelog);
360 context.insert("stat", stat);
361 context.insert("diff", diff);
362 if let Some(entries) = existing_entries {
363 context.insert("existing_entries", entries);
364 }
365
366 let mut tera = TERA.lock();
368 let rendered = tera.render_str(&template_content, &context).map_err(|e| {
369 CommitGenError::Other(format!("Failed to render changelog prompt template '{variant}': {e}"))
370 })?;
371 Ok(split_prompt_parts(&rendered))
372}
373
374pub fn render_map_prompt(
376 variant: &str,
377 filename: &str,
378 diff: &str,
379 context_header: &str,
380) -> Result<PromptParts> {
381 let template_content = load_template_file("map", variant)?;
382
383 let mut context = Context::new();
384 context.insert("filename", filename);
385 context.insert("diff", diff);
386 if !context_header.is_empty() {
387 context.insert("context_header", context_header);
388 }
389
390 let mut tera = TERA.lock();
391 let rendered = tera.render_str(&template_content, &context).map_err(|e| {
392 CommitGenError::Other(format!("Failed to render map prompt template '{variant}': {e}"))
393 })?;
394 Ok(split_prompt_parts(&rendered))
395}
396
397pub fn render_reduce_prompt(
399 variant: &str,
400 observations: &str,
401 stat: &str,
402 scope_candidates: &str,
403 types_description: Option<&str>,
404) -> Result<PromptParts> {
405 let template_content = load_template_file("reduce", variant)?;
406
407 let mut context = Context::new();
408 context.insert("observations", observations);
409 context.insert("stat", stat);
410 context.insert("scope_candidates", scope_candidates);
411 if let Some(types_desc) = types_description {
412 context.insert("types_description", types_desc);
413 }
414
415 let mut tera = TERA.lock();
416 let rendered = tera.render_str(&template_content, &context).map_err(|e| {
417 CommitGenError::Other(format!("Failed to render reduce prompt template '{variant}': {e}"))
418 })?;
419 Ok(split_prompt_parts(&rendered))
420}