1use anyhow::Result;
2use async_openai::config::OpenAIConfig;
3use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
4use async_openai::Client;
5use serde_json::Value;
6use futures::future::join_all;
7
8use crate::multi_step_analysis::{
9 create_analyze_function_tool, create_generate_function_tool, create_score_function_tool, FileDataForScoring, FileWithScore
10};
11use crate::function_calling::{create_commit_function_tool, CommitFunctionArgs};
12use crate::debug_output;
13
/// A single file entry extracted from a git diff by [`parse_diff`].
///
/// `Clone`, `PartialEq` and `Eq` are derived so parsed entries can be duplicated
/// and compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedFile {
  /// Path of the file as taken from the `diff --git` header (the "new" side when available).
  pub path:         String,
  /// Kind of change. `parse_diff` produces `"modified"`, `"added"`, `"deleted"`,
  /// `"renamed"` or `"binary"`.
  pub operation:    String,
  /// The diff text collected for this file, one line per `\n`.
  pub diff_content: String
}
21
22pub async fn generate_commit_message_multi_step(
24 client: &Client<OpenAIConfig>, model: &str, diff_content: &str, max_length: Option<usize>
25) -> Result<String> {
26 log::info!("Starting multi-step commit message generation");
27
28 if let Some(session) = debug_output::debug_session() {
30 session.init_multi_step_debug();
31 }
32
33 let parsed_files = parse_diff(diff_content)?;
35 log::info!("Parsed {} files from diff", parsed_files.len());
36
37 if let Some(session) = debug_output::debug_session() {
39 session.set_total_files_parsed(parsed_files.len());
40 }
41
42 log::debug!("Analyzing {} files in parallel", parsed_files.len());
44
45 let analysis_futures: Vec<_> = parsed_files
47 .iter()
48 .map(|file| {
49 let file_path = file.path.clone();
50 let operation = file.operation.clone();
51 async move {
52 log::debug!("Analyzing file: {file_path}");
53 let start_time = std::time::Instant::now();
54 let payload = format!("{{\"file_path\": \"{file_path}\", \"operation_type\": \"{operation}\", \"diff_content\": \"...\"}}");
55
56 let result = call_analyze_function(client, model, file).await;
57 let duration = start_time.elapsed();
58 (file, result, duration, payload)
59 }
60 })
61 .collect();
62
63 let analysis_results = join_all(analysis_futures).await;
65
66 let mut file_analyses = Vec::new();
68 for (i, (file, result, duration, payload)) in analysis_results.into_iter().enumerate() {
69 match result {
70 Ok(analysis) => {
71 log::debug!("Successfully analyzed file {}: {}", i, file.path);
72
73 let analysis_result = crate::multi_step_analysis::FileAnalysisResult {
75 lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
76 lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
77 file_category: analysis["file_category"]
78 .as_str()
79 .unwrap_or("source")
80 .to_string(),
81 summary: analysis["summary"].as_str().unwrap_or("").to_string()
82 };
83
84 if let Some(session) = debug_output::debug_session() {
86 session.add_file_analysis_debug(file.path.clone(), file.operation.clone(), analysis_result.clone(), duration, payload);
87 }
88
89 file_analyses.push((file, analysis));
90 }
91 Err(e) => {
92 log::warn!("Failed to analyze file {}: {}", file.path, e);
93 }
95 }
96 }
97
98 if file_analyses.is_empty() {
99 anyhow::bail!("Failed to analyze any files");
100 }
101
102 let files_data: Vec<FileDataForScoring> = file_analyses
104 .iter()
105 .map(|(file, analysis)| {
106 FileDataForScoring {
107 file_path: file.path.clone(),
108 operation_type: file.operation.clone(),
109 lines_added: analysis["lines_added"].as_u64().unwrap_or(0) as u32,
110 lines_removed: analysis["lines_removed"].as_u64().unwrap_or(0) as u32,
111 file_category: analysis["file_category"]
112 .as_str()
113 .unwrap_or("source")
114 .to_string(),
115 summary: analysis["summary"].as_str().unwrap_or("").to_string()
116 }
117 })
118 .collect();
119
120 let score_start_time = std::time::Instant::now();
122 let score_payload = format!(
123 "{{\"files_data\": [{{\"{}\", ...}}, ...]}}",
124 if !files_data.is_empty() {
125 &files_data[0].file_path
126 } else {
127 "no files"
128 }
129 );
130
131 let score_future = call_score_function(client, model, files_data);
134
135 let scored_files = score_future.await?;
137 let score_duration = score_start_time.elapsed();
138
139 if let Some(session) = debug_output::debug_session() {
141 session.set_score_debug(scored_files.clone(), score_duration, score_payload);
142 }
143
144 let generate_start_time = std::time::Instant::now();
146 let generate_payload = format!("{{\"files_with_scores\": [...], \"max_length\": {}}}", max_length.unwrap_or(72));
147
148 let generate_future = call_generate_function(client, model, scored_files.clone(), max_length.unwrap_or(72));
150
151 let candidates = generate_future.await?;
152 let generate_duration = generate_start_time.elapsed();
153
154 if let Some(session) = debug_output::debug_session() {
156 session.set_generate_debug(candidates.clone(), generate_duration, generate_payload);
157 }
158
159 let final_message_start_time = std::time::Instant::now();
161 let final_message = select_best_candidate(client, model, &candidates, &scored_files, diff_content).await?;
162 let final_message_duration = final_message_start_time.elapsed();
163
164 if let Some(session) = debug_output::debug_session() {
166 session.set_final_message_debug(final_message_duration);
167 session.set_commit_result(final_message.clone(), candidates["reasoning"].as_str().unwrap_or("").to_string());
168 }
169
170 Ok(final_message)
171}
172
173pub fn parse_diff(diff_content: &str) -> Result<Vec<ParsedFile>> {
175 let mut files = Vec::new();
176 let mut current_file: Option<ParsedFile> = None;
177 let mut current_diff = String::new();
178
179 log::debug!("Parsing diff with {} lines", diff_content.lines().count());
181
182 if log::log_enabled!(log::Level::Debug) && !diff_content.is_empty() {
184 let preview = if diff_content.len() > 500 {
186 let truncated_index = diff_content
187 .char_indices()
188 .take_while(|(i, _)| *i < 500)
189 .last()
190 .map(|(i, c)| i + c.len_utf8())
191 .unwrap_or(0);
192
193 format!("{}... (truncated)", &diff_content[..truncated_index])
194 } else {
195 diff_content.to_string()
196 };
197 log::debug!("Diff content preview: \n{preview}");
198 }
199
200 let mut in_diff_section = false;
202 let mut _commit_hash_line: Option<&str> = None;
203
204 for line in diff_content.lines().take(3) {
206 if line.len() >= 40 && line.chars().take(40).all(|c| c.is_ascii_hexdigit()) {
207 _commit_hash_line = Some(line);
208 break;
209 }
210 }
211
212 for line in diff_content.lines() {
214 if line.starts_with("commit ") || (line.len() >= 40 && line.chars().take(40).all(|c| c.is_ascii_hexdigit())) || line.is_empty() {
216 continue;
217 }
218
219 if line.starts_with("diff --git") {
221 in_diff_section = true;
222 if let Some(mut file) = current_file.take() {
224 file.diff_content = current_diff.clone();
225 log::debug!("Adding file to results: {} ({})", file.path, file.operation);
226 files.push(file);
227 current_diff.clear();
228 }
229
230 let parts: Vec<&str> = line.split_whitespace().collect();
232 if parts.len() >= 4 {
233 let a_path = parts[2].trim_start_matches("a/");
234 let b_path = parts[3].trim_start_matches("b/");
235
236 let path = if !b_path.is_empty() {
238 b_path
239 } else {
240 a_path
241 };
242 log::debug!("Found new file in diff: {path}");
243 current_file = Some(ParsedFile {
244 path: path.to_string(),
245 operation: "modified".to_string(), diff_content: String::new()
247 });
248 }
249
250 current_diff.push_str(line);
252 current_diff.push('\n');
253 } else if line.starts_with("new file mode") {
254 if let Some(ref mut file) = current_file {
255 log::debug!("File {} is newly added", file.path);
256 file.operation = "added".to_string();
257 }
258 current_diff.push_str(line);
259 current_diff.push('\n');
260 } else if line.starts_with("deleted file mode") {
261 if let Some(ref mut file) = current_file {
262 log::debug!("File {} is deleted", file.path);
263 file.operation = "deleted".to_string();
264 }
265 current_diff.push_str(line);
266 current_diff.push('\n');
267 } else if line.starts_with("rename from") || line.starts_with("rename to") {
268 if let Some(ref mut file) = current_file {
269 log::debug!("File {} is renamed", file.path);
270 file.operation = "renamed".to_string();
271 }
272 current_diff.push_str(line);
273 current_diff.push('\n');
274 } else if line.starts_with("Binary files") {
275 if let Some(ref mut file) = current_file {
276 log::debug!("File {} is binary", file.path);
277 file.operation = "binary".to_string();
278 }
279 current_diff.push_str(line);
280 current_diff.push('\n');
281 } else if line.starts_with("index ") || line.starts_with("--- ") || line.starts_with("+++ ") || line.starts_with("@@ ") {
282 current_diff.push_str(line);
284 current_diff.push('\n');
285 } else if in_diff_section {
286 current_diff.push_str(line);
287 current_diff.push('\n');
288 }
289 }
290
291 if let Some(mut file) = current_file {
293 file.diff_content = current_diff;
294 log::debug!("Adding final file to results: {} ({})", file.path, file.operation);
295 files.push(file);
296 }
297
298 if files.is_empty() && !diff_content.trim().is_empty() {
301 log::debug!("Trying to parse as raw git diff output with commit info");
302
303 let sections: Vec<&str> = diff_content.split("diff --git").skip(1).collect();
305
306 if !sections.is_empty() {
307 for (i, section) in sections.iter().enumerate() {
308 let full_section = format!("diff --git{section}");
310
311 let mut path = "unknown";
313 let mut found_path = false;
314
315 for section_line in full_section.lines().take(3) {
317 if section_line.starts_with("diff --git") {
318 let parts: Vec<&str> = section_line.split_whitespace().collect();
319 if parts.len() >= 4 {
320 path = parts[3].trim_start_matches("b/");
321 found_path = true;
322 break;
323 }
324 }
325 }
326
327 if found_path {
328 log::debug!("Found file in section {i}: {path}");
329 files.push(ParsedFile {
330 path: path.to_string(),
331 operation: "modified".to_string(), diff_content: full_section
333 });
334 }
335 }
336 }
337 }
338
339 if files.is_empty() && !diff_content.trim().is_empty() {
341 log::debug!("No standard diff format found, treating as single file change");
342 files.push(ParsedFile {
343 path: "unknown".to_string(),
344 operation: "modified".to_string(),
345 diff_content: diff_content.to_string()
346 });
347 }
348
349 log::debug!("Parsed {} files from diff", files.len());
350
351 if log::log_enabled!(log::Level::Debug) {
353 for (i, file) in files.iter().enumerate() {
354 let content_preview = if file.diff_content.len() > 200 {
355 let truncated_index = file
357 .diff_content
358 .char_indices()
359 .take_while(|(i, _)| *i < 200)
360 .last()
361 .map(|(i, c)| i + c.len_utf8())
362 .unwrap_or(0);
363
364 format!("{}... (truncated)", &file.diff_content[..truncated_index])
365 } else {
366 file.diff_content.clone()
367 };
368 log::debug!("File {}: {} ({})\nContent preview:\n{}", i, file.path, file.operation, content_preview);
369 }
370 }
371
372 Ok(files)
373}
374
375async fn call_analyze_function(client: &Client<OpenAIConfig>, model: &str, file: &ParsedFile) -> Result<Value> {
377 let tools = vec![create_analyze_function_tool()?];
378
379 let system_message = ChatCompletionRequestSystemMessageArgs::default()
380 .content("You are a git diff analyzer. Analyze the provided file changes and return structured data.")
381 .build()?
382 .into();
383
384 let user_message = ChatCompletionRequestUserMessageArgs::default()
385 .content(format!(
386 "Analyze this file change:\nPath: {}\nOperation: {}\nDiff:\n{}",
387 file.path, file.operation, file.diff_content
388 ))
389 .build()?
390 .into();
391
392 let request = CreateChatCompletionRequestArgs::default()
393 .model(model)
394 .messages(vec![system_message, user_message])
395 .tools(tools)
396 .tool_choice("analyze")
397 .build()?;
398
399 let response = client.chat().create(request).await?;
400
401 if let Some(tool_call) = response.choices[0]
402 .message
403 .tool_calls
404 .as_ref()
405 .and_then(|calls| calls.first())
406 {
407 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
408 Ok(args)
409 } else {
410 anyhow::bail!("No tool call in response")
411 }
412}
413
414async fn call_score_function(
416 client: &Client<OpenAIConfig>, model: &str, files_data: Vec<FileDataForScoring>
417) -> Result<Vec<FileWithScore>> {
418 let tools = vec![create_score_function_tool()?];
419
420 let system_message = ChatCompletionRequestSystemMessageArgs::default()
421 .content("You are a git commit impact scorer. Calculate impact scores for the provided file changes.")
422 .build()?
423 .into();
424
425 let user_message = ChatCompletionRequestUserMessageArgs::default()
426 .content(format!(
427 "Calculate impact scores for these {} file changes:\n{}",
428 files_data.len(),
429 serde_json::to_string_pretty(&files_data)?
430 ))
431 .build()?
432 .into();
433
434 let request = CreateChatCompletionRequestArgs::default()
435 .model(model)
436 .messages(vec![system_message, user_message])
437 .tools(tools)
438 .tool_choice("score")
439 .build()?;
440
441 let response = client.chat().create(request).await?;
442
443 if let Some(tool_call) = response.choices[0]
444 .message
445 .tool_calls
446 .as_ref()
447 .and_then(|calls| calls.first())
448 {
449 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
450 let files_with_scores: Vec<FileWithScore> = if args["files_with_scores"].is_null() {
451 Vec::new() } else {
453 serde_json::from_value(args["files_with_scores"].clone())?
454 };
455 Ok(files_with_scores)
456 } else {
457 anyhow::bail!("No tool call in response")
458 }
459}
460
461async fn call_generate_function(
463 client: &Client<OpenAIConfig>, model: &str, files_with_scores: Vec<FileWithScore>, max_length: usize
464) -> Result<Value> {
465 let tools = vec![create_generate_function_tool()?];
466
467 let system_message = ChatCompletionRequestSystemMessageArgs::default()
468 .content("You are a git commit message generator. Generate concise, descriptive commit messages.")
469 .build()?
470 .into();
471
472 let user_message = ChatCompletionRequestUserMessageArgs::default()
473 .content(format!(
474 "Generate commit message candidates (max {} chars) for these scored changes:\n{}",
475 max_length,
476 serde_json::to_string_pretty(&files_with_scores)?
477 ))
478 .build()?
479 .into();
480
481 let request = CreateChatCompletionRequestArgs::default()
482 .model(model)
483 .messages(vec![system_message, user_message])
484 .tools(tools)
485 .tool_choice("generate")
486 .build()?;
487
488 let response = client.chat().create(request).await?;
489
490 if let Some(tool_call) = response.choices[0]
491 .message
492 .tool_calls
493 .as_ref()
494 .and_then(|calls| calls.first())
495 {
496 let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
497 Ok(args)
498 } else {
499 anyhow::bail!("No tool call in response")
500 }
501}
502
503async fn select_best_candidate(
505 client: &Client<OpenAIConfig>, model: &str, candidates: &Value, scored_files: &[FileWithScore], original_diff: &str
506) -> Result<String> {
507 let tools = vec![create_commit_function_tool(Some(72))?];
509
510 let system_message = ChatCompletionRequestSystemMessageArgs::default()
511 .content(
512 "You are a git commit message expert. Based on the multi-step analysis, \
513 select the best commit message and provide the final formatted response."
514 )
515 .build()?
516 .into();
517
518 let user_message = ChatCompletionRequestUserMessageArgs::default()
519 .content(format!(
520 "Based on this multi-step analysis:\n\n\
521 Candidates: {}\n\
522 Reasoning: {}\n\n\
523 Scored files: {}\n\n\
524 Original diff:\n{}\n\n\
525 Select the best commit message and format the response using the commit function.",
526 serde_json::to_string_pretty(&candidates["candidates"])?,
527 candidates["reasoning"].as_str().unwrap_or(""),
528 serde_json::to_string_pretty(&scored_files)?,
529 original_diff
530 ))
531 .build()?
532 .into();
533
534 let request = CreateChatCompletionRequestArgs::default()
535 .model(model)
536 .messages(vec![system_message, user_message])
537 .tools(tools)
538 .tool_choice("commit")
539 .build()?;
540
541 let response = client.chat().create(request).await?;
542
543 if let Some(tool_call) = response.choices[0]
544 .message
545 .tool_calls
546 .as_ref()
547 .and_then(|calls| calls.first())
548 {
549 let raw_args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
551
552 if let Some(message) = raw_args.get("message").and_then(|m| m.as_str()) {
554 return Ok(message.to_string());
555 }
556
557 let args: CommitFunctionArgs = serde_json::from_str(&tool_call.function.arguments)?;
559 Ok(args.message)
560 } else {
561 anyhow::bail!("No tool call in response")
562 }
563}
564
565pub fn generate_commit_message_local(diff_content: &str, max_length: Option<usize>) -> Result<String> {
567 use crate::multi_step_analysis::{analyze_file, calculate_impact_scores, generate_commit_messages};
568
569 log::info!("Starting local multi-step commit message generation");
570
571 let parsed_files = parse_diff(diff_content)?;
573
574 if let Some(session) = debug_output::debug_session() {
576 session.set_total_files_parsed(parsed_files.len());
577 }
578
579 let mut files_data = Vec::new();
581 for file in parsed_files {
582 let analysis = analyze_file(&file.path, &file.diff_content, &file.operation);
583 files_data.push(FileDataForScoring {
584 file_path: file.path,
585 operation_type: file.operation,
586 lines_added: analysis.lines_added,
587 lines_removed: analysis.lines_removed,
588 file_category: analysis.file_category,
589 summary: analysis.summary
590 });
591 }
592
593 let score_result = calculate_impact_scores(files_data);
595
596 let generate_result = generate_commit_messages(score_result.files_with_scores, max_length.unwrap_or(72));
598
599 Ok(
601 generate_result
602 .candidates
603 .first()
604 .cloned()
605 .unwrap_or_else(|| "Update files".to_string())
606 )
607}
608
// Unit tests for `parse_diff` and for the local (no API call) generation path.
609#[cfg(test)]
610mod tests {
611 use super::*;
612
  // A diff with two `diff --git` sections: a modified file and a newly added file.
613 #[test]
614 fn test_parse_diff() {
615 let diff = r#"diff --git a/src/main.rs b/src/main.rs
616index 1234567..abcdefg 100644
617--- a/src/main.rs
618+++ b/src/main.rs
619@@ -1,5 +1,6 @@
620 fn main() {
621- println!("Hello");
622+ println!("Hello, world!");
623+ println!("New line");
624 }
625diff --git a/Cargo.toml b/Cargo.toml
626new file mode 100644
627index 0000000..1111111
628--- /dev/null
629+++ b/Cargo.toml
630@@ -0,0 +1,8 @@
631+[package]
632+name = "test"
633+version = "0.1.0"
634"#;
635
    // Both files are found, in order, with the operation taken from the headers.
636 let files = parse_diff(diff).unwrap();
637 assert_eq!(files.len(), 2);
638 assert_eq!(files[0].path, "src/main.rs");
639 assert_eq!(files[0].operation, "modified");
640 assert_eq!(files[1].path, "Cargo.toml");
641 assert_eq!(files[1].operation, "added");
642
    // Each file carries its own diff text.
643 assert!(!files[0].diff_content.is_empty());
645 assert!(!files[1].diff_content.is_empty());
646 }
647
  // Input whose first line is a 40-character commit hash followed by a subject.
648 #[test]
649 fn test_parse_diff_with_commit_hash() {
650 let diff = r#"0472ffa1665c4c5573fb8f7698c9965122eda675 Update files
652diff --git a/src/openai.rs b/src/openai.rs
653index a67ebbe..da223be 100644
654--- a/src/openai.rs
655+++ b/src/openai.rs
656@@ -15,11 +15,6 @@ use crate::multi_step_integration::{generate_commit_message_local, generate_comm
657
658 const MAX_ATTEMPTS: usize = 3;
659
660-#[derive(Debug, Clone, PartialEq)]
661-pub struct Response {
662- pub response: String
663-}
664-
665 #[derive(Debug, Clone, PartialEq)]
666 pub struct Request {
667 pub prompt: String,
668@@ -28,6 +23,11 @@ pub struct Request {
669 pub model: Model
670 }
671
672+#[derive(Debug, Clone, PartialEq)]
673+pub struct Response {
674+ pub response: String
675+}
676+
677 /// Generates an improved commit message using the provided prompt and diff
678 /// Now uses the multi-step approach by default
679 pub async fn generate_commit_message(diff: &str) -> Result<String> {
680"#;
681
    // The hash line must not create a file of its own.
682 let files = parse_diff(diff).unwrap();
683 assert_eq!(files.len(), 1);
684 assert_eq!(files[0].path, "src/openai.rs");
685 assert_eq!(files[0].operation, "modified");
686
    // Hunk content is kept in the file's diff.
687 assert!(files[0].diff_content.contains("pub struct Response"));
689
    // The commit hash line is excluded from the file's diff.
690 assert!(!files[0]
692 .diff_content
693 .contains("0472ffa1665c4c5573fb8f7698c9965122eda675"));
694 }
695
  // The local generator must return a non-empty message within the requested length.
696 #[test]
697 fn test_local_generation() {
698 let diff = r#"diff --git a/src/auth.rs b/src/auth.rs
699index 1234567..abcdefg 100644
700--- a/src/auth.rs
701+++ b/src/auth.rs
702@@ -10,7 +10,15 @@ pub fn authenticate(user: &str, pass: &str) -> Result<Token> {
703- if user == "admin" && pass == "password" {
704- Ok(Token::new())
705- } else {
706- Err(AuthError::InvalidCredentials)
707- }
708+ // Validate input
709+ if user.is_empty() || pass.is_empty() {
710+ return Err(AuthError::EmptyCredentials);
711+ }
712+
713+ // Check credentials against database
714+ let hashed = hash_password(pass);
715+ if validate_user(user, &hashed)? {
716+ Ok(Token::generate(user))
717+ } else {
718+ Err(AuthError::InvalidCredentials)
719+ }
720 }"#;
721
    // 72 is passed explicitly as the maximum length and checked in bytes via `len()`.
722 let message = generate_commit_message_local(diff, Some(72)).unwrap();
723 assert!(!message.is_empty());
724 assert!(message.len() <= 72);
725 }
726}