use std::time::{Duration, Instant};

use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_openai::config::OpenAIConfig;
use async_openai::Client;
use async_openai::error::OpenAIError;
use anyhow::{anyhow, Context, Result};
use reqwest;
use futures::future::join_all;

use crate::{commit, config, debug_output, function_calling, profile};
use crate::model::Model;
use crate::config::App as Settings;
use crate::multi_step_integration::generate_commit_message_multi_step;

const MAX_ATTEMPTS: usize = 3;

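/// Successful completion payload: the generated commit message text.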
#[derive(Debug, Clone, PartialEq)]
pub struct Response {
  pub response: String
}

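/// A single chat-completion request: user prompt, system prompt, completion token budget and target model.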
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
  pub prompt: String,
  pub system: String,
  pub max_tokens: u16,
  pub model: Model
}

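/// Generates a commit message for the given diff.
///
/// When `OPENAI_API_KEY` is set, the OpenAI-backed generator is tried first; if it fails
/// (or no key is configured) a purely local heuristic summarises the diff instead.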
pub async fn generate_commit_message(diff: &str) -> Result<String> {
  profile!("Generate commit message (simplified)");

  // Prefer the OpenAI-backed path when an API key is available.
  if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
    if !api_key.is_empty() {
      match commit::generate(diff.to_string(), 256, Model::GPT4oMini, None).await {
        Ok(response) => return Ok(response.response.trim().to_string()),
        Err(e) => {
          log::warn!("Direct generation failed, falling back to local: {e}");
        }
      }
    }
  }

  // Local heuristic fallback: count added/removed lines and collect the files touched by the diff.
  let mut lines_added = 0;
  let mut lines_removed = 0;
  let mut files_mentioned = std::collections::HashSet::new();

  for line in diff.lines() {
    if line.starts_with("diff --git") {
      // "diff --git a/<path> b/<path>": take the post-image path.
      let parts: Vec<&str> = line.split_whitespace().collect();
      if parts.len() >= 4 {
        let path = parts[3].trim_start_matches("b/");
        files_mentioned.insert(path);
      }
    } else if line.starts_with("+++") || line.starts_with("---") {
      if let Some(file) = line.split_whitespace().nth(1) {
        let cleaned = file.trim_start_matches("a/").trim_start_matches("b/");
        if cleaned != "/dev/null" {
          files_mentioned.insert(cleaned);
        }
      }
    } else if line.starts_with('+') && !line.starts_with("+++") {
      lines_added += 1;
    } else if line.starts_with('-') && !line.starts_with("---") {
      lines_removed += 1;
    }
  }

  if let Some(session) = debug_output::debug_session() {
    session.set_total_files_parsed(files_mentioned.len());
  }

  let message = match files_mentioned.len().cmp(&1) {
    std::cmp::Ordering::Equal => {
      let file = files_mentioned
        .iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("No files mentioned in commit message"))?;
      if lines_added > 0 && lines_removed == 0 {
        format!(
          "Add {} to {}",
          if lines_added == 1 {
            "content"
          } else {
            "new content"
          },
          file
        )
      } else if lines_removed > 0 && lines_added == 0 {
        format!("Remove content from {file}")
      } else {
        format!("Update {file}")
      }
    }
    std::cmp::Ordering::Greater => format!("Update {} files", files_mentioned.len()),
    std::cmp::Ordering::Less => "Update files".to_string()
  };

  Ok(message.trim().to_string())
}

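/// Builds an `OpenAIConfig` from the application settings, rejecting missing or
/// placeholder API keys.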
pub fn create_openai_config(settings: &Settings) -> Result<OpenAIConfig> {
  let api_key = settings
    .openai_api_key
    .as_ref()
    .ok_or_else(|| anyhow!("OpenAI API key not configured"))?;

  if api_key.is_empty() || api_key == "<PLACE HOLDER FOR YOUR API KEY>" {
    return Err(anyhow!("Invalid OpenAI API key"));
  }

  let config = OpenAIConfig::new().with_api_key(api_key);

  Ok(config)
}

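/// Truncates `text` so it fits within `max_tokens` for `model`, preferring to cut at a
/// newline boundary found via binary search over character positions.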
fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result<String> {
  profile!("Truncate to fit");

  // Fast path: very short inputs are returned untouched without counting tokens.
  if text.len() < 1000 {
    return Ok(text.to_string());
  }

  let token_count = model.count_tokens(text)?;
  if token_count <= max_tokens {
    return Ok(text.to_string());
  }

  let char_indices: Vec<(usize, char)> = text.char_indices().collect();
  if char_indices.is_empty() {
    return Ok(String::new());
  }

  // Binary search over character positions for the longest newline-terminated
  // prefix that stays within the token budget.
  let mut low = 0;
  let mut high = char_indices.len();
  let mut best_fit = String::new();

  while low < high {
    let mid = (low + high) / 2;

    let byte_index = if mid < char_indices.len() {
      char_indices[mid].0
    } else {
      text.len()
    };

    let truncated = &text[..byte_index];

    if let Some(last_newline_pos) = truncated.rfind('\n') {
      let candidate = &text[..last_newline_pos];
      let candidate_tokens = model.count_tokens(candidate)?;

      if candidate_tokens <= max_tokens {
        // This prefix fits; remember it and search for a longer one. Moving `low`
        // past `mid` guarantees the loop terminates even when the remainder of the
        // text contains no further newlines.
        best_fit = candidate.to_string();
        low = mid + 1;
      } else {
        // Too long: shrink the window to just before the newline we tried.
        let newline_char_idx = char_indices
          .iter()
          .rposition(|(idx, _)| *idx <= last_newline_pos)
          .unwrap_or(0);
        high = newline_char_idx;
      }
    } else {
      // No newline in this prefix; try a shorter one.
      high = mid;
    }
  }

  if best_fit.is_empty() {
    // No newline-aligned prefix fits; fall back to a hard token-based truncation.
    model.truncate(text, max_tokens)
  } else {
    Ok(best_fit)
  }
}

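/// Sends a chat-completion request using the provided OpenAI configuration.
///
/// The multi-step generator is tried first; if it fails, a single-step request using the
/// `commit` function tool is issued and retried up to `MAX_ATTEMPTS` times.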
pub async fn call_with_config(request: Request, config: OpenAIConfig) -> Result<Response> {
  profile!("OpenAI API call with custom config");

  let client = Client::with_config(config.clone());
  let model = request.model.to_string();

  // Try the multi-step generation pipeline first.
  match generate_commit_message_multi_step(&client, &model, &request.prompt, config::APP.max_commit_length).await {
    Ok(message) => return Ok(Response { response: message }),
    Err(e) => {
      log::warn!("Multi-step approach failed, falling back to single-step: {e}");
    }
  }

  // Single-step fallback; honour the configured request timeout if one is set.
  let client = if let Some(timeout) = config::APP.timeout {
    let http_client = reqwest::ClientBuilder::new()
      .timeout(Duration::from_secs(timeout as u64))
      .build()?;
    Client::with_config(config).with_http_client(http_client)
  } else {
    Client::with_config(config)
  };

  // Reserve room for the system prompt and the completion, then truncate the user
  // prompt to whatever fits in the model's context window.
  let system_tokens = request.model.count_tokens(&request.system)?;
  let model_context_size = request.model.context_size();
  let available_tokens = model_context_size.saturating_sub(system_tokens + request.max_tokens as usize);

  let truncated_prompt = truncate_to_fit(&request.prompt, available_tokens, &request.model)?;

  let commit_tool = function_calling::create_commit_function_tool(config::APP.max_commit_length)?;

  let chat_request = CreateChatCompletionRequestArgs::default()
    .max_tokens(request.max_tokens)
    .model(request.model.to_string())
    .messages([
      ChatCompletionRequestSystemMessageArgs::default()
        .content(request.system)
        .build()?
        .into(),
      ChatCompletionRequestUserMessageArgs::default()
        .content(truncated_prompt)
        .build()?
        .into()
    ])
    .tools(vec![commit_tool])
    .tool_choice("commit")
    .build()?;

  let mut last_error = None;

  for attempt in 1..=MAX_ATTEMPTS {
    log::debug!("OpenAI API attempt {attempt} of {MAX_ATTEMPTS}");

    let api_start = Instant::now();

    match client.chat().create(chat_request.clone()).await {
      Ok(response) => {
        let api_duration = api_start.elapsed();

        if let Some(session) = debug_output::debug_session() {
          session.set_api_duration(api_duration);
        }

        log::debug!("OpenAI API call successful on attempt {attempt}");

        let choice = response
          .choices
          .into_iter()
          .next()
          .context("No response choices available")?;

        // Prefer structured output from "commit" tool calls when the model returns any.
        if let Some(tool_calls) = &choice.message.tool_calls {
          let tool_futures: Vec<_> = tool_calls
            .iter()
            .filter(|tool_call| tool_call.function.name == "commit")
            .map(|tool_call| {
              let args = tool_call.function.arguments.clone();
              async move { function_calling::parse_commit_function_response(&args) }
            })
            .collect();

          let results = join_all(tool_futures).await;

          let mut commit_messages = Vec::new();
          for (i, result) in results.into_iter().enumerate() {
            match result {
              Ok(commit_args) => {
                if let Some(session) = debug_output::debug_session() {
                  session.set_commit_result(commit_args.message.clone(), commit_args.reasoning.clone());
                  session.set_files_analyzed(commit_args.clone());
                }
                commit_messages.push(commit_args.message);
              }
              Err(e) => {
                log::warn!("Failed to parse tool call {i}: {e}");
              }
            }
          }

          if !commit_messages.is_empty() {
            return Ok(Response {
              response: commit_messages
                .into_iter()
                .next()
                .ok_or_else(|| anyhow::anyhow!("No commit messages generated"))?
            });
          }
        }

        // No usable tool call; fall back to the plain message content.
        let content = choice
          .message
          .content
          .clone()
          .context("No response content available")?;

        return Ok(Response { response: content });
      }
      Err(e) => {
        last_error = Some(e);
        log::warn!("OpenAI API attempt {attempt} failed");

        if attempt < MAX_ATTEMPTS {
          // Back off linearly: 500 ms per attempt already made.
          let delay = Duration::from_millis(500 * attempt as u64);
          log::debug!("Retrying after {delay:?}");
          tokio::time::sleep(delay).await;
        }
      }
    }
  }

  match last_error {
    Some(OpenAIError::ApiError(api_err)) => {
      let error_msg = format!(
        "OpenAI API error: {} (type: {:?}, code: {:?})",
        api_err.message,
        api_err.r#type.as_deref().unwrap_or("unknown"),
        api_err.code.as_deref().unwrap_or("unknown")
      );
      log::error!("{error_msg}");
      Err(anyhow!(error_msg))
    }
    Some(e) => {
      log::error!("OpenAI request failed: {e}");
      Err(anyhow!("OpenAI request failed: {}", e))
    }
    None => Err(anyhow!("OpenAI request failed after {} attempts", MAX_ATTEMPTS))
  }
}

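/// Convenience wrapper that builds the OpenAI configuration from the global settings
/// and delegates to [`call_with_config`].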
pub async fn call(request: Request) -> Result<Response> {
  profile!("OpenAI API call");

  let config = create_openai_config(&config::APP)?;

  call_with_config(request, config).await
}