// File: ai/commit.rs

1use anyhow::{anyhow, bail, Result};
2use maplit::hashmap;
3use mustache;
4use async_openai::Client;
5
6use crate::{config, debug_output, openai, profile};
7use crate::model::Model;
8use crate::config::AppConfig;
9use crate::multi_step_integration::{generate_commit_message_local, generate_commit_message_multi_step};
10
11/// The instruction template included at compile time
12const INSTRUCTION_TEMPLATE: &str = include_str!("../resources/prompt.md");
13
14/// Returns the instruction template for the AI model.
15/// This template guides the model in generating appropriate commit messages.
16///
17/// # Returns
18/// * `Result<String>` - The rendered template or an error
19///
20/// Note: This function is public only for testing purposes
21#[doc(hidden)]
22pub fn get_instruction_template() -> Result<String> {
23  profile!("Generate instruction template");
24  let max_length = config::APP_CONFIG
25    .max_commit_length
26    .unwrap_or(72)
27    .to_string();
28  let template = mustache::compile_str(INSTRUCTION_TEMPLATE)
29    .map_err(|e| anyhow!("Template compilation error: {}", e))?
30    .render_to_string(&hashmap! {
31      "max_length" => max_length
32    })
33    .map_err(|e| anyhow!("Template rendering error: {}", e))?;
34  Ok(template)
35}
36
37/// Creates an OpenAI request for commit message generation.
38///
39/// # Arguments
40/// * `diff` - The git diff to generate a commit message for
41/// * `max_tokens` - Maximum number of tokens allowed for the response
42/// * `model` - The AI model to use for generation
43///
44/// # Returns
45/// * `Result<openai::Request>` - The prepared request
46///
47/// Note: This function is public only for testing purposes
48#[doc(hidden)]
49pub fn create_commit_request(diff: String, max_tokens: usize, model: Model) -> Result<openai::Request> {
50  profile!("Prepare OpenAI request");
51  let template = get_instruction_template()?;
52  Ok(openai::Request {
53    system: template,
54    prompt: diff,
55    max_tokens: max_tokens.try_into().unwrap_or(u16::MAX),
56    model
57  })
58}
59
/// Generates a commit message using the AI model.
/// Now uses the multi-step approach by default with fallback to single-step.
///
/// Fallback order: multi-step via the OpenAI API (custom settings or env key),
/// then local multi-step generation, then the original single-step request.
///
/// # Arguments
/// * `patch` - The git diff to generate a commit message for
/// * `remaining_tokens` - Maximum number of tokens allowed for the response
/// * `model` - The AI model to use for generation
/// * `settings` - Optional application settings to customize the request
///
/// # Returns
/// * `Result<openai::Response>` - The generated commit message or an error
///
/// # Errors
/// Returns an error if:
/// - remaining_tokens is 0
/// - no usable API key is configured
/// - the API key is rejected by OpenAI
/// - OpenAI API call fails
pub async fn generate(patch: String, remaining_tokens: usize, model: Model, settings: Option<&AppConfig>) -> Result<openai::Response> {
  profile!("Generate commit message");

  if remaining_tokens == 0 {
    bail!("Maximum token count must be greater than zero")
  }

  // Commit-subject length cap: explicit settings win over the global config.
  let max_length = settings
    .and_then(|s| s.max_commit_length)
    .or(config::APP_CONFIG.max_commit_length);

  // Check if we have a valid API key configuration.
  // A key equal to the placeholder string is treated the same as no key.
  // NOTE(review): when `settings` is provided, the OPENAI_API_KEY env var is
  // deliberately NOT consulted here — custom settings must carry their own key.
  let has_valid_api_key = if let Some(custom_settings) = settings {
    custom_settings
      .openai_api_key
      .as_ref()
      .map(|key| !key.is_empty() && key != "<PLACE HOLDER FOR YOUR API KEY>")
      .unwrap_or(false)
  } else {
    // Check environment variable or config
    config::APP_CONFIG
      .openai_api_key
      .as_ref()
      .map(|key| !key.is_empty() && key != "<PLACE HOLDER FOR YOUR API KEY>")
      .unwrap_or(false)
      || std::env::var("OPENAI_API_KEY")
        .map(|key| !key.is_empty())
        .unwrap_or(false)
  };

  if !has_valid_api_key {
    bail!("OpenAI API key not configured. Please set your API key using:\n  git-ai config set openai-api-key <your-key>\nor set the OPENAI_API_KEY environment variable.");
  }

  // Multi-step attempt #1: caller-provided settings with their own API key.
  if let Some(custom_settings) = settings {
    if let Some(api_key) = &custom_settings.openai_api_key {
      if !api_key.is_empty() && api_key != "<PLACE HOLDER FOR YOUR API KEY>" {
        match openai::create_openai_config(custom_settings) {
          Ok(config) => {
            let client = Client::with_config(config);
            let model_str = model.to_string();

            match generate_commit_message_multi_step(&client, &model_str, &patch, max_length).await {
              Ok(message) => return Ok(openai::Response { response: message }),
              Err(e) => {
                // An invalid key is fatal — retrying other strategies with the
                // same key cannot succeed, so bail out immediately.
                if e.to_string().contains("invalid_api_key") || e.to_string().contains("Incorrect API key") {
                  bail!("Invalid OpenAI API key. Please check your API key configuration.");
                }
                // Any other failure is non-fatal: record it and fall through
                // to the local / single-step fallbacks below.
                log::warn!("Multi-step generation with custom settings failed: {e}");
                if let Some(session) = debug_output::debug_session() {
                  session.set_multi_step_error(e.to_string());
                }
              }
            }
          }
          Err(e) => {
            // If config creation fails due to API key, propagate the error
            return Err(e);
          }
        }
      }
    }
  } else {
    // Multi-step attempt #2: no custom settings — use the env-var key with a
    // default client.
    if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
      if !api_key.is_empty() {
        let client = Client::new();
        let model_str = model.to_string();

        match generate_commit_message_multi_step(&client, &model_str, &patch, max_length).await {
          Ok(message) => return Ok(openai::Response { response: message }),
          Err(e) => {
            // Same invalid-key short-circuit as the custom-settings path.
            if e.to_string().contains("invalid_api_key") || e.to_string().contains("Incorrect API key") {
              bail!("Invalid OpenAI API key. Please check your API key configuration.");
            }
            log::warn!("Multi-step generation failed: {e}");
            if let Some(session) = debug_output::debug_session() {
              session.set_multi_step_error(e.to_string());
            }
          }
        }
      }
    }
  }

  // Fallback: local multi-step generation (no API call).
  match generate_commit_message_local(&patch, max_length) {
    Ok(message) => return Ok(openai::Response { response: message }),
    Err(e) => {
      log::warn!("Local multi-step generation failed: {e}");
    }
  }

  // Mark that we're using single-step fallback
  if let Some(session) = debug_output::debug_session() {
    session.set_single_step_success(true);
  }

  // Final fallback: original single-step approach via a plain completion request.
  let request = create_commit_request(patch, remaining_tokens, model)?;

  // Use custom settings if provided, otherwise use global config
  match settings {
    Some(custom_settings) => {
      // Create a client with custom settings
      match openai::create_openai_config(custom_settings) {
        Ok(config) => openai::call_with_config(request, config).await,
        Err(e) => Err(e)
      }
    }
    None => {
      // Use the default global config
      openai::call(request).await
    }
  }
}
196
/// Returns the number of tokens consumed by the instruction template alone.
///
/// Thin wrapper around [`get_instruction_token_count`], kept for API
/// compatibility with existing callers.
///
/// # Arguments
/// * `model` - The AI model whose tokenizer is used for counting
///
/// # Returns
/// * `Result<usize>` - The number of tokens used or an error
pub fn token_used(model: &Model) -> Result<usize> {
  get_instruction_token_count(model)
}
200
201/// Calculates the number of tokens used by the instruction template.
202///
203/// # Arguments
204/// * `model` - The AI model to use for token counting
205///
206/// # Returns
207/// * `Result<usize>` - The number of tokens used or an error
208pub fn get_instruction_token_count(model: &Model) -> Result<usize> {
209  profile!("Calculate instruction tokens");
210  let template = get_instruction_template()?;
211  model.count_tokens(&template)
212}
213
#[cfg(test)]
mod tests {
  use super::*;

