alith_prompt/token_count.rs

use crate::PromptTokenizer;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;

pub const DEFAULT_SAFETY_TOKENS: u64 = 10;

/// Sets and validates the 'max_tokens', 'n_ctx', or 'n_predict' parameter for a request.
/// First, it checks that total_prompt_tokens is less than ctx_size - safety_tokens.
/// Then it returns 'available_tokens' as the lower of either:
/// ctx_size - total_prompt_tokens - safety_tokens or, if provided, inference_ctx_size.
/// If 'requested_tokens' is provided, 'requested_tokens' is returned if it is less than 'available_tokens'.
/// If 'requested_tokens' is 'None' or greater than 'available_tokens',
/// 'available_tokens' is returned.
///
/// # Arguments
///
/// * `ctx_size` - The total context length for the model or system.
/// * `inference_ctx_size` - Optional output size for models with output generation limits. Defaults to None.
/// * `total_prompt_tokens` - The total number of prompt tokens as an unsigned 64-bit integer.
/// * `safety_tokens` - Optional padding. Defaults to 10.
/// * `requested_tokens` - Optional 'max_tokens' for the response. Defaults to 'available_tokens'.
///
/// # Returns
///
/// A u64 to be used for the 'max_tokens' or 'n_ctx' parameter of inference requests.
///
/// # Errors
///
/// Returns an error if any of the validation checks fail.
pub fn check_and_get_max_tokens(
    ctx_size: u64,
    inference_ctx_size: Option<u64>,
    total_prompt_tokens: u64,
    safety_tokens: Option<u64>,
    requested_tokens: Option<u64>,
) -> Result<u64, RequestTokenLimitError> {
    let available_tokens = available_tokens(
        ctx_size,
        inference_ctx_size,
        total_prompt_tokens,
        safety_tokens,
    )?;
    let requested_tokens = if let Some(requested_tokens) = requested_tokens {
        if requested_tokens > available_tokens {
            eprintln!(
                "requested_tokens ({requested_tokens}) is greater than available_tokens ({available_tokens}). Using available_tokens for request."
            );
            available_tokens
        } else {
            requested_tokens
        }
    } else {
        available_tokens
    };

    if total_prompt_tokens + requested_tokens >= ctx_size {
        panic!(
            "total_prompt_tokens ({total_prompt_tokens}) + requested_tokens ({requested_tokens}) >= ctx_size ({ctx_size}). This should never happen."
        );
    }
    Ok(requested_tokens)
}
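
// A minimal usage sketch, not part of the original file. The values are illustrative
// assumptions: with an 8192-token context, a 1000-token prompt, and the default safety
// padding of 10, available_tokens works out to 7182.
#[cfg(test)]
mod check_and_get_max_tokens_sketch {
    use super::*;

    #[test]
    fn requested_tokens_is_returned_when_it_fits() {
        // 512 <= 8192 - 1000 - 10, so the requested value is passed through.
        assert_eq!(
            check_and_get_max_tokens(8192, None, 1000, None, Some(512)).unwrap(),
            512
        );
    }

    #[test]
    fn oversized_request_falls_back_to_available_tokens() {
        // 10_000 exceeds the remaining budget, so available_tokens (7182) is used instead.
        assert_eq!(
            check_and_get_max_tokens(8192, None, 1000, None, Some(10_000)).unwrap(),
            7182
        );
    }
}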

fn available_tokens(
    ctx_size: u64,
    inference_ctx_size: Option<u64>,
    total_prompt_tokens: u64,
    safety_tokens: Option<u64>,
) -> Result<u64, RequestTokenLimitError> {
    let safety_tokens = safety_tokens.unwrap_or(DEFAULT_SAFETY_TOKENS);

    if total_prompt_tokens >= ctx_size - safety_tokens {
        return Err(RequestTokenLimitError::PromptTokensExceeds {
            total_prompt_tokens,
            ctx_size: ctx_size - safety_tokens,
        });
    }

    // Subtract safety_tokens exactly once, matching the documented behaviour.
    let available_tokens = if let Some(inference_ctx_size) = inference_ctx_size {
        std::cmp::min(ctx_size - total_prompt_tokens, inference_ctx_size) - safety_tokens
    } else {
        ctx_size - total_prompt_tokens - safety_tokens
    };
    if available_tokens == 0 {
        panic!("available_tokens == 0. This should never happen.");
    }
    Ok(available_tokens)
}
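
// A small sketch, not part of the original file, of the helper's arithmetic under the
// documented behaviour (safety_tokens subtracted exactly once); the numbers are
// illustrative assumptions, not values taken from the crate.
#[cfg(test)]
mod available_tokens_sketch {
    use super::*;

    #[test]
    fn caps_the_budget_at_inference_ctx_size() {
        // min(4096 - 1000, 512) - 10 = 502
        assert_eq!(available_tokens(4096, Some(512), 1000, None).unwrap(), 502);
    }

    #[test]
    fn rejects_prompts_that_leave_no_room() {
        // 4090 >= 4096 - 10, so the prompt alone exceeds the usable context.
        assert!(available_tokens(4096, None, 4090, None).is_err());
    }
}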

/// Counts the total prompt tokens for an OpenAI-format message list
/// (a slice of 'role'/'content' maps), adding the per-message and
/// per-name overheads when provided.
pub(crate) fn total_prompt_tokens_openai_format(
    prompt: &[HashMap<String, String>],
    tokens_per_message: Option<u32>,
    tokens_per_name: Option<i32>,
    tokenizer: &Arc<dyn PromptTokenizer>,
) -> u64 {
    let tokens_per_message = tokens_per_message.unwrap_or(0);
    let mut num_tokens: u64 = 0;
    for message in prompt {
        num_tokens += tokens_per_message as u64;

        for (key, value) in message.iter() {
            num_tokens += tokenizer.count_tokens(value) as u64;
            if let Some(tokens_per_name) = tokens_per_name {
                if key == "name" {
                    if tokens_per_name < 0 {
                        // Handles cases for certain models where 'name' doesn't count towards the token count
                        num_tokens -= tokens_per_name.unsigned_abs() as u64;
                    } else {
                        num_tokens += tokens_per_name as u64;
                    }
                }
            }
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    num_tokens
}
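
// A hedged sketch of calling the OpenAI-style counter. It assumes PromptTokenizer's
// only required method is roughly `fn count_tokens(&self, text: &str) -> u32` (the
// real trait lives elsewhere in the crate and may differ); the whitespace tokenizer
// below is a stand-in, not a real tokenizer.
#[cfg(test)]
mod total_prompt_tokens_openai_format_sketch {
    use super::*;

    struct WhitespaceTokenizer;

    impl PromptTokenizer for WhitespaceTokenizer {
        // Assumed signature; counts whitespace-separated words as "tokens".
        fn count_tokens(&self, text: &str) -> u32 {
            text.split_whitespace().count() as u32
        }
    }

    #[test]
    fn counts_message_values_plus_reply_priming() {
        let tokenizer: Arc<dyn PromptTokenizer> = Arc::new(WhitespaceTokenizer);
        let prompt = vec![HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), "hello there".to_string()),
        ])];
        // 1 ("user") + 2 ("hello there") + 3 reply-priming tokens = 6
        assert_eq!(
            total_prompt_tokens_openai_format(&prompt, None, None, &tokenizer),
            6
        );
    }
}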

#[derive(Debug, Clone)]
pub struct MaxTokenState {
    pub actual_request: u64,
    pub requested_response: u64,
}

#[derive(Error, Debug, Clone)]
pub enum RequestTokenLimitError {
    #[error("total_prompt_tokens ({total_prompt_tokens}) exceeds ctx_size ({ctx_size})")]
    PromptTokensExceeds {
        total_prompt_tokens: u64,
        ctx_size: u64,
    },
    #[error("GenericPromptError: {e}")]
    GenericPromptError { e: String },
    #[error("PromptTokensNotSet: Prompt tokens not set.")]
    PromptTokensNotSet,
    #[error("TokenLimitIncreaseError: initial_state: {initial_state:?}, new_state: {new_state:?}")]
    TokenLimitIncreaseError {
        initial_state: MaxTokenState,
        new_state: MaxTokenState,
    },
}