use std::default::Default;
use std::fmt::{self, Display};
use std::str::FromStr;
use std::sync::OnceLock;

use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use tiktoken_rs::model::get_context_size;
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use colored::Colorize;

use crate::profile;
use crate::config::AppConfig;

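/// Tokenizer shared across token-count calls, initialised lazily on first use.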
static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

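// Model identifiers as expected by the OpenAI API.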
const MODEL_GPT4_1: &str = "gpt-4.1";
const MODEL_GPT4_1_MINI: &str = "gpt-4.1-mini";
const MODEL_GPT4_1_NANO: &str = "gpt-4.1-nano";
const MODEL_GPT4_5: &str = "gpt-4.5";
const DEFAULT_MODEL_NAME: &str = "gpt-4.1";

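/// Supported OpenAI chat models. The canonical API name of each variant is
/// available via `AsRef<str>` / `Display`; `Model::GPT41` is the default.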
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
pub enum Model {
  #[default]
  GPT41,
  GPT41Mini,
  GPT41Nano,
  GPT45
}

impl Model {
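  /// Counts the BPE tokens in `text` for this model. Empty input returns
  /// `Ok(0)` without touching the tokenizer.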
  pub fn count_tokens(&self, text: &str) -> Result<usize> {
    profile!("Count tokens");

    if text.is_empty() {
      return Ok(0);
    }

    let tokenizer = TOKENIZER.get_or_init(|| get_tokenizer(self.as_ref()));

    let tokens = tokenizer.encode_ordinary(text);
    Ok(tokens.len())
  }

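  /// Returns the model's context window size in tokens, as reported by
  /// `tiktoken_rs`.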
  pub fn context_size(&self) -> usize {
    profile!("Get context size");
    get_context_size(self.as_ref())
  }

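  /// Truncates `text` so that it fits within `max_tokens` tokens.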
  pub(crate) fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
    profile!("Truncate text");
    self.walk_truncate(text, max_tokens, usize::MAX)
  }

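  /// Searches for the longest whitespace-delimited prefix of `text` that
  /// stays within `max_tokens`, bisecting over the word count. `within` is a
  /// tolerance: once the best candidate is within that many tokens of the
  /// limit the search stops early (`truncate` passes `usize::MAX`, so the
  /// first fitting prefix is accepted).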
  pub(crate) fn walk_truncate(&self, text: &str, max_tokens: usize, within: usize) -> Result<String> {
    profile!("Walk truncate iteration");
    log::debug!("max_tokens: {max_tokens}, within: {within}");

    let current_tokens = self.count_tokens(text)?;
    if current_tokens <= max_tokens {
      return Ok(text.to_string());
    }

    let words: Vec<&str> = text.split_whitespace().collect();
    let mut left = 0;
    let mut right = words.len();
    let mut best_fit = String::new();
    let mut best_tokens = 0;

    while left < right {
      let mid = (left + right).div_ceil(2);
      let candidate = words[..mid].join(" ");
      let tokens = self.count_tokens(&candidate)?;

      if tokens <= max_tokens {
        best_fit = candidate;
        best_tokens = tokens;
        left = mid;
      } else {
        right = mid - 1;
      }

      if best_tokens > 0 && max_tokens.saturating_sub(best_tokens) <= within {
        break;
      }
    }

    if best_fit.is_empty() && !words.is_empty() {
      best_fit = words[0].to_string();
      let tokens = self.count_tokens(&best_fit)?;

      if tokens > max_tokens {
        // Even a single word exceeds the budget: fall back to a character
        // cut, assuming roughly three characters per token as a first guess.
        let char_limit = max_tokens * 3;
        best_fit = text.chars().take(char_limit).collect();

        // Shrink by ~10% per iteration until the text fits. Counting
        // characters rather than bytes keeps the loop shrinking for
        // multi-byte UTF-8 input.
        while self.count_tokens(&best_fit)? > max_tokens && !best_fit.is_empty() {
          let new_len = (best_fit.chars().count() * 9) / 10;
          best_fit = best_fit.chars().take(new_len).collect();
        }
      }
    }

    Ok(best_fit)
  }
}

impl AsRef<str> for Model {
  fn as_ref(&self) -> &str {
    match self {
      Model::GPT41 => MODEL_GPT4_1,
      Model::GPT41Mini => MODEL_GPT4_1_MINI,
      Model::GPT41Nano => MODEL_GPT4_1_NANO,
      Model::GPT45 => MODEL_GPT4_5
    }
  }
}

impl From<&Model> for String {
  fn from(model: &Model) -> Self {
    model.as_ref().to_string()
  }
}

impl Model {
  pub fn as_str(&self) -> &str {
    self.as_ref()
  }
}

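/// Parses a model name case-insensitively, ignoring surrounding whitespace.
/// Deprecated names (`gpt-4`, `gpt-4o`, `gpt-4o-mini`, `gpt-3.5-turbo`) are
/// mapped to the closest supported model with a warning; anything else is an
/// error.
///
/// ```ignore
/// // Illustrative only; adjust the import path to this crate's layout.
/// let model: Model = " GPT-4.1-mini ".parse()?;
/// assert_eq!(model, Model::GPT41Mini);
/// ```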
impl FromStr for Model {
  type Err = anyhow::Error;

  fn from_str(s: &str) -> Result<Self> {
    let normalized = s.trim().to_lowercase();
    match normalized.as_str() {
      "gpt-4.1" => Ok(Model::GPT41),
      "gpt-4.1-mini" => Ok(Model::GPT41Mini),
      "gpt-4.1-nano" => Ok(Model::GPT41Nano),
      "gpt-4.5" => Ok(Model::GPT45),
      "gpt-4" | "gpt-4o" => {
        log::warn!(
          "Model '{}' is deprecated. Mapping to 'gpt-4.1'. \
           Please update your configuration with: git ai config set model gpt-4.1",
          s
        );
        Ok(Model::GPT41)
      }
      "gpt-4o-mini" | "gpt-3.5-turbo" => {
        log::warn!(
          "Model '{}' is deprecated. Mapping to 'gpt-4.1-mini'. \
           Please update your configuration with: git ai config set model gpt-4.1-mini",
          s
        );
        Ok(Model::GPT41Mini)
      }
      model =>
        bail!(
          "Invalid model name: '{}'. Supported models: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5",
          model
        ),
    }
  }
}

impl Display for Model {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{}", self.as_ref())
  }
}

impl From<&str> for Model {
  fn from(s: &str) -> Self {
    s.parse().unwrap_or_else(|e| {
      log::error!("Failed to parse model '{}': {}. Falling back to default model 'gpt-4.1'.", s, e);
      Model::default()
    })
  }
}

impl From<String> for Model {
  fn from(s: String) -> Self {
    s.as_str().into()
  }
}

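/// Builds the BPE tokenizer used for token counting. The model name is
/// currently ignored and `cl100k_base` is used for every model, which may
/// only approximate the encoding of newer models.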
fn get_tokenizer(_model_str: &str) -> CoreBPE {
  tiktoken_rs::cl100k_base().expect("Failed to create tokenizer")
}

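/// Sends `content` to the configured OpenAI chat model and returns the
/// trimmed response text. Bails if the prompt does not fit in the model's
/// context window or if the API returns an empty response.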
pub async fn run(settings: AppConfig, content: String) -> Result<String> {
  let model_str = settings.model.as_deref().unwrap_or(DEFAULT_MODEL_NAME);
  let model: Model = model_str.into();

  let client = async_openai::Client::new();
  let prompt = content;
  let tokens = model.count_tokens(&prompt)?;

  if tokens > model.context_size() {
    bail!(
      "Input too large: {} tokens. Max {} tokens for {}",
      tokens.to_string().red(),
      model.context_size().to_string().green(),
      model_str.yellow()
    );
  }

  let temperature_value = 0.7;

  log::info!(
    "Using model: {}, Tokens: {}, Max tokens: {}, Temperature: {}",
    model_str.yellow(),
    tokens.to_string().green(),
    (model.context_size() - tokens).to_string().green(),
    temperature_value.to_string().blue()
  );

  let request = CreateChatCompletionRequestArgs::default()
    .model(model_str)
    .messages([ChatCompletionRequestUserMessageArgs::default()
      .content(prompt)
      .build()?
      .into()])
    .temperature(temperature_value)
    // Cap the completion budget at u16::MAX so the cast below cannot wrap
    // for models with very large context windows.
    .max_tokens((model.context_size() - tokens).min(u16::MAX as usize) as u16)
    .build()?;

  profile!("OpenAI API call");
  let response = client.chat().create(request).await?;
  let result = response
    .choices
    .first()
    .and_then(|choice| choice.message.content.clone())
    .unwrap_or_default();

  if result.is_empty() {
    bail!("No response from OpenAI");
  }

  Ok(result.trim().to_string())
}