use std::cmp::Ordering;
use indxvec::Search;
use tiktoken_rs::get_chat_completion_max_tokens;
use tiktoken_rs::model::get_context_size;
/// Configuration for splitting a slice of chat messages into a leading part and a
/// trailing part, based on a message-count cap and a token threshold.
///
/// Values are set through the builder-style methods on this type and consumed by `split`.
#[derive(Clone, Debug)]
pub struct ChatSplitter {
/// Model name handed to `tiktoken_rs` for context-size lookup and token counting.
model: String,
/// Token threshold that the token-based split position is compared against.
max_tokens: u16,
/// Maximum number of trailing messages kept by the count-based split position
/// (additionally capped by `MAX_MESSAGES_LIMIT` when splitting).
max_messages: usize,
}
/// Hard cap on the number of trailing messages kept when splitting; configuring a larger
/// `max_messages` only logs a warning, and the cap is still applied at split time.
/// The default `max_messages` is half of this value.
const MAX_MESSAGES_LIMIT: usize = 2_048;
/// Smallest `max_tokens` value that can be configured without a warning being logged.
const RECOMMENDED_MIN_MAX_TOKENS: u16 = 256;
impl Default for ChatSplitter {
#[inline]
fn default() -> Self {
Self::new("gpt-3.5-turbo")
}
}
impl ChatSplitter {
/// Creates a splitter for `model` with default limits.
///
/// * `max_tokens` defaults to half of the model's context size as reported by
///   `tiktoken_rs::model::get_context_size`, saturating at `u16::MAX` for models whose
///   half-window does not fit in a `u16`.
/// * `max_messages` defaults to half of `MAX_MESSAGES_LIMIT`.
#[inline]
pub fn new(model: impl Into<String>) -> Self {
    let model = model.into();
    // Saturate instead of unwrapping: a large context window must not make
    // construction panic just because the default does not fit the field type.
    let max_tokens = u16::try_from(get_context_size(&model) / 2).unwrap_or(u16::MAX);
    let max_messages = MAX_MESSAGES_LIMIT / 2;
    Self {
        model,
        max_tokens,
        max_messages,
    }
}
/// Sets the maximum number of trailing messages to keep and returns the updated splitter.
///
/// A value above `MAX_MESSAGES_LIMIT` is still stored, but a warning is logged.
#[inline]
#[must_use]
pub fn max_messages(mut self, max_messages: impl Into<usize>) -> Self {
    let requested: usize = max_messages.into();
    if requested > MAX_MESSAGES_LIMIT {
        log::warn!("max_messages = {} > {MAX_MESSAGES_LIMIT}", requested);
    }
    self.max_messages = requested;
    self
}
/// Sets the token threshold used when splitting and returns the updated splitter.
///
/// A value below `RECOMMENDED_MIN_MAX_TOKENS` is still stored, but a warning is logged.
#[inline]
#[must_use]
pub fn max_tokens(mut self, max_tokens: impl Into<u16>) -> Self {
    let requested: u16 = max_tokens.into();
    if requested < RECOMMENDED_MIN_MAX_TOKENS {
        log::warn!("max_tokens = {} < {RECOMMENDED_MIN_MAX_TOKENS}", requested);
    }
    self.max_tokens = requested;
    self
}
#[inline]
#[must_use]
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Returns the index at which the trailing `max_messages` messages begin.
///
/// The configured count is capped at `MAX_MESSAGES_LIMIT`. When `messages` is no longer
/// than that cap the result is `0`.
#[inline]
fn position_by_max_messages<M>(&self, messages: &[M]) -> usize {
    let cap = MAX_MESSAGES_LIMIT.min(self.max_messages);
    // Number of leading messages that do not fit under the cap.
    let start = messages.len().saturating_sub(cap);
    debug_assert!(messages[start..].len() <= cap);
    start
}
/// Returns an index into `messages` chosen by a binary search over token counts.
///
/// Every message is cloned and converted to the `tiktoken_rs` message type. The candidate
/// start indices `0..=messages.len()` are then searched with `binary_any`, comparing the
/// value of `get_chat_completion_max_tokens` for the tail `messages[n..]` against the
/// limit `min(self.max_tokens, get_context_size(&self.model))`.
///
/// In debug builds the result is checked: the value computed for `messages[n..]` must be
/// at least that limit.
///
/// # Panics
///
/// Panics if `get_chat_completion_max_tokens` returns an error.
#[inline]
fn position_by_max_tokens<M>(&self, messages: &[M]) -> usize
where
M: IntoChatCompletionRequestMessage + Clone,
{
let max_tokens = self.max_tokens as usize;
// The threshold never exceeds the model's context size.
let lower_limit = max_tokens.min(get_context_size(&self.model));
// Token counting works on the `tiktoken_rs` message type, so convert owned copies up front.
let messages: Vec<_> = messages
.iter()
.cloned()
.map(IntoChatCompletionRequestMessage::into_tiktoken_rs)
.collect();
// NOTE(review): presumably `get_chat_completion_max_tokens` reports the tokens still
// available for a completion after the given messages, so the value grows as `n`
// grows — confirm against the tiktoken-rs documentation.
let (n, _range) = (0..=messages.len()).binary_any(|n| {
// NOTE(review): assumes `binary_any` never probes `n == messages.len()` even though
// the searched range is inclusive — TODO confirm against indxvec.
debug_assert!(n < messages.len());
let tokens = get_chat_completion_max_tokens(&self.model, &messages[n..])
.expect("tokenizer should be available");
let cmp = tokens.cmp(&lower_limit);
// NOTE(review): this fires in debug builds whenever the value equals the limit
// exactly — confirm that `Ordering::Equal` is really meant to be impossible here.
debug_assert_ne!(cmp, Ordering::Equal);
cmp
});
// Sanity check on the chosen index (debug builds only).
debug_assert!(
get_chat_completion_max_tokens(&self.model, &messages[n..])
.expect("tokenizer should be available")
>= lower_limit
);
n
}
/// Computes the overall split index for `messages`.
///
/// The message-count limit is applied first; the token-based position is then computed on
/// the remaining tail and added to it, so the result is an index into the full slice.
#[inline]
fn position<M>(&self, messages: &[M]) -> usize
where
    M: IntoChatCompletionRequestMessage + Clone,
{
    let by_count = self.position_by_max_messages(messages);
    let tail = &messages[by_count..];
    let by_tokens = self.position_by_max_tokens(tail);
    by_count + by_tokens
}
/// Splits `messages` in two at the position derived from the configured limits.
///
/// Returns `(before, after)`, where `before` holds the messages preceding the split
/// position and `after` holds the trailing messages from that position onwards. Both
/// halves borrow from `messages`; nothing is copied into the result.
#[inline]
pub fn split<'a, M>(&self, messages: &'a [M]) -> (&'a [M], &'a [M])
where
    M: IntoChatCompletionRequestMessage + Clone,
{
    let cut = self.position(messages);
    messages.split_at(cut)
}
}
/// Consuming conversions from a chat message type into the request-message types of
/// `tiktoken_rs` and `async_openai`.
pub trait IntoChatCompletionRequestMessage {
/// Converts `self` into a `tiktoken_rs::ChatCompletionRequestMessage`.
fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage;
/// Converts `self` into an `async_openai::types::ChatCompletionRequestMessage`.
fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage;
}
impl IntoChatCompletionRequestMessage for tiktoken_rs::ChatCompletionRequestMessage {
    /// Identity conversion: the value already has the `tiktoken_rs` type.
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        self
    }

    /// Converts into the `async_openai` request message, mapping the textual role onto
    /// the `Role` enum.
    ///
    /// # Panics
    ///
    /// Panics if the role string is not one of `user`, `system`, `assistant` or `function`.
    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        use async_openai::types::{FunctionCall, Role};

        // Resolve the role first; the remaining fields are moved over unchanged.
        let mapped_role = match self.role.as_str() {
            "user" => Role::User,
            "system" => Role::System,
            "assistant" => Role::Assistant,
            "function" => Role::Function,
            role => panic!("unknown role '{role}'"),
        };
        let function_call = self.function_call.map(|call| FunctionCall {
            name: call.name,
            arguments: call.arguments,
        });

        async_openai::types::ChatCompletionRequestMessage {
            role: mapped_role,
            content: self.content,
            function_call,
            name: self.name,
        }
    }
}
impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionRequestMessage {
    /// Converts into the `tiktoken_rs` request message, rendering the role as a string
    /// and carrying over content, function call and name.
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        let role = self.role.to_string();
        let function_call = self.function_call.map(|call| tiktoken_rs::FunctionCall {
            name: call.name,
            arguments: call.arguments,
        });

        tiktoken_rs::ChatCompletionRequestMessage {
            role,
            content: self.content,
            function_call,
            name: self.name,
        }
    }

    /// Identity conversion: the value already has the `async_openai` request type.
    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        self
    }
}
impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionResponseMessage {
    /// Converts a response message into the `tiktoken_rs` request message.
    ///
    /// The role is rendered as a string and `name` is always set to `None`.
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        let role = self.role.to_string();
        let function_call = self.function_call.map(|call| tiktoken_rs::FunctionCall {
            name: call.name,
            arguments: call.arguments,
        });

        tiktoken_rs::ChatCompletionRequestMessage {
            role,
            content: self.content,
            function_call,
            name: None,
        }
    }

    /// Converts a response message into the `async_openai` request message.
    ///
    /// Role, content and function call are moved over as-is; `name` is always `None`.
    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        let async_openai::types::ChatCompletionResponseMessage {
            role,
            content,
            function_call,
            ..
        } = self;

        async_openai::types::ChatCompletionRequestMessage {
            role,
            content,
            function_call,
            name: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splitting an empty history with the default splitter yields two empty halves.
    #[test]
    fn it_works() {
        let messages = Vec::<async_openai::types::ChatCompletionRequestMessage>::new();
        let (before, after) = ChatSplitter::default().split(&messages);
        assert_eq!(before, &[]);
        assert_eq!(after, &[]);
    }
}