1use std::default::Default;
2use std::fmt::{self, Display};
3use std::str::FromStr;
4use std::sync::OnceLock;
5
6use anyhow::{bail, Result};
7use serde::{Deserialize, Serialize};
8use tiktoken_rs::CoreBPE;
9use tiktoken_rs::model::get_context_size;
10use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
11use colored::Colorize;
12
13use crate::profile;
14use crate::config::App as Settings; static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
19
const MODEL_GPT4: &str = "gpt-4";
const MODEL_GPT4_OPTIMIZED: &str = "gpt-4o";
const MODEL_GPT4_MINI: &str = "gpt-4o-mini";
const MODEL_GPT4_1: &str = "gpt-4.1";
const DEFAULT_MODEL_NAME: &str = MODEL_GPT4_1;

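/// OpenAI chat models supported by this module; `gpt-4.1` is the default.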
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
pub enum Model {
  GPT4,
  GPT4o,
  GPT4oMini,
  #[default]
  GPT41
}

impl Model {
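  /// Counts the tokens in `text` using the cached BPE tokenizer.
  ///
  /// Returns `Ok(0)` for empty input without touching the tokenizer.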
  pub fn count_tokens(&self, text: &str) -> Result<usize> {
    profile!("Count tokens");

    if text.is_empty() {
      return Ok(0);
    }

    // The tokenizer is initialized once; the first model to call this wins.
    let tokenizer = TOKENIZER.get_or_init(|| {
      let model_str: &str = self.into();
      get_tokenizer(model_str)
    });

    let tokens = tokenizer.encode_ordinary(text);
    Ok(tokens.len())
  }

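  /// Returns the maximum context window, in tokens, for this model.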
  pub fn context_size(&self) -> usize {
    profile!("Get context size");
    let model_str: &str = self.into();
    get_context_size(model_str)
  }

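  /// Truncates `text` so that it fits within `max_tokens` tokens.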
  pub(crate) fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
    profile!("Truncate text");
    self.walk_truncate(text, max_tokens, usize::MAX)
  }

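  /// Binary-searches for the longest word-boundary prefix of `text` that fits
  /// within `max_tokens`, stopping early once the best fit is within `within`
  /// tokens of the budget.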
  pub(crate) fn walk_truncate(&self, text: &str, max_tokens: usize, within: usize) -> Result<String> {
    profile!("Walk truncate iteration");
    log::debug!("max_tokens: {max_tokens}, within: {within}");

    let current_tokens = self.count_tokens(text)?;
    if current_tokens <= max_tokens {
      return Ok(text.to_string());
    }

    // Binary search over word boundaries for the longest prefix that fits.
    let words: Vec<&str> = text.split_whitespace().collect();
    let mut left = 0;
    let mut right = words.len();
    let mut best_fit = String::new();
    let mut best_tokens = 0;

    while left < right {
      let mid = (left + right).div_ceil(2);
      let candidate = words[..mid].join(" ");
      let tokens = self.count_tokens(&candidate)?;

      if tokens <= max_tokens {
        best_fit = candidate;
        best_tokens = tokens;
        left = mid;
      } else {
        right = mid - 1;
      }

      // Stop early once the best fit is close enough to the budget.
      if best_tokens > 0 && max_tokens.saturating_sub(best_tokens) <= within {
        break;
      }
    }

    // Fall back to character-level truncation when not even one word fits.
    if best_fit.is_empty() && !words.is_empty() {
      best_fit = words[0].to_string();
      let tokens = self.count_tokens(&best_fit)?;

      if tokens > max_tokens {
        // Rough heuristic: one token is about three characters of English text.
        let char_limit = max_tokens * 3;
        best_fit = text.chars().take(char_limit).collect();

        // Shrink by 10% of the character count per iteration until it fits.
        while self.count_tokens(&best_fit)? > max_tokens && !best_fit.is_empty() {
          let new_len = (best_fit.chars().count() * 9) / 10;
          best_fit = best_fit.chars().take(new_len).collect();
        }
      }
    }

    Ok(best_fit)
  }
}

impl From<&Model> for &str {
  fn from(model: &Model) -> Self {
    match model {
      Model::GPT4 => MODEL_GPT4,
      Model::GPT4o => MODEL_GPT4_OPTIMIZED,
      Model::GPT4oMini => MODEL_GPT4_MINI,
      Model::GPT41 => MODEL_GPT4_1
    }
  }
}

impl FromStr for Model {
  type Err = anyhow::Error;

  fn from_str(s: &str) -> Result<Self> {
    match s.trim().to_lowercase().as_str() {
      "gpt-4" => Ok(Model::GPT4),
      "gpt-4o" => Ok(Model::GPT4o),
      "gpt-4o-mini" => Ok(Model::GPT4oMini),
      "gpt-4.1" => Ok(Model::GPT41),
      model => bail!("Invalid model name: {model}")
    }
  }
}

impl Display for Model {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{}", <&str>::from(self))
  }
}

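// Unlike `FromStr`, these infallible conversions fall back to the default
// model when the name is unrecognized.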
impl From<&str> for Model {
  fn from(s: &str) -> Self {
    s.parse().unwrap_or_default()
  }
}

impl From<String> for Model {
  fn from(s: String) -> Self {
    s.as_str().into()
  }
}

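/// Returns the BPE tokenizer used for token counting.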
fn get_tokenizer(model_str: &str) -> CoreBPE {
  // Prefer the model's own encoding; fall back to cl100k_base for names
  // tiktoken does not recognize (e.g. very new models).
  tiktoken_rs::get_bpe_from_model(model_str)
    .unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to create tokenizer"))
}

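/// Sends `content` to the OpenAI chat completions API and returns the trimmed
/// response text, erroring when the input exceeds the model's context window
/// or the API returns an empty choice.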
pub async fn run(settings: Settings, content: String) -> Result<String> {
  let model_str = settings.model.as_deref().unwrap_or(DEFAULT_MODEL_NAME);
  let model: Model = model_str.into();

  let client = async_openai::Client::new();
  let prompt = content;
  let tokens = model.count_tokens(&prompt)?;

  if tokens > model.context_size() {
    bail!(
      "Input too large: {} tokens. Max {} tokens for {}",
      tokens.to_string().red(),
      model.context_size().to_string().green(),
      model_str.yellow()
    );
  }

  let temperature_value = 0.7;
  // The response budget is whatever the context window has left after the
  // prompt, clamped to what the u16 `max_tokens` field can represent.
  let max_response_tokens = (model.context_size() - tokens).min(u16::MAX as usize) as u16;

  log::info!(
    "Using model: {}, Tokens: {}, Max tokens: {}, Temperature: {}",
    model_str.yellow(),
    tokens.to_string().green(),
    max_response_tokens.to_string().green(),
    temperature_value.to_string().blue()
  );

  let request = CreateChatCompletionRequestArgs::default()
    .model(model_str)
    .messages([ChatCompletionRequestUserMessageArgs::default()
      .content(prompt)
      .build()?
      .into()])
    .temperature(temperature_value)
    .max_tokens(max_response_tokens)
    .build()?;

  profile!("OpenAI API call");
  let response = client.chat().create(request).await?;
  let result = response
    .choices
    .first()
    .and_then(|choice| choice.message.content.clone())
    .unwrap_or_default();

  if result.is_empty() {
    bail!("No response from OpenAI");
  }

  Ok(result.trim().to_string())
}
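
#[cfg(test)]
mod tests {
  use super::*;

  // Illustrative sketch of the truncation contract; exact token counts are
  // tokenizer-dependent, so the assertion only checks the budget is respected.
  #[test]
  fn truncate_respects_token_budget() {
    let model = Model::default();
    let text = "one two three four five six seven eight nine ten ".repeat(20);
    let truncated = model.truncate(&text, 16).expect("truncation should succeed");
    assert!(model.count_tokens(&truncated).expect("count should succeed") <= 16);
  }

  // Parsing errors on unknown names, while `From<&str>` falls back to the default.
  #[test]
  fn parses_known_model_names() {
    assert_eq!("gpt-4.1".parse::<Model>().unwrap(), Model::GPT41);
    assert_eq!(Model::from("unknown-model"), Model::default());
    assert!("unknown-model".parse::<Model>().is_err());
  }
}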