chat_splitter/lib.rs
//! [![Build Status]][actions] [![Latest Version]][crates.io]
//!
//! [Build Status]: https://github.com/schneiderfelipe/chat-splitter/actions/workflows/rust.yml/badge.svg
//! [actions]: https://github.com/schneiderfelipe/chat-splitter/actions/workflows/rust.yml
//! [Latest Version]: https://img.shields.io/crates/v/chat_splitter.svg
//! [crates.io]: https://crates.io/crates/chat_splitter
//!
//! > For more information,
//! > please refer to the [blog announcement](https://schneiderfelipe.github.io/posts/chat-splitter-first-release/).
//!
//! When using the [`async_openai`](https://github.com/64bit/async-openai) [Rust](https://www.rust-lang.org/) crate,
//! it is crucial to ensure that you do not exceed
//! the [maximum number of tokens](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) specified by [OpenAI](https://openai.com/)'s [chat models](https://platform.openai.com/docs/api-reference/chat).
//!
//! [`chat-splitter`](https://crates.io/crates/chat_splitter) splits chat messages into 'outdated' and 'recent' groups,
//! where the 'recent' messages always satisfy both the maximum
//! message count and the maximum chat completion token count.
//! The token counting functionality is provided by
//! [`tiktoken_rs`](https://github.com/zurawiki/tiktoken-rs).
//!
//! # Usage
//!
//! Here's a basic example:
//!
//! ```ignore
//! // Get all your previously stored chat messages...
//! let mut stored_messages = /* get_stored_messages()? */;
//!
//! // ...and split into 'outdated' and 'recent',
//! // where 'recent' always fits the context size.
//! let (outdated_messages, recent_messages) =
//!     ChatSplitter::default().split(&stored_messages);
//! ```
//!
//! For a more detailed example,
//! see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/main/examples/chat.rs).
//!
//! # Contributing
//!
//! Contributions to `chat-splitter` are welcome!
//! If you find a bug or have a feature request,
//! please [submit an issue](https://github.com/schneiderfelipe/chat-splitter/issues).
//! If you'd like to contribute code,
//! please feel free to [submit a pull request](https://github.com/schneiderfelipe/chat-splitter/pulls).

use std::cmp::Ordering;

use indxvec::Search;
use tiktoken_rs::get_chat_completion_max_tokens;
use tiktoken_rs::model::get_context_size;

/// Chat splitter for [OpenAI](https://openai.com/)'s [chat models](https://platform.openai.com/docs/api-reference/chat) when using [`async_openai`].
///
/// For more detailed information,
/// see the [crate documentation](`crate`).
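///
/// # Examples
///
/// A minimal configuration sketch (the limits shown are illustrative,
/// not recommendations):
///
/// ```ignore
/// let splitter = ChatSplitter::new("gpt-3.5-turbo")
///     .max_messages(64_usize)
///     .max_tokens(1024_u16);
/// ```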
#[derive(Clone, Debug)]
pub struct ChatSplitter {
    /// The model to use for tokenization,
    /// e.g.,
    /// `gpt-3.5-turbo`.
    ///
    /// It is passed to [`tiktoken_rs`] to select the correct tokenizer.
    model: String,

    /// The maximum number of tokens to leave for chat completion.
    ///
    /// This is the same as in the [official API](https://platform.openai.com/docs/api-reference/chat#completions/create-prompt) and given to [`async_openai`].
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context size.
    /// Splits will have at least that many tokens
    /// available for chat completion,
    /// never less.
    max_tokens: u16,

    /// The maximum number of messages to have in the chat.
    ///
    /// Splits will have at most that many messages,
    /// never more.
    max_messages: usize,
}

/// Hard limit that seems to be imposed by the `OpenAI` API.
const MAX_MESSAGES_LIMIT: usize = 2_048;

/// Recommended minimum for maximum chat completion tokens.
const RECOMMENDED_MIN_MAX_TOKENS: u16 = 256;

impl Default for ChatSplitter {
    #[inline]
    fn default() -> Self {
        Self::new("gpt-3.5-turbo")
    }
}

impl ChatSplitter {
    /// Create a new [`ChatSplitter`] for the given model.
    ///
    /// # Panics
    ///
    /// If for some reason [`tiktoken_rs`] reports a context size more than
    /// twice as large as what would fit in a [`u16`].
    /// If this happens,
    /// it should be considered a bug,
    /// but this behaviour might change in the future,
    /// as models with larger context sizes are released.
    #[inline]
    pub fn new(model: impl Into<String>) -> Self {
        let model = model.into();

        // By default, reserve half the model's context size for completion.
        let max_tokens = u16::try_from(get_context_size(&model) / 2).unwrap();

        // And keep at most half the API's hard limit on message count.
        let max_messages = MAX_MESSAGES_LIMIT / 2;

        Self {
            model,
            max_tokens,
            max_messages,
        }
    }

    /// Set the maximum number of messages to have in the chat.
    ///
    /// Splits will have at most that many messages,
    /// never more.
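    ///
    /// # Examples
    ///
    /// A minimal sketch (the limit shown is illustrative):
    ///
    /// ```ignore
    /// let splitter = ChatSplitter::default().max_messages(32_usize);
    /// ```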
    #[inline]
    #[must_use]
    pub fn max_messages(mut self, max_messages: impl Into<usize>) -> Self {
        self.max_messages = max_messages.into();
        if self.max_messages > MAX_MESSAGES_LIMIT {
            log::warn!(
                "max_messages = {} > {MAX_MESSAGES_LIMIT}",
                self.max_messages
            );
        }
        self
    }

    /// Set the maximum number of tokens to leave for chat completion.
    ///
    /// This is the same as in the [official API](https://platform.openai.com/docs/api-reference/chat#completions/create-prompt) and given to [`async_openai`].
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context size.
    /// Splits will have at least that many tokens
    /// available for chat completion,
    /// never less.
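    ///
    /// # Examples
    ///
    /// A minimal sketch (the limit shown is illustrative):
    ///
    /// ```ignore
    /// let splitter = ChatSplitter::default().max_tokens(1024_u16);
    /// ```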
    #[inline]
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: impl Into<u16>) -> Self {
        self.max_tokens = max_tokens.into();
        if self.max_tokens < RECOMMENDED_MIN_MAX_TOKENS {
            log::warn!(
                "max_tokens = {} < {RECOMMENDED_MIN_MAX_TOKENS}",
                self.max_tokens
            );
        }
        self
    }

    /// Set the model to use for tokenization,
    /// e.g.,
    /// `gpt-3.5-turbo`.
    ///
    /// It is passed to [`tiktoken_rs`] to select the correct tokenizer.
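    ///
    /// # Examples
    ///
    /// A minimal sketch (any chat model known to [`tiktoken_rs`] should work):
    ///
    /// ```ignore
    /// let splitter = ChatSplitter::default().model("gpt-4");
    /// ```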
    #[inline]
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Get a split position by only considering `max_messages`.
    #[inline]
    fn position_by_max_messages<M>(&self, messages: &[M]) -> usize {
        let upper_limit = self.max_messages.min(MAX_MESSAGES_LIMIT);

        // Keep at most `upper_limit` of the most recent messages;
        // everything before position `n` is considered outdated.
        let n = messages.len();
        let n = if n <= upper_limit { 0 } else { n - upper_limit };
        debug_assert!(messages[n..].len() <= upper_limit);
        n
    }

    /// Get a split position by only considering `max_tokens`.
    ///
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
    #[inline]
    fn position_by_max_tokens<M>(&self, messages: &[M]) -> usize
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        let max_tokens = self.max_tokens as usize;
        let lower_limit = max_tokens.min(get_context_size(&self.model));

        let messages: Vec<_> = messages
            .iter()
            .cloned()
            .map(IntoChatCompletionRequestMessage::into_tiktoken_rs)
            .collect();

        // Binary search for a position `n` such that the messages from `n`
        // onwards still leave at least `lower_limit` tokens available for
        // chat completion.
        let (n, _range) = (0..=messages.len()).binary_any(|n| {
            debug_assert!(n < messages.len());

            let tokens = get_chat_completion_max_tokens(&self.model, &messages[n..])
                .expect("tokenizer should be available");

            let cmp = tokens.cmp(&lower_limit);
            debug_assert_ne!(cmp, Ordering::Equal);
            cmp
        });

        debug_assert!(
            get_chat_completion_max_tokens(&self.model, &messages[n..])
                .expect("tokenizer should be available")
                >= lower_limit
        );
        n
    }

    /// Get a split position by first considering the `max_messages` limit,
    /// then the `max_tokens` limit.
    ///
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
    #[inline]
    fn position<M>(&self, messages: &[M]) -> usize
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        let n = self.position_by_max_messages(messages);
        n + self.position_by_max_tokens(&messages[n..])
    }

    /// Split the chat into two groups of messages,
    /// the 'outdated' and the 'recent' ones.
    ///
    /// The 'recent' messages are guaranteed to satisfy the given limits,
    /// while the 'outdated' ones contain all the ones before 'recent'.
    ///
    /// For a detailed usage example,
    /// see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/main/examples/chat.rs).
    ///
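    /// # Examples
    ///
    /// A minimal sketch,
    /// assuming `stored_messages` holds your previously stored messages:
    ///
    /// ```ignore
    /// let (outdated, recent) = ChatSplitter::default().split(&stored_messages);
    /// assert_eq!(outdated.len() + recent.len(), stored_messages.len());
    /// ```
    ///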
    /// # Panics
    ///
    /// If the tokenizer for the specified model is not found or the model is
    /// not a supported chat model.
    #[inline]
    pub fn split<'a, M>(&self, messages: &'a [M]) -> (&'a [M], &'a [M])
    where
        M: IntoChatCompletionRequestMessage + Clone,
    {
        messages.split_at(self.position(messages))
    }
}

/// Extension trait for converting between different chat completion request
/// message types.
///
/// For a usage example,
/// see [`examples/chat.rs`](https://github.com/schneiderfelipe/chat-splitter/blob/736f4fceb57bc12adb2b70deb990030a266a95a5/examples/chat.rs#L44-L55).
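///
/// # Examples
///
/// A minimal conversion sketch (the message fields shown are the ones this
/// crate maps; the field values are illustrative):
///
/// ```ignore
/// let message = tiktoken_rs::ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello!".to_string()),
///     name: None,
///     function_call: None,
/// };
/// let message = message.into_async_openai();
/// ```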
pub trait IntoChatCompletionRequestMessage {
    /// Convert to [`tiktoken_rs` chat completion request message
    /// type](`tiktoken_rs::ChatCompletionRequestMessage`).
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage;

    /// Convert to [`async_openai` chat completion request message
    /// type](`async_openai::types::ChatCompletionRequestMessage`).
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage;
}

impl IntoChatCompletionRequestMessage for tiktoken_rs::ChatCompletionRequestMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        self
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        async_openai::types::ChatCompletionRequestMessage {
            role: match self.role.as_ref() {
                "user" => async_openai::types::Role::User,
                "system" => async_openai::types::Role::System,
                "assistant" => async_openai::types::Role::Assistant,
                "function" => async_openai::types::Role::Function,
                role => panic!("unknown role '{role}'"),
            },
            content: self.content,
            function_call: self.function_call.map(|fc| {
                async_openai::types::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: self.name,
        }
    }
}

impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionRequestMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        tiktoken_rs::ChatCompletionRequestMessage {
            role: self.role.to_string(),
            content: self.content,
            function_call: self.function_call.map(|fc| {
                tiktoken_rs::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: self.name,
        }
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        self
    }
}

impl IntoChatCompletionRequestMessage for async_openai::types::ChatCompletionResponseMessage {
    #[inline]
    fn into_tiktoken_rs(self) -> tiktoken_rs::ChatCompletionRequestMessage {
        tiktoken_rs::ChatCompletionRequestMessage {
            role: self.role.to_string(),
            content: self.content,
            function_call: self.function_call.map(|fc| {
                tiktoken_rs::FunctionCall {
                    name: fc.name,
                    arguments: fc.arguments,
                }
            }),

            name: None,
        }
    }

    #[inline]
    fn into_async_openai(self) -> async_openai::types::ChatCompletionRequestMessage {
        async_openai::types::ChatCompletionRequestMessage {
            role: self.role,
            content: self.content,
            function_call: self.function_call,

            name: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let messages: Vec<async_openai::types::ChatCompletionRequestMessage> = Vec::new();

        assert_eq!(ChatSplitter::default().split(&messages).0, &[]);
        assert_eq!(ChatSplitter::default().split(&messages).1, &[]);
    }
}
368}