use std::default::Default;
use std::fmt::{self, Display};
use std::str::FromStr;
use std::sync::OnceLock;
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use tiktoken_rs::model::get_context_size;
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use colored::Colorize;
use crate::profile;
use crate::config::AppConfig;
/// Lazily initialized tokenizer; `Model::count_tokens` fills it on first use and
/// every later call reuses the same instance.
static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
/// Name string that `Model::GPT41` converts to.
const MODEL_GPT4_1: &str = "gpt-4.1";
/// Name string that `Model::GPT41Mini` converts to.
const MODEL_GPT4_1_MINI: &str = "gpt-4.1-mini";
/// Name string that `Model::GPT41Nano` converts to.
const MODEL_GPT4_1_NANO: &str = "gpt-4.1-nano";
/// Name string that `Model::GPT45` converts to.
const MODEL_GPT4_5: &str = "gpt-4.5";
/// Model name `run` falls back to when the configuration does not set one.
const DEFAULT_MODEL_NAME: &str = "gpt-4.1";
/// Chat models this module knows how to address.
///
/// Each variant converts to one model-name string through the `AsRef<str>` impl,
/// and `GPT41` is the variant produced by `Model::default()`.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
pub enum Model {
/// `gpt-4.1` (the default variant).
#[default]
GPT41,
/// `gpt-4.1-mini`.
GPT41Mini,
/// `gpt-4.1-nano`.
GPT41Nano,
/// `gpt-4.5`.
GPT45
}
impl Model {
    /// Counts how many tokens `text` encodes to.
    ///
    /// Empty input short-circuits to `Ok(0)` without touching the tokenizer.
    /// The tokenizer is created on the first non-empty call and then reused from
    /// the shared `TOKENIZER` cell, so whichever model triggers initialization
    /// determines the instance used by all later calls.
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        profile!("Count tokens");
        if text.is_empty() {
            return Ok(0);
        }
        let tokenizer = TOKENIZER.get_or_init(|| get_tokenizer(self.as_ref()));
        let tokens = tokenizer.encode_ordinary(text);
        Ok(tokens.len())
    }

    /// Returns the context size reported by `tiktoken_rs` for this model's name.
    pub fn context_size(&self) -> usize {
        profile!("Get context size");
        get_context_size(self.as_ref())
    }

    /// Shortens `text` so that it encodes to at most `max_tokens` tokens.
    ///
    /// Delegates to [`Model::walk_truncate`] with `within = usize::MAX`, which
    /// makes the search accept the first word prefix that fits rather than
    /// looking for the longest one.
    pub(crate) fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
        profile!("Truncate text");
        self.walk_truncate(text, max_tokens, usize::MAX)
    }

    /// Binary-searches for a whitespace-delimited word prefix of `text` that
    /// fits in `max_tokens` tokens.
    ///
    /// * `text` - input to shorten; returned unchanged when it already fits.
    /// * `max_tokens` - token budget for the result.
    /// * `within` - acceptable slack: the search stops as soon as a fitting
    ///   prefix leaves at most `within` tokens of the budget unused.
    ///
    /// When truncation happens, the result is the kept words joined by single
    /// spaces, so the original whitespace (newlines, runs of spaces) is not
    /// preserved. If not even the first word fits, a character prefix of the
    /// raw text is returned instead.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Model::count_tokens`].
    pub(crate) fn walk_truncate(&self, text: &str, max_tokens: usize, within: usize) -> Result<String> {
        profile!("Walk truncate iteration");
        log::debug!("max_tokens: {max_tokens}, within: {within}");

        let current_tokens = self.count_tokens(text)?;
        if current_tokens <= max_tokens {
            return Ok(text.to_string());
        }

        let words: Vec<&str> = text.split_whitespace().collect();
        let mut left = 0;
        let mut right = words.len();
        let mut best_fit = String::new();
        let mut best_tokens = 0;

        while left < right {
            // Rounding up keeps `mid > left`, so `left = mid` always makes progress
            // and `mid - 1` cannot underflow.
            let mid = (left + right).div_ceil(2);
            let candidate = words[..mid].join(" ");
            let tokens = self.count_tokens(&candidate)?;
            if tokens <= max_tokens {
                best_fit = candidate;
                best_tokens = tokens;
                left = mid;
            } else {
                right = mid - 1;
            }
            // Good enough: a fitting prefix exists and its unused budget is within tolerance.
            if best_tokens > 0 && max_tokens.saturating_sub(best_tokens) <= within {
                break;
            }
        }

        if best_fit.is_empty() && !words.is_empty() {
            best_fit = words[0].to_string();
            let tokens = self.count_tokens(&best_fit)?;
            if tokens > max_tokens {
                // Not even one word fits: fall back to a character prefix of the raw
                // text, starting from a rough guess of three characters per token.
                let char_limit = max_tokens.saturating_mul(3);
                best_fit = text.chars().take(char_limit).collect();
                while !best_fit.is_empty() && self.count_tokens(&best_fit)? > max_tokens {
                    // Shrink by ~10%, measured in characters. Using the byte length
                    // here would exceed the character count for multi-byte text, so
                    // `take` would keep everything and the loop would never end.
                    let new_len = best_fit.chars().count() * 9 / 10;
                    best_fit = best_fit.chars().take(new_len).collect();
                }
            }
        }

        Ok(best_fit)
    }
}
impl AsRef<str> for Model {
    /// Returns the model-name string associated with this variant.
    fn as_ref(&self) -> &str {
        // `Model` is `Copy`, so matching on the dereferenced value is free.
        match *self {
            Self::GPT41 => MODEL_GPT4_1,
            Self::GPT41Mini => MODEL_GPT4_1_MINI,
            Self::GPT41Nano => MODEL_GPT4_1_NANO,
            Self::GPT45 => MODEL_GPT4_5,
        }
    }
}
impl From<&Model> for String {
    /// Builds an owned copy of the model's name string.
    fn from(model: &Model) -> Self {
        let name: &str = model.as_ref();
        name.to_owned()
    }
}
impl Model {
    /// Returns the model's name string; equivalent to the `AsRef<str>` conversion.
    pub fn as_str(&self) -> &str {
        AsRef::<str>::as_ref(self)
    }
}
impl FromStr for Model {
    type Err = anyhow::Error;

