use anyhow::{anyhow, Result};
use crate::{
cl100k_base,
model::get_context_size,
p50k_base, p50k_edit, r50k_base,
tokenizer::{get_tokenizer, Tokenizer},
CoreBPE,
};
/// Computes how many tokens remain in `model`'s context window once `prompt`
/// has been encoded.
///
/// The prompt is tokenized with `encode_with_special_tokens` using the BPE
/// selected for `model`. The result is the context size minus the prompt's
/// token count, clamped at zero when the prompt is longer than the context.
///
/// # Errors
///
/// Returns an error if no BPE can be obtained for `model`.
pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
    let window = get_context_size(model);
    get_bpe_from_model(model).map(|encoder| {
        let used = encoder.encode_with_special_tokens(prompt).len();
        window.saturating_sub(used)
    })
}
/// A function invocation attached to a chat message.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Arguments of the call, carried as a single string.
    /// NOTE(review): presumably JSON-encoded as in the OpenAI API; nothing
    /// here validates the format.
    pub arguments: String,
}
/// A single message of a chat completion request, in the shape used by the
/// token-counting helpers of this module.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ChatCompletionRequestMessage {
    /// Role of the message author, as plain text (for example `"system"` or
    /// `"user"`).
    pub role: String,
    /// Text of the message, if any.
    pub content: Option<String>,
    /// Optional name of the message author.
    pub name: Option<String>,
    /// Optional function call carried by the message.
    pub function_call: Option<FunctionCall>,
}
/// Estimates how many tokens a list of chat messages occupies in a request.
///
/// Every message contributes a fixed per-message overhead plus the encoded
/// length of its `role` and of its `content` (a missing content counts as an
/// empty string). When a `name` is present, its encoded length and a per-name
/// adjustment are added as well. A constant of 3 is added once at the end.
///
/// Models whose name starts with `"gpt-3.5"` use an overhead of 4 per message
/// and subtract 1 when a name is present; every other model uses 3 and adds 1.
///
/// NOTE(review): these constants appear to follow the OpenAI cookbook recipe
/// for counting chat tokens — confirm against its current revision.
/// NOTE(review): `function_call` is not included in the count.
///
/// # Errors
///
/// Returns an error when no tokenizer is known for `model`, when the model's
/// tokenizer is not `Tokenizer::Cl100kBase`, or when the BPE cannot be built.
pub fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let tokenizer =
        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
    if tokenizer != Tokenizer::Cl100kBase {
        anyhow::bail!("Chat completion is only supported for chat models")
    }
    let bpe = get_bpe_from_tokenizer(tokenizer)?;

    let is_gpt35 = model.starts_with("gpt-3.5");
    let tokens_per_message: usize = if is_gpt35 { 4 } else { 3 };

    // The count stays in `usize` throughout, so no lossy integer casts are
    // needed on the way in or out.
    let mut num_tokens: usize = 0;
    for message in messages {
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&message.role).len();
        // Borrow the content instead of cloning it; `None` counts as "".
        num_tokens += bpe
            .encode_with_special_tokens(message.content.as_deref().unwrap_or(""))
            .len();
        if let Some(name) = &message.name {
            num_tokens += bpe.encode_with_special_tokens(name).len();
            if is_gpt35 {
                // Cannot underflow: this message has already contributed at
                // least its per-message overhead of 4.
                num_tokens -= 1;
            } else {
                num_tokens += 1;
            }
        }
    }
    // Fixed trailing overhead added once per request.
    num_tokens += 3;
    Ok(num_tokens)
}
/// Computes how many tokens remain in `model`'s context window after the
/// given chat `messages` have been accounted for.
///
/// The message cost comes from `num_tokens_from_messages`. The result is
/// clamped at zero when the messages exceed the context size.
///
/// # Errors
///
/// Propagates any error from `num_tokens_from_messages`.
pub fn get_chat_completion_max_tokens(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let window = get_context_size(model);
    num_tokens_from_messages(model, messages).map(|used| window.saturating_sub(used))
}
/// Looks up the tokenizer registered for `model` and builds its `CoreBPE`.
///
/// # Errors
///
/// Returns an error when `get_tokenizer` has no entry for `model`, or when
/// building the BPE for the resolved tokenizer fails.
pub fn get_bpe_from_model(model: &str) -> Result<CoreBPE> {
    match get_tokenizer(model) {
        Some(tokenizer) => get_bpe_from_tokenizer(tokenizer),
        None => Err(anyhow!("No tokenizer found for model {}", model)),
    }
}
/// Builds the `CoreBPE` that corresponds to a `Tokenizer` variant.
///
/// Each variant is served by the constructor of the same name, except
/// `Tokenizer::Gpt2`, which uses `r50k_base()` just like
/// `Tokenizer::R50kBase`.
///
/// # Errors
///
/// Propagates whatever error the selected constructor returns.
pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<CoreBPE> {
    match tokenizer {
        Tokenizer::Cl100kBase => cl100k_base(),
        Tokenizer::P50kBase => p50k_base(),
        Tokenizer::P50kEdit => p50k_edit(),
        Tokenizer::R50kBase | Tokenizer::Gpt2 => r50k_base(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chat message with the given role, optional name and content,
    /// and no function call.
    fn chat_message(
        role: &str,
        name: Option<&str>,
        content: &str,
    ) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage {
            role: String::from(role),
            name: name.map(String::from),
            content: Some(String::from(content)),
            function_call: None,
        }
    }

    /// Token 15339 of the cl100k_base encoding decodes to "hello".
    #[test]
    fn test_get_bpe_from_tokenizer() {
        let bpe = get_bpe_from_tokenizer(Tokenizer::Cl100kBase).unwrap();
        let decoded = bpe.decode(vec![15339]).unwrap();
        assert_eq!(decoded, "hello");
    }

    /// The same conversation is counted with the gpt-3.5 and the default
    /// per-message/per-name constants.
    #[test]
    fn test_num_tokens_from_messages() {
        let messages = vec![
            chat_message(
                "system",
                None,
                "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.",
            ),
            chat_message(
                "system",
                Some("example_user"),
                "New synergies will help drive top-line growth.",
            ),
            chat_message(
                "system",
                Some("example_assistant"),
                "Things working well together will increase revenue.",
            ),
            chat_message(
                "system",
                Some("example_user"),
                "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.",
            ),
            chat_message(
                "system",
                Some("example_assistant"),
                "Let's talk later when we're less busy about how to do better.",
            ),
            chat_message(
                "user",
                None,
                "This late pivot means we don't have time to boil the ocean for the client deliverable.",
            ),
        ];

        for &(model, expected) in &[("gpt-3.5-turbo-0301", 127), ("gpt-4-0314", 129)] {
            let num_tokens = num_tokens_from_messages(model, &messages).unwrap();
            assert_eq!(num_tokens, expected);
        }
    }

    /// A short conversation must leave room in the context window.
    #[test]
    fn test_get_chat_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let messages = vec![
            chat_message(
                "system",
                None,
                "You are a helpful assistant that only speaks French.",
            ),
            chat_message("user", None, "Hello, how are you?"),
            chat_message("system", None, "Parlez-vous francais?"),
        ];
        let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
        assert!(max_tokens > 0);
    }

    /// A short prompt must leave room in the context window.
    #[test]
    fn test_get_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let prompt = "Translate the following English text to French: '";
        let max_tokens = get_completion_max_tokens(model, prompt).unwrap();
        assert!(max_tokens > 0);
    }
}
/// Adapters that let the token-counting helpers accept `async_openai` request
/// types directly. Compiled only when the `async-openai` feature is enabled.
#[cfg(feature = "async-openai")]
pub mod async_openai {
    use anyhow::Result;

