ai/
model.rs

use std::default::Default;
use std::fmt::{self, Display};
use std::str::FromStr;
use std::sync::OnceLock;

use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use tiktoken_rs::model::get_context_size;
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use colored::Colorize;

use crate::profile;
// use crate::config::format_prompt; // Temporarily comment out
use crate::config::AppConfig;

// Cached tokenizer for performance
static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

// Model identifiers - using screaming case for constants
const MODEL_GPT4_1: &str = "gpt-4.1";
const MODEL_GPT4_1_MINI: &str = "gpt-4.1-mini";
const MODEL_GPT4_1_NANO: &str = "gpt-4.1-nano";
const MODEL_GPT4_5: &str = "gpt-4.5";
// TODO: Get this from config.rs or a shared constants module
const DEFAULT_MODEL_NAME: &str = "gpt-4.1";

/// Represents the available AI models for commit message generation.
/// Each model has different capabilities and token limits.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
pub enum Model {
  /// Default model - GPT-4.1 latest version
  #[default]
  GPT41,
  /// Mini version of GPT-4.1 for faster processing
  GPT41Mini,
  /// Nano version of GPT-4.1 for very fast processing
  GPT41Nano,
  /// GPT-4.5 model for advanced capabilities
  GPT45
}

impl Model {
  /// Counts the number of tokens in the given text for the current model.
  /// This is used to ensure we stay within the model's token limits.
  ///
  /// # Arguments
  /// * `text` - The text to count tokens for
  ///
  /// # Returns
  /// * `Result<usize>` - The number of tokens or an error
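  ///
  /// # Examples
  ///
  /// A minimal usage sketch (marked `ignore` because the crate path is assumed):
  ///
  /// ```ignore
  /// let model = Model::default();
  /// // Empty input takes the fast path and yields zero tokens.
  /// assert_eq!(model.count_tokens("")?, 0);
  /// // Non-empty input is tokenized with the cached cl100k_base encoder.
  /// assert!(model.count_tokens("Add error handling to the parser")? > 0);
  /// ```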
  pub fn count_tokens(&self, text: &str) -> Result<usize> {
    profile!("Count tokens");

    // Fast path for empty text
    if text.is_empty() {
      return Ok(0);
    }

    // Always use the proper tokenizer for accurate counts
    // We cannot afford to underestimate tokens as it may cause API failures
    let tokenizer = TOKENIZER.get_or_init(|| get_tokenizer(self.as_ref()));

    // Use direct tokenization for accurate token count
    let tokens = tokenizer.encode_ordinary(text);
    Ok(tokens.len())
  }

  /// Gets the maximum context size for the current model.
  ///
  /// # Returns
  /// * `usize` - The maximum number of tokens the model can process
  pub fn context_size(&self) -> usize {
    profile!("Get context size");
    get_context_size(self.as_ref())
  }

  /// Truncates the given text to fit within the specified token limit.
  ///
  /// # Arguments
  /// * `text` - The text to truncate
  /// * `max_tokens` - The maximum number of tokens allowed
  ///
  /// # Returns
  /// * `Result<String>` - The truncated text or an error
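  ///
  /// # Examples
  ///
  /// A minimal usage sketch (marked `ignore` because the crate path is assumed):
  ///
  /// ```ignore
  /// let model = Model::default();
  /// let diff = "a very long diff ...";
  /// let shortened = model.truncate(diff, 32)?;
  /// assert!(model.count_tokens(&shortened)? <= 32);
  /// ```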
  pub(crate) fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
    profile!("Truncate text");
    self.walk_truncate(text, max_tokens, usize::MAX)
  }

  /// Truncates text to fit within token limits while maintaining coherence.
  /// Uses a binary search over the word count to find the optimal truncation point.
  ///
  /// # Arguments
  /// * `text` - The text to truncate
  /// * `max_tokens` - The maximum number of tokens allowed
  /// * `within` - The maximum allowed deviation from the target token count
  ///
  /// # Returns
  /// * `Result<String>` - The truncated text or an error
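  ///
  /// # Examples
  ///
  /// A minimal sketch (marked `ignore`; crate path assumed). A `within` of 10
  /// lets the search stop as soon as the result is within 10 tokens of the target:
  ///
  /// ```ignore
  /// let fitted = model.walk_truncate(diff, 100, 10)?;
  /// assert!(model.count_tokens(&fitted)? <= 100);
  /// ```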
  pub(crate) fn walk_truncate(&self, text: &str, max_tokens: usize, within: usize) -> Result<String> {
    profile!("Walk truncate iteration");
    log::debug!("max_tokens: {max_tokens}, within: {within}");

    // Check if the text already fits within the token limit
    let current_tokens = self.count_tokens(text)?;
    if current_tokens <= max_tokens {
      return Ok(text.to_string());
    }

    // Binary search approach to find the right truncation point
    let words: Vec<&str> = text.split_whitespace().collect();
    let mut left = 0;
    let mut right = words.len();
    let mut best_fit = String::new();
    let mut best_tokens = 0;

    // Perform binary search to find the optimal word count
    while left < right {
      let mid = (left + right).div_ceil(2);
      let candidate = words[..mid].join(" ");
      let tokens = self.count_tokens(&candidate)?;

      if tokens <= max_tokens {
        // This fits, try to find a longer text that still fits
        best_fit = candidate;
        best_tokens = tokens;
        left = mid;
      } else {
        // Too many tokens, try shorter text
        right = mid - 1;
      }

      // If we're close enough to the target, we can stop
      if best_tokens > 0 && max_tokens.saturating_sub(best_tokens) <= within {
        break;
      }
    }

    // If we couldn't find any fitting text, truncate more aggressively
    if best_fit.is_empty() && !words.is_empty() {
      // Try with just one word
      best_fit = words[0].to_string();
      let tokens = self.count_tokens(&best_fit)?;

      // If even one word is too long, truncate at the character level
      if tokens > max_tokens {
        // Estimate character limit based on token limit
        // Conservative estimate: ~3 chars per token
        let char_limit = max_tokens * 3;
        best_fit = text.chars().take(char_limit).collect();

        // Ensure we don't exceed the token limit
        while self.count_tokens(&best_fit)? > max_tokens && !best_fit.is_empty() {
          // Remove the last 10% of characters; count chars (not bytes) so the
          // string always shrinks, even for multi-byte text
          let char_count = best_fit.chars().count();
          let new_len = (char_count * 9) / 10;
          best_fit = best_fit.chars().take(new_len).collect();
        }
      }
    }

    Ok(best_fit)
  }
}

impl AsRef<str> for Model {
  fn as_ref(&self) -> &str {
    match self {
      Model::GPT41 => MODEL_GPT4_1,
      Model::GPT41Mini => MODEL_GPT4_1_MINI,
      Model::GPT41Nano => MODEL_GPT4_1_NANO,
      Model::GPT45 => MODEL_GPT4_5
    }
  }
}

// Keep conversion to String for cases that need owned strings
impl From<&Model> for String {
  fn from(model: &Model) -> Self {
    model.as_ref().to_string()
  }
}

// Keep the old impl for backwards compatibility where possible
impl Model {
  pub fn as_str(&self) -> &str {
    self.as_ref()
  }
}

impl FromStr for Model {
  type Err = anyhow::Error;

  fn from_str(s: &str) -> Result<Self> {
    let normalized = s.trim().to_lowercase();
    match normalized.as_str() {
      "gpt-4.1" => Ok(Model::GPT41),
      "gpt-4.1-mini" => Ok(Model::GPT41Mini),
      "gpt-4.1-nano" => Ok(Model::GPT41Nano),
      "gpt-4.5" => Ok(Model::GPT45),
      // Backward compatibility for deprecated models - map to the closest GPT-4.1 equivalent
      "gpt-4" | "gpt-4o" => {
        log::warn!(
          "Model '{}' is deprecated. Mapping to 'gpt-4.1'. \
          Please update your configuration with: git ai config set model gpt-4.1",
          s
        );
        Ok(Model::GPT41)
      }
      "gpt-4o-mini" | "gpt-3.5-turbo" => {
        log::warn!(
          "Model '{}' is deprecated. Mapping to 'gpt-4.1-mini'. \
          Please update your configuration with: git ai config set model gpt-4.1-mini",
          s
        );
        Ok(Model::GPT41Mini)
      }
      model =>
        bail!(
          "Invalid model name: '{}'. Supported models: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5",
          model
        ),
    }
  }
}

impl Display for Model {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{}", self.as_ref())
  }
}

// Implement conversion from string types to Model with fallback to default
impl From<&str> for Model {
  fn from(s: &str) -> Self {
    s.parse().unwrap_or_else(|e| {
      log::error!("Failed to parse model '{}': {}. Falling back to default model 'gpt-4.1'.", s, e);
      Model::default()
    })
  }
}

impl From<String> for Model {
  fn from(s: String) -> Self {
    s.as_str().into()
  }
}

fn get_tokenizer(_model_str: &str) -> CoreBPE {
  // TODO: This should be based on the model string, but for now we'll just use cl100k_base
  // which is used by gpt-3.5-turbo and gpt-4
  tiktoken_rs::cl100k_base().expect("Failed to create tokenizer")
}

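/// Sends the prompt to the OpenAI chat completions API and returns the generated
/// commit message. Fails if the prompt exceeds the model's context size or if the
/// API returns an empty response.
///
/// A minimal calling sketch (marked `ignore`; the crate path and a configured
/// `OPENAI_API_KEY` environment variable are assumed):
///
/// ```ignore
/// let settings = AppConfig::default();
/// let message = run(settings, diff_text).await?;
/// println!("{message}");
/// ```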
pub async fn run(settings: AppConfig, content: String) -> Result<String> {
  let model_str = settings.model.as_deref().unwrap_or(DEFAULT_MODEL_NAME);

  let client = async_openai::Client::new();
  // let prompt = format_prompt(&content, &settings.prompt(), settings.template())?; // Temporarily comment out
  let prompt = content; // Use raw content as prompt for now
  let model: Model = settings
    .model
    .as_deref()
    .unwrap_or(DEFAULT_MODEL_NAME)
    .into();
  let tokens = model.count_tokens(&prompt)?;

  if tokens > model.context_size() {
    bail!(
      "Input too large: {} tokens. Max {} tokens for {}",
      tokens.to_string().red(),
      model.context_size().to_string().green(),
      model_str.yellow()
    );
  }

  // TODO: Make temperature configurable
  let temperature_value = 0.7;

  // Remaining token budget for the completion, capped so the cast to u16 cannot overflow
  let completion_budget = (model.context_size() - tokens).min(u16::MAX as usize);

  log::info!(
    "Using model: {}, Tokens: {}, Max completion tokens: {}, Temperature: {}",
    model_str.yellow(),
    tokens.to_string().green(),
    // TODO: Make max_tokens configurable
    completion_budget.to_string().green(),
    temperature_value.to_string().blue()
  );

  let request = CreateChatCompletionRequestArgs::default()
    .model(model_str)
    .messages([ChatCompletionRequestUserMessageArgs::default()
      .content(prompt)
      .build()?
      .into()])
    .temperature(temperature_value)
    // TODO: Make max_tokens configurable
    .max_tokens(completion_budget as u16)
    .build()?;

  profile!("OpenAI API call");
  let response = client.chat().create(request).await?;
  // Avoid indexing into `choices`, which would panic on an empty response
  let result = response
    .choices
    .first()
    .and_then(|choice| choice.message.content.clone())
    .unwrap_or_default();

  if result.is_empty() {
    bail!("No response from OpenAI");
  }

  Ok(result.trim().to_string())
}
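
// A minimal test sketch exercising parsing, display, and token counting.
// Assumes tiktoken-rs ships the cl100k_base data, so no network access is needed.
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn parses_supported_and_deprecated_model_names() {
    assert_eq!("gpt-4.1".parse::<Model>().unwrap(), Model::GPT41);
    assert_eq!("gpt-4.5".parse::<Model>().unwrap(), Model::GPT45);
    // Deprecated names are mapped to their closest GPT-4.1 equivalent
    assert_eq!("gpt-4o".parse::<Model>().unwrap(), Model::GPT41);
    assert_eq!("gpt-4o-mini".parse::<Model>().unwrap(), Model::GPT41Mini);
    assert!("not-a-model".parse::<Model>().is_err());
  }

  #[test]
  fn displays_the_wire_name() {
    assert_eq!(Model::GPT41Mini.to_string(), "gpt-4.1-mini");
    assert_eq!(Model::default().to_string(), DEFAULT_MODEL_NAME);
  }

  #[test]
  fn counts_and_truncates_tokens() {
    let model = Model::default();
    // Empty input takes the fast path
    assert_eq!(model.count_tokens("").unwrap(), 0);

    // Truncation should never exceed the requested token budget
    let text = "one two three four five six seven eight nine ten ".repeat(10);
    let truncated = model.truncate(&text, 8).unwrap();
    assert!(model.count_tokens(&truncated).unwrap() <= 8);
  }
}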