another_tiktoken_rs/api.rs

use anyhow::{anyhow, Result};

use crate::{
    cl100k_base,
    model::get_context_size,
    p50k_base, p50k_edit, r50k_base,
    tokenizer::{get_tokenizer, Tokenizer},
    CoreBPE,
};

/// Calculates the maximum number of tokens available for completion based on the model and prompt provided.
///
/// This function determines the number of tokens left for a completion task, given the model and a prompt string.
/// It first retrieves the context size for the given model and the `CoreBPE` instance for tokenization.
/// Then, it calculates the number of tokens in the prompt using the appropriate tokenizer.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
/// * `prompt` - A string slice containing the prompt text.
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// * If no tokenizer is found for the specified model.
/// * If there is a failure in creating a `CoreBPE` instance for the specified model.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_completion_max_tokens;
///
/// let model = "gpt-3.5-turbo";
/// let prompt = "Translate the following English text to French: '";
/// let max_tokens = get_completion_max_tokens(model, prompt).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the maximum number of tokens available for completion,
/// based on the given model and prompt.
pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
    let context_size = get_context_size(model);
    let bpe = get_bpe_from_model(model)?;
    let prompt_tokens = bpe.encode_with_special_tokens(prompt).len();
    Ok(context_size.saturating_sub(prompt_tokens))
}

/// The name and arguments of a function that should be called, as generated by the model.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FunctionCall {
    /// The name of the function to call.
    pub name: String,
    /// The arguments to call the function with, as generated by the model in JSON format.
    /// Note that the model does not always generate valid JSON, and may hallucinate parameters
    /// not defined by your function schema. Validate the arguments in your code before calling
    /// your function.
    pub arguments: String,
}

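/// A single chat message, mirroring the message shape of the OpenAI chat completion API.
/// A minimal construction sketch:
///
/// ```
/// use another_tiktoken_rs::ChatCompletionRequestMessage;
///
/// let message = ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello!".to_string()),
///     name: None,
///     function_call: None,
/// };
/// ```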
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
    /// The role of the message's author. One of `system`, `user`, `assistant`, or `function`.
    pub role: String,
    /// The contents of the message.
    /// `content` is required for all messages except assistant messages with function calls.
    pub content: Option<String>,
    /// The name of the author of this message. `name` is required if the role is `function`,
    /// and it should be the name of the function whose response is in the `content`.
    /// May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
    pub name: Option<String>,
    /// The name and arguments of a function that should be called, as generated by the model.
    pub function_call: Option<FunctionCall>,
}

/// Based on <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>
///
/// `num_tokens_from_messages` returns the number of tokens required to encode the given messages
/// for the given model. This is used to estimate the number of tokens that will be used for chat
/// completion.
///
/// # Arguments
///
/// * `model` - A string slice containing the model name (e.g., "gpt-3.5-turbo").
/// * `messages` - A slice of `ChatCompletionRequestMessage` structs representing chat messages.
///
/// # Returns
///
/// * `Result<usize>`: A `Result` containing the total number of tokens needed to encode the messages
///   for the specified model, or an error if the tokenizer for the model is not found or not supported.
///
/// # Errors
///
/// This function will return an error if:
///
/// * The tokenizer for the specified model is not found.
/// * The tokenizer is not a supported chat model (i.e., not `Tokenizer::Cl100kBase`).
///
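/// # Examples
///
/// A minimal sketch, assuming this function is re-exported at the crate root
/// like the other helpers documented in this file:
///
/// ```
/// // Assumes the same crate-root re-export used by the other examples here.
/// use another_tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};
///
/// let messages = vec![ChatCompletionRequestMessage {
///     role: "user".to_string(),
///     content: Some("Hello, how are you?".to_string()),
///     name: None,
///     function_call: None,
/// }];
/// let num_tokens = num_tokens_from_messages("gpt-4-0314", &messages).unwrap();
/// assert!(num_tokens > 0);
/// ```
///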
pub fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let tokenizer =
        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
    if tokenizer != Tokenizer::Cl100kBase {
        anyhow::bail!("Chat completion is only supported for chat models")
    }
    let bpe = get_bpe_from_tokenizer(tokenizer)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <im_start>{role/name}\n{content}<im_end>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&message.role).len() as i32;
        num_tokens += bpe
            .encode_with_special_tokens(message.content.as_deref().unwrap_or_default())
            .len() as i32;
        if let Some(name) = &message.name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}

/// Calculates the maximum number of tokens available for chat completion based on the model and messages provided.
///
/// This function determines the number of tokens left for a chat completion task, given the model and a slice of
/// chat completion request messages. It first retrieves the tokenizer for the given model and checks if chat completion
/// is supported. Then, it calculates the number of tokens in the existing messages using the appropriate tokenizer.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
/// * `messages` - A slice of `ChatCompletionRequestMessage` instances containing the chat context.
///
/// # Errors
///
/// This function returns an error in the following cases:
///
/// * If there is no tokenizer found for the specified model.
/// * If chat completion is not supported for the specified model.
/// * If there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
///
/// # Example
///
/// ```
/// use another_tiktoken_rs::{get_chat_completion_max_tokens, ChatCompletionRequestMessage};
///
/// let model = "gpt-3.5-turbo";
/// let messages = vec![
///     ChatCompletionRequestMessage {
///         content: Some("You are a helpful assistant that only speaks French.".to_string()),
///         role: "system".to_string(),
///         name: None,
///         function_call: None,
///     },
///     ChatCompletionRequestMessage {
///         content: Some("Hello, how are you?".to_string()),
///         role: "user".to_string(),
///         name: None,
///         function_call: None,
///     },
///     ChatCompletionRequestMessage {
///         content: Some("Parlez-vous francais?".to_string()),
///         role: "system".to_string(),
///         name: None,
///         function_call: None,
///     },
/// ];
/// let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the maximum number of tokens available for chat completion,
/// based on the given model and messages.
pub fn get_chat_completion_max_tokens(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize> {
    let context_size = get_context_size(model);
    let prompt_tokens = num_tokens_from_messages(model, messages)?;
    Ok(context_size.saturating_sub(prompt_tokens))
}

/// Returns a `CoreBPE` instance corresponding to the tokenizer used by the given model.
///
/// This function first retrieves the tokenizer associated with the specified model name
/// and then maps the tokenizer to the appropriate `CoreBPE` instance, which is used for
/// tokenization in different models.
///
/// # Arguments
///
/// * `model` - A string slice representing the model name for which a `CoreBPE` instance should be retrieved.
///
/// # Errors
///
/// This function returns an error if:
/// * No tokenizer is found for the given model.
/// * There is a failure in creating a `CoreBPE` instance for the tokenizer.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_bpe_from_model;
///
/// let model = "gpt-4-0314";
/// let bpe = get_bpe_from_model(model).unwrap();
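/// // The returned `CoreBPE` can then tokenize text directly:
/// let tokens = bpe.encode_with_special_tokens("Hello, world!");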
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the tokenizer used by the given model.
pub fn get_bpe_from_model(model: &str) -> Result<CoreBPE> {
    let tokenizer =
        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
    let bpe = get_bpe_from_tokenizer(tokenizer)?;
    Ok(bpe)
}

/// Returns a `CoreBPE` instance corresponding to the given tokenizer.
///
/// This function is responsible for mapping a `Tokenizer` enum variant to the appropriate
/// `CoreBPE` instance, which is used for tokenization in different models.
///
/// # Arguments
///
/// * `tokenizer` - A `Tokenizer` enum variant representing the tokenizer for which a `CoreBPE` instance should be retrieved.
///
/// # Errors
///
/// This function returns an error if there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
///
/// # Examples
///
/// ```
/// use another_tiktoken_rs::get_bpe_from_tokenizer;
/// use another_tiktoken_rs::tokenizer::Tokenizer;
///
/// let tokenizer = Tokenizer::Cl100kBase;
/// let bpe = get_bpe_from_tokenizer(tokenizer).unwrap();
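/// // Decoding a known cl100k_base token id round-trips back to text
/// // (the same id is exercised in the tests below):
/// assert_eq!(bpe.decode(vec![15339]).unwrap(), "hello");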
/// ```
///
/// # Returns
///
/// If successful, the function returns a `Result` containing the `CoreBPE` instance corresponding to the given tokenizer.
pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<CoreBPE> {
    match tokenizer {
        Tokenizer::Cl100kBase => cl100k_base(),
        Tokenizer::R50kBase => r50k_base(),
        Tokenizer::P50kBase => p50k_base(),
        Tokenizer::P50kEdit => p50k_edit(),
        Tokenizer::Gpt2 => r50k_base(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_bpe_from_tokenizer() {
        let bpe = get_bpe_from_tokenizer(Tokenizer::Cl100kBase).unwrap();
        assert_eq!(bpe.decode(vec![15339]).unwrap(), "hello");
    }

    #[test]
    fn test_num_tokens_from_messages() {
        let messages = vec![
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("New synergies will help drive top-line growth.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Things working well together will increase revenue.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Let's talk later when we're less busy about how to do better.".to_string()),
                function_call: None,
            },
            ChatCompletionRequestMessage {
                role: "user".to_string(),
                name: None,
                content: Some("This late pivot means we don't have time to boil the ocean for the client deliverable.".to_string()),
                function_call: None,
            },
        ];
        let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0301", &messages).unwrap();
        assert_eq!(num_tokens, 127);

        let num_tokens = num_tokens_from_messages("gpt-4-0314", &messages).unwrap();
        assert_eq!(num_tokens, 129);
    }

    #[test]
    fn test_get_chat_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let messages = vec![
            ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: "system".to_string(),
                name: None,
                function_call: None,
            },
            ChatCompletionRequestMessage {
                content: Some("Hello, how are you?".to_string()),
                role: "user".to_string(),
                name: None,
                function_call: None,
            },
            ChatCompletionRequestMessage {
                content: Some("Parlez-vous francais?".to_string()),
                role: "system".to_string(),
                name: None,
                function_call: None,
            },
        ];
        let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
        assert!(max_tokens > 0);
    }

    #[test]
    fn test_get_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let prompt = "Translate the following English text to French: '";
        let max_tokens = get_completion_max_tokens(model, prompt).unwrap();
        assert!(max_tokens > 0);
    }
}

/// This module provides support for working with the `async_openai` crate.
#[cfg(feature = "async-openai")]
pub mod async_openai {
    use anyhow::Result;

    impl From<&async_openai::types::FunctionCall> for super::FunctionCall {
        fn from(f: &async_openai::types::FunctionCall) -> Self {
            Self {
                name: f.name.clone(),
                arguments: f.arguments.clone(),
            }
        }
    }

    impl From<&async_openai::types::ChatCompletionRequestMessage>
        for super::ChatCompletionRequestMessage
    {
        fn from(m: &async_openai::types::ChatCompletionRequestMessage) -> Self {
            Self {
                role: m.role.to_string(),
                name: m.name.clone(),
                content: m.content.clone(),
                function_call: m.function_call.as_ref().map(|f| f.into()),
            }
        }
    }

    /// Calculates the total number of tokens for the given list of messages.
    ///
    /// # Arguments
    ///
    /// * `model` - A string slice representing the name of the model.
    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the total number of tokens (`usize`) or an error if the calculation fails.
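    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming this module is reachable at the crate root as
    /// `another_tiktoken_rs::async_openai` and the `async-openai` feature is enabled:
    ///
    /// ```
    /// // Assumes the crate-root path above; the message shape matches the tests below.
    /// use another_tiktoken_rs::async_openai::num_tokens_from_messages;
    ///
    /// let messages = vec![async_openai::types::ChatCompletionRequestMessage {
    ///     role: async_openai::types::Role::User,
    ///     content: Some("Hello, how are you?".to_string()),
    ///     name: None,
    ///     function_call: None,
    /// }];
    /// let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0301", &messages).unwrap();
    /// assert!(num_tokens > 0);
    /// ```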
    pub fn num_tokens_from_messages(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>();
        super::num_tokens_from_messages(model, &messages)
    }

    /// Calculates the maximum number of tokens available for chat completion with the
    /// given model and messages.
    ///
    /// # Arguments
    ///
    /// * `model` - A string slice representing the name of the model.
    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the maximum number of tokens (`usize`) allowed for chat completions or an error if the retrieval fails.
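    ///
    /// # Examples
    ///
    /// A sketch under the same assumptions as `num_tokens_from_messages` above:
    ///
    /// ```
    /// // Assumes the crate-root path `another_tiktoken_rs::async_openai`.
    /// use another_tiktoken_rs::async_openai::get_chat_completion_max_tokens;
    ///
    /// let messages = vec![async_openai::types::ChatCompletionRequestMessage {
    ///     role: async_openai::types::Role::System,
    ///     content: Some("You are a helpful assistant that only speaks French.".to_string()),
    ///     name: None,
    ///     function_call: None,
    /// }];
    /// let max_tokens = get_chat_completion_max_tokens("gpt-3.5-turbo", &messages).unwrap();
    /// assert!(max_tokens > 0);
    /// ```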
    pub fn get_chat_completion_max_tokens(
        model: &str,
        messages: &[async_openai::types::ChatCompletionRequestMessage],
    ) -> Result<usize> {
        let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>();
        super::get_chat_completion_max_tokens(model, &messages)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_num_tokens_from_messages() {
            let model = "gpt-3.5-turbo-0301";
            let messages = &[async_openai::types::ChatCompletionRequestMessage {
                role: async_openai::types::Role::System,
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                function_call: None,
            }];
            let num_tokens = num_tokens_from_messages(model, messages).unwrap();
            assert!(num_tokens > 0);
        }

        #[test]
        fn test_get_chat_completion_max_tokens() {
            let model = "gpt-3.5-turbo";
            let messages = &[async_openai::types::ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: async_openai::types::Role::System,
                name: None,
                function_call: None,
            }];
            let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
            assert!(max_tokens > 0);
        }
    }
}