// chat_splitter/lib.rs

//! [![Build Status]][actions] [![Latest Version]][crates.io]
//!
//! [Build Status]: https://github.com/schneiderfelipe/chat-splitter/actions/workflows/rust.yml/badge.svg
//! [actions]: https://github.com/schneiderfelipe/chat-splitter/actions/workflows/rust.yml
//! [Latest Version]: https://img.shields.io/crates/v/chat_splitter.svg
//! [crates.io]: https://crates.io/crates/chat_splitter
//!
//! > For more information,
//! > please refer to the [blog announcement](https://schneiderfelipe.github.io/posts/chat-splitter-first-release/).
//!
//! When using the [`async_openai`](https://github.com/64bit/async-openai) [Rust](https://www.rust-lang.org/) crate,
//! you must ensure that you do not exceed
//! the [maximum number of tokens](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) supported by [OpenAI](https://openai.com/)'s [chat models](https://platform.openai.com/docs/api-reference/chat).
//!
//! [`chat-splitter`](https://crates.io/crates/chat_splitter) splits chat messages into 'outdated' and 'recent' groups,
//! so that the 'recent' messages always respect both the maximum
//! message count and the maximum chat completion token count.
//! Token counting is provided by
//! [`tiktoken_rs`](https://github.com/zurawiki/tiktoken-rs).
//!
//! # Usage
//!
//! Here's a basic example:
//!
//! ```ignore
//! // Get all your previously stored chat messages...
//! let mut stored_messages = /* get_stored_messages()? */;
//!
//! // ...and split into 'outdated' and 'recent',
//! // where 'recent' always fits the context size.
//! let (outdated_messages, recent_messages) =
//!     ChatSplitter::default().split(&stored_messages);
//! ```
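//!
//! The 'recent' messages can then be sent for a chat completion.
//! The sketch below is illustrative only;
//! the `client` value and the request-builder calls are assumptions about
//! [`async_openai`]'s API rather than something this crate provides:
//!
//! ```ignore
//! // Send only the 'recent' messages for completion,
//! // keeping the splitter's `max_tokens` available for the reply.
//! let request = async_openai::types::CreateChatCompletionRequestArgs::default()
//!     .model("gpt-3.5-turbo")
//!     .messages(recent_messages.to_vec())
//!     .build()?;
//! let response = client.chat().create(request).await?;
//! ```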
//!
//! For a more detailed example,
//! see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/main/examples/chat.rs).
//!
//! # Contributing
//!
//! Contributions to `chat-splitter` are welcome!
//! If you find a bug or have a feature request,
//! please [submit an issue](https://github.com/schneiderfelipe/chat-splitter/issues).
//! If you'd like to contribute code,
//! please feel free to [submit a pull request](https://github.com/schneiderfelipe/chat-splitter/pulls).

use std::cmp::Ordering;

use indxvec::Search;
use tiktoken_rs::get_chat_completion_max_tokens;
use tiktoken_rs::model::get_context_size;

/// Chat splitter for [OpenAI](https://openai.com/)'s [chat models](https://platform.openai.com/docs/api-reference/chat) when using [`async_openai`].
///
/// For more detailed information,
/// see the [crate documentation](`crate`).
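///
/// A minimal configuration sketch
/// (the limits shown are arbitrary examples, not recommendations):
///
/// ```
/// use chat_splitter::ChatSplitter;
///
/// let splitter = ChatSplitter::new("gpt-3.5-turbo")
///     .max_tokens(1024u16)
///     .max_messages(128usize);
/// ```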
#[derive(Clone, Debug)]
pub struct ChatSplitter {
    /// The model to use for tokenization,
    /// e.g.,
    /// `gpt-3.5-turbo`.
    ///
    /// It is passed to [`tiktoken_rs`] to select the correct tokenizer.
    model: String,

    /// The maximum number of tokens to leave for chat completion.
    ///
    /// This is the same as in the [official API](https://platform.openai.com/docs/api-reference/chat#completions/create-prompt) and given to [`async_openai`].
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context size.
    /// Splits will have at least that many tokens
    /// available for chat completion,
    /// never less.
    max_tokens: u16,

    /// The maximum number of messages to have in the chat.
    ///
    /// Splits will have at most that many messages,
    /// never more.
    max_messages: usize,
}

/// Hard limit that seems to be imposed by the `OpenAI` API.
const MAX_MESSAGES_LIMIT: usize = 2_048;

/// Recommended minimum for maximum chat completion tokens.
const RECOMMENDED_MIN_MAX_TOKENS: u16 = 256;

impl Default for ChatSplitter {
    #[inline]
    fn default() -> Self {
        Self::new("gpt-3.5-turbo")
    }
}

impl ChatSplitter {
    /// Create a new [`ChatSplitter`] for the given model.
    ///
    /// By default,
    /// half of the model's context size is left for chat completion
    /// and at most half of the API message limit is kept in the chat.
    ///
    /// # Panics
    ///
    /// If [`tiktoken_rs`] reports a context size more than twice as large
    /// as what would fit in a [`u16`].
    /// If this happens,
    /// it should be considered a bug,
    /// but this behaviour might change in the future,
    /// as models with larger context sizes are released.
    #[inline]
    pub fn new(model: impl Into<String>) -> Self {
        let model = model.into();
        let max_tokens = u16::try_from(get_context_size(&model) / 2).unwrap();

        let max_messages = MAX_MESSAGES_LIMIT / 2;

        Self {
            model,
            max_tokens,
            max_messages,
        }
    }

    /// Set the maximum number of messages to have in the chat.
    ///
    /// Splits will have at most that many messages,
    /// never more.
    #[inline]
    #[must_use]
    pub fn max_messages(mut self, max_messages: impl Into<usize>) -> Self {
        self.max_messages = max_messages.into();
        if self.max_messages > MAX_MESSAGES_LIMIT {
            log::warn!(
                "max_messages = {} > {MAX_MESSAGES_LIMIT}",
                self.max_messages
            );
        }
        self
    }

    /// Set the maximum number of tokens to leave for chat completion.
    ///
    /// This is the same as in the [official API](https://platform.openai.com/docs/api-reference/chat#completions/create-prompt) and given to [`async_openai`].
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context size.
    /// Splits will have at least that many tokens
    /// available for chat completion,
    /// never less.
    #[inline]
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: impl Into<u16>) -> Self {
        self.max_tokens = max_tokens.into();
        if self.max_tokens < RECOMMENDED_MIN_MAX_TOKENS {
            log::warn!(
                "max_tokens = {} < {RECOMMENDED_MIN_MAX_TOKENS}",
                self.max_tokens
            );
        }
        self
    }

    /// Set the model to use for tokenization,
    /// e.g.,
    /// `gpt-3.5-turbo`.
    ///
    /// It is passed to [`tiktoken_rs`] to select the correct tokenizer.
    #[inline]
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Get a split position by only considering `max_messages`.
    #[inline]
    fn position_by_max_messages<M>(&self, messages: &[M]) -> usize {
        let upper_limit = self.max_messages.min(MAX_MESSAGES_LIMIT);

        let n = messages.len();
        let n = if n <= upper_limit { 0 } else { n - upper_limit };
        debug_assert!(messages[n..].len() <= upper_limit);
        n
    }

    /// Get a split position by only considering `max_tokens`.
    ///
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
    #[inline]
    fn position_by_max_tokens<M>(&self, messages: &[M]) -> usize
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        let max_tokens = self.max_tokens as usize;
        let lower_limit = max_tokens.min(get_context_size(&self.model));

        let messages: Vec<_> = messages
            .iter()
            .cloned()
            .map(IntoChatCompletionRequestMessage::into_tiktoken_rs)
            .collect();

        // Binary-search for a suffix start `n` such that sending `messages[n..]`
        // leaves at least `lower_limit` tokens available for completion.
        let (n, _range) = (0..=messages.len()).binary_any(|n| {
            debug_assert!(n < messages.len());

            let tokens = get_chat_completion_max_tokens(&self.model, &messages[n..])
                .expect("tokenizer should be available");

            let cmp = tokens.cmp(&lower_limit);
            debug_assert_ne!(cmp, Ordering::Equal);
            cmp
        });

        debug_assert!(
            get_chat_completion_max_tokens(&self.model, &messages[n..])
                .expect("tokenizer should be available")
                >= lower_limit
        );
        n
    }

    /// Get a split position by first considering the `max_messages` limit,
    /// then the `max_tokens` limit.
    ///
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
    #[inline]
    fn position<M>(&self, messages: &[M]) -> usize
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        let n = self.position_by_max_messages(messages);
        n + self.position_by_max_tokens(&messages[n..])
    }

    /// Split the chat into two groups of messages,
    /// the 'outdated' and the 'recent' ones.
    ///
    /// The 'recent' messages are guaranteed to satisfy the given limits,
    /// while the 'outdated' ones contain every message that comes before 'recent'.
    ///
    /// For a detailed usage example,
    /// see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/main/examples/chat.rs).
    ///
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
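    ///
    /// A minimal sketch,
    /// assuming the stored messages are [`async_openai`] request messages:
    ///
    /// ```
    /// use chat_splitter::ChatSplitter;
    ///
    /// let stored: Vec<async_openai::types::ChatCompletionRequestMessage> = Vec::new();
    /// let (outdated, recent) = ChatSplitter::default().split(&stored);
    /// assert!(outdated.is_empty());
    /// assert!(recent.is_empty());
    /// ```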
    #[inline]
    pub fn split<'a, M>(&self, messages: &'a [M]) -> (&'a [M], &'a [M])
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        messages.split_at(self.position(messages))
    }
}

/// Extension trait for converting between different chat completion request
/// message types.
///
/// For a usage example,
/// see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/736f4fceb57bc12adb2b70deb990030a266a95a5/examples/chat.rs#L44-L55).
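///
/// A rough sketch of implementing the trait for a hypothetical stored-message
/// type that wraps an [`async_openai`] request message
/// (the `StoredMessage` type below is illustrative and not part of this crate):
///
/// ```ignore
/// #[derive(Clone)]
/// struct StoredMessage {
///     inner: async_openai::types::ChatCompletionRequestMessage,
/// }
///
/// impl IntoChatCompletionRequestMessage for StoredMessage {
///     fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
///         // Delegate to the implementation provided for `async_openai` messages.
///         self.inner.into_tiktoken_rs()
///     }
///
///     fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
///         self.inner
///     }
/// }
/// ```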
pub trait IntoChatCompletionRequestMessage {
    /// Convert to [`tiktoken_rs` chat completion request message
    /// type](`tiktoken_rs::ChatCompletionRequestMessage`).
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage;

    /// Convert to [`async_openai` chat completion request message
    /// type](`async_openai::types::ChatCompletionRequestMessage`).
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage;
}

impl IntoChatCompletionRequestMessage for tiktoken_rs::ChatCompletionRequestMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        self
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        async_openai::types::ChatCompletionRequestMessage {
            role: match self.role.as_ref() {
                "user" => async_openai::types::Role::User,
                "system" => async_openai::types::Role::System,
                "assistant" => async_openai::types::Role::Assistant,
                "function" => async_openai::types::Role::Function,
                role => panic!("unknown role '{role}'"),
            },
            content: self.content,
            function_call: self.function_call.map(|fc| {
                async_openai::types::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: self.name,
        }
    }
}

impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionRequestMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        tiktoken_rs::ChatCompletionRequestMessage {
            role: self.role.to_string(),
            content: self.content,
            function_call: self.function_call.map(|fc| {
                tiktoken_rs::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: self.name,
        }
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        self
    }
}

impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionResponseMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        tiktoken_rs::ChatCompletionRequestMessage {
            role: self.role.to_string(),
            content: self.content,
            function_call: self.function_call.map(|fc| {
                tiktoken_rs::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: None,
        }
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        async_openai::types::ChatCompletionRequestMessage {
            role: self.role,
            content: self.content,
            function_call: self.function_call,

            name: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let messages: Vec<async_openai::types::ChatCompletionRequestMessage> = Vec::new();

        assert_eq!(ChatSplitter::default().split(&messages).0, &[]);
        assert_eq!(ChatSplitter::default().split(&messages).1, &[]);
    }
}