use anyhow::Result;
use async_openai::config::OpenAIConfig;
use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_openai::Client;
use serde_json::Value;
use futures::future::join_all;
use crate::multi_step_analysis::{
create_analyze_function_tool, create_generate_function_tool, create_score_function_tool, FileDataForScoring, FileWithScore
};
use crate::function_calling::{create_commit_function_tool, CommitFunctionArgs};
use crate::debug_output;
/// One file's section of a unified git diff, as produced by [`parse_diff`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedFile {
    /// Path of the file with the git source/destination prefix removed.
    /// [`parse_diff`] uses `"unknown"` when no `diff --git` header was recognized.
    pub path: String,
    /// As set by [`parse_diff`]: one of `"modified"`, `"added"`, `"deleted"`,
    /// `"renamed"` or `"binary"`.
    pub operation: String,
    /// The raw diff text belonging to this file.
    pub diff_content: String,
}
pub async fn generate_commit_message_multi_step(
client: &Client<OpenAIConfig>, model: &str, diff_content: &str, max_length: Option<usize>
) -> Result<String> {
log::info!("Starting multi-step commit message generation");
if let Some(session) = debug_output::debug_session() {
session.init_multi_step_debug();
}
let parsed_files = parse_diff(diff_content)?;
log::info!("Parsed {} files from diff", parsed_files.len());
if let Some(session) = debug_output::debug_session() {
session.set_total_files_parsed(parsed_files.len());
}
log::debug!("Analyzing {} files in parallel", parsed_files.len());
let analysis_futures: Vec<_> = parsed_files
.iter()
.map(|file| {
let file_path = file.path.clone();
let operation = file.operation.clone();
async move {
log::debug!("Analyzing file: {file_path}");
let start_time = std::time::Instant::now();
let payload = format!("{{\"file_path\": \"{file_path}\", \"operation_type\": \"{operation}\", \"diff_content\": \"...\"}}");
let result = call_analyze_function(client, model, file).await;
let duration = start_time.elapsed();
(file, result, duration, payload)
}
})
.collect();
let analysis_results = join_all(analysis_futures).await;
let mut file_analyses = Vec::new();
for (i, (file, result, duration, payload)) in analysis_results.into_iter().enumerate() {
match result {
Ok(analysis) => {
log::debug!("Successfully analyzed file {}: {}", i, file.path);
let analysis_result = crate::multi_step_analysis::FileAnalysisResult {
lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
file_category: analysis["file_category"]
.as_str()
.unwrap_or("source")
.to_string(),
summary: analysis["summary"].as_str().unwrap_or("").to_string()
};
if let Some(session) = debug_output::debug_session() {
session.add_file_analysis_debug(file.path.clone(), file.operation.clone(), analysis_result.clone(), duration, payload);
}
file_analyses.push((file, analysis));
}
Err(e) => {
let error_str = e.to_string();
if error_str.contains("invalid_api_key") || error_str.contains("Incorrect API key") || error_str.contains("Invalid API key") {
return Err(e);
}
log::warn!("Failed to analyze file {}: {}", file.path, e);
}
}
}
if file_analyses.is_empty() {
anyhow::bail!("Failed to analyze any files");
}
let files_data: Vec<FileDataForScoring> = file_analyses
.iter()
.map(|(file, analysis)| {
FileDataForScoring {
file_path: file.path.clone(),
operation_type: file.operation.clone(),
lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
file_category: analysis["file_category"]
.as_str()
.unwrap_or("source")
.to_string(),
summary: analysis["summary"].as_str().unwrap_or("").to_string()
}
})
.collect();
let score_start_time = std::time::Instant::now();
let score_payload = format!(
"{{\"files_data\": [{{\"{}\", ...}}, ...]}}",
if !files_data.is_empty() {
&files_data[0].file_path
} else {
"no files"
}
);
let score_future = call_score_function(client, model, files_data);
let scored_files = score_future.await?;
let score_duration = score_start_time.elapsed();
if let Some(session) = debug_output::debug_session() {
session.set_score_debug(scored_files.clone(), score_duration, score_payload);
}
let generate_start_time = std::time::Instant::now();
let generate_payload = format!("{{\"files_with_scores\": [...], \"max_length\": {}}}", max_length.unwrap_or(72));
let generate_future = call_generate_function(client, model, scored_files.clone(), max_length.unwrap_or(72));
let candidates = generate_future.await?;
let generate_duration = generate_start_time.elapsed();
if let Some(session) = debug_output::debug_session() {
session.set_generate_debug(candidates.clone(), generate_duration, generate_payload);
}
let final_message_start_time = std::time::Instant::now();
let final_message = select_best_candidate(client, model, &candidates, &scored_files, diff_content).await?;
let final_message_duration = final_message_start_time.elapsed();
if let Some(session) = debug_output::debug_session() {
session.set_final_message_debug(final_message_duration);
session.set_commit_result(final_message.clone(), candidates["reasoning"].as_str().unwrap_or("").to_string());
}
Ok(final_message)
}
/// Extracts the repository-relative path from the whitespace-split tokens of a
/// `diff --git <old> <new>` header line.
///
/// `parts` is expected to be `["diff", "--git", old_path, new_path, ...]`.
///
/// At most ONE leading git prefix is removed from each path. Accepted prefixes are
/// the default `a/`/`b/` and the `diff.mnemonicPrefix` prefixes (`c/`, `i/`, `w/`,
/// `o/`). Removing only one prefix preserves real directories whose names happen
/// to look like a prefix, e.g. `a/b/c/foo.rs` -> `b/c/foo.rs`.
///
/// Returns the new path, or the old path when the new side is `/dev/null` (a
/// deletion). Returns `None` when the header has fewer than four tokens.
///
/// Paths containing whitespace are split across tokens and are not recovered here.
fn extract_file_path_from_diff_parts(parts: &[&str]) -> Option<String> {
    const GIT_PREFIXES: [&str; 6] = ["a/", "b/", "c/", "i/", "w/", "o/"];

    let (old_raw, new_raw) = match parts {
        [_, _, old, new, ..] => (*old, *new),
        _ => return None,
    };

    let strip_git_prefix = |s: &str| -> String {
        GIT_PREFIXES
            .iter()
            .find_map(|prefix| s.strip_prefix(*prefix))
            .unwrap_or(s)
            .to_string()
    };

    let new_path = strip_git_prefix(new_raw);
    // `b/dev/null` becomes `dev/null` after stripping, so accept both spellings.
    Some(if new_path == "/dev/null" || new_path == "dev/null" {
        strip_git_prefix(old_raw)
    } else {
        new_path
    })
}
/// Splits a unified git diff (optionally preceded by commit header lines) into
/// per-file sections.
///
/// Detection rules:
/// - Each `diff --git` header starts a new file. Its path comes from
///   [`extract_file_path_from_diff_parts`], and its operation defaults to
///   `"modified"`.
/// - `new file mode`, `deleted file mode`, `rename from`/`rename to` and
///   `Binary files` lines update the current file's operation.
/// - `commit ...` lines, lines starting with a 40-character hex SHA, and empty lines
///   are skipped everywhere. Text before the first `diff --git` header belongs to
///   no file.
///
/// Fallbacks when no file was found but the input is non-blank:
/// 1. re-split the input on every `diff --git` occurrence, including mid-line ones;
/// 2. otherwise, return a single `"unknown"` file holding the whole input.
///
/// Returns an empty vector for blank input. The function currently never
/// returns `Err`.
pub fn parse_diff(diff_content: &str) -> Result<Vec<ParsedFile>> {
    let mut files = Vec::new();
    let mut current_file: Option<ParsedFile> = None;
    let mut current_diff = String::new();
    let mut in_diff_section = false;

    log::debug!("Parsing diff with {} lines", diff_content.lines().count());
    if log::log_enabled!(log::Level::Debug) && !diff_content.is_empty() {
        let preview = diff_preview_for_log(diff_content, 500);
        log::debug!("Diff content preview: \n{preview}");
    }

    // Commit headers (e.g. from `git log -p` / `git show`) and blank lines.
    let is_skippable = |line: &str| {
        line.starts_with("commit ")
            || (line.len() >= 40 && line.chars().take(40).all(|c| c.is_ascii_hexdigit()))
            || line.is_empty()
    };

    for line in diff_content.lines() {
        if is_skippable(line) {
            continue;
        }
        if line.starts_with("diff --git") {
            in_diff_section = true;
            if let Some(mut file) = current_file.take() {
                file.diff_content = std::mem::take(&mut current_diff);
                log::debug!("Adding file to results: {} ({})", file.path, file.operation);
                files.push(file);
            }
            // Always start each section with an empty buffer. This holds even when the
            // previous header could not be parsed, so orphaned lines never leak into
            // the next file's content.
            current_diff.clear();
            let parts: Vec<&str> = line.split_whitespace().collect();
            current_file = extract_file_path_from_diff_parts(&parts).map(|path| {
                log::debug!("Found new file in diff: {path}");
                ParsedFile {
                    path,
                    operation: "modified".to_string(),
                    diff_content: String::new(),
                }
            });
        } else if !in_diff_section {
            // Preamble before the first file header (commit message text, etc.).
            continue;
        } else if let (Some(file), Some((operation, description))) =
            (current_file.as_mut(), diff_operation_marker(line))
        {
            log::debug!("File {} is {}", file.path, description);
            file.operation = operation.to_string();
        }
        current_diff.push_str(line);
        current_diff.push('\n');
    }

    if let Some(mut file) = current_file {
        file.diff_content = current_diff;
        log::debug!("Adding final file to results: {} ({})", file.path, file.operation);
        files.push(file);
    }

    if files.is_empty() && !diff_content.trim().is_empty() {
        log::debug!("Trying to parse as raw git diff output with commit info");
        // After splitting, each section's first line is the remainder of its header.
        for (i, section) in diff_content.split("diff --git").skip(1).enumerate() {
            let full_section = format!("diff --git{section}");
            let header_path = full_section.lines().next().and_then(|header| {
                let parts: Vec<&str> = header.split_whitespace().collect();
                extract_file_path_from_diff_parts(&parts)
            });
            if let Some(path) = header_path {
                log::debug!("Found file in section {i}: {path}");
                files.push(ParsedFile {
                    path,
                    operation: "modified".to_string(),
                    diff_content: full_section,
                });
            }
        }
    }

    if files.is_empty() && !diff_content.trim().is_empty() {
        log::debug!("No standard diff format found, treating as single file change");
        files.push(ParsedFile {
            path: "unknown".to_string(),
            operation: "modified".to_string(),
            diff_content: diff_content.to_string(),
        });
    }

    log::debug!("Parsed {} files from diff", files.len());
    if log::log_enabled!(log::Level::Debug) {
        for (i, file) in files.iter().enumerate() {
            let content_preview = diff_preview_for_log(&file.diff_content, 200);
            log::debug!("File {}: {} ({})\nContent preview:\n{}", i, file.path, file.operation, content_preview);
        }
    }
    Ok(files)
}

/// Maps a git extended-header line to the operation it implies, plus a short
/// description used in debug logging. Returns `None` for any other line.
fn diff_operation_marker(line: &str) -> Option<(&'static str, &'static str)> {
    if line.starts_with("new file mode") {
        Some(("added", "newly added"))
    } else if line.starts_with("deleted file mode") {
        Some(("deleted", "deleted"))
    } else if line.starts_with("rename from") || line.starts_with("rename to") {
        Some(("renamed", "renamed"))
    } else if line.starts_with("Binary files") {
        Some(("binary", "binary"))
    } else {
        None
    }
}

/// Returns `text` unchanged when it is at most `limit` bytes long. Otherwise returns
/// its prefix up to and including the character that starts before byte `limit`,
/// followed by a truncation marker. The cut always lands on a UTF-8 boundary.
fn diff_preview_for_log(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_string();
    }
    let cut = text
        .char_indices()
        .take_while(|(i, _)| *i < limit)
        .last()
        .map_or(0, |(i, c)| i + c.len_utf8());
    format!("{}... (truncated)", &text[..cut])
}
/// Asks the model to analyze one file's diff, forcing the `analyze` tool.
///
/// Returns the tool-call arguments as raw JSON.
///
/// # Errors
/// Returns an error if:
/// - the request cannot be built or the API call fails;
/// - the response has no choices or no tool call;
/// - the tool arguments are not valid JSON.
async fn call_analyze_function(client: &Client<OpenAIConfig>, model: &str, file: &ParsedFile) -> Result<Value> {
    let tools = vec![create_analyze_function_tool()?];
    let system_message = ChatCompletionRequestSystemMessageArgs::default()
        .content("You are a git diff analyzer. Analyze the provided file changes and return structured data.")
        .build()?
        .into();
    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content(format!(
            "Analyze this file change:\nPath: {}\nOperation: {}\nDiff:\n{}",
            file.path, file.operation, file.diff_content
        ))
        .build()?
        .into();
    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages(vec![system_message, user_message])
        .tools(tools)
        .tool_choice("analyze")
        .build()?;
    let response = client.chat().create(request).await?;
    // `choices` may be empty; indexing `[0]` would panic instead of returning an error.
    let tool_call = response
        .choices
        .first()
        .and_then(|choice| choice.message.tool_calls.as_ref())
        .and_then(|calls| calls.first())
        .ok_or_else(|| anyhow::anyhow!("No tool call in response"))?;
    let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
    Ok(args)
}
/// Asks the model to assign impact scores to the analyzed files, forcing the
/// `score` tool.
///
/// Returns an empty vector if the tool arguments have no (or a null)
/// `files_with_scores` field.
///
/// # Errors
/// Returns an error if:
/// - the request cannot be built or the API call fails;
/// - the response has no choices or no tool call;
/// - the tool arguments are not valid JSON, or `files_with_scores` does not
///   deserialize into `Vec<FileWithScore>`.
async fn call_score_function(
    client: &Client<OpenAIConfig>, model: &str, files_data: Vec<FileDataForScoring>,
) -> Result<Vec<FileWithScore>> {
    let tools = vec![create_score_function_tool()?];
    let system_message = ChatCompletionRequestSystemMessageArgs::default()
        .content("You are a git commit impact scorer. Calculate impact scores for the provided file changes.")
        .build()?
        .into();
    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content(format!(
            "Calculate impact scores for these {} file changes:\n{}",
            files_data.len(),
            serde_json::to_string_pretty(&files_data)?
        ))
        .build()?
        .into();
    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages(vec![system_message, user_message])
        .tools(tools)
        .tool_choice("score")
        .build()?;
    let response = client.chat().create(request).await?;
    // `choices` may be empty; indexing `[0]` would panic instead of returning an error.
    let tool_call = response
        .choices
        .first()
        .and_then(|choice| choice.message.tool_calls.as_ref())
        .and_then(|calls| calls.first())
        .ok_or_else(|| anyhow::anyhow!("No tool call in response"))?;
    let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
    // Indexing a missing key yields `Value::Null`, so this covers "absent" and "null".
    let files_with_scores: Vec<FileWithScore> = if args["files_with_scores"].is_null() {
        Vec::new()
    } else {
        serde_json::from_value(args["files_with_scores"].clone())?
    };
    Ok(files_with_scores)
}
/// Asks the model for commit message candidates for the scored files, forcing
/// the `generate` tool.
///
/// `max_length` is stated in the prompt as the character limit. Returns the
/// tool-call arguments as raw JSON. Callers read its `candidates` and `reasoning`
/// fields.
///
/// # Errors
/// Returns an error if:
/// - the request cannot be built or the API call fails;
/// - the response has no choices or no tool call;
/// - the tool arguments are not valid JSON.
async fn call_generate_function(
    client: &Client<OpenAIConfig>, model: &str, files_with_scores: Vec<FileWithScore>, max_length: usize,
) -> Result<Value> {
    let tools = vec![create_generate_function_tool()?];
    let system_message = ChatCompletionRequestSystemMessageArgs::default()
        .content("You are a git commit message generator. Generate concise, descriptive commit messages.")
        .build()?
        .into();
    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content(format!(
            "Generate commit message candidates (max {} chars) for these scored changes:\n{}",
            max_length,
            serde_json::to_string_pretty(&files_with_scores)?
        ))
        .build()?
        .into();
    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages(vec![system_message, user_message])
        .tools(tools)
        .tool_choice("generate")
        .build()?;
    let response = client.chat().create(request).await?;
    // `choices` may be empty; indexing `[0]` would panic instead of returning an error.
    let tool_call = response
        .choices
        .first()
        .and_then(|choice| choice.message.tool_calls.as_ref())
        .and_then(|calls| calls.first())
        .ok_or_else(|| anyhow::anyhow!("No tool call in response"))?;
    let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
    Ok(args)
}
/// Asks the model to pick and format the final commit message, forcing the
/// `commit` tool.
///
/// The prompt includes:
/// - the generated `candidates["candidates"]` and `candidates["reasoning"]`;
/// - the scored files;
/// - the full original diff.
///
/// The `message` string is taken directly from the tool arguments when present.
/// Otherwise the arguments are deserialized as `CommitFunctionArgs`.
///
/// # Errors
/// Returns an error if:
/// - the request cannot be built or the API call fails;
/// - the response has no choices or no tool call;
/// - the tool arguments cannot be parsed.
async fn select_best_candidate(
    client: &Client<OpenAIConfig>, model: &str, candidates: &Value, scored_files: &[FileWithScore], original_diff: &str,
) -> Result<String> {
    // NOTE(review): the commit tool limit is fixed at 72 here, independent of the
    // `max_length` the caller passed to the generation step.
    let tools = vec![create_commit_function_tool(Some(72))?];
    let system_message = ChatCompletionRequestSystemMessageArgs::default()
        .content(
            "You are a git commit message expert. Based on the multi-step analysis, \
select the best commit message and provide the final formatted response."
        )
        .build()?
        .into();
    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content(format!(
            "Based on this multi-step analysis:\n\n\
Candidates: {}\n\
Reasoning: {}\n\n\
Scored files: {}\n\n\
Original diff:\n{}\n\n\
Select the best commit message and format the response using the commit function.",
            serde_json::to_string_pretty(&candidates["candidates"])?,
            candidates["reasoning"].as_str().unwrap_or(""),
            serde_json::to_string_pretty(&scored_files)?,
            original_diff
        ))
        .build()?
        .into();
    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages(vec![system_message, user_message])
        .tools(tools)
        .tool_choice("commit")
        .build()?;
    let response = client.chat().create(request).await?;
    // `choices` may be empty; indexing `[0]` would panic instead of returning an error.
    let tool_call = response
        .choices
        .first()
        .and_then(|choice| choice.message.tool_calls.as_ref())
        .and_then(|calls| calls.first())
        .ok_or_else(|| anyhow::anyhow!("No tool call in response"))?;
    let raw_args: Value = serde_json::from_str(&tool_call.function.arguments)?;
    if let Some(message) = raw_args.get("message").and_then(|m| m.as_str()) {
        return Ok(message.to_string());
    }
    let args: CommitFunctionArgs = serde_json::from_str(&tool_call.function.arguments)?;
    Ok(args.message)
}
pub fn generate_commit_message_local(diff_content: &str, max_length: Option<usize>) -> Result<String> {
use crate::multi_step_analysis::{analyze_file, calculate_impact_scores, generate_commit_messages};
log::info!("Starting local multi-step commit message generation");
let parsed_files = parse_diff(diff_content)?;
if let Some(session) = debug_output::debug_session() {
session.set_total_files_parsed(parsed_files.len());
}
let mut files_data = Vec::new();
for file in parsed_files {
let analysis = analyze_file(&file.path, &file.diff_content, &file.operation);
files_data.push(FileDataForScoring {
file_path: file.path,
operation_type: file.operation,
lines_added: analysis.lines_added,
lines_removed: analysis.lines_removed,
file_category: analysis.file_category,
summary: analysis.summary
});
}
let score_result = calculate_impact_scores(files_data);
let generate_result = generate_commit_messages(score_result.files_with_scores, max_length.unwrap_or(72));
Ok(
generate_result
.candidates
.first()
.cloned()
.unwrap_or_else(|| "Update files".to_string())
)
}
// Unit tests for diff parsing and for the offline (local) generation path.
// The diff fixtures are raw strings whose exact line contents matter to the parser.
#[cfg(test)]
mod tests {
use super::*;
// Two files in one diff: a plain modification, then a new file detected via
// `new file mode`. Checks paths, operations and non-empty per-file content.
#[test]
fn test_parse_diff() {
let diff = r#"diff --git a/src/main.rs b/src/main.rs
index 1234567..abcdefg 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,6 @@
fn main() {
- println!("Hello");
+ println!("Hello, world!");
+ println!("New line");
}
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..1111111
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,8 @@
+[package]
+name = "test"
+version = "0.1.0"
"#;
let files = parse_diff(diff).unwrap();
assert_eq!(files.len(), 2);
assert_eq!(files[0].path, "src/main.rs");
assert_eq!(files[0].operation, "modified");
assert_eq!(files[1].path, "Cargo.toml");
assert_eq!(files[1].operation, "added");
assert!(!files[0].diff_content.is_empty());
assert!(!files[1].diff_content.is_empty());
}
// A leading `<40-hex sha> <subject>` line must be skipped and must not appear in
// the parsed file's content.
#[test]
fn test_parse_diff_with_commit_hash() {
let diff = r#"0472ffa1665c4c5573fb8f7698c9965122eda675 Update files
diff --git a/src/openai.rs b/src/openai.rs
index a67ebbe..da223be 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -15,11 +15,6 @@ use crate::multi_step_integration::{generate_commit_message_local, generate_comm
const MAX_ATTEMPTS: usize = 3;
-#[derive(Debug, Clone, PartialEq)]
-pub struct Response {
- pub response: String
-}
-
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
pub prompt: String,
@@ -28,6 +23,11 @@ pub struct Request {
pub model: Model
}
+#[derive(Debug, Clone, PartialEq)]
+pub struct Response {
+ pub response: String
+}
+
/// Generates an improved commit message using the provided prompt and diff
/// Now uses the multi-step approach by default
pub async fn generate_commit_message(diff: &str) -> Result<String> {
"#;
let files = parse_diff(diff).unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].path, "src/openai.rs");
assert_eq!(files[0].operation, "modified");
assert!(files[0].diff_content.contains("pub struct Response"));
assert!(!files[0]
.diff_content
.contains("0472ffa1665c4c5573fb8f7698c9965122eda675"));
}
// Diffs produced with `diff.mnemonicPrefix` use `c/` and `i/` instead of `a/`
// and `b/`; the prefixes must be stripped from the reported paths.
#[test]
fn test_parse_diff_with_c_i_prefixes() {
let diff = r#"diff --git c/test.md i/test.md
new file mode 100644
index 0000000..6c61a60
--- /dev/null
+++ i/test.md
@@ -0,0 +1 @@
+# Test File
diff --git c/test.js i/test.js
new file mode 100644
index 0000000..a730e61
--- /dev/null
+++ i/test.js
@@ -0,0 +1 @@
+console.log('Hello');
"#;
let files = parse_diff(diff).unwrap();
assert_eq!(files.len(), 2);
assert_eq!(files[0].path, "test.md", "Should extract clean path without i/ prefix");
assert_eq!(files[0].operation, "added");
assert_eq!(files[1].path, "test.js", "Should extract clean path without i/ prefix");
assert_eq!(files[1].operation, "added");
assert!(files[0].diff_content.contains("# Test File"));
assert!(files[1].diff_content.contains("console.log"));
}
// When the header's new side is `b/dev/null`, the old path is reported and
// `deleted file mode` marks the operation as "deleted".
#[test]
fn test_parse_diff_with_deleted_file() {
let diff = r#"diff --git a/deleted.txt b/dev/null
deleted file mode 100644
index 1234567..0000000
--- a/deleted.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-This file
-will be
-deleted
"#;
let files = parse_diff(diff).unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].path, "deleted.txt", "Should use a path for deleted files");
assert_eq!(files[0].operation, "deleted");
assert!(files[0].diff_content.contains("This file"));
}
// Smoke test of the offline path: it must yield a non-empty message no longer
// than the requested limit (measured in bytes by `len()`).
#[test]
fn test_local_generation() {
let diff = r#"diff --git a/src/auth.rs b/src/auth.rs
index 1234567..abcdefg 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -10,7 +10,15 @@ pub fn authenticate(user: &str, pass: &str) -> Result<Token> {
- if user == "admin" && pass == "password" {
- Ok(Token::new())
- } else {
- Err(AuthError::InvalidCredentials)
- }
+ // Validate input
+ if user.is_empty() || pass.is_empty() {
+ return Err(AuthError::EmptyCredentials);
+ }
+
+ // Check credentials against database
+ let hashed = hash_password(pass);
+ if validate_user(user, &hashed)? {
+ Ok(Token::generate(user))
+ } else {
+ Err(AuthError::InvalidCredentials)
+ }
}"#;
let message = generate_commit_message_local(diff, Some(72)).unwrap();
assert!(!message.is_empty());
assert!(message.len() <= 72);
}
}