use std::{
path::{Path, PathBuf},
sync::LazyLock,
};
use parking_lot::Mutex;
use rust_embed::RustEmbed;
use serde::Serialize;
use tera::{Context, Tera};
use crate::error::{CommitGenError, Result};
/// A rendered prompt, split into a static system message and a rendered user message.
///
/// Produced by `render_prompt_parts`: `system` is the trimmed text above the
/// `======USER=======` separator (empty when the template has no separator), and `user` is the
/// trimmed Tera rendering of everything below it (or of the whole template).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptParts {
    /// Static system text, taken verbatim (trimmed) from above the separator.
    pub system: String,
    /// User text rendered through Tera from below the separator.
    pub user: String,
}
/// Line that separates the static system section from the Tera-rendered user section.
const USER_SEPARATOR_MARKER: &str = "======USER=======";

/// Locates the first user separator in `content`.
///
/// Returns `(system_end, user_start)` byte offsets such that `content[..system_end]` is the
/// system section without the single line break (`\n` or `\r\n`) directly before the marker,
/// and `content[user_start..]` is the user section without the single line break directly
/// after it. Returns `None` when the marker is absent.
///
/// The neighbouring line breaks are detected with `ends_with`/`starts_with` on slices that end
/// or start at the marker, which is always a char boundary. Slicing fixed byte ranges such as
/// `marker_pos - 2..marker_pos` instead would panic whenever a multi-byte UTF-8 character
/// precedes the marker (e.g. `"café\n======USER======="`).
fn find_user_separator(content: &str) -> Option<(usize, usize)> {
    let marker_pos = content.find(USER_SEPARATOR_MARKER)?;

    let before = &content[..marker_pos];
    let system_end = if before.ends_with("\r\n") {
        marker_pos - 2
    } else if before.ends_with('\n') {
        marker_pos - 1
    } else {
        marker_pos
    };

    let after_marker = marker_pos + USER_SEPARATOR_MARKER.len();
    let after = &content[after_marker..];
    let user_start = if after.starts_with("\r\n") {
        after_marker + 2
    } else if after.starts_with('\n') {
        after_marker + 1
    } else {
        after_marker
    };

    Some((system_end, user_start))
}
/// Splits a prompt template into its optional system section and its user section.
///
/// With a separator present, returns `(Some(system), user)`, where a single line break on
/// each side of the separator line is excluded from both halves. Without one, the whole
/// template is the user section and the system section is `None`.
fn split_prompt_template(template_content: &str) -> (Option<&str>, &str) {
    let Some((system_end, user_start)) = find_user_separator(template_content) else {
        return (None, template_content);
    };
    let system_section = &template_content[..system_end];
    let user_section = &template_content[user_start..];
    (Some(system_section), user_section)
}
/// Rejects a system section that contains Tera syntax.
///
/// The system section is used as-is (only trimmed) and never rendered. Any `{{`, `{%` or `{#`
/// opener there would therefore reach the model as literal text. Interpolated content belongs
/// below the `======USER=======` separator.
///
/// # Errors
/// Returns `CommitGenError::Other` naming `template_name` when any tag opener is present.
fn ensure_static_system_prompt(system_template: &str, template_name: &str) -> Result<()> {
    const TAG_OPENERS: [&str; 3] = ["{{", "{%", "{#"];

    let is_static = !TAG_OPENERS
        .iter()
        .copied()
        .any(|opener| system_template.contains(opener));
    if is_static {
        return Ok(());
    }

    Err(CommitGenError::Other(format!(
        "Template '{template_name}' contains dynamic tags in system section. Move interpolated \
         content below ======USER=======."
    )))
}
/// Splits `template_content` at the user separator and produces the final [`PromptParts`].
///
/// The system section, if any, must contain no Tera tags. It is used verbatim after trimming.
/// The user section is rendered with the shared `TERA` instance against `context` and then
/// trimmed. Without a separator the whole template is the user section and `system` is empty.
///
/// # Errors
/// - The system section contains Tera tags (`ensure_static_system_prompt`).
/// - Rendering the user section fails. Tera reports the useful detail, such as the missing
///   variable or the syntax error, through `std::error::Error::source`. The whole cause chain
///   is therefore folded into the message, not just the outermost "failed to render" text.
fn render_prompt_parts(
    template_name: &str,
    template_content: &str,
    context: &Context,
) -> Result<PromptParts> {
    let (system_template, user_template) = split_prompt_template(template_content);

    let system = match system_template {
        Some(system_template) => {
            ensure_static_system_prompt(system_template, template_name)?;
            system_template.trim().to_string()
        },
        None => String::new(),
    };

    // Hold the global lock only for the render itself (`render_str` needs `&mut Tera`).
    let rendered = {
        let mut tera = TERA.lock();
        tera.render_str(user_template, context)
    };

    let rendered_user = rendered.map_err(|e| {
        let mut message = format!("Failed to render {template_name} prompt template: {e}");
        let mut cause = std::error::Error::source(&e);
        while let Some(inner) = cause {
            message.push_str(": ");
            message.push_str(&inner.to_string());
            cause = std::error::Error::source(inner);
        }
        CommitGenError::Other(message)
    })?;

    Ok(PromptParts { system, user: rendered_user.trim().to_string() })
}
/// Inputs for [`render_analysis_prompt`].
///
/// `variant` selects the template file `analysis/<variant>.md`. The three required strings are
/// always inserted into the Tera context under their field names. Each `Option` field is
/// inserted only when `Some`, so it is absent from the context when `None`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnalysisParams<'a> {
    /// Template variant (file stem under `analysis/`).
    pub variant: &'a str,
    /// Context key `stat` (diffstat-style text in the tests).
    pub stat: &'a str,
    /// Context key `diff` (unified diff text in the tests).
    pub diff: &'a str,
    /// Context key `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Optional context key `recent_commits`.
    pub recent_commits: Option<&'a str>,
    /// Optional context key `common_scopes`.
    pub common_scopes: Option<&'a str>,
    /// Optional context key `types_description`.
    pub types_description: Option<&'a str>,
    /// Optional context key `project_context`.
    pub project_context: Option<&'a str>,
}
/// Default prompt templates from the crate's `prompts/` folder, exposed through `rust_embed`.
///
/// Only `*.md` files are included. Keys are paths relative to `prompts/` in the form
/// `<category>/<variant>.md` (e.g. `analysis/default.md`), which is how `load_template_file`
/// and the `TERA` initializer look them up.
#[derive(RustEmbed)]
#[folder = "prompts/"]
#[include = "**/*.md"]
struct Prompts;
/// Process-wide Tera instance used to render every prompt.
///
/// The first access runs these steps in order:
/// 1. Sync the embedded defaults into the user prompts directory via `ensure_prompts_dir`.
///    A failure only produces a warning.
/// 2. Register every `*.md` file found in each known per-category user subdirectory.
/// 3. Register each embedded template whose name was not already registered in step 2, so
///    user-directory templates take precedence.
/// 4. Disable autoescaping.
///
/// All problems are reported on stderr as warnings; initialization never fails. It is wrapped
/// in a `Mutex` because rendering goes through `Tera::render_str`, which needs mutable access.
static TERA: LazyLock<Mutex<Tera>> = LazyLock::new(|| {
    if let Err(e) = ensure_prompts_dir() {
        eprintln!("Warning: Failed to initialize prompts directory: {e}");
    }

    // User subdirectories scanned for templates, in registration order.
    const CATEGORIES: [&str; 8] = [
        "analysis",
        "summary",
        "changelog",
        "map",
        "reduce",
        "fast",
        "compose-intent",
        "compose-bind",
    ];

    let mut tera = Tera::default();

    if let Some(user_root) = get_user_prompts_dir() {
        for category in CATEGORIES {
            let category_dir = user_root.join(category);
            if let Err(e) = register_directory_templates(&mut tera, &category_dir, category) {
                eprintln!("Warning: {e}");
            }
        }
    }

    for name in Prompts::iter() {
        let key: &str = name.as_ref();
        let already_registered = tera.get_template_names().any(|existing| existing == key);
        if already_registered {
            continue;
        }
        let Some(embedded) = Prompts::get(key) else {
            continue;
        };
        match std::str::from_utf8(embedded.data.as_ref()) {
            Ok(source) => {
                if let Err(e) = tera.add_raw_template(key, source) {
                    eprintln!("Warning: Failed to register embedded template {}: {}", key, e);
                }
            },
            Err(e) => {
                eprintln!("Warning: Embedded template {} is not valid UTF-8: {}", key, e);
            },
        }
    }

    tera.autoescape_on(vec![]);
    Mutex::new(tera)
});
/// Returns `<home>/.llm-git/prompts`, the directory holding user prompt templates.
///
/// The home directory comes from `HOME`, falling back to `USERPROFILE`. Unset or empty
/// variables are skipped. An empty value would otherwise yield the *relative* path
/// `.llm-git/prompts`, making `ensure_prompts_dir` create and write files in the current
/// working directory. Values are read with `var_os`, so a home path that is not valid UTF-8 is
/// still honoured instead of silently disabling user prompts.
///
/// Returns `None` when neither variable provides a usable value.
fn get_user_prompts_dir() -> Option<PathBuf> {
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());
    non_empty("HOME")
        .or_else(|| non_empty("USERPROFILE"))
        .map(|home| PathBuf::from(home).join(".llm-git").join("prompts"))
}
pub fn ensure_prompts_dir() -> Result<()> {
let Some(user_prompts_dir) = get_user_prompts_dir() else {
return Ok(());
};
let user_llm_git_dir = user_prompts_dir
.parent()
.ok_or_else(|| CommitGenError::Other("Invalid prompts directory path".to_string()))?;
if !user_llm_git_dir.exists() {
std::fs::create_dir_all(user_llm_git_dir).map_err(|e| {
CommitGenError::Other(format!(
"Failed to create directory {}: {}",
user_llm_git_dir.display(),
e
))
})?;
}
if !user_prompts_dir.exists() {
std::fs::create_dir_all(&user_prompts_dir).map_err(|e| {
CommitGenError::Other(format!(
"Failed to create directory {}: {}",
user_prompts_dir.display(),
e
))
})?;
}
for file in Prompts::iter() {
let file_path = user_prompts_dir.join(file.as_ref());
if let Some(parent) = file_path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
CommitGenError::Other(format!("Failed to create directory {}: {}", parent.display(), e))
})?;
}
if let Some(embedded_file) = Prompts::get(file.as_ref()) {
let embedded_content = embedded_file.data;
let should_write = if file_path.exists() {
match std::fs::read(&file_path) {
Ok(existing_content) => existing_content != embedded_content.as_ref(),
Err(_) => true, }
} else {
true };
if should_write {
std::fs::write(&file_path, embedded_content.as_ref()).map_err(|e| {
CommitGenError::Other(format!("Failed to write file {}: {}", file_path.display(), e))
})?;
}
}
}
Ok(())
}
/// Registers every `*.md` file directly inside `directory` as Tera template `<category>/<file>`.
///
/// A missing directory is not an error. Unreadable directory entries, files with non-UTF-8
/// names, and files Tera fails to load are skipped with a warning on stderr.
///
/// Non-UTF-8 names are skipped because they cannot be expressed as a template key. Previously
/// they were registered under the bare key `<category>/`, where several such files would
/// silently replace one another.
///
/// # Errors
/// Returns `CommitGenError::Other` only when `directory` exists but cannot be read.
fn register_directory_templates(tera: &mut Tera, directory: &Path, category: &str) -> Result<()> {
    if !directory.exists() {
        return Ok(());
    }

    let entries = std::fs::read_dir(directory).map_err(|e| {
        CommitGenError::Other(format!(
            "Failed to read {} templates directory {}: {}",
            category,
            directory.display(),
            e
        ))
    })?;

    for entry in entries {
        let entry = match entry {
            Ok(entry) => entry,
            Err(e) => {
                eprintln!(
                    "Warning: Failed to iterate template entry in {}: {}",
                    directory.display(),
                    e
                );
                continue;
            },
        };

        let path = entry.path();
        if path.extension().and_then(|s| s.to_str()) != Some("md") {
            continue;
        }

        let Some(file_name) = path.file_name().and_then(|s| s.to_str()) else {
            eprintln!(
                "Warning: Skipping template file with non-UTF-8 name: {}",
                path.display()
            );
            continue;
        };

        let template_name = format!("{category}/{file_name}");
        if let Err(e) = tera.add_template_file(&path, Some(&template_name)) {
            eprintln!("Warning: Failed to load template file {}: {}", path.display(), e);
        }
    }

    Ok(())
}
/// Loads the raw source of template `<category>/<variant>.md`.
///
/// A file at `<user prompts dir>/<category>/<variant>.md` takes precedence. Otherwise the
/// embedded default with the same key is used.
///
/// # Errors
/// Returns `CommitGenError::Other` if the user file exists but cannot be read, if the
/// embedded copy is not valid UTF-8, or if neither source has the template.
fn load_template_file(category: &str, variant: &str) -> Result<String> {
    let user_override = get_user_prompts_dir()
        .map(|dir| dir.join(category).join(format!("{variant}.md")))
        .filter(|path| path.exists());

    if let Some(path) = user_override {
        return std::fs::read_to_string(&path).map_err(|e| {
            CommitGenError::Other(format!(
                "Failed to read template file {}: {}",
                path.display(),
                e
            ))
        });
    }

    let embedded_key = format!("{category}/{variant}.md");
    let Some(embedded) = Prompts::get(&embedded_key) else {
        return Err(CommitGenError::Other(format!(
            "Template variant '{variant}' in category '{category}' not found as user override or \
             embedded default"
        )));
    };

    match std::str::from_utf8(embedded.data.as_ref()) {
        Ok(text) => Ok(text.to_owned()),
        Err(e) => Err(CommitGenError::Other(format!(
            "Embedded template {embedded_key} is not valid UTF-8: {e}"
        ))),
    }
}
/// Renders the `analysis/<variant>.md` prompt from [`AnalysisParams`].
///
/// `stat`, `diff` and `scope_candidates` are always placed in the Tera context. Each optional
/// field is placed there only when it is `Some`.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_analysis_prompt(p: &AnalysisParams<'_>) -> Result<PromptParts> {
    let template_content = load_template_file("analysis", p.variant)?;

    let mut context = Context::new();
    let required = [("stat", p.stat), ("diff", p.diff), ("scope_candidates", p.scope_candidates)];
    for (key, value) in required {
        context.insert(key, value);
    }

    let optional = [
        ("recent_commits", p.recent_commits),
        ("common_scopes", p.common_scopes),
        ("types_description", p.types_description),
        ("project_context", p.project_context),
    ];
    for (key, value) in optional {
        if let Some(value) = value {
            context.insert(key, value);
        }
    }

    let template_name = format!("analysis/{}.md", p.variant);
    render_prompt_parts(&template_name, &template_content, &context)
}
/// Renders the `summary/<variant>.md` prompt.
///
/// Every string argument is placed in the Tera context under its parameter name.
/// `user_context` is added only when it is `Some`.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_summary_prompt(
    variant: &str,
    commit_type: &str,
    scope: &str,
    chars: &str,
    details: &str,
    stat: &str,
    user_context: Option<&str>,
) -> Result<PromptParts> {
    let template_content = load_template_file("summary", variant)?;

    let mut context = Context::new();
    let entries = [
        ("commit_type", commit_type),
        ("scope", scope),
        ("chars", chars),
        ("details", details),
        ("stat", stat),
    ];
    for (key, value) in entries {
        context.insert(key, value);
    }
    if let Some(extra) = user_context {
        context.insert("user_context", extra);
    }

    let template_name = format!("summary/{variant}.md");
    render_prompt_parts(&template_name, &template_content, &context)
}
/// Renders the `changelog/<variant>.md` prompt.
///
/// The context holds `changelog_path`, the boolean `is_package_changelog`, `stat` and `diff`.
/// `existing_entries` is added only when it is `Some`.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_changelog_prompt(
    variant: &str,
    changelog_path: &str,
    is_package_changelog: bool,
    stat: &str,
    diff: &str,
    existing_entries: Option<&str>,
) -> Result<PromptParts> {
    let template_content = load_template_file("changelog", variant)?;

    let mut context = Context::new();
    for (key, value) in [("changelog_path", changelog_path), ("stat", stat), ("diff", diff)] {
        context.insert(key, value);
    }
    context.insert("is_package_changelog", &is_package_changelog);
    if let Some(entries) = existing_entries {
        context.insert("existing_entries", entries);
    }

    let template_name = format!("changelog/{variant}.md");
    render_prompt_parts(&template_name, &template_content, &context)
}
/// One file's entry in the map-phase prompt.
///
/// `render_map_prompt` inserts a slice of these into the Tera context as `files`, so templates
/// see each item's `path` and `diff` fields. Serde serializes fields in declaration order.
#[derive(Serialize)]
pub struct MapFile<'a> {
    /// Path of the file (presumably repository-relative; confirm against callers).
    pub path: &'a str,
    /// Diff text for this file.
    pub diff: &'a str,
}
/// Renders the `map/<variant>.md` prompt for a batch of files.
///
/// `files` is exposed to the template as the list `files`. `context_header` is added as
/// `context_header` only when it is non-empty; an empty header leaves the key undefined.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_map_prompt(
    variant: &str,
    files: &[MapFile<'_>],
    context_header: &str,
) -> Result<PromptParts> {
    let template_content = load_template_file("map", variant)?;

    let mut context = Context::new();
    context.insert("files", files);
    match context_header {
        "" => {},
        header => context.insert("context_header", header),
    }

    let template_name = format!("map/{variant}.md");
    render_prompt_parts(&template_name, &template_content, &context)
}
/// Renders the `reduce/<variant>.md` prompt that combines per-file map observations.
///
/// `observations`, `stat` and `scope_candidates` are always in the context.
/// `types_description` is added only when it is `Some`.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_reduce_prompt(
    variant: &str,
    observations: &str,
    stat: &str,
    scope_candidates: &str,
    types_description: Option<&str>,
) -> Result<PromptParts> {
    let template_content = load_template_file("reduce", variant)?;

    let mut context = Context::new();
    let required = [
        ("observations", observations),
        ("stat", stat),
        ("scope_candidates", scope_candidates),
    ];
    for (key, value) in required {
        context.insert(key, value);
    }
    if let Some(description) = types_description {
        context.insert("types_description", description);
    }

    render_prompt_parts(&format!("reduce/{variant}.md"), &template_content, &context)
}
/// Inputs for [`render_compose_intent_prompt`].
///
/// `variant` selects `compose-intent/<variant>.md`. Every other field is inserted into the
/// Tera context under its own name.
#[derive(Debug, Clone, Copy)]
pub struct ComposeIntentPromptParams<'a> {
    /// Template variant (file stem under `compose-intent/`).
    pub variant: &'a str,
    /// Context key `max_commits` (presumably an upper bound on planned commits; confirm).
    pub max_commits: usize,
    /// Context key `stat`.
    pub stat: &'a str,
    /// Context key `snapshot_summary`.
    pub snapshot_summary: &'a str,
    /// Context key `planning_targets`.
    pub planning_targets: &'a str,
    /// Context key `planning_notes`.
    pub planning_notes: &'a str,
    /// Context key `split_bias`.
    pub split_bias: &'a str,
}
/// Renders the `compose-intent/<variant>.md` prompt.
///
/// `max_commits` is inserted as a number. All string fields of `p` except `variant` are
/// inserted under their own names.
///
/// # Errors
/// Propagates failures from loading the template or from `render_prompt_parts`.
pub fn render_compose_intent_prompt(p: &ComposeIntentPromptParams<'_>) -> Result<PromptParts> {
    let template_content = load_template_file("compose-intent", p.variant)?;

    let mut context = Context::new();
    context.insert("max_commits", &p.max_commits);
    let text_fields = [
        ("stat", p.stat),
        ("snapshot_summary", p.snapshot_summary),
        ("planning_targets", p.planning_targets),
        ("planning_notes", p.planning_notes),
        ("split_bias", p.split_bias),
    ];
    for (key, value) in text_fields {
        context.insert(key, value);
    }

    let template_name = format!("compose-intent/{}.md", p.variant);
    render_prompt_parts(&template_name, &template_content, &context)
}
/// Inputs for [`render_compose_bind_prompt`].
///
/// `variant` selects `compose-bind/<variant>.md`. The other fields become the context keys
/// `groups` and `ambiguous_files`.
#[derive(Debug, Clone, Copy)]
pub struct ComposeBindPromptParams<'a> {
    /// Template variant (file stem under `compose-bind/`).
    pub variant: &'a str,
    /// Context key `groups`.
    pub groups: &'a str,
    /// Context key `ambiguous_files`.
    pub ambiguous_files: &'a str,
}
pub fn render_compose_bind_prompt(p: &ComposeBindPromptParams<'_>) -> Result<PromptParts> {
let template_content = load_template_file("compose-bind", p.variant)?;
let mut context = Context::new();
context.insert("groups", p.groups);
context.insert("ambiguous_files", p.ambiguous_files);
render_prompt_parts(&format!("compose-bind/{}.md", p.variant), &template_content, &context)
}
/// Inputs for [`render_fast_prompt`].
///
/// `variant` selects `fast/<variant>.md`. The three required strings are always inserted into
/// the Tera context. Each `Option` field is inserted only when it is `Some`.
#[derive(Debug, Clone, Copy)]
pub struct FastPromptParams<'a> {
    /// Template variant (file stem under `fast/`).
    pub variant: &'a str,
    /// Context key `stat`.
    pub stat: &'a str,
    /// Context key `diff`.
    pub diff: &'a str,
    /// Context key `scope_candidates`.
    pub scope_candidates: &'a str,
    /// Optional context key `user_context`.
    pub user_context: Option<&'a str>,
    /// Optional context key `types_description`.
    pub types_description: Option<&'a str>,
}
pub fn render_fast_prompt(p: &FastPromptParams<'_>) -> Result<PromptParts> {
let template_content = load_template_file("fast", p.variant)?;
let mut context = Context::new();
context.insert("stat", p.stat);
context.insert("diff", p.diff);
context.insert("scope_candidates", p.scope_candidates);
if let Some(ctx) = p.user_context {
context.insert("user_context", ctx);
}
if let Some(types_desc) = p.types_description {
context.insert("types_description", types_desc);
}
render_prompt_parts(&format!("fast/{}.md", p.variant), &template_content, &context)
}
// Unit tests for template splitting and for rendering each prompt family.
//
// NOTE(review): most rendering tests call `ensure_prompts_dir()` directly, and rendering also
// initializes `TERA`, whose initializer calls it as well. Running these tests therefore writes
// the embedded prompts into the real user prompts directory (`$HOME/.llm-git/prompts` or the
// `USERPROFILE` equivalent). Many assertions also match exact wording from the prompt files,
// so editing a prompt can break a test here.
#[cfg(test)]
mod tests {
    use super::{
        AnalysisParams, ComposeBindPromptParams, ComposeIntentPromptParams, FastPromptParams,
        ensure_prompts_dir, render_analysis_prompt, render_changelog_prompt, render_compose_bind_prompt,
        render_compose_intent_prompt, render_fast_prompt, render_reduce_prompt,
        render_summary_prompt, split_prompt_template,
    };