  /// Minimal diff shared by the API-key error tests.
  const SAMPLE_DIFF: &str = "diff --git a/test.txt b/test.txt\n+Hello World";

  /// Builds test settings with the given API key and default values elsewhere.
  fn settings_with_key(key: Option<String>) -> AppConfig {
    AppConfig {
      openai_api_key:    key,
      model:             Some("gpt-4o-mini".to_string()),
      max_tokens:        Some(1024),
      max_commit_length: Some(72),
      timeout:           Some(30)
    }
  }

  #[tokio::test]
  async fn test_missing_api_key_error() {
    // Settings with no API key at all.
    let settings = settings_with_key(None);

    // Temporarily clear the environment variable so it cannot mask the error.
    let saved_key = std::env::var("OPENAI_API_KEY").ok();
    std::env::remove_var("OPENAI_API_KEY");

    // generate must fail when no key is configured anywhere.
    let outcome = generate(SAMPLE_DIFF.to_string(), 1024, Model::GPT41Mini, Some(&settings)).await;

    // Restore the environment variable if it was previously set.
    if let Some(key) = saved_key {
      std::env::set_var("OPENAI_API_KEY", key);
    }

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
      message.contains("OpenAI API key not configured"),
      "Expected error message about missing API key, got: {}",
      message
    );
  }

  #[tokio::test]
  async fn test_invalid_api_key_error() {
    // The placeholder key must be treated the same as a missing key.
    let settings = settings_with_key(Some("<PLACE HOLDER FOR YOUR API KEY>".to_string()));

    let outcome = generate(SAMPLE_DIFF.to_string(), 1024, Model::GPT41Mini, Some(&settings)).await;

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
      message.contains("OpenAI API key not configured"),
      "Expected error message about invalid API key, got: {}",
      message
    );
  }
}