1use std::time::{Duration, Instant};
2
3use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
4use async_openai::config::OpenAIConfig;
5use async_openai::Client;
6use async_openai::error::OpenAIError;
7use anyhow::{anyhow, Context, Result};
8use reqwest;
9use futures::future::join_all;
10
11use crate::{commit, config, debug_output, function_calling, profile};
12use crate::model::Model;
13use crate::config::AppConfig;
14use crate::multi_step_integration::generate_commit_message_multi_step;
15
/// Maximum number of retry attempts for the single-step OpenAI chat-completion call.
const MAX_ATTEMPTS: usize = 3;
17
/// A successful completion returned by the OpenAI-backed generators.
#[derive(Debug, Clone, PartialEq)]
pub struct Response {
  // The generated commit message text.
  pub response: String
}
22
/// Parameters for a single chat-completion request.
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
  // User prompt (typically the diff to summarize); may be truncated to fit the context window.
  pub prompt: String,
  // System instruction sent ahead of the prompt.
  pub system: String,
  // Completion token budget passed to the API.
  pub max_tokens: u16,
  // Which model to call; also used for token counting.
  pub model: Model
}
30
31pub async fn generate_commit_message(diff: &str) -> Result<String> {
34 profile!("Generate commit message (simplified)");
35
36 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
38 if !api_key.is_empty() {
39 match commit::generate(diff.to_string(), 256, Model::GPT41Mini, None).await {
41 Ok(response) => return Ok(response.response.trim().to_string()),
42 Err(e) => {
43 log::warn!("Direct generation failed, falling back to local: {e}");
44 }
45 }
46 }
47 }
48
49 let mut lines_added = 0;
52 let mut lines_removed = 0;
53 let mut files_mentioned = std::collections::HashSet::new();
54
55 for line in diff.lines() {
56 if line.starts_with("diff --git") {
57 let parts: Vec<&str> = line.split_whitespace().collect();
59 if parts.len() >= 4 {
60 let path = parts[3].trim_start_matches("b/");
61 files_mentioned.insert(path);
62 }
63 } else if line.starts_with("+++") || line.starts_with("---") {
64 if let Some(file) = line.split_whitespace().nth(1) {
65 let cleaned = file.trim_start_matches("a/").trim_start_matches("b/");
66 if cleaned != "/dev/null" {
67 files_mentioned.insert(cleaned);
68 }
69 }
70 } else if line.starts_with('+') && !line.starts_with("+++") {
71 lines_added += 1;
72 } else if line.starts_with('-') && !line.starts_with("---") {
73 lines_removed += 1;
74 }
75 }
76
77 if let Some(session) = debug_output::debug_session() {
79 session.set_total_files_parsed(files_mentioned.len());
80 }
81
82 let message = match files_mentioned.len().cmp(&1) {
84 std::cmp::Ordering::Equal => {
85 let file = files_mentioned
86 .iter()
87 .next()
88 .ok_or_else(|| anyhow::anyhow!("No files mentioned in commit message"))?;
89 if lines_added > 0 && lines_removed == 0 {
90 format!(
91 "Add {} to {}",
92 if lines_added == 1 {
93 "content"
94 } else {
95 "new content"
96 },
97 file
98 )
99 } else if lines_removed > 0 && lines_added == 0 {
100 format!("Remove content from {file}")
101 } else {
102 format!("Update {file}")
103 }
104 }
105 std::cmp::Ordering::Greater => format!("Update {} files", files_mentioned.len()),
106 std::cmp::Ordering::Less => "Update files".to_string()
107 };
108
109 Ok(message.trim().to_string())
110}
111
112pub fn create_openai_config(settings: &AppConfig) -> Result<OpenAIConfig> {
114 let api_key = settings
115 .openai_api_key
116 .as_ref()
117 .ok_or_else(|| anyhow!("OpenAI API key not configured"))?;
118
119 if api_key.is_empty() || api_key == "<PLACE HOLDER FOR YOUR API KEY>" {
120 return Err(anyhow!("Invalid OpenAI API key"));
121 }
122
123 let config = OpenAIConfig::new().with_api_key(api_key);
124
125 Ok(config)
126}
127
128fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result<String> {
130 profile!("Truncate to fit");
131
132 if text.len() < 1000 {
134 return Ok(text.to_string());
135 }
136
137 let token_count = model.count_tokens(text)?;
138 if token_count <= max_tokens {
139 return Ok(text.to_string());
140 }
141
142 let char_indices: Vec<(usize, char)> = text.char_indices().collect();
144 if char_indices.is_empty() {
145 return Ok(String::new());
146 }
147
148 let mut low = 0;
150 let mut high = char_indices.len();
151 let mut best_fit = String::new();
152
153 while low < high {
154 let mid = (low + high) / 2;
155
156 let byte_index = if mid < char_indices.len() {
158 char_indices[mid].0
159 } else {
160 text.len()
161 };
162
163 let truncated = &text[..byte_index];
164
165 if let Some(last_newline_pos) = truncated.rfind('\n') {
167 let candidate = &text[..last_newline_pos];
169 let candidate_tokens = model.count_tokens(candidate)?;
170
171 if candidate_tokens <= max_tokens {
172 best_fit = candidate.to_string();
173 let next_char_idx = char_indices
175 .iter()
176 .position(|(idx, _)| *idx > last_newline_pos)
177 .unwrap_or(char_indices.len());
178 low = next_char_idx;
179 } else {
180 let newline_char_idx = char_indices
182 .iter()
183 .rposition(|(idx, _)| *idx <= last_newline_pos)
184 .unwrap_or(0);
185 high = newline_char_idx;
186 }
187 } else {
188 high = mid;
189 }
190 }
191
192 if best_fit.is_empty() {
193 model.truncate(text, max_tokens)
195 } else {
196 Ok(best_fit)
197 }
198}
199
/// Calls the OpenAI chat-completion API using an explicit client `config`.
///
/// Tries the multi-step generation pipeline first; if that fails for any
/// reason other than an invalid API key, falls back to a single-step
/// function-calling request, retrying up to `MAX_ATTEMPTS` times with a
/// linearly increasing backoff.
///
/// # Errors
/// Returns an error on an invalid API key, on request-building failures,
/// or when all retry attempts are exhausted.
pub async fn call_with_config(request: Request, config: OpenAIConfig) -> Result<Response> {
  profile!("OpenAI API call with custom config");

  let client = Client::with_config(config.clone());
  let model = request.model.to_string();

  // Preferred path: multi-step commit-message generation.
  match generate_commit_message_multi_step(&client, &model, &request.prompt, config::APP_CONFIG.max_commit_length).await {
    Ok(message) => return Ok(Response { response: message }),
    Err(e) => {
      // An invalid key cannot be recovered by falling back; surface it now.
      if e.to_string().contains("invalid_api_key") || e.to_string().contains("Incorrect API key") {
        return Err(e);
      }
      log::warn!("Multi-step approach failed, falling back to single-step: {e}");
    }
  }

  // Rebuild the client, honoring the configured request timeout if present.
  let client = if let Some(timeout) = config::APP_CONFIG.timeout {
    let http_client = reqwest::ClientBuilder::new()
      .timeout(Duration::from_secs(timeout as u64))
      .build()?;
    Client::with_config(config).with_http_client(http_client)
  } else {
    Client::with_config(config)
  };

  // Budget the prompt so system + prompt + completion fit the context window.
  let system_tokens = request.model.count_tokens(&request.system)?;
  let model_context_size = request.model.context_size();
  let available_tokens = model_context_size.saturating_sub(system_tokens + request.max_tokens as usize);

  let truncated_prompt = truncate_to_fit(&request.prompt, available_tokens, &request.model)?;

  // Force the model to answer through the "commit" function tool.
  let commit_tool = function_calling::create_commit_function_tool(config::APP_CONFIG.max_commit_length)?;

  let chat_request = CreateChatCompletionRequestArgs::default()
    .max_tokens(request.max_tokens)
    .model(request.model.to_string())
    .messages([
      ChatCompletionRequestSystemMessageArgs::default()
        .content(request.system)
        .build()?
        .into(),
      ChatCompletionRequestUserMessageArgs::default()
        .content(truncated_prompt)
        .build()?
        .into()
    ])
    .tools(vec![commit_tool])
    .tool_choice("commit")
    .build()?;

  let mut last_error = None;

  for attempt in 1..=MAX_ATTEMPTS {
    log::debug!("OpenAI API attempt {attempt} of {MAX_ATTEMPTS}");

    let api_start = Instant::now();

    match client.chat().create(chat_request.clone()).await {
      Ok(response) => {
        let api_duration = api_start.elapsed();

        if let Some(session) = debug_output::debug_session() {
          session.set_api_duration(api_duration);
        }

        log::debug!("OpenAI API call successful on attempt {attempt}");

        let choice = response
          .choices
          .into_iter()
          .next()
          .context("No response choices available")?;

        // Prefer structured output from the "commit" tool call(s).
        if let Some(tool_calls) = &choice.message.tool_calls {
          let tool_futures: Vec<_> = tool_calls
            .iter()
            .filter(|tool_call| tool_call.function.name == "commit")
            .map(|tool_call| {
              let args = tool_call.function.arguments.clone();
              async move { function_calling::parse_commit_function_response(&args) }
            })
            .collect();

          let results = join_all(tool_futures).await;

          // Collect every successfully parsed tool call; parse failures are
          // logged and skipped rather than failing the whole request.
          let mut commit_messages = Vec::new();
          for (i, result) in results.into_iter().enumerate() {
            match result {
              Ok(commit_args) => {
                if let Some(session) = debug_output::debug_session() {
                  session.set_commit_result(commit_args.message.clone(), commit_args.reasoning.clone());
                  session.set_files_analyzed(commit_args.clone());
                }
                commit_messages.push(commit_args.message);
              }
              Err(e) => {
                log::warn!("Failed to parse tool call {i}: {e}");
              }
            }
          }

          // Use the first parsed message when at least one was produced.
          if !commit_messages.is_empty() {
            return Ok(Response {
              response: commit_messages
                .into_iter()
                .next()
                .ok_or_else(|| anyhow::anyhow!("No commit messages generated"))?
            });
          }
        }

        // No usable tool call: fall back to the plain message content.
        let content = choice
          .message
          .content
          .clone()
          .context("No response content available")?;

        return Ok(Response { response: content });
      }
      Err(e) => {
        last_error = Some(e);
        log::warn!("OpenAI API attempt {attempt} failed");

        // An invalid API key will never succeed on retry; fail immediately.
        if let OpenAIError::ApiError(ref api_err) = &last_error.as_ref().unwrap() {
          if api_err.code.as_deref() == Some("invalid_api_key") {
            let error_msg = format!("Invalid OpenAI API key: {}", api_err.message);
            log::error!("{error_msg}");
            return Err(anyhow!(error_msg));
          }
        }

        if attempt < MAX_ATTEMPTS {
          // Linear backoff: 500 ms, 1000 ms, ...
          let delay = Duration::from_millis(500 * attempt as u64);
          log::debug!("Retrying after {delay:?}");
          tokio::time::sleep(delay).await;
        }
      }
    }
  }

  // All attempts failed: report the most recent error with detail when available.
  match last_error {
    Some(OpenAIError::ApiError(api_err)) => {
      let error_msg = format!(
        "OpenAI API error: {} (type: {:?}, code: {:?})",
        api_err.message,
        api_err.r#type.as_deref().unwrap_or("unknown"),
        api_err.code.as_deref().unwrap_or("unknown")
      );
      log::error!("{error_msg}");
      Err(anyhow!(error_msg))
    }
    Some(e) => {
      log::error!("OpenAI request failed: {e}");
      Err(anyhow!("OpenAI request failed: {}", e))
    }
    None => Err(anyhow!("OpenAI request failed after {} attempts", MAX_ATTEMPTS))
  }
}
379
380pub async fn call(request: Request) -> Result<Response> {
382 profile!("OpenAI API call");
383
384 let config = create_openai_config(&config::APP_CONFIG)?;
386
387 call_with_config(request, config).await
389}