    /// Parses a model name, ignoring surrounding whitespace and letter case.
    ///
    /// Deprecated names are accepted with a logged warning and mapped to their
    /// replacement: `gpt-4` / `gpt-4o` become [`Model::GPT41`], and
    /// `gpt-4o-mini` / `gpt-3.5-turbo` become [`Model::GPT41Mini`].
    ///
    /// # Errors
    ///
    /// Returns an error listing the supported names when the input matches none
    /// of the current or deprecated names.
    fn from_str(s: &str) -> Result<Self> {
        let wanted = s.trim().to_lowercase();

        let parsed = match wanted.as_str() {
            MODEL_GPT4_1 => Model::GPT41,
            MODEL_GPT4_1_MINI => Model::GPT41Mini,
            MODEL_GPT4_1_NANO => Model::GPT41Nano,
            MODEL_GPT4_5 => Model::GPT45,
            "gpt-4" | "gpt-4o" => {
                // The warning echoes the caller's original spelling, not the normalized one.
                log::warn!(
                    "Model '{}' is deprecated. Mapping to 'gpt-4.1'. \
                     Please update your configuration with: git ai config set model gpt-4.1",
                    s
                );
                Model::GPT41
            }
            "gpt-4o-mini" | "gpt-3.5-turbo" => {
                log::warn!(
                    "Model '{}' is deprecated. Mapping to 'gpt-4.1-mini'. \
                     Please update your configuration with: git ai config set model gpt-4.1-mini",
                    s
                );
                Model::GPT41Mini
            }
            unknown => bail!(
                "Invalid model name: '{}'. Supported models: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5",
                unknown
            ),
        };

        Ok(parsed)
    }
}
impl Display for Model {
    /// Writes the model's name string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_ref())
    }
}
impl From<&str> for Model {
    /// Infallible conversion: parses `s` and, when parsing fails, logs the error
    /// and returns `Model::default()` instead.
    fn from(s: &str) -> Self {
        match s.parse::<Model>() {
            Ok(model) => model,
            Err(e) => {
                log::error!("Failed to parse model '{}': {}. Falling back to default model 'gpt-4.1'.", s, e);
                Model::default()
            }
        }
    }
}
impl From<String> for Model {
fn from(s: String) -> Self {
s.as_str().into()
}
}
/// Builds the tokenizer used for token counting.
///
/// The model name is currently ignored: the `cl100k_base` encoding is used for
/// every model.
///
/// # Panics
///
/// Panics if `tiktoken_rs` fails to construct the encoding.
fn get_tokenizer(_model_str: &str) -> CoreBPE {
    let encoding = tiktoken_rs::cl100k_base();
    encoding.expect("Failed to create tokenizer")
}
/// Sends `content` as a single user message to the OpenAI chat completion API
/// and returns the trimmed text of the first choice.
///
/// The model comes from `settings.model`, falling back to `DEFAULT_MODEL_NAME`
/// when unset. The configured name is parsed into a [`Model`] first, and the
/// parsed model's canonical name is what gets logged and sent, so the token
/// limits checked here always belong to the model actually requested.
///
/// # Errors
///
/// Returns an error when:
/// * the prompt encodes to more tokens than the model's context size,
/// * the request cannot be built or the API call fails,
/// * the response has no choices or the first choice has empty content.
pub async fn run(settings: AppConfig, content: String) -> Result<String> {
    // Resolve the configured name once; deprecated or unknown names are mapped by
    // the `From<&str>` conversion, and the mapped name is used from here on.
    let model: Model = settings
        .model
        .as_deref()
        .unwrap_or(DEFAULT_MODEL_NAME)
        .into();
    let model_str = model.as_str();
    let client = async_openai::Client::new();
    let prompt = content;

    let context_size = model.context_size();
    let tokens = model.count_tokens(&prompt)?;
    if tokens > context_size {
        bail!(
            "Input too large: {} tokens. Max {} tokens for {}",
            tokens.to_string().red(),
            context_size.to_string().green(),
            model_str.yellow()
        );
    }

    // The request field is a `u16`. Clamp instead of casting directly: a bare
    // `as u16` wraps modulo 65536 and can turn a large budget into a tiny one.
    let remaining = context_size - tokens;
    let max_tokens = remaining.min(usize::from(u16::MAX)) as u16;

    let temperature_value = 0.7;
    log::info!(
        "Using model: {}, Tokens: {}, Max tokens: {}, Temperature: {}",
        model_str.yellow(),
        tokens.to_string().green(),
        max_tokens.to_string().green(),
        temperature_value.to_string().blue()
    );

    let request = CreateChatCompletionRequestArgs::default()
        .model(model_str)
        .messages([ChatCompletionRequestUserMessageArgs::default()
            .content(prompt)
            .build()?
            .into()])
        .temperature(temperature_value)
        .max_tokens(max_tokens)
        .build()?;

    profile!("OpenAI API call");
    let response = client.chat().create(request).await?;

    // An empty `choices` list is reported as an error rather than panicking on index 0.
    let result = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.clone())
        .unwrap_or_default();
    if result.is_empty() {
        bail!("No response from OpenAI");
    }

    Ok(result.trim().to_string())
}