    /// Copies the name and arguments of an `async_openai` function call.
    impl From<&async_openai::types::FunctionCall> for super::FunctionCall {
        fn from(call: &async_openai::types::FunctionCall) -> Self {
            let name = call.name.clone();
            let arguments = call.arguments.clone();
            Self { name, arguments }
        }
    }

    /// Converts an `async_openai` chat message, rendering its role as text
    /// and cloning the remaining fields.
    impl From<&async_openai::types::ChatCompletionRequestMessage>
        for super::ChatCompletionRequestMessage
    {
        fn from(message: &async_openai::types::ChatCompletionRequestMessage) -> Self {
            let function_call = message
                .function_call
                .as_ref()
                .map(super::FunctionCall::from);
            Self {
                role: message.role.to_string(),
                name: message.name.clone(),
                content: message.content.clone(),
                function_call,
            }
        }
    }

    /// Converts a slice of `async_openai` messages into this crate's
    /// message type.
    fn to_local_messages(
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Vec<super::ChatCompletionRequestMessage> {
        messages
            .iter()
            .map(super::ChatCompletionRequestMessage::from)
            .collect()
    }

    /// Same as the parent module's `num_tokens_from_messages`, but takes
    /// `async_openai` messages.
    ///
    /// # Errors
    ///
    /// Propagates any error from the parent module's function.
    pub fn num_tokens_from_messages(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let converted = to_local_messages(messages);
        super::num_tokens_from_messages(model, &converted)
    }

    /// Same as the parent module's `get_chat_completion_max_tokens`, but
    /// takes `async_openai` messages.
    ///
    /// # Errors
    ///
    /// Propagates any error from the parent module's function.
    pub fn get_chat_completion_max_tokens(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let converted = to_local_messages(messages);
        super::get_chat_completion_max_tokens(model, &converted)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// A single system message must produce a non-zero token count.
        #[test]
        fn test_num_tokens_from_messages() {
            let model = "gpt-3.5-turbo-0301";
            let system_message = async_openai::types::ChatCompletionRequestMessage {
                role: async_openai::types::Role::System,
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                function_call: None,
            };
            let messages = &[system_message];
            let num_tokens = num_tokens_from_messages(model, messages).unwrap();
            assert!(num_tokens > 0);
        }

        /// A single system message must leave room in the context window.
        #[test]
        fn test_get_chat_completion_max_tokens() {
            let model = "gpt-3.5-turbo";
            let system_message = async_openai::types::ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: async_openai::types::Role::System,
                name: None,
                function_call: None,
            };
            let messages = &[system_message];
            let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
            assert!(max_tokens > 0);
        }
    }
}