ai/openai.rs

use std::time::{Duration, Instant};

use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_openai::config::OpenAIConfig;
use async_openai::Client;
use async_openai::error::OpenAIError;
use anyhow::{anyhow, Context, Result};
use futures::future::join_all;

use crate::{commit, config, debug_output, function_calling, profile};
use crate::model::Model;
use crate::config::App as Settings;
use crate::multi_step_integration::generate_commit_message_multi_step;

const MAX_ATTEMPTS: usize = 3;

/// A successful completion carrying the generated text.
#[derive(Debug, Clone, PartialEq)]
pub struct Response {
  pub response: String
}

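/// A single completion request: user prompt, system prompt, token budget, and
/// target model.
///
/// A minimal construction sketch (the values shown are illustrative, not
/// defaults):
///
/// ```ignore
/// let request = Request {
///   prompt:     diff.to_string(),
///   system:     "You write conventional commit messages.".to_string(),
///   max_tokens: 256,
///   model:      Model::GPT4oMini
/// };
/// ```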
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
  pub prompt:     String,
  pub system:     String,
  pub max_tokens: u16,
  pub model:      Model
}

/// Generates a commit message for the provided diff.
/// Uses a simplified approach that does not require structured parsing of the diff.
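///
/// A usage sketch (the diff content here is illustrative):
///
/// ```ignore
/// let diff = "diff --git a/src/main.rs b/src/main.rs\n+fn main() {}\n";
/// let message = generate_commit_message(diff).await?;
/// assert!(!message.is_empty());
/// ```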
pub async fn generate_commit_message(diff: &str) -> Result<String> {
  profile!("Generate commit message (simplified)");

  // Try the simplified approach against OpenAI first
  if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
    if !api_key.is_empty() {
      // Use the commit function directly, without parsing the diff
      match commit::generate(diff.to_string(), 256, Model::GPT4oMini, None).await {
        Ok(response) => return Ok(response.response.trim().to_string()),
        Err(e) => {
          log::warn!("Direct generation failed, falling back to local: {e}");
        }
      }
    }
  }

  // Fallback to local generation (simplified version):
  // count basic statistics from the diff
  let mut lines_added = 0;
  let mut lines_removed = 0;
  let mut files_mentioned = std::collections::HashSet::new();

  for line in diff.lines() {
    if line.starts_with("diff --git") {
      // Extract the file path from the `diff --git a/... b/...` line
      let parts: Vec<&str> = line.split_whitespace().collect();
      if parts.len() >= 4 {
        let path = parts[3].trim_start_matches("b/");
        files_mentioned.insert(path);
      }
    } else if line.starts_with("+++") || line.starts_with("---") {
      if let Some(file) = line.split_whitespace().nth(1) {
        let cleaned = file.trim_start_matches("a/").trim_start_matches("b/");
        if cleaned != "/dev/null" {
          files_mentioned.insert(cleaned);
        }
      }
    } else if line.starts_with('+') && !line.starts_with("+++") {
      lines_added += 1;
    } else if line.starts_with('-') && !line.starts_with("---") {
      lines_removed += 1;
    }
  }

  // Track in the debug session
  if let Some(session) = debug_output::debug_session() {
    session.set_total_files_parsed(files_mentioned.len());
  }

  // Generate a simple commit message from the statistics
  let message = match files_mentioned.len().cmp(&1) {
    std::cmp::Ordering::Equal => {
      let file = files_mentioned
        .iter()
        .next()
        .ok_or_else(|| anyhow!("No files mentioned in diff"))?;
      if lines_added > 0 && lines_removed == 0 {
        format!(
          "Add {} to {}",
          if lines_added == 1 {
            "content"
          } else {
            "new content"
          },
          file
        )
      } else if lines_removed > 0 && lines_added == 0 {
        format!("Remove content from {file}")
      } else {
        format!("Update {file}")
      }
    }
    std::cmp::Ordering::Greater => format!("Update {} files", files_mentioned.len()),
    std::cmp::Ordering::Less => "Update files".to_string()
  };

  Ok(message.trim().to_string())
}

/// Creates an OpenAI client configuration from application settings.
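///
/// A usage sketch (`settings` is assumed to come from this app's own config
/// loading; the check below rejects empty and placeholder keys):
///
/// ```ignore
/// // `settings` is loaded elsewhere, e.g. from the application's config file.
/// let config = create_openai_config(&settings)?;
/// let client = Client::with_config(config);
/// ```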
pub fn create_openai_config(settings: &Settings) -> Result<OpenAIConfig> {
  let api_key = settings
    .openai_api_key
    .as_ref()
    .ok_or_else(|| anyhow!("OpenAI API key not configured"))?;

  if api_key.is_empty() || api_key == "<PLACE HOLDER FOR YOUR API KEY>" {
    return Err(anyhow!("Invalid OpenAI API key"));
  }

  let config = OpenAIConfig::new().with_api_key(api_key);

  Ok(config)
}

/// Truncates text to fit within a token budget, preferring complete lines.
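///
/// Strategy: return the text unchanged when it already fits, otherwise
/// binary-search for the longest prefix that ends on a complete line and still
/// fits, falling back to the model's raw truncation when no newline-aligned
/// prefix fits.
///
/// An illustrative call (actual counts depend on the model's tokenizer):
///
/// ```ignore
/// let truncated = truncate_to_fit(&long_diff, 4_000, &Model::GPT4oMini)?;
/// assert!(Model::GPT4oMini.count_tokens(&truncated)? <= 4_000);
/// ```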
fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result<String> {
  profile!("Truncate to fit");

  // Fast path: a BPE token decodes to at least one byte, so byte length is an
  // upper bound on token count and such texts are guaranteed to fit
  if text.len() <= max_tokens {
    return Ok(text.to_string());
  }

  let token_count = model.count_tokens(text)?;
  if token_count <= max_tokens {
    return Ok(text.to_string());
  }

  // Collect character indices to ensure we slice at valid UTF-8 boundaries
  let char_indices: Vec<(usize, char)> = text.char_indices().collect();
  if char_indices.is_empty() {
    return Ok(String::new());
  }

  // Binary search for the right truncation point
  let mut low = 0;
  let mut high = char_indices.len();
  let mut best_fit = String::new();

  while low < high {
    let mid = (low + high) / 2;

    // Get the byte index for this character position
    let byte_index = if mid < char_indices.len() {
      char_indices[mid].0
    } else {
      text.len()
    };

    let truncated = &text[..byte_index];

    // Find the last complete line
    if let Some(last_newline_pos) = truncated.rfind('\n') {
      // '\n' is ASCII, so this byte index is a valid UTF-8 boundary
      let candidate = &text[..last_newline_pos];
      let candidate_tokens = model.count_tokens(candidate)?;

      if candidate_tokens <= max_tokens {
        best_fit = candidate.to_string();
        // Find the character index after the newline
        let next_char_idx = char_indices
          .iter()
          .position(|(idx, _)| *idx > last_newline_pos)
          .unwrap_or(char_indices.len());
        low = next_char_idx;
      } else {
        // Find the character index of the newline
        let newline_char_idx = char_indices
          .iter()
          .rposition(|(idx, _)| *idx <= last_newline_pos)
          .unwrap_or(0);
        high = newline_char_idx;
      }
    } else {
      high = mid;
    }
  }

  if best_fit.is_empty() {
    // No newline-aligned prefix fits; fall back to the model's raw truncation
    model.truncate(text, max_tokens)
  } else {
    Ok(best_fit)
  }
}

/// Calls the OpenAI API with the provided configuration.
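///
/// A usage sketch pointing at a custom, OpenAI-compatible endpoint (the base
/// URL and key are placeholders):
///
/// ```ignore
/// let config = OpenAIConfig::new()
///   .with_api_key("sk-...")
///   .with_api_base("https://api.example.com/v1");
/// let response = call_with_config(request, config).await?;
/// ```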
pub async fn call_with_config(request: Request, config: OpenAIConfig) -> Result<Response> {
  profile!("OpenAI API call with custom config");

  // Always try the multi-step approach first (it's now the default)
  let client = Client::with_config(config.clone());
  let model = request.model.to_string();

  match generate_commit_message_multi_step(&client, &model, &request.prompt, config::APP.max_commit_length).await {
    Ok(message) => return Ok(Response { response: message }),
    Err(e) => {
      log::warn!("Multi-step approach failed, falling back to single-step: {e}");
    }
  }

  // Original single-step implementation as fallback.
  // Create the client, with a timeout if one is configured.
  let client = if let Some(timeout) = config::APP.timeout {
    let http_client = reqwest::ClientBuilder::new()
      .timeout(Duration::from_secs(timeout as u64))
      .build()?;
    Client::with_config(config).with_http_client(http_client)
  } else {
    Client::with_config(config)
  };

  // Calculate available tokens from the model's context size
  let system_tokens = request.model.count_tokens(&request.system)?;
  let model_context_size = request.model.context_size();
  let available_tokens = model_context_size.saturating_sub(system_tokens + request.max_tokens as usize);

  // Truncate the prompt if needed
  let truncated_prompt = truncate_to_fit(&request.prompt, available_tokens, &request.model)?;

  // Create the commit function tool
  let commit_tool = function_calling::create_commit_function_tool(config::APP.max_commit_length)?;

  let chat_request = CreateChatCompletionRequestArgs::default()
    .max_tokens(request.max_tokens)
    .model(request.model.to_string())
    .messages([
      ChatCompletionRequestSystemMessageArgs::default()
        .content(request.system)
        .build()?
        .into(),
      ChatCompletionRequestUserMessageArgs::default()
        .content(truncated_prompt)
        .build()?
        .into()
    ])
    .tools(vec![commit_tool])
    .tool_choice("commit")
    .build()?;

  let mut last_error = None;

  for attempt in 1..=MAX_ATTEMPTS {
    log::debug!("OpenAI API attempt {attempt} of {MAX_ATTEMPTS}");

    // Track API call duration
    let api_start = Instant::now();

    match client.chat().create(chat_request.clone()).await {
      Ok(response) => {
        let api_duration = api_start.elapsed();

        // Record the API duration in the debug session
        if let Some(session) = debug_output::debug_session() {
          session.set_api_duration(api_duration);
        }

        log::debug!("OpenAI API call successful on attempt {attempt}");

        // Extract the response
        let choice = response
          .choices
          .into_iter()
          .next()
          .context("No response choices available")?;

        // Check whether the model used function calling
        if let Some(tool_calls) = &choice.message.tool_calls {
          // Collect the `commit` tool calls for parsing
          let tool_futures: Vec<_> = tool_calls
            .iter()
            .filter(|tool_call| tool_call.function.name == "commit")
            .map(|tool_call| {
              let args = tool_call.function.arguments.clone();
              async move { function_calling::parse_commit_function_response(&args) }
            })
            .collect();

          // Resolve all the parse futures (the parsing itself is synchronous)
          let results = join_all(tool_futures).await;

          // Process results and handle errors
          let mut commit_messages = Vec::new();
          for (i, result) in results.into_iter().enumerate() {
            match result {
              Ok(commit_args) => {
                // Record commit results in the debug session
                if let Some(session) = debug_output::debug_session() {
                  session.set_commit_result(commit_args.message.clone(), commit_args.reasoning.clone());
                  session.set_files_analyzed(commit_args.clone());
                }
                commit_messages.push(commit_args.message);
              }
              Err(e) => {
                log::warn!("Failed to parse tool call {i}: {e}");
              }
            }
          }

          // Return the first successful commit message
          if !commit_messages.is_empty() {
            // For now, return the first message; they could also be combined if needed
            return Ok(Response {
              response: commit_messages
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No commit messages generated"))?
            });
          }
        }

        // Fall back to the regular message content if there was no tool call
        let content = choice
          .message
          .content
          .clone()
          .context("No response content available")?;

        return Ok(Response { response: content });
      }
      Err(e) => {
        log::warn!("OpenAI API attempt {attempt} failed: {e}");
        last_error = Some(e);

        if attempt < MAX_ATTEMPTS {
          let delay = Duration::from_millis(500 * attempt as u64);
          log::debug!("Retrying after {delay:?}");
          tokio::time::sleep(delay).await;
        }
      }
    }
  }

  // All attempts failed
  match last_error {
    Some(OpenAIError::ApiError(api_err)) => {
      let error_msg = format!(
        "OpenAI API error: {} (type: {}, code: {})",
        api_err.message,
        api_err.r#type.as_deref().unwrap_or("unknown"),
        api_err.code.as_deref().unwrap_or("unknown")
      );
      log::error!("{error_msg}");
      Err(anyhow!(error_msg))
    }
    Some(e) => {
      log::error!("OpenAI request failed: {e}");
      Err(anyhow!("OpenAI request failed: {}", e))
    }
    None => Err(anyhow!("OpenAI request failed after {} attempts", MAX_ATTEMPTS))
  }
}

/// Calls the OpenAI API with the default configuration from settings.
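///
/// A usage sketch (assumes `openai_api_key` is set in the app settings):
///
/// ```ignore
/// let response = call(request).await?;
/// println!("{}", response.response);
/// ```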
pub async fn call(request: Request) -> Result<Response> {
  profile!("OpenAI API call");

  // Create the OpenAI configuration from our settings
  let config = create_openai_config(&config::APP)?;

  // Delegate to call_with_config with the default config
  call_with_config(request, config).await
}