ai/
model.rs

use std::default::Default;
use std::fmt::{self, Display};
use std::str::FromStr;
use std::sync::OnceLock;

use anyhow::{bail, Result};
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use colored::Colorize;
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use tiktoken_rs::model::get_context_size;

use crate::profile;
// use crate::config::format_prompt; // Temporarily comment out
use crate::config::App as Settings; // Use App as Settings

// Cached tokenizer for performance
static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

// Model identifiers - constants use SCREAMING_SNAKE_CASE
const MODEL_GPT4: &str = "gpt-4";
const MODEL_GPT4_OPTIMIZED: &str = "gpt-4o";
const MODEL_GPT4_MINI: &str = "gpt-4o-mini";
const MODEL_GPT4_1: &str = "gpt-4.1";
// TODO: Get this from config.rs or a shared constants module
const DEFAULT_MODEL_NAME: &str = "gpt-4.1";

/// Represents the available AI models for commit message generation.
/// Each model has different capabilities and token limits.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
pub enum Model {
  /// Standard GPT-4 model
  GPT4,
  /// Optimized GPT-4 model for better performance
  GPT4o,
  /// Mini version of optimized GPT-4 for faster processing
  GPT4oMini,
  /// Default model - GPT-4.1 latest version
  #[default]
  GPT41
}

impl Model {
  /// Counts the number of tokens in the given text for the current model.
  /// This is used to ensure we stay within the model's token limits.
  ///
  /// # Arguments
  /// * `text` - The text to count tokens for
  ///
  /// # Returns
  /// * `Result<usize>` - The number of tokens or an error
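  ///
  /// # Examples
  ///
  /// A minimal usage sketch (marked `ignore` so it is not compiled as a doctest):
  ///
  /// ```ignore
  /// let model = Model::GPT41;
  /// let tokens = model.count_tokens("Hello, world!").unwrap();
  /// assert!(tokens > 0);
  /// ```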
  pub fn count_tokens(&self, text: &str) -> Result<usize> {
    profile!("Count tokens");

    // Fast path for empty text
    if text.is_empty() {
      return Ok(0);
    }

    // Always use the proper tokenizer for accurate counts
    // We cannot afford to underestimate tokens as it may cause API failures
    let tokenizer = TOKENIZER.get_or_init(|| {
      let model_str: &str = self.into();
      get_tokenizer(model_str)
    });

    // Use direct tokenization for accurate token count
    let tokens = tokenizer.encode_ordinary(text);
    Ok(tokens.len())
  }

  /// Gets the maximum context size for the current model.
  ///
  /// # Returns
  /// * `usize` - The maximum number of tokens the model can process
  pub fn context_size(&self) -> usize {
    profile!("Get context size");
    let model_str: &str = self.into();
    get_context_size(model_str)
  }

  /// Truncates the given text to fit within the specified token limit.
  ///
  /// # Arguments
  /// * `text` - The text to truncate
  /// * `max_tokens` - The maximum number of tokens allowed
  ///
  /// # Returns
  /// * `Result<String>` - The truncated text or an error
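  ///
  /// # Examples
  ///
  /// An illustrative sketch (crate-internal API, shown with `ignore` rather than
  /// as a compiled doctest; `long_diff` stands for any large input string):
  ///
  /// ```ignore
  /// let model = Model::GPT41;
  /// let shortened = model.truncate(&long_diff, 1024)?;
  /// assert!(model.count_tokens(&shortened)? <= 1024);
  /// ```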
  pub(crate) fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
    profile!("Truncate text");
    self.walk_truncate(text, max_tokens, usize::MAX)
  }

  /// Truncates text to fit within token limits, keeping whole words where possible.
  /// Uses a binary search over the word count to find a prefix that fits.
  ///
  /// # Arguments
  /// * `text` - The text to truncate
  /// * `max_tokens` - The maximum number of tokens allowed
  /// * `within` - Stop searching once the best fit is within this many tokens of `max_tokens`
  ///
  /// # Returns
  /// * `Result<String>` - The truncated text or an error
  pub(crate) fn walk_truncate(&self, text: &str, max_tokens: usize, within: usize) -> Result<String> {
    profile!("Walk truncate iteration");
    log::debug!("max_tokens: {max_tokens}, within: {within}");

    // Check if text already fits within token limit
    let current_tokens = self.count_tokens(text)?;
    if current_tokens <= max_tokens {
      return Ok(text.to_string());
    }

    // Binary search approach to find the right truncation point
    let words: Vec<&str> = text.split_whitespace().collect();
    let mut left = 0;
    let mut right = words.len();
    let mut best_fit = String::new();
    let mut best_tokens = 0;

    // Perform binary search to find optimal word count
    while left < right {
      let mid = (left + right).div_ceil(2);
      let candidate = words[..mid].join(" ");
      let tokens = self.count_tokens(&candidate)?;

      if tokens <= max_tokens {
        // This fits, try to find a longer text that still fits
        best_fit = candidate;
        best_tokens = tokens;
        left = mid;
      } else {
        // Too many tokens, try shorter text
        right = mid - 1;
      }

      // If we're close enough to the target, we can stop
      if best_tokens > 0 && max_tokens.saturating_sub(best_tokens) <= within {
        break;
      }
    }

    // If we couldn't find any fitting text, truncate more aggressively
    if best_fit.is_empty() && !words.is_empty() {
      // Try with just one word
      best_fit = words[0].to_string();
      let tokens = self.count_tokens(&best_fit)?;

      // If even one word is too long, truncate at character level
      if tokens > max_tokens {
        // Estimate character limit based on token limit
        // Conservative estimate: ~3 chars per token
        let char_limit = max_tokens * 3;
        best_fit = text.chars().take(char_limit).collect();

        // Ensure we don't exceed token limit
        while self.count_tokens(&best_fit)? > max_tokens && !best_fit.is_empty() {
          // Remove the last 10% of characters (count chars, not bytes, so the
          // string always shrinks even for multi-byte text)
          let new_len = (best_fit.chars().count() * 9) / 10;
          best_fit = best_fit.chars().take(new_len).collect();
        }
      }
    }

    Ok(best_fit)
  }
}

impl From<&Model> for &str {
  fn from(model: &Model) -> Self {
    match model {
      Model::GPT4o => MODEL_GPT4_OPTIMIZED,
      Model::GPT4 => MODEL_GPT4,
      Model::GPT4oMini => MODEL_GPT4_MINI,
      Model::GPT41 => MODEL_GPT4_1
    }
  }
}

impl FromStr for Model {
  type Err = anyhow::Error;

  fn from_str(s: &str) -> Result<Self> {
    match s.trim().to_lowercase().as_str() {
      "gpt-4o" => Ok(Model::GPT4o),
      "gpt-4" => Ok(Model::GPT4),
      "gpt-4o-mini" => Ok(Model::GPT4oMini),
      "gpt-4.1" => Ok(Model::GPT41),
      model => bail!("Invalid model name: {}", model)
    }
  }
}

impl Display for Model {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{}", <&str>::from(self))
  }
}

// Implement conversion from string types to Model with fallback to default
impl From<&str> for Model {
  fn from(s: &str) -> Self {
    s.parse().unwrap_or_default()
  }
}

impl From<String> for Model {
  fn from(s: String) -> Self {
    s.as_str().into()
  }
}

fn get_tokenizer(_model_str: &str) -> CoreBPE {
  // TODO: This should be based on the model string, but for now we'll just use cl100k_base
  // which is used by gpt-3.5-turbo and gpt-4
  tiktoken_rs::cl100k_base().expect("Failed to create tokenizer")
}

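/// Generates a commit message for `content` using the OpenAI chat completions API.
///
/// The model is taken from `settings.model`, falling back to `DEFAULT_MODEL_NAME`.
/// Returns an error if the prompt exceeds the model's context size, if the API
/// call fails, or if the API returns an empty response.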
pub async fn run(settings: Settings, content: String) -> Result<String> {
  let model_str = settings.model.as_deref().unwrap_or(DEFAULT_MODEL_NAME);
  let model: Model = model_str.into();

  let client = async_openai::Client::new();
  // let prompt = format_prompt(&content, &settings.prompt(), settings.template())?; // Temporarily comment out
  let prompt = content; // Use raw content as prompt for now
  let tokens = model.count_tokens(&prompt)?;

  if tokens > model.context_size() {
    bail!(
      "Input too large: {} tokens. Max {} tokens for {}",
      tokens.to_string().red(),
      model.context_size().to_string().green(),
      model_str.yellow()
    );
  }

  // TODO: Make temperature configurable
  let temperature_value = 0.7;

  log::info!(
    "Using model: {}, Tokens: {}, Max tokens: {}, Temperature: {}",
    model_str.yellow(),
    tokens.to_string().green(),
    // TODO: Make max_tokens configurable
    (model.context_size() - tokens).to_string().green(),
    temperature_value.to_string().blue()
  );

  let request = CreateChatCompletionRequestArgs::default()
    .model(model_str)
    .messages([ChatCompletionRequestUserMessageArgs::default()
      .content(prompt)
      .build()?
      .into()])
    .temperature(temperature_value)
    // TODO: Make max_tokens configurable
    // Clamp to u16::MAX so large-context models don't wrap the cast
    .max_tokens((model.context_size() - tokens).min(u16::MAX as usize) as u16)
    .build()?;

  profile!("OpenAI API call");
  let response = client.chat().create(request).await?;
  // Avoid panicking if the API returns no choices
  let result = response
    .choices
    .first()
    .and_then(|choice| choice.message.content.clone())
    .unwrap_or_default();

  if result.is_empty() {
    bail!("No response from OpenAI");
  }

  Ok(result.trim().to_string())
}
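
// Illustrative tests for the string conversions and truncation helpers above.
// They exercise only APIs defined in this file plus the bundled cl100k_base
// tokenizer, so they run offline; exact token counts may vary by tokenizer,
// which is why the assertions stay conservative.
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn model_round_trips_between_string_and_enum() {
    let model: Model = "gpt-4.1".parse().expect("known model name should parse");
    assert_eq!(model, Model::GPT41);
    assert_eq!(model.to_string(), "gpt-4.1");

    // Unknown names fall back to the default model via `From<&str>`.
    assert_eq!(Model::from("not-a-model"), Model::default());
  }

  #[test]
  fn count_and_truncate_respect_token_limits() {
    let model = Model::default();

    // Empty input takes the fast path and reports zero tokens.
    assert_eq!(model.count_tokens("").unwrap(), 0);

    // Truncation never returns more tokens than the requested limit.
    let text = "one two three four five six seven eight nine ten";
    let truncated = model.truncate(text, 5).unwrap();
    assert!(model.count_tokens(&truncated).unwrap() <= 5);
  }
}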