ai/
simple_multi_step.rs

1use anyhow::Result;
2use async_openai::config::OpenAIConfig;
3use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
4use async_openai::Client;
5
6use crate::function_calling::{create_commit_function_tool, CommitFunctionArgs};
7use crate::debug_output;
8
9/// Simplified multi-step commit message generation that works with raw diff
10pub async fn generate_commit_message_simple(
11  client: &Client<OpenAIConfig>, model: &str, diff_content: &str, max_length: Option<usize>
12) -> Result<String> {
13  log::info!("Starting simplified multi-step commit message generation");
14
15  // Initialize multi-step debug session
16  if let Some(session) = debug_output::debug_session() {
17    session.init_multi_step_debug();
18  }
19
20  // Use the commit function tool directly with the full diff
21  let tools = vec![create_commit_function_tool(max_length)?];
22
23  let system_message = ChatCompletionRequestSystemMessageArgs::default()
24    .content(
25      "You are a git commit message expert. Analyze the provided git diff and generate a concise, \
26       descriptive commit message. Focus on the most significant changes and their impact. \
27       The message should explain WHAT changed and WHY it matters."
28    )
29    .build()?
30    .into();
31
32  let user_message = ChatCompletionRequestUserMessageArgs::default()
33    .content(format!("Generate a commit message for the following git diff:\n\n{diff_content}"))
34    .build()?
35    .into();
36
37  let request = CreateChatCompletionRequestArgs::default()
38    .model(model)
39    .messages(vec![system_message, user_message])
40    .tools(tools)
41    .tool_choice("commit")
42    .build()?;
43
44  let response = client.chat().create(request).await?;
45
46  if let Some(tool_call) = response.choices[0]
47    .message
48    .tool_calls
49    .as_ref()
50    .and_then(|calls| calls.first())
51  {
52    let args: CommitFunctionArgs = serde_json::from_str(&tool_call.function.arguments)?;
53
54    // Record in debug session
55    if let Some(session) = debug_output::debug_session() {
56      session.set_commit_result(args.message.clone(), args.reasoning.clone());
57      session.set_files_analyzed(args.clone());
58      // Set a dummy count since we're not parsing files
59      session.set_total_files_parsed(1);
60    }
61
62    Ok(args.message)
63  } else {
64    anyhow::bail!("No tool call in response")
65  }
66}
67
68/// Local version that doesn't require parsing
69pub fn generate_commit_message_simple_local(diff_content: &str, max_length: Option<usize>) -> Result<String> {
70  log::info!("Starting simplified local commit message generation");
71
72  // Count basic statistics from the diff
73  let mut lines_added = 0;
74  let mut lines_removed = 0;
75  let mut files_mentioned = std::collections::HashSet::new();
76
77  for line in diff_content.lines() {
78    if line.starts_with("+++") || line.starts_with("---") {
79      if let Some(file) = line.split_whitespace().nth(1) {
80        files_mentioned.insert(file.trim_start_matches("a/").trim_start_matches("b/"));
81      }
82    } else if line.starts_with('+') && !line.starts_with("+++") {
83      lines_added += 1;
84    } else if line.starts_with('-') && !line.starts_with("---") {
85      lines_removed += 1;
86    }
87  }
88
89  // Track in debug session
90  if let Some(session) = debug_output::debug_session() {
91    session.set_total_files_parsed(files_mentioned.len());
92  }
93
94  // Generate a simple commit message based on the diff
95  let message = match files_mentioned.len().cmp(&1) {
96    std::cmp::Ordering::Equal => {
97      let file = files_mentioned
98        .iter()
99        .next()
100        .ok_or_else(|| anyhow::anyhow!("No files mentioned in commit message"))?;
101      if lines_added > 0 && lines_removed == 0 {
102        format!(
103          "Add {} to {}",
104          if lines_added == 1 {
105            "content"
106          } else {
107            "new content"
108          },
109          file
110        )
111      } else if lines_removed > 0 && lines_added == 0 {
112        format!("Remove content from {file}")
113      } else {
114        format!("Update {file}")
115      }
116    }
117    std::cmp::Ordering::Greater => format!("Update {} files", files_mentioned.len()),
118    std::cmp::Ordering::Less => "Update files".to_string()
119  };
120
121  // Ensure it fits within the length limit
122  let max_len = max_length.unwrap_or(72);
123  if message.len() > max_len {
124    Ok(message.chars().take(max_len - 3).collect::<String>() + "...")
125  } else {
126    Ok(message)
127  }
128}