    // LF endings: the line break on each side of the separator is dropped; the trailing
    // newline of the user body is kept.
    #[test]
    fn test_split_prompt_template_lf() {
        let content = "system text\nmore system\n======USER=======\nuser body\n";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, Some("system text\nmore system"));
        assert_eq!(user, "user body\n");
    }

    // CRLF endings: the full `\r\n` on each side of the separator is dropped.
    #[test]
    fn test_split_prompt_template_crlf() {
        let content = "system text\r\nmore system\r\n======USER=======\r\nuser body\r\n";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, Some("system text\r\nmore system"));
        assert_eq!(user, "user body\r\n");
    }

    // Without a separator the whole template is the user section.
    #[test]
    fn test_split_prompt_template_no_separator() {
        let content = "no separator here";
        let (system, user) = split_prompt_template(content);
        assert_eq!(system, None);
        assert_eq!(user, content);
    }

    // Checks wording of the `analysis/default.md` system section.
    #[test]
    fn test_render_analysis_prompt_requests_holistic_summary() {
        ensure_prompts_dir().unwrap();
        let parts = render_analysis_prompt(&AnalysisParams {
            variant: "default",
            stat: "src/api/client.rs | 24 +++++++++++++++---------",
            diff: "diff --git a/src/api/client.rs b/src/api/client.rs\n",
            scope_candidates: "api",
            recent_commits: None,
            common_scopes: None,
            types_description: None,
            project_context: None,
        })
        .unwrap();
        assert!(parts.system.contains("Generate Summary"));
        assert!(parts.system.contains("\"summary\""));
        assert!(
            parts
                .system
                .contains("umbrella headline for the whole changeset")
        );
        assert!(parts.system.contains("Does not copy detail #1"));
    }

    // Both changelog variants must render and include the context values. The `default`
    // variant advertises JSON output; `markdown` advertises markdown sections instead.
    #[test]
    fn test_render_changelog_prompt_variants_render() {
        ensure_prompts_dir().unwrap();
        for variant in ["default", "markdown"] {
            let parts = render_changelog_prompt(
                variant,
                "CHANGELOG.md",
                false,
                "src/api.rs | 4 ++--",
                "diff --git a/src/api.rs b/src/api.rs\n",
                Some("- Added existing entry"),
            )
            .unwrap_or_else(|e| panic!("{variant} changelog prompt failed to render: {e}"));
            assert!(parts.user.contains("src/api.rs"), "{variant}: diff missing");
            assert!(
                parts.user.contains("Added existing entry"),
                "{variant}: existing entries missing"
            );
            match variant {
                "markdown" => {
                    assert!(
                        parts.system.contains("# Added"),
                        "markdown variant must advertise markdown sections"
                    );
                    assert!(
                        !parts.system.contains("{\"entries\""),
                        "markdown variant must not advertise JSON output"
                    );
                },
                "default" => assert!(
                    parts.system.contains("{\"entries\""),
                    "default variant must advertise JSON output"
                ),
                _ => unreachable!(),
            }
        }
    }

    // When `types_description` is provided, the fast prompt's user section shows it inside a
    // `<commit_types>` block.
    #[test]
    fn test_render_fast_prompt_surfaces_type_guidance() {
        ensure_prompts_dir().unwrap();
        let parts = render_fast_prompt(&FastPromptParams {
            variant: "default",
            stat: "prompts/analysis/default.md | 5 +++++",
            diff: "diff --git a/prompts/analysis/default.md \
                   b/prompts/analysis/default.md\n",
            scope_candidates: "prompts",
            user_context: None,
            types_description: Some(
                "**docs**: Documentation only changes\n Note: Excludes prompt template files.",
            ),
        })
        .unwrap();
        assert!(parts.user.contains("<commit_types>"));
        assert!(parts.user.contains("Excludes prompt template files."));
        assert!(parts.system.contains("not `docs`"));
    }

    // With `types_description: None` the key is absent from the context, and the template
    // omits the `<commit_types>` block.
    #[test]
    fn test_render_fast_prompt_omits_commit_types_when_absent() {
        ensure_prompts_dir().unwrap();
        let parts = render_fast_prompt(&FastPromptParams {
            variant: "default",
            stat: "src/main.rs | 5 +++++",
            diff: "diff --git a/src/main.rs b/src/main.rs\n",
            scope_candidates: "",
            user_context: None,
            types_description: None,
        })
        .unwrap();
        assert!(!parts.user.contains("<commit_types>"));
    }

    // Checks wording of the `reduce/default.md` system section.
    #[test]
    fn test_render_reduce_prompt_guides_grouped_synthesis() {
        ensure_prompts_dir().unwrap();
        let parts = render_reduce_prompt(
            "default",
            r#"[{"file":"src/a.rs","observations":["Added retry handling."]}]"#,
            "src/a.rs | 10 +++++-----",
            "api",
            None,
        )
        .unwrap();
        assert!(parts.system.contains("3-4 strong grouped details"));
        assert!(
            parts
                .system
                .contains("Synthesize repeated file observations")
        );
        assert!(parts.system.contains("over enumerating files"));
    }

    // Unlike most tests above, this does not call `ensure_prompts_dir()` itself. Rendering
    // still initializes `TERA`, whose initializer calls it.
    #[test]
    fn test_render_compose_intent_prompt() {
        let parts = render_compose_intent_prompt(&ComposeIntentPromptParams {
            variant: "default",
            max_commits: 3,
            stat: "src/foo.rs | 10 +++++-----",
            snapshot_summary: "- F1 src/foo.rs",
            planning_targets: "file IDs",
            planning_notes: "Prefer conservative grouping over speculative splitting.",
            split_bias: "Prefer fewer groups when the split is uncertain.",
        })
        .unwrap();
        assert!(parts.system.contains("create_compose_intent_plan"));
        assert!(parts.user.contains("max_commits: 3"));
        assert!(parts.user.contains("src/foo.rs"));
    }

    // Checks the summary system wording and that each detail line reaches the user section.
    #[test]
    fn test_render_summary_prompt_guides_umbrella_title() {
        ensure_prompts_dir().unwrap();
        let parts = render_summary_prompt(
            "default",
            "feat",
            "api",
            "72",
            "Added websocket reconnects.\nUpdated client retry tests.",
            "src/api/client.rs | 24 +++++++++++++++---------",
            None,
        )
        .unwrap();
        assert!(
            parts
                .system
                .contains("umbrella description for the whole changeset")
        );
        assert!(parts.system.contains("not as candidate titles to copy"));
        assert!(
            parts
                .system
                .contains("does not copy or narrowly paraphrase one detail point")
        );
        assert!(parts.user.contains("<detail_points>"));
        assert!(parts.user.contains("Added websocket reconnects."));
        assert!(parts.user.contains("Updated client retry tests."));
    }

    // Also relies on `TERA` initialization rather than calling `ensure_prompts_dir()`.
    #[test]
    fn test_render_compose_bind_prompt() {
        let parts = render_compose_bind_prompt(&ComposeBindPromptParams {
            variant: "default",
            groups: "- G1 [feat(api)] Added endpoint",
            ambiguous_files: "- F2 src/api.rs candidates: G1",
        })
        .unwrap();
        assert!(parts.system.contains("bind_compose_hunks"));
        assert!(parts.user.contains("G1"));
        assert!(parts.user.contains("src/api.rs"));
    }
}