alith_prompt/token_count.rs

use crate::PromptTokenizer;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;

pub const DEFAULT_SAFETY_TOKENS: u64 = 10;

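/// Clamps a requested completion size to the tokens actually available in
/// the model's context window and returns the value to use for the request.
/// An oversized `requested_tokens` is reduced with a warning on stderr rather
/// than rejected; only a prompt that already overflows the context is an error.
///
/// A sketch of the expected arithmetic (marked `ignore` because the import
/// path is an assumption):
///
/// ```ignore
/// // ctx 4096, prompt 1000, default safety margin 10 -> 3086 tokens available.
/// let max = check_and_get_max_tokens(4096, None, 1000, None, Some(8192))?;
/// assert_eq!(max, 3086);
/// ```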
pub fn check_and_get_max_tokens(
    ctx_size: u64,
    inference_ctx_size: Option<u64>,
    total_prompt_tokens: u64,
    safety_tokens: Option<u64>,
    requested_tokens: Option<u64>,
) -> Result<u64, RequestTokenLimitError> {
    let available_tokens = available_tokens(
        ctx_size,
        inference_ctx_size,
        total_prompt_tokens,
        safety_tokens,
    )?;
    let requested_tokens = if let Some(requested_tokens) = requested_tokens {
        if requested_tokens > available_tokens {
            eprintln!(
                "requested_tokens ({requested_tokens}) is greater than available_tokens ({available_tokens}). Using available_tokens for the request."
            );
            available_tokens
        } else {
            requested_tokens
        }
    } else {
        available_tokens
    };

    if total_prompt_tokens + requested_tokens >= ctx_size {
        panic!(
            "total_prompt_tokens ({total_prompt_tokens}) + requested_tokens ({requested_tokens}) >= ctx_size ({ctx_size}). This should never happen."
        );
    }
    Ok(requested_tokens)
}

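/// Token budget left for the completion once the prompt and a safety margin
/// are accounted for, optionally capped by the backend's own inference
/// context size. For example, `ctx_size = 4096`, `total_prompt_tokens = 1000`,
/// and the default margin of 10 leave 3086 tokens.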
fn available_tokens(
    ctx_size: u64,
    inference_ctx_size: Option<u64>,
    total_prompt_tokens: u64,
    safety_tokens: Option<u64>,
) -> Result<u64, RequestTokenLimitError> {
    let safety_tokens = safety_tokens.unwrap_or(DEFAULT_SAFETY_TOKENS);

    if total_prompt_tokens >= ctx_size.saturating_sub(safety_tokens) {
        return Err(RequestTokenLimitError::PromptTokensExceeds {
            total_prompt_tokens,
            ctx_size: ctx_size.saturating_sub(safety_tokens),
        });
    }

    // Reserve the safety margin exactly once; the backend's inference window,
    // when given, caps the budget before the margin is taken out.
    let available_tokens = if let Some(inference_ctx_size) = inference_ctx_size {
        std::cmp::min(ctx_size - total_prompt_tokens, inference_ctx_size)
            .saturating_sub(safety_tokens)
    } else {
        ctx_size - total_prompt_tokens - safety_tokens
    };
    if available_tokens == 0 {
        panic!("available_tokens == 0. This should never happen.");
    }
    Ok(available_tokens)
}

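/// Counts prompt tokens for an OpenAI-style `messages` array: every message
/// pays a fixed per-message overhead plus the tokens of each value, a `name`
/// key is adjusted by `tokens_per_name` (which may be negative), and a final
/// 3 tokens cover the assistant reply priming, following OpenAI's published
/// counting recipe for chat completions.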
pub(crate) fn total_prompt_tokens_openai_format(
    prompt: &[HashMap<String, String>],
    tokens_per_message: Option<u32>,
    tokens_per_name: Option<i32>,
    tokenizer: &Arc<dyn PromptTokenizer>,
) -> u64 {
    let tokens_per_message = tokens_per_message.unwrap_or(0);
    let mut num_tokens: u64 = 0;
    for message in prompt {
        num_tokens += tokens_per_message as u64;

        for (key, value) in message.iter() {
            num_tokens += tokenizer.count_tokens(value) as u64;
            if let Some(tokens_per_name) = tokens_per_name {
                if key == "name" {
                    if tokens_per_name < 0 {
                        // A negative adjustment removes tokens; saturate so
                        // the count cannot underflow.
                        num_tokens =
                            num_tokens.saturating_sub(tokens_per_name.unsigned_abs() as u64);
                    } else {
                        num_tokens += tokens_per_name as u64;
                    }
                }
            }
        }
    }
    // Every reply is primed with <|start|>assistant<|message|>, which costs
    // three tokens.
    num_tokens += 3;
    num_tokens
}

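/// Snapshot of a request's token budget: the size actually sent with the
/// request versus the response size the caller originally asked for.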
#[derive(Debug, Clone)]
pub struct MaxTokenState {
    pub actual_request: u64,
    pub requested_response: u64,
}

#[derive(Error, Debug, Clone)]
pub enum RequestTokenLimitError {
    #[error("total_prompt_tokens ({total_prompt_tokens}) exceeds ctx_size ({ctx_size})")]
    PromptTokensExceeds {
        total_prompt_tokens: u64,
        ctx_size: u64,
    },
    #[error("GenericPromptError: {e}")]
    GenericPromptError { e: String },
    #[error("PromptTokensNotSet: prompt tokens were never set")]
    PromptTokensNotSet,
    #[error("TokenLimitIncreaseError: initial_state: {initial_state:?}, new_state: {new_state:?}")]
    TokenLimitIncreaseError {
        initial_state: MaxTokenState,
        new_state: MaxTokenState,
    },
}
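
// A minimal sanity sketch of the budget arithmetic above. The fixture values
// are illustrative assumptions, not taken from the crate's real test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn oversized_request_is_clamped_to_available_tokens() {
        // ctx 4096, prompt 1000, default safety margin 10 -> 3086 available,
        // so a request for 10_000 tokens is clamped down to 3086.
        let max = check_and_get_max_tokens(4096, None, 1000, None, Some(10_000)).unwrap();
        assert_eq!(max, 3086);
    }

    #[test]
    fn prompt_overflowing_the_context_is_rejected() {
        // 95 prompt tokens against a 100-token context with the default
        // margin of 10 leaves no room at all.
        let result = check_and_get_max_tokens(100, None, 95, None, None);
        assert!(matches!(
            result,
            Err(RequestTokenLimitError::PromptTokensExceeds { .. })
        ));
    }

    #[test]
    fn inference_ctx_size_caps_the_budget() {
        // The backend's own window (512) is smaller than what the model
        // context would allow, so it wins: 512 - 10 = 502.
        let max = check_and_get_max_tokens(4096, Some(512), 1000, None, None).unwrap();
        assert_eq!(max, 502);
    }
}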