use anyhow::{anyhow, Result};
use crate::{
cl100k_base_singleton,
model::get_context_size,
o200k_base_singleton, o200k_harmony_singleton, p50k_base_singleton, p50k_edit_singleton,
r50k_base_singleton,
tokenizer::{get_tokenizer, Tokenizer},
CoreBPE,
};
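/// Returns the number of tokens left in `model`'s context window after
/// `prompt` is encoded, i.e. the largest `max_tokens` a text completion
/// request could ask for. Errors if the model's context size or tokenizer is
/// unknown.
///
/// A minimal usage sketch (the model name is assumed to be one this crate
/// recognizes):
///
/// ```ignore
/// let remaining = get_text_completion_max_tokens("gpt-4o", "Hello, world!")?;
/// assert!(remaining > 0);
/// ```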
pub fn get_text_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
let context_size = get_context_size(model)
.ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
let tokenizer =
get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
let bpe = bpe_singleton(tokenizer);
let prompt_tokens = bpe.count_with_special_tokens(prompt);
Ok(context_size.saturating_sub(prompt_tokens))
}
#[deprecated(since = "0.10.0", note = "renamed to `get_text_completion_max_tokens`")]
pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
get_text_completion_max_tokens(model, prompt)
}
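/// A function or tool invocation emitted by the assistant: the function name
/// plus its arguments as a raw JSON string. Both fields are encoded verbatim
/// when estimating token usage.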
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
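/// A provider-agnostic chat message used for token counting. Fields that are
/// `None` (or empty, in the case of `tool_calls`) contribute no tokens.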
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
pub role: String,
pub content: Option<String>,
pub name: Option<String>,
pub function_call: Option<FunctionCall>,
pub tool_calls: Vec<FunctionCall>,
pub refusal: Option<String>,
}
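/// Estimates how many prompt tokens `messages` will consume for `model`,
/// following the accounting in OpenAI's token-counting cookbook: a fixed
/// per-message overhead, the encoded role/content/name fields, and a final
/// three-token reply priming. Only models on the cl100k, o200k, and
/// o200k-harmony tokenizers are supported. Counts for function and tool calls
/// are approximations, since the exact server-side serialization is not
/// published.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let messages = vec![ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello!".to_string()),
///     ..Default::default()
/// }];
/// let tokens = num_tokens_from_messages("gpt-4o", &messages)?;
/// ```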
pub fn num_tokens_from_messages(
model: &str,
messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
let tokenizer =
get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
if tokenizer != Tokenizer::Cl100kBase
&& tokenizer != Tokenizer::O200kBase
&& tokenizer != Tokenizer::O200kHarmony
{
anyhow::bail!(
"Chat token counting is not supported for model {:?} (tokenizer {:?}). \
Supported tokenizers: Cl100kBase, O200kBase, O200kHarmony.",
model,
tokenizer
)
}
let bpe = bpe_singleton(tokenizer);
    // A function/tool call carries a small amount of extra serialization
    // framing; one token per call approximates that overhead.
    const FUNCTION_CALL_OVERHEAD: i32 = 1;
    // Every reply is primed with <|start|>assistant<|message|>, which costs
    // three tokens (per OpenAI's token-counting cookbook).
    const REPLY_PRIMING: i32 = 3;
    let (tokens_per_message, tokens_per_name) = if model == "gpt-3.5-turbo-0301" {
        // Messages follow <|start|>{role/name}\n{content}<|end|>\n; when a
        // name is present the role is omitted, hence the -1.
        (4, -1)
    } else {
        (3, 1)
    };
    // Signed arithmetic, because `tokens_per_name` can be negative.
    let mut num_tokens: i32 = 0;
for message in messages {
num_tokens += tokens_per_message;
num_tokens += bpe.count_with_special_tokens(&message.role) as i32;
if let Some(content) = &message.content {
num_tokens += bpe.count_with_special_tokens(content) as i32;
}
if let Some(name) = &message.name {
num_tokens += bpe.count_with_special_tokens(name) as i32;
num_tokens += tokens_per_name;
}
if let Some(function_call) = &message.function_call {
num_tokens += bpe.count_with_special_tokens(&function_call.name) as i32;
num_tokens += bpe.count_with_special_tokens(&function_call.arguments) as i32;
num_tokens += FUNCTION_CALL_OVERHEAD;
}
for tool_call in &message.tool_calls {
num_tokens += bpe.count_with_special_tokens(&tool_call.name) as i32;
num_tokens += bpe.count_with_special_tokens(&tool_call.arguments) as i32;
num_tokens += FUNCTION_CALL_OVERHEAD;
}
if let Some(refusal) = &message.refusal {
num_tokens += bpe.count_with_special_tokens(refusal) as i32;
}
}
num_tokens += REPLY_PRIMING;
Ok(num_tokens as usize)
}
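/// Returns the tokens remaining in `model`'s context window once `messages`
/// are accounted for, i.e. an upper bound on `max_tokens` for a chat
/// completion request. Saturates to zero if the messages already fill the
/// context.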
pub fn get_chat_completion_max_tokens(
model: &str,
messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
let context_size = get_context_size(model)
.ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
let prompt_tokens = num_tokens_from_messages(model, messages)?;
Ok(context_size.saturating_sub(prompt_tokens))
}
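/// Maps a [`Tokenizer`] to its lazily initialized, process-wide [`CoreBPE`]
/// encoder. GPT-2 shares its vocabulary with r50k_base, so both map to the
/// same singleton.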
fn bpe_singleton(tokenizer: Tokenizer) -> &'static CoreBPE {
match tokenizer {
Tokenizer::O200kHarmony => o200k_harmony_singleton(),
Tokenizer::O200kBase => o200k_base_singleton(),
Tokenizer::Cl100kBase => cl100k_base_singleton(),
Tokenizer::R50kBase => r50k_base_singleton(),
Tokenizer::P50kBase => p50k_base_singleton(),
Tokenizer::P50kEdit => p50k_edit_singleton(),
Tokenizer::Gpt2 => r50k_base_singleton(),
}
}
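/// Returns the shared [`CoreBPE`] encoder for `model`, or an error if no
/// tokenizer is registered for that model name.
///
/// A minimal usage sketch (assuming "gpt-4o" resolves to a known tokenizer):
///
/// ```ignore
/// let bpe = bpe_for_model("gpt-4o")?;
/// let tokens = bpe.encode_with_special_tokens("Hello, world!");
/// ```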
pub fn bpe_for_model(model: &str) -> Result<&'static CoreBPE> {
let tokenizer =
get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
bpe_for_tokenizer(tokenizer)
}
#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_model`")]
pub fn get_bpe_from_model(model: &str) -> Result<&'static CoreBPE> {
bpe_for_model(model)
}
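/// Returns the shared [`CoreBPE`] encoder for `tokenizer`. This currently
/// always succeeds; the `Result` return type is kept so the signature matches
/// the fallible model-based lookup.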
pub fn bpe_for_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
Ok(bpe_singleton(tokenizer))
}
#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_tokenizer`")]
pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
bpe_for_tokenizer(tokenizer)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bpe_for_tokenizer() {
let bpe = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
assert_eq!(bpe.decode(&[15339]).unwrap(), "hello");
}
#[test]
fn test_num_tokens_from_messages() {
let messages = vec![
ChatCompletionRequestMessage {
role: "system".to_string(),
name: None,
content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
..Default::default()
},
ChatCompletionRequestMessage {
role: "system".to_string(),
name: Some("example_user".to_string()),
content: Some("New synergies will help drive top-line growth.".to_string()),
..Default::default()
},
ChatCompletionRequestMessage {
role: "system".to_string(),
name: Some("example_assistant".to_string()),
content: Some("Things working well together will increase revenue.".to_string()),
..Default::default()
},
ChatCompletionRequestMessage {
role: "system".to_string(),
name: Some("example_user".to_string()),
content: Some("Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.".to_string()),
..Default::default()
},
ChatCompletionRequestMessage {
role: "system".to_string(),
name: Some("example_assistant".to_string()),
content: Some("Let's talk later when we're less busy about how to do better.".to_string()),
..Default::default()
},
ChatCompletionRequestMessage {
role: "user".to_string(),
name: None,
content: Some("This late pivot means we don't have time to boil the ocean for the client deliverable.".to_string()),
..Default::default()
},
];
let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0301", &messages).unwrap();
assert_eq!(num_tokens, 127);
let num_tokens = num_tokens_from_messages("gpt-4-0314", &messages).unwrap();
assert_eq!(num_tokens, 129);
let num_tokens = num_tokens_from_messages("gpt-4o-2024-05-13", &messages).unwrap();
assert_eq!(num_tokens, 124);
let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0125", &messages).unwrap();
assert_eq!(num_tokens, 129);
}
#[test]
fn test_num_tokens_from_messages_with_function_call() {
let messages = vec![
ChatCompletionRequestMessage {
role: "system".to_string(),
content: Some("You are a friendly chatbot.\n".to_string()),
name: None,
..Default::default()
},
ChatCompletionRequestMessage {
role: "assistant".to_string(),
content: Some("Hello, I am a friendly chatbot!\n".to_string()),
name: None,
..Default::default()
},
ChatCompletionRequestMessage {
role: "user".to_string(),
content: Some("What is the weather in New York?".to_string()),
name: None,
..Default::default()
},
ChatCompletionRequestMessage {
role: "assistant".to_string(),
content: Some(String::new()),
function_call: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: "{\n \"city\": \"New York\"\n}".to_string(),
}),
..Default::default()
},
ChatCompletionRequestMessage {
role: "function".to_string(),
content: Some(
"{\"temperature\": 72, \"conditions\": \"partly_cloudy\"}".to_string(),
),
name: Some("get_weather".to_string()),
..Default::default()
},
];
let num_tokens = num_tokens_from_messages("gpt-4-0613", &messages).unwrap();
assert_eq!(num_tokens, 78);
}
#[test]
fn test_num_tokens_from_messages_with_tool_calls() {
let messages_with = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
tool_calls: vec![FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"city": "Paris"}"#.to_string(),
}],
..Default::default()
}];
let messages_without = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
..Default::default()
}];
let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
assert!(
with > without,
"tool_calls should contribute tokens: {with} vs {without}"
);
}
#[test]
fn test_num_tokens_from_messages_with_multiple_tool_calls() {
let single = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
tool_calls: vec![FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"city": "Paris"}"#.to_string(),
}],
..Default::default()
}];
let double = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
tool_calls: vec![
FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"city": "Paris"}"#.to_string(),
},
FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"city": "London"}"#.to_string(),
},
],
..Default::default()
}];
let single_tokens = num_tokens_from_messages("gpt-4o", &single).unwrap();
let double_tokens = num_tokens_from_messages("gpt-4o", &double).unwrap();
assert!(
double_tokens > single_tokens,
"multiple tool_calls should each contribute tokens: {double_tokens} vs {single_tokens}"
);
}
#[test]
fn test_num_tokens_from_messages_with_refusal() {
let messages_with = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
refusal: Some("I cannot help with that request.".to_string()),
..Default::default()
}];
let messages_without = vec![ChatCompletionRequestMessage {
role: "assistant".to_string(),
..Default::default()
}];
let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
assert!(
with > without,
"refusal should contribute tokens: {with} vs {without}"
);
}
#[test]
fn test_num_tokens_from_messages_repeated_calls_consistent() {
let messages = vec![ChatCompletionRequestMessage {
role: "user".to_string(),
content: Some("Hello, world!".to_string()),
..Default::default()
}];
let first = num_tokens_from_messages("gpt-4o", &messages).unwrap();
for _ in 0..5 {
let result = num_tokens_from_messages("gpt-4o", &messages).unwrap();
assert_eq!(first, result);
}
}
#[test]
fn test_text_completion_max_tokens_repeated_calls_consistent() {
let first = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
for _ in 0..5 {
let result = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
assert_eq!(first, result);
}
}
    #[test]
    fn test_bpe_for_tokenizer_matches_singleton() {
        // `bpe_for_tokenizer` is a thin wrapper over `bpe_singleton`, so both
        // paths must yield identical encodings.
        let singleton = bpe_singleton(Tokenizer::Cl100kBase);
        let via_api = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
        let text = "The quick brown fox jumps over the lazy dog";
        assert_eq!(
            singleton.encode_with_special_tokens(text),
            via_api.encode_with_special_tokens(text),
        );
    }
#[test]
fn test_get_chat_completion_max_tokens() {
let model = "gpt-3.5-turbo";
let messages = vec![
ChatCompletionRequestMessage {
content: Some("You are a helpful assistant that only speaks French.".to_string()),
role: "system".to_string(),
name: None,
..Default::default()
},
ChatCompletionRequestMessage {
content: Some("Hello, how are you?".to_string()),
role: "user".to_string(),
name: None,
..Default::default()
},
ChatCompletionRequestMessage {
content: Some("Parlez-vous francais?".to_string()),
role: "system".to_string(),
name: None,
..Default::default()
},
];
let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
assert!(max_tokens > 0);
}
#[test]
fn test_text_completion_max_tokens() {
let model = "gpt-3.5-turbo";
let prompt = "Translate the following English text to French: '";
let max_tokens = get_text_completion_max_tokens(model, prompt).unwrap();
assert!(max_tokens > 0);
}
}
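/// Adapters that let the token-counting helpers above accept request message
/// types from the `async-openai` crate directly. Each message variant is
/// flattened into this crate's simpler `ChatCompletionRequestMessage` before
/// counting; non-text content parts (images, audio, files) contribute no
/// tokens, so counts for multimodal messages are lower bounds.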
#[cfg(feature = "async-openai")]
pub mod async_openai {
use anyhow::Result;
use async_openai::types::chat::{
ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestAssistantMessageContentPart,
ChatCompletionRequestDeveloperMessageContent,
ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestSystemMessageContentPart,
ChatCompletionRequestToolMessageContent, ChatCompletionRequestToolMessageContentPart,
ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
FunctionCall,
};
impl From<&FunctionCall> for super::FunctionCall {
fn from(f: &FunctionCall) -> Self {
Self {
name: f.name.clone(),
arguments: f.arguments.clone(),
}
}
}
fn join_texts(texts: Vec<String>) -> Option<String> {
if texts.is_empty() {
None
} else {
Some(texts.join(""))
}
}
fn system_content_text(content: &ChatCompletionRequestSystemMessageContent) -> Option<String> {
match content {
ChatCompletionRequestSystemMessageContent::Text(s) => Some(s.clone()),
ChatCompletionRequestSystemMessageContent::Array(parts) => join_texts(
parts
.iter()
.map(|ChatCompletionRequestSystemMessageContentPart::Text(t)| t.text.clone())
.collect(),
),
}
}
fn developer_content_text(
content: &ChatCompletionRequestDeveloperMessageContent,
) -> Option<String> {
match content {
ChatCompletionRequestDeveloperMessageContent::Text(s) => Some(s.clone()),
ChatCompletionRequestDeveloperMessageContent::Array(parts) => join_texts(
parts
.iter()
.map(|ChatCompletionRequestDeveloperMessageContentPart::Text(t)| t.text.clone())
.collect(),
),
}
}
fn user_content_text(content: &ChatCompletionRequestUserMessageContent) -> Option<String> {
match content {
ChatCompletionRequestUserMessageContent::Text(s) => Some(s.clone()),
ChatCompletionRequestUserMessageContent::Array(parts) => join_texts(
parts
.iter()
.filter_map(|p| match p {
ChatCompletionRequestUserMessageContentPart::Text(t) => {
Some(t.text.clone())
}
ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
| ChatCompletionRequestUserMessageContentPart::InputAudio(_)
| ChatCompletionRequestUserMessageContentPart::File(_) => None,
})
.collect(),
),
}
}
fn assistant_content_text(
content: &ChatCompletionRequestAssistantMessageContent,
) -> (Option<String>, Option<String>) {
match content {
ChatCompletionRequestAssistantMessageContent::Text(s) => (Some(s.clone()), None),
ChatCompletionRequestAssistantMessageContent::Array(parts) => {
let mut texts = Vec::new();
let mut refusals = Vec::new();
for p in parts {
match p {
ChatCompletionRequestAssistantMessageContentPart::Text(t) => {
texts.push(t.text.clone());
}
ChatCompletionRequestAssistantMessageContentPart::Refusal(r) => {
refusals.push(r.refusal.clone());
}
}
}
(join_texts(texts), join_texts(refusals))
}
}
}
fn tool_content_text(content: &ChatCompletionRequestToolMessageContent) -> Option<String> {
match content {
ChatCompletionRequestToolMessageContent::Text(s) => Some(s.clone()),
ChatCompletionRequestToolMessageContent::Array(parts) => join_texts(
parts
.iter()
.map(|ChatCompletionRequestToolMessageContentPart::Text(t)| t.text.clone())
.collect(),
),
}
}
fn extract_tool_calls(
tool_calls: &Option<Vec<ChatCompletionMessageToolCalls>>,
) -> Vec<super::FunctionCall> {
tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.map(|tc| match tc {
ChatCompletionMessageToolCalls::Function(f) => (&f.function).into(),
ChatCompletionMessageToolCalls::Custom(c) => super::FunctionCall {
name: c.custom_tool.name.clone(),
arguments: c.custom_tool.input.clone(),
},
})
.collect()
})
.unwrap_or_default()
}
#[allow(deprecated)]
impl From<&ChatCompletionRequestMessage> for super::ChatCompletionRequestMessage {
fn from(m: &ChatCompletionRequestMessage) -> Self {
match m {
ChatCompletionRequestMessage::System(msg) => Self {
role: "system".to_string(),
name: msg.name.clone(),
content: Some(system_content_text(&msg.content).unwrap_or_default()),
..Default::default()
},
ChatCompletionRequestMessage::Developer(msg) => Self {
role: "developer".to_string(),
name: msg.name.clone(),
content: Some(developer_content_text(&msg.content).unwrap_or_default()),
..Default::default()
},
ChatCompletionRequestMessage::User(msg) => Self {
role: "user".to_string(),
name: msg.name.clone(),
content: Some(user_content_text(&msg.content).unwrap_or_default()),
..Default::default()
},
ChatCompletionRequestMessage::Assistant(msg) => {
let (content, refusal) = msg
.content
.as_ref()
.map(assistant_content_text)
.unwrap_or_default();
let refusal = refusal.or_else(|| msg.refusal.clone());
Self {
role: "assistant".to_string(),
name: msg.name.clone(),
content,
function_call: msg.function_call.as_ref().map(|f| f.into()),
tool_calls: extract_tool_calls(&msg.tool_calls),
refusal,
}
}
                ChatCompletionRequestMessage::Tool(msg) => Self {
                    role: "tool".to_string(),
                    // The tool call id is the closest analogue to a message
                    // name for token-counting purposes.
                    name: Some(msg.tool_call_id.clone()),
                    content: Some(tool_content_text(&msg.content).unwrap_or_default()),
                    ..Default::default()
                },
ChatCompletionRequestMessage::Function(msg) => Self {
role: "function".to_string(),
name: Some(msg.name.clone()),
content: msg.content.clone(),
..Default::default()
},
}
}
}
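    /// Counts prompt tokens for a slice of `async_openai` request messages by
    /// converting each message into this crate's internal representation and
    /// delegating to [`super::num_tokens_from_messages`].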
pub fn num_tokens_from_messages(
model: &str,
messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
let messages: Vec<super::ChatCompletionRequestMessage> =
messages.iter().map(|m| m.into()).collect();
super::num_tokens_from_messages(model, &messages)
}
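    /// Returns the tokens remaining in the model's context window after the
    /// given `async_openai` messages, via
    /// [`super::get_chat_completion_max_tokens`].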
pub fn get_chat_completion_max_tokens(
model: &str,
messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
let messages: Vec<super::ChatCompletionRequestMessage> =
messages.iter().map(|m| m.into()).collect();
super::get_chat_completion_max_tokens(model, &messages)
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use async_openai::types::chat::{
ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
};
#[test]
fn test_num_tokens_from_messages_system() {
let model = "gpt-4o";
let messages = &[ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
"You are a helpful assistant.".to_string(),
),
name: None,
},
)];
let num_tokens = num_tokens_from_messages(model, messages).unwrap();
assert!(num_tokens > 0);
}
#[test]
fn test_num_tokens_from_messages_user() {
let model = "gpt-4o";
let messages = &[ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(
"Hello, how are you?".to_string(),
),
name: None,
},
)];
let num_tokens = num_tokens_from_messages(model, messages).unwrap();
assert!(num_tokens > 0);
}
#[test]
fn test_num_tokens_with_tool_calls() {
let model = "gpt-4o";
let messages = &[ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function(
ChatCompletionMessageToolCall {
id: "call_123".to_string(),
function: FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location": "Paris"}"#.to_string(),
},
},
)]),
function_call: None,
},
)];
let tokens_with = num_tokens_from_messages(model, messages).unwrap();
let empty = &[ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: None,
function_call: None,
},
)];
let tokens_without = num_tokens_from_messages(model, empty).unwrap();
assert!(
tokens_with > tokens_without,
"tool_calls should contribute tokens: {tokens_with} vs {tokens_without}"
);
}
#[test]
fn test_num_tokens_with_refusal() {
let model = "gpt-4o";
let messages = &[ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: Some("I cannot help with that request.".to_string()),
name: None,
audio: None,
tool_calls: None,
function_call: None,
},
)];
let tokens_with = num_tokens_from_messages(model, messages).unwrap();
let empty = &[ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: None,
function_call: None,
},
)];
let tokens_without = num_tokens_from_messages(model, empty).unwrap();
assert!(
tokens_with > tokens_without,
"refusal should contribute tokens: {tokens_with} vs {tokens_without}"
);
}
#[test]
fn test_get_chat_completion_max_tokens() {
let model = "gpt-4o";
let messages = &[ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
"You are a helpful assistant.".to_string(),
),
name: None,
},
)];
let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
assert!(max_tokens > 0);
}
}
}