another_tiktoken_rs/api.rs
use anyhow::{anyhow, Result};

use crate::{
    cl100k_base,
    model::get_context_size,
    p50k_base, p50k_edit, r50k_base,
    tokenizer::{get_tokenizer, Tokenizer},
    CoreBPE,
};

/// Calculates the maximum number of tokens available for completion based on the model and prompt provided.
///
/// This function determines the number of tokens left for a completion task, given the model and a prompt string.
/// It first retrieves the context size for the given model and the `CoreBPE` instance for tokenization.
/// Then, it calculates the number of tokens in the prompt using the appropriate tokenizer.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
/// * `prompt` - A string slice containing the prompt text.
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// * If there is a failure in creating a `CoreBPE` instance for the specified model.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_completion_max_tokens;
///
/// let model = "gpt-3.5-turbo";
/// let prompt = "Translate the following English text to French: '";
/// let max_tokens = get_completion_max_tokens(model, prompt).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the maximum number of tokens available for completion,
/// based on the given model and prompt.
pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
    let context_size = get_context_size(model);
    let bpe = get_bpe_from_model(model)?;
    let prompt_tokens = bpe.encode_with_special_tokens(prompt).len();
    Ok(context_size.saturating_sub(prompt_tokens))
}

/// The name and arguments of a function that should be called, as generated by the model.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FunctionCall {
    /// The name of the function to call.
    pub name: String,
    /// The arguments to call the function with, as generated by the model in JSON format.
    /// Note that the model does not always generate valid JSON, and may hallucinate
    /// parameters not defined by your function schema. Validate the arguments in your
    /// code before calling your function.
    pub arguments: String,
}

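/// A single chat message in a completion request, mirroring the message format of the
/// OpenAI chat API. A minimal construction sketch (the field values are illustrative):
///
/// ```
/// use another_tiktoken_rs::ChatCompletionRequestMessage;
///
/// let message = ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello, how are you?".to_string()),
///     name: None,
///     function_call: None,
/// };
/// ```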
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
    /// The role of the message's author. One of `system`, `user`, `assistant`, or `function`.
    pub role: String,
    /// The contents of the message.
    /// `content` is required for all messages except assistant messages with function calls.
    pub content: Option<String>,
    /// The name of the author of this message. `name` is required if the role is `function`,
    /// and it should be the name of the function whose response is in the `content`.
    /// May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
    pub name: Option<String>,
    /// The name and arguments of a function that should be called, as generated by the model.
    pub function_call: Option<FunctionCall>,
}

/// Based on <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>
///
/// `num_tokens_from_messages` returns the number of tokens required to encode the given messages
/// for the given model. This is used to estimate the number of tokens that will be used for chat
/// completion.
///
/// # Arguments
///
/// * `model` - A string slice containing the model name (e.g., "gpt-3.5-turbo").
/// * `messages` - A slice of `ChatCompletionRequestMessage` structs representing chat messages.
///
/// # Returns
///
/// * `Result<usize>`: A `Result` containing the total number of tokens needed to encode the messages
///   for the specified model, or an error if the tokenizer for the model is not found or not supported.
///
/// # Errors
///
/// This function will return an error if:
///
/// * The tokenizer for the specified model is not found.
/// * The tokenizer is not a supported chat model (i.e., not `Tokenizer::Cl100kBase`).
///
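/// # Examples
///
/// A minimal sketch (exact counts vary by model; this assumes `num_tokens_from_messages`
/// is re-exported at the crate root like the other functions in this file):
///
/// ```
/// use another_tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};
///
/// let messages = vec![ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello, how are you?".to_string()),
///     name: None,
///     function_call: None,
/// }];
/// let num_tokens = num_tokens_from_messages("gpt-3.5-turbo", &messages).unwrap();
/// assert!(num_tokens > 0);
/// ```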
pub fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let tokenizer =
        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
    if tokenizer != Tokenizer::Cl100kBase {
        anyhow::bail!("Chat completion is only supported for chat models")
    }
    let bpe = get_bpe_from_tokenizer(tokenizer)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <im_start>{role/name}\n{content}<im_end>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&message.role).len() as i32;
        num_tokens += bpe
            .encode_with_special_tokens(message.content.as_deref().unwrap_or_default())
            .len() as i32;
        if let Some(name) = &message.name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}

/// Calculates the maximum number of tokens available for chat completion based on the model and messages provided.
///
/// This function determines the number of tokens left for a chat completion task, given the model and a slice of
/// chat completion request messages. It first retrieves the tokenizer for the given model and checks if chat completion
/// is supported. Then, it calculates the number of tokens in the existing messages using the appropriate tokenizer.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
/// * `messages` - A slice of `ChatCompletionRequestMessage` instances containing the chat context.
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// * If there is no tokenizer found for the specified model.
/// * If chat completion is not supported for the specified model.
/// * If there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
///
/// # Example
///
/// ```
/// use another_tiktoken_rs::{get_chat_completion_max_tokens, ChatCompletionRequestMessage};
///
/// let model = "gpt-3.5-turbo";
/// let messages = vec![
///     ChatCompletionRequestMessage {
///         content: Some("You are a helpful assistant that only speaks French.".to_string()),
///         role: "system".to_string(),
///         name: None,
///         function_call: None,
///     },
///     ChatCompletionRequestMessage {
///         content: Some("Hello, how are you?".to_string()),
///         role: "user".to_string(),
///         name: None,
///         function_call: None,
///     },
///     ChatCompletionRequestMessage {
///         content: Some("Parlez-vous francais?".to_string()),
///         role: "system".to_string(),
///         name: None,
///         function_call: None,
///     },
/// ];
/// let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the maximum number of tokens available for chat completion,
/// based on the given model and messages.
pub fn get_chat_completion_max_tokens(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let context_size = get_context_size(model);
    let prompt_tokens = num_tokens_from_messages(model, messages)?;
    Ok(context_size.saturating_sub(prompt_tokens))
}

/// Returns a `CoreBPE` instance corresponding to the tokenizer used by the given model.
///
/// This function first retrieves the tokenizer associated with the specified model name
/// and then maps the tokenizer to the appropriate `CoreBPE` instance, which is used for
/// tokenization in different models.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name for which a `CoreBPE` instance should be retrieved.
///
/// # Errors
///
/// This function returns an error if:
///
/// * No tokenizer is found for the given model.
/// * There is a failure in creating a `CoreBPE` instance for the tokenizer.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_bpe_from_model;
///
/// let model = "gpt-4-0314";
/// let bpe = get_bpe_from_model(model).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the tokenizer used by the given model.
pub fn get_bpe_from_model(model: &str) -> Result<CoreBPE> {
    let tokenizer =
        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
    let bpe = get_bpe_from_tokenizer(tokenizer)?;
    Ok(bpe)
}

/// Returns a `CoreBPE` instance corresponding to the given tokenizer.
///
/// This function is responsible for mapping a `Tokenizer` enum variant to the appropriate
/// `CoreBPE` instance, which is used for tokenization in different models.
///
/// # Arguments
///
/// * `tokenizer` - A `Tokenizer` enum variant representing the tokenizer for which a `CoreBPE` instance should be retrieved.
///
/// # Errors
///
/// This function returns an error if there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_bpe_from_tokenizer;
/// use another_tiktoken_rs::tokenizer::Tokenizer;
///
/// let tokenizer = Tokenizer::Cl100kBase;
/// let bpe = get_bpe_from_tokenizer(tokenizer).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the given tokenizer.
pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<CoreBPE> {
    match tokenizer {
        Tokenizer::Cl100kBase => cl100k_base(),
        Tokenizer::R50kBase => r50k_base(),
        Tokenizer::P50kBase => p50k_base(),
        Tokenizer::P50kEdit => p50k_edit(),
        Tokenizer::Gpt2 => r50k_base(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_bpe_from_tokenizer() {
        let bpe = get_bpe_from_tokenizer(Tokenizer::Cl100kBase).unwrap();
        assert_eq!(bpe.decode(vec![15339]).unwrap(), "hello");
    }

    #[test]
    fn test_num_tokens_from_messages() {
        let messages = vec![
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("New synergies will help drive top-line growth.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Things working well together will increase revenue.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Let's talk later when we're less busy about how to do better.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "user".to_string(),
                name: None,
                content: Some("This late pivot means we don't have time to boil the ocean for the client deliverable.".to_string()),
                function_call: None,
            },
        ];
        let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0301", &messages).unwrap();
        assert_eq!(num_tokens, 127);

        let num_tokens = num_tokens_from_messages("gpt-4-0314", &messages).unwrap();
        assert_eq!(num_tokens, 129);
    }

    #[test]
    fn test_get_chat_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let messages = vec![
            ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: "system".to_string(),
                name: None,
                function_call: None,
            },
            ChatCompletionRequestMessage {
                content: Some("Hello, how are you?".to_string()),
                role: "user".to_string(),
                name: None,
                function_call: None,
            },
            ChatCompletionRequestMessage {
                content: Some("Parlez-vous francais?".to_string()),
                role: "system".to_string(),
                name: None,
                function_call: None,
            },
        ];
        let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
        assert!(max_tokens > 0);
    }

    #[test]
    fn test_get_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let prompt = "Translate the following English text to French: '";
        let max_tokens = get_completion_max_tokens(model, prompt).unwrap();
        assert!(max_tokens > 0);
    }
}

/// This module provides support for working with the `async_openai` crate.
#[cfg(feature = "async-openai")]
pub mod async_openai {
    use anyhow::Result;

    impl From<&async_openai::types::FunctionCall> for super::FunctionCall {
        fn from(f: &async_openai::types::FunctionCall) -> Self {
            Self {
                name: f.name.clone(),
                arguments: f.arguments.clone(),
            }
        }
    }

    impl From<&async_openai::types::ChatCompletionRequestMessage>
        for super::ChatCompletionRequestMessage
    {
        fn from(m: &async_openai::types::ChatCompletionRequestMessage) -> Self {
            Self {
                role: m.role.to_string(),
                name: m.name.clone(),
                content: m.content.clone(),
                function_call: m.function_call.as_ref().map(|f| f.into()),
            }
        }
    }

    /// Calculates the total number of tokens for the given list of messages.
    ///
    /// # Arguments
    ///
    /// * `model` - A string slice representing the name of the model.
    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the total number of tokens (`usize`) or an error if the calculation fails.
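    ///
    /// # Examples
    ///
    /// A minimal sketch, mirroring the construction used in this module's tests. It assumes
    /// the `async-openai` feature is enabled and that this module is re-exported at the
    /// crate root (the path shown is illustrative):
    ///
    /// ```ignore
    /// let messages = [async_openai::types::ChatCompletionRequestMessage {
    ///     role: async_openai::types::Role::User,
    ///     name: None,
    ///     content: Some("Hello, how are you?".to_string()),
    ///     function_call: None,
    /// }];
    /// let num_tokens =
    ///     another_tiktoken_rs::async_openai::num_tokens_from_messages("gpt-3.5-turbo", &messages)?;
    /// ```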
    pub fn num_tokens_from_messages(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>();
        super::num_tokens_from_messages(model, &messages)
    }

    /// Retrieves the maximum token limit for chat completions.
    ///
    /// # Arguments
    ///
    /// * `model` - A string slice representing the name of the model.
    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the maximum number of tokens (`usize`) allowed for chat completions or an error if the retrieval fails.
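    ///
    /// # Examples
    ///
    /// A minimal sketch under the same assumptions as above; the returned budget can be used
    /// to cap `max_tokens` on the subsequent request:
    ///
    /// ```ignore
    /// let messages = [async_openai::types::ChatCompletionRequestMessage {
    ///     role: async_openai::types::Role::System,
    ///     name: None,
    ///     content: Some("You are a helpful assistant.".to_string()),
    ///     function_call: None,
    /// }];
    /// let max_tokens =
    ///     another_tiktoken_rs::async_openai::get_chat_completion_max_tokens("gpt-4", &messages)?;
    /// ```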
    pub fn get_chat_completion_max_tokens(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>();
        super::get_chat_completion_max_tokens(model, &messages)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_num_tokens_from_messages() {
            let model = "gpt-3.5-turbo-0301";
            let messages = &[async_openai::types::ChatCompletionRequestMessage {
                role: async_openai::types::Role::System,
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                function_call: None,
            }];
            let num_tokens = num_tokens_from_messages(model, messages).unwrap();
            assert!(num_tokens > 0);
        }

        #[test]
        fn test_get_chat_completion_max_tokens() {
            let model = "gpt-3.5-turbo";
            let messages = &[async_openai::types::ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: async_openai::types::Role::System,
                name: None,
                function_call: None,
            }];
            let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
            assert!(max_tokens > 0);
        }
    }
}