1use anyhow::Result;
2use async_openai::config::OpenAIConfig;
3use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
4use async_openai::Client;
5use serde_json::Value;
6use futures::future::join_all;
7
8use crate::multi_step_analysis::{
9 create_analyze_function_tool, create_generate_function_tool, create_score_function_tool, FileDataForScoring, FileWithScore
10};
11use crate::function_calling::{create_commit_function_tool, CommitFunctionArgs};
12use crate::debug_output;
13
/// One file's worth of changes extracted from a git diff by [`parse_diff`].
///
/// All fields are owned copies of text taken from (or derived from) the diff.
/// `Clone`, `PartialEq` and `Eq` are derived so values can be duplicated and
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedFile {
    /// Path of the file with the diff side prefix (such as `a/` or `b/`) removed,
    /// or `"unknown"` when the input did not look like a standard diff.
    pub path: String,
    /// Kind of change. `parse_diff` sets one of `"modified"`, `"added"`,
    /// `"deleted"`, `"renamed"` or `"binary"`.
    pub operation: String,
    /// The diff text that belongs to this file, header lines included.
    pub diff_content: String,
}
21
22pub async fn generate_commit_message_multi_step(
24 client: &Client<OpenAIConfig>, model: &str, diff_content: &str, max_length: Option<usize>
25) -> Result<String> {
26 log::info!("Starting multi-step commit message generation");
27
28 if let Some(session) = debug_output::debug_session() {
30 session.init_multi_step_debug();
31 }
32
33 let parsed_files = parse_diff(diff_content)?;
35 log::info!("Parsed {} files from diff", parsed_files.len());
36
37 if let Some(session) = debug_output::debug_session() {
39 session.set_total_files_parsed(parsed_files.len());
40 }
41
42 log::debug!("Analyzing {} files in parallel", parsed_files.len());
44
45 let analysis_futures: Vec<_> = parsed_files
47 .iter()
48 .map(|file| {
49 let file_path = file.path.clone();
50 let operation = file.operation.clone();
51 async move {
52 log::debug!("Analyzing file: {file_path}");
53 let start_time = std::time::Instant::now();
54 let payload = format!("{{\"file_path\": \"{file_path}\", \"operation_type\": \"{operation}\", \"diff_content\": \"...\"}}");
55
56 let result = call_analyze_function(client, model, file).await;
57 let duration = start_time.elapsed();
58 (file, result, duration, payload)
59 }
60 })
61 .collect();
62
63 let analysis_results = join_all(analysis_futures).await;
65
66 let mut file_analyses = Vec::new();
68 for (i, (file, result, duration, payload)) in analysis_results.into_iter().enumerate() {
69 match result {
70 Ok(analysis) => {
71 log::debug!("Successfully analyzed file {}: {}", i, file.path);
72
73 let analysis_result = crate::multi_step_analysis::FileAnalysisResult {
75 lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
76 lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
77 file_category: analysis["file_category"]
78 .as_str()
79 .unwrap_or("source")
80 .to_string(),
81 summary: analysis["summary"].as_str().unwrap_or("").to_string()
82 };
83
84 if let Some(session) = debug_output::debug_session() {
86 session.add_file_analysis_debug(file.path.clone(), file.operation.clone(), analysis_result.clone(), duration, payload);
87 }
88
89 file_analyses.push((file, analysis));
90 }
91 Err(e) => {
92 let error_str = e.to_string();
94 if error_str.contains("invalid_api_key") || error_str.contains("Incorrect API key") || error_str.contains("Invalid API key") {
95 return Err(e);
96 }
97 log::warn!("Failed to analyze file {}: {}", file.path, e);
98 }
100 }
101 }
102
103 if file_analyses.is_empty() {
104 anyhow::bail!("Failed to analyze any files");
105 }
106
107 let files_data: Vec<FileDataForScoring> = file_analyses
109 .iter()
110 .map(|(file, analysis)| {
111 FileDataForScoring {
112 file_path: file.path.clone(),
113 operation_type: file.operation.clone(),
114 lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
115 lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
116 file_category: analysis["file_category"]
117 .as_str()
118 .unwrap_or("source")
119 .to_string(),
120 summary: analysis["summary"].as_str().unwrap_or("").to_string()
121 }
122 })
123 .collect();
124
125 let score_start_time = std::time::Instant::now();
127 let score_payload = format!(
128 "{{\"files_data\": [{{\"{}\", ...}}, ...]}}",
129 if !files_data.is_empty() {
130 &files_data[0].file_path
131 } else {
132 "no files"
133 }
134 );
135
136 let score_future = call_score_function(client, model, files_data);
139
140 let scored_files = score_future.await?;
142 let score_duration = score_start_time.elapsed();
143
144 if let Some(session) = debug_output::debug_session() {
146 session.set_score_debug(scored_files.clone(), score_duration, score_payload);
147 }
148
149 let generate_start_time = std::time::Instant::now();
151 let generate_payload = format!("{{\"files_with_scores\": [...], \"max_length\": {}}}", max_length.unwrap_or(72));
152
153 let generate_future = call_generate_function(client, model, scored_files.clone(), max_length.unwrap_or(72));
155
156 let candidates = generate_future.await?;
157 let generate_duration = generate_start_time.elapsed();
158
159 if let Some(session) = debug_output::debug_session() {
161 session.set_generate_debug(candidates.clone(), generate_duration, generate_payload);
162 }
163
164 let final_message_start_time = std::time::Instant::now();
166 let final_message = select_best_candidate(client, model, &candidates, &scored_files, diff_content).await?;
167 let final_message_duration = final_message_start_time.elapsed();
168
169 if let Some(session) = debug_output::debug_session() {
171 session.set_final_message_debug(final_message_duration);
172 session.set_commit_result(final_message.clone(), candidates["reasoning"].as_str().unwrap_or("").to_string());
173 }
174
175 Ok(final_message)
176}
177
/// Derives the affected file's path from the whitespace-split tokens of a
/// `diff --git <old> <new>` header line.
///
/// `parts[2]` and `parts[3]` are taken as the old and new path. At most one
/// leading side prefix (`a/`, `b/`, `c/` or `i/`, as in
/// `diff --git c/test.md i/test.md`) is removed from each, so directories that
/// happen to share a prefix name (`b/b/main.rs` -> `b/main.rs`) are preserved.
///
/// The new path is returned unless it denotes `/dev/null` (the file was
/// deleted), in which case the old path is returned instead.
///
/// Returns `None` when fewer than four tokens are supplied.
fn extract_file_path_from_diff_parts(parts: &[&str]) -> Option<String> {
    // Prefixes git puts in front of the two sides of a diff.
    const SIDE_PREFIXES: [&str; 4] = ["a/", "b/", "c/", "i/"];

    // Expected layout: ["diff", "--git", <old>, <new>, ..]
    let (old_raw, new_raw) = match parts {
        [_, _, old, new, ..] => (*old, *new),
        _ => return None,
    };

    // Strip exactly one side prefix; anything after it is part of the real path.
    let strip_side_prefix = |raw: &str| -> String {
        SIDE_PREFIXES
            .iter()
            .find_map(|&prefix| raw.strip_prefix(prefix))
            .unwrap_or(raw)
            .to_string()
    };

    let new_path = strip_side_prefix(new_raw);
    let old_path = strip_side_prefix(old_raw);

    // `b/dev/null` becomes `dev/null` after stripping, hence both spellings.
    Some(if new_path == "/dev/null" || new_path == "dev/null" {
        old_path
    } else {
        new_path
    })
}
210
211pub fn parse_diff(diff_content: &str) -> Result<Vec<ParsedFile>> {
213 let mut files = Vec::new();
214 let mut current_file: Option<ParsedFile> = None;
215 let mut current_diff = String::new();
216
217 log::debug!("Parsing diff with {} lines", diff_content.lines().count());
219
220 if log::log_enabled!(log::Level::Debug) && !diff_content.is_empty() {
222 let preview = if diff_content.len() > 500 {
224 let truncated_index = diff_content
225 .char_indices()
226 .take_while(|(i, _)| *i < 500)
227 .last()
228 .map(|(i, c)| i + c.len_utf8())
229 .unwrap_or(0);
230
231 format!("{}... (truncated)", &diff_content[..truncated_index])
232 } else {
233 diff_content.to_string()
234 };
235 log::debug!("Diff content preview: \n{preview}");
236 }
237
238 let mut in_diff_section = false;
240 let mut _commit_hash_line: Option<&str> = None;
241
242 for line in diff_content.lines().take(3) {
244 if line.len() >= 40 && line.chars().take(40).all(|c| c.is_ascii_hexdigit()) {
245 _commit_hash_line = Some(line);
246 break;
247 }
248 }
249
250 for line in diff_content.lines() {
252 if line.starts_with("commit ") || (line.len() >= 40 && line.chars().take(40).all(|c| c.is_ascii_hexdigit())) || line.is_empty() {
254 continue;
255 }
256
257 if line.starts_with("diff --git") {
259 in_diff_section = true;
260 if let Some(mut file) = current_file.take() {
262 file.diff_content = current_diff.clone();
263 log::debug!("Adding file to results: {} ({})", file.path, file.operation);
264 files.push(file);
265 current_diff.clear();
266 }
267
268 let parts: Vec<&str> = line.split_whitespace().collect();
270 if let Some(path) = extract_file_path_from_diff_parts(&parts) {
271 log::debug!("Found new file in diff: {path}");
272 current_file = Some(ParsedFile {
273 path,
274 operation: "modified".to_string(), diff_content: String::new()
276 });
277 }
278
279 current_diff.push_str(line);
281 current_diff.push('\n');
282 } else if line.starts_with("new file mode") {
283 if let Some(ref mut file) = current_file {
284 log::debug!("File {} is newly added", file.path);
285 file.operation = "added".to_string();
286 }
287 current_diff.push_str(line);
288 current_diff.push('\n');
289 } else if line.starts_with("deleted file mode") {
290 if let Some(ref mut file) = current_file {
291 log::debug!("File {} is deleted", file.path);
292 file.operation = "deleted".to_string();
293 }
294 current_diff.push_str(line);
295 current_diff.push('\n');
296 } else if line.starts_with("rename from") || line.starts_with("rename to") {
297 if let Some(ref mut file) = current_file {
298 log::debug!("File {} is renamed", file.path);
299 file.operation = "renamed".to_string();
300 }
301 current_diff.push_str(line);
302 current_diff.push('\n');
303 } else if line.starts_with("Binary files") {
304 if let Some(ref mut file) = current_file {
305 log::debug!("File {} is binary", file.path);
306 file.operation = "binary".to_string();
307 }
308 current_diff.push_str(line);
309 current_diff.push('\n');
310 } else if line.starts_with("index ") || line.starts_with("--- ") || line.starts_with("+++ ") || line.starts_with("@@ ") {
311 current_diff.push_str(line);
313 current_diff.push('\n');
314 } else if in_diff_section {
315 current_diff.push_str(line);
316 current_diff.push('\n');
317 }
318 }
319
320 if let Some(mut file) = current_file {
322 file.diff_content = current_diff;
323 log::debug!("Adding final file to results: {} ({})", file.path, file.operation);
324 files.push(file);
325 }
326
327 if files.is_empty() && !diff_content.trim().is_empty() {
330 log::debug!("Trying to parse as raw git diff output with commit info");
331
332 let sections: Vec<&str> = diff_content.split("diff --git").skip(1).collect();
334
335 if !sections.is_empty() {
336 for (i, section) in sections.iter().enumerate() {
337 let full_section = format!("diff --git{section}");
339
340 let mut found_path = false;
342
343 let mut extracted_path = String::new();
345 for section_line in full_section.lines().take(3) {
346 if section_line.starts_with("diff --git") {
347 let parts: Vec<&str> = section_line.split_whitespace().collect();
348 if let Some(p) = extract_file_path_from_diff_parts(&parts) {
349 extracted_path = p;
350 found_path = true;
351 break;
352 }
353 }
354 }
355
356 if found_path {
357 log::debug!("Found file in section {i}: {extracted_path}");
358 files.push(ParsedFile {
359 path: extracted_path,
360 operation: "modified".to_string(), diff_content: full_section
362 });
363 }
364 }
365 }
366 }
367
368 if files.is_empty() && !diff_content.trim().is_empty() {
370 log::debug!("No standard diff format found, treating as single file change");
371 files.push(ParsedFile {
372 path: "unknown".to_string(),
373 operation: "modified".to_string(),
374 diff_content: diff_content.to_string()
375 });
376 }
377
378 log::debug!("Parsed {} files from diff", files.len());
379
380 if log::log_enabled!(log::Level::Debug) {
382 for (i, file) in files.iter().enumerate() {
383 let content_preview = if file.diff_content.len() > 200 {
384 let truncated_index = file
386 .diff_content
387 .char_indices()
388 .take_while(|(i, _)| *i < 200)
389 .last()
390 .map(|(i, c)| i + c.len_utf8())
391 .unwrap_or(0);
392
393 format!("{}... (truncated)", &file.diff_content[..truncated_index])
394 } else {
395 file.diff_content.clone()
396 };
397 log::debug!("File {}: {} ({})\nContent preview:\n{}", i, file.path, file.operation, content_preview);
398 }
399 }
400
401 Ok(files)
402}
403
404async fn call_analyze_function(client: &Client<OpenAIConfig>, model: &str, file: &ParsedFile) -> Result<Value> {
406 let tools = vec![create_analyze_function_tool()?];
407
408 let system_message = ChatCompletionRequestSystemMessageArgs::default()
409 .content("You are a git diff analyzer. Analyze the provided file changes and return structured data.")
410 .build()?
411 .into();
412
413 let user_message = ChatCompletionRequestUserMessageArgs::default()
414 .content(format!(
415 "Analyze this file change:\nPath: {}\nOperation: {}\nDiff:\n{}",
416 file.path, file.operation, file.diff_content
417 ))
418 .build()?
419 .into();
420
421 let request = CreateChatCompletionRequestArgs::default()
422 .model(model)
423 .messages(vec![system_message, user_message])
424 .tools(tools)
425 .tool_choice("analyze")
426 .build()?;
427
428 let response = client.chat().create(request).await?;
429
430 if let Some(tool_call) = response.choices[0]
431 .message
432 .tool_calls
433 .as_ref()
434 .and_then(|calls| calls.first())
435 {
436 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
437 Ok(args)
438 } else {
439 anyhow::bail!("No tool call in response")
440 }
441}
442
443async fn call_score_function(
445 client: &Client<OpenAIConfig>, model: &str, files_data: Vec<FileDataForScoring>
446) -> Result<Vec<FileWithScore>> {
447 let tools = vec![create_score_function_tool()?];
448
449 let system_message = ChatCompletionRequestSystemMessageArgs::default()
450 .content("You are a git commit impact scorer. Calculate impact scores for the provided file changes.")
451 .build()?
452 .into();
453
454 let user_message = ChatCompletionRequestUserMessageArgs::default()
455 .content(format!(
456 "Calculate impact scores for these {} file changes:\n{}",
457 files_data.len(),
458 serde_json::to_string_pretty(&files_data)?
459 ))
460 .build()?
461 .into();
462
463 let request = CreateChatCompletionRequestArgs::default()
464 .model(model)
465 .messages(vec![system_message, user_message])
466 .tools(tools)
467 .tool_choice("score")
468 .build()?;
469
470 let response = client.chat().create(request).await?;
471
472 if let Some(tool_call) = response.choices[0]
473 .message
474 .tool_calls
475 .as_ref()
476 .and_then(|calls| calls.first())
477 {
478 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
479 let files_with_scores: Vec<FileWithScore> = if args["files_with_scores"].is_null() {
480 Vec::new() } else {
482 serde_json::from_value(args["files_with_scores"].clone())?
483 };
484 Ok(files_with_scores)
485 } else {
486 anyhow::bail!("No tool call in response")
487 }
488}
489
490async fn call_generate_function(
492 client: &Client<OpenAIConfig>, model: &str, files_with_scores: Vec<FileWithScore>, max_length: usize
493) -> Result<Value> {
494 let tools = vec![create_generate_function_tool()?];
495
496 let system_message = ChatCompletionRequestSystemMessageArgs::default()
497 .content("You are a git commit message generator. Generate concise, descriptive commit messages.")
498 .build()?
499 .into();
500
501 let user_message = ChatCompletionRequestUserMessageArgs::default()
502 .content(format!(
503 "Generate commit message candidates (max {} chars) for these scored changes:\n{}",
504 max_length,
505 serde_json::to_string_pretty(&files_with_scores)?
506 ))
507 .build()?
508 .into();
509
510 let request = CreateChatCompletionRequestArgs::default()
511 .model(model)
512 .messages(vec![system_message, user_message])
513 .tools(tools)
514 .tool_choice("generate")
515 .build()?;
516
517 let response = client.chat().create(request).await?;
518
519 if let Some(tool_call) = response.choices[0]
520 .message
521 .tool_calls
522 .as_ref()
523 .and_then(|calls| calls.first())
524 {
525 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
526 Ok(args)
527 } else {
528 anyhow::bail!("No tool call in response")
529 }
530}
531
532async fn select_best_candidate(
534 client: &Client<OpenAIConfig>, model: &str, candidates: &Value, scored_files: &[FileWithScore], original_diff: &str
535) -> Result<String> {
536 let tools = vec![create_commit_function_tool(Some(72))?];
538
539 let system_message = ChatCompletionRequestSystemMessageArgs::default()
540 .content(
541 "You are a git commit message expert. Based on the multi-step analysis, \
542 select the best commit message and provide the final formatted response."
543 )
544 .build()?
545 .into();
546
547 let user_message = ChatCompletionRequestUserMessageArgs::default()
548 .content(format!(
549 "Based on this multi-step analysis:\n\n\
550 Candidates: {}\n\
551 Reasoning: {}\n\n\
552 Scored files: {}\n\n\
553 Original diff:\n{}\n\n\
554 Select the best commit message and format the response using the commit function.",
555 serde_json::to_string_pretty(&candidates["candidates"])?,
556 candidates["reasoning"].as_str().unwrap_or(""),
557 serde_json::to_string_pretty(&scored_files)?,
558 original_diff
559 ))
560 .build()?
561 .into();
562
563 let request = CreateChatCompletionRequestArgs::default()
564 .model(model)
565 .messages(vec![system_message, user_message])
566 .tools(tools)
567 .tool_choice("commit")
568 .build()?;
569
570 let response = client.chat().create(request).await?;
571
572 if let Some(tool_call) = response.choices[0]
573 .message
574 .tool_calls
575 .as_ref()
576 .and_then(|calls| calls.first())
577 {
578 let raw_args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
580
581 if let Some(message) = raw_args.get("message").and_then(|m| m.as_str()) {
583 return Ok(message.to_string());
584 }
585
586 let args: CommitFunctionArgs = serde_json::from_str(&tool_call.function.arguments)?;
588 Ok(args.message)
589 } else {
590 anyhow::bail!("No tool call in response")
591 }
592}
593
594pub fn generate_commit_message_local(diff_content: &str, max_length: Option<usize>) -> Result<String> {
596 use crate::multi_step_analysis::{analyze_file, calculate_impact_scores, generate_commit_messages};
597
598 log::info!("Starting local multi-step commit message generation");
599
600 let parsed_files = parse_diff(diff_content)?;
602
603 if let Some(session) = debug_output::debug_session() {
605 session.set_total_files_parsed(parsed_files.len());
606 }
607
608 let mut files_data = Vec::new();
610 for file in parsed_files {
611 let analysis = analyze_file(&file.path, &file.diff_content, &file.operation);
612 files_data.push(FileDataForScoring {
613 file_path: file.path,
614 operation_type: file.operation,
615 lines_added: analysis.lines_added,
616 lines_removed: analysis.lines_removed,
617 file_category: analysis.file_category,
618 summary: analysis.summary
619 });
620 }
621
622 let score_result = calculate_impact_scores(files_data);
624
625 let generate_result = generate_commit_messages(score_result.files_with_scores, max_length.unwrap_or(72));
627
628 Ok(
630 generate_result
631 .candidates
632 .first()
633 .cloned()
634 .unwrap_or_else(|| "Update files".to_string())
635 )
636}
637
// Unit tests for diff parsing and the local (offline) message generation path.
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-file diff (one modification, one new file) yields two entries with
    /// the expected paths and operations, each carrying diff text.
    #[test]
    fn test_parse_diff() {
        let diff = r#"diff --git a/src/main.rs b/src/main.rs
index 1234567..abcdefg 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,6 @@
 fn main() {
- println!("Hello");
+ println!("Hello, world!");
+ println!("New line");
 }
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..1111111
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,8 @@
+[package]
+name = "test"
+version = "0.1.0"
"#;

        let files = parse_diff(diff).unwrap();
        assert_eq!(files.len(), 2);
        assert_eq!(files[0].path, "src/main.rs");
        assert_eq!(files[0].operation, "modified");
        assert_eq!(files[1].path, "Cargo.toml");
        assert_eq!(files[1].operation, "added");

        // Both files must carry their share of the diff text.
        assert!(!files[0].diff_content.is_empty());
        assert!(!files[1].diff_content.is_empty());
    }

    /// A leading "<40-hex-hash> <subject>" line is skipped and must not end up
    /// in the parsed file's diff text.
    #[test]
    fn test_parse_diff_with_commit_hash() {
        let diff = r#"0472ffa1665c4c5573fb8f7698c9965122eda675 Update files
diff --git a/src/openai.rs b/src/openai.rs
index a67ebbe..da223be 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -15,11 +15,6 @@ use crate::multi_step_integration::{generate_commit_message_local, generate_comm

 const MAX_ATTEMPTS: usize = 3;

-#[derive(Debug, Clone, PartialEq)]
-pub struct Response {
- pub response: String
-}
-
 #[derive(Debug, Clone, PartialEq)]
 pub struct Request {
 pub prompt: String,
@@ -28,6 +23,11 @@ pub struct Request {
 pub model: Model
 }

+#[derive(Debug, Clone, PartialEq)]
+pub struct Response {
+ pub response: String
+}
+
 /// Generates an improved commit message using the provided prompt and diff
 /// Now uses the multi-step approach by default
 pub async fn generate_commit_message(diff: &str) -> Result<String> {
"#;

        let files = parse_diff(diff).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "src/openai.rs");
        assert_eq!(files[0].operation, "modified");

        // The hunk content is kept ...
        assert!(files[0].diff_content.contains("pub struct Response"));

        // ... while the commit hash line is not.
        assert!(!files[0]
            .diff_content
            .contains("0472ffa1665c4c5573fb8f7698c9965122eda675"));
    }

    /// Headers that use `c/` and `i/` side prefixes instead of `a/` and `b/`
    /// still produce clean, prefix-free paths.
    #[test]
    fn test_parse_diff_with_c_i_prefixes() {
        let diff = r#"diff --git c/test.md i/test.md
new file mode 100644
index 0000000..6c61a60
--- /dev/null
+++ i/test.md
@@ -0,0 +1 @@
+# Test File

diff --git c/test.js i/test.js
new file mode 100644
index 0000000..a730e61
--- /dev/null
+++ i/test.js
@@ -0,0 +1 @@
+console.log('Hello');
"#;

        let files = parse_diff(diff).unwrap();
        assert_eq!(files.len(), 2);
        assert_eq!(files[0].path, "test.md", "Should extract clean path without i/ prefix");
        assert_eq!(files[0].operation, "added");
        assert_eq!(files[1].path, "test.js", "Should extract clean path without i/ prefix");
        assert_eq!(files[1].operation, "added");

        // Each file keeps its own added line.
        assert!(files[0].diff_content.contains("# Test File"));
        assert!(files[1].diff_content.contains("console.log"));
    }

    /// When the new side of the header is `dev/null`, the old path is used and
    /// the file is reported as deleted.
    #[test]
    fn test_parse_diff_with_deleted_file() {
        let diff = r#"diff --git a/deleted.txt b/dev/null
deleted file mode 100644
index 1234567..0000000
--- a/deleted.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-This file
-will be
-deleted
"#;

        let files = parse_diff(diff).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "deleted.txt", "Should use a path for deleted files");
        assert_eq!(files[0].operation, "deleted");

        // The removed lines are part of the captured diff text.
        assert!(files[0].diff_content.contains("This file"));
    }

    /// The local generation path returns a non-empty message that fits into the
    /// requested 72 characters for a small single-file diff.
    #[test]
    fn test_local_generation() {
        let diff = r#"diff --git a/src/auth.rs b/src/auth.rs
index 1234567..abcdefg 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -10,7 +10,15 @@ pub fn authenticate(user: &str, pass: &str) -> Result<Token> {
- if user == "admin" && pass == "password" {
- Ok(Token::new())
- } else {
- Err(AuthError::InvalidCredentials)
- }
+ // Validate input
+ if user.is_empty() || pass.is_empty() {
+ return Err(AuthError::EmptyCredentials);
+ }
+
+ // Check credentials against database
+ let hashed = hash_password(pass);
+ if validate_user(user, &hashed)? {
+ Ok(Token::generate(user))
+ } else {
+ Err(AuthError::InvalidCredentials)
+ }
 }"#;

        let message = generate_commit_message_local(diff, Some(72)).unwrap();
        assert!(!message.is_empty());
        assert!(message.len() <= 72);
    }
}