use crate::step::Step;
use crate::{traits, Parameters};
use serde::{Deserialize, Serialize};
use std::cmp::max;
use thiserror::Error;

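/// Errors that can occur while counting or computing prompt tokens.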
#[derive(Clone, Debug, Error)]
pub enum PromptTokensError {
    #[error("The prompt tokens are not accessible for this type of step.")]
    NotAvailable,
    #[error("The prompt tokens could not be computed.")]
    UnableToCompute,
    #[error("Formatting prompt failed: {0}")]
    PromptFormatFailed(#[from] crate::prompt::StringTemplateError),
    #[error("Tokenizer error: {0}")]
    TokenizerError(#[from] crate::tokens::TokenizerError),
}

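/// Extension trait that adds token-counting helpers to any [`traits::Executor`].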
pub trait ExecutorTokenCountExt: traits::Executor {
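    /// Splits the text in `doc` into chunks that each fit in the model's context
    /// window once the prompt built from `step` and `base_parameters` is accounted
    /// for. Consecutive chunks overlap by `chunk_overlap` tokens (default 0).
    ///
    /// # Errors
    /// Returns a [`PromptTokensError`] if the tokenizer cannot be created, the
    /// document has no text, or the token counts cannot be computed.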
    fn split_to_fit(
        &self,
        step: &Step,
        doc: &Parameters,
        base_parameters: &Parameters,
        chunk_overlap: Option<usize>,
    ) -> Result<Vec<Parameters>, PromptTokensError> {
        let splitter = self
            .get_tokenizer(step.options())
            .map_err(|_e| PromptTokensError::UnableToCompute)?;

        let text = doc.get_text().ok_or(PromptTokensError::UnableToCompute)?;

        // Format the prompt with an empty text placeholder to measure how many
        // tokens the prompt scaffolding itself consumes.
        let prompt = step.format(&base_parameters.combine(&Parameters::new_with_text("")))?;
        let tokens_used = self.tokens_used(step.options(), &prompt)?;
        let chunk_overlap = chunk_overlap.unwrap_or(0);

        // Whatever budget remains after the prompt is available for each chunk
        // of the document text.
        let split_params = splitter
            .split_text(
                &text,
                tokens_used.max_tokens as usize - tokens_used.tokens_used as usize,
                chunk_overlap,
            )
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .into_iter()
            .map(Parameters::new_with_text)
            .collect();
        Ok(split_params)
    }
}

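// Blanket implementation: every `Executor` gets the token-counting helpers for free.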
impl<E: traits::Executor> ExecutorTokenCountExt for E {}

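/// The token budget for a call: the model's maximum context size and the number
/// of tokens already consumed by the formatted prompt.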
pub struct TokenCount {
    /// The maximum number of tokens the model can handle.
    max_tokens: i32,
    /// The number of tokens already used by the prompt.
    tokens_used: i32,
}
impl TokenCount {
    /// Creates a new `TokenCount` from a maximum token count and the number of
    /// tokens already used.
    pub fn new(max_tokens: i32, tokens_used: i32) -> Self {
        Self {
            max_tokens,
            tokens_used,
        }
    }

    /// Returns how many tokens are still available.
    pub fn tokens_remaining(&self) -> i32 {
        self.max_tokens - self.tokens_used
    }

    /// Returns `true` if at least one token is still available.
    pub fn has_tokens_remaining(&self) -> bool {
        self.has_room_for(1)
    }

    /// Returns `true` if there is room for at least `tokens` more tokens.
    pub fn has_room_for(&self, tokens: i32) -> bool {
        self.tokens_remaining() >= tokens
    }
}

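/// Errors produced by [`Tokenizer`] implementations.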
#[derive(Error, Debug, Clone)]
pub enum TokenizerError {
    #[error("Error tokenizing input text")]
    TokenizationError,
    #[error("Error stringifying tokens to text")]
    ToStringError,
    #[error("Error creating tokenizer")]
    TokenizerCreationError,
    #[error("Token Collection type mismatch")]
    TokenCollectionTypeMismatch,
}

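/// Converts text to and from model-specific tokens and splits text into
/// token-bounded chunks.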
pub trait Tokenizer {
    /// Tokenizes `doc` into a collection of tokens.
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError>;

    /// Converts a collection of tokens back into a string.
    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError>;

    /// Splits `doc` into chunks of at most `max_tokens_per_chunk` tokens, with
    /// consecutive chunks overlapping by `chunk_overlap` tokens.
    fn split_text(
        &self,
        doc: &str,
        max_tokens_per_chunk: usize,
        chunk_overlap: usize,
    ) -> Result<Vec<String>, TokenizerError> {
        let tokens = self.tokenize_str(doc)?;
        // Each chunk starts `step_size` tokens after the previous one; the step
        // is clamped to at least 1 so the iteration always advances.
        let step_size = max(
            max_tokens_per_chunk.checked_sub(chunk_overlap).unwrap_or(1),
            1,
        );

        debug_assert_ne!(step_size, 0);

        (0..tokens.len())
            .step_by(step_size)
            .map(|start_idx| {
                let end_idx = usize::min(start_idx + max_tokens_per_chunk, tokens.len());
                self.to_string(tokens.slice(start_idx, end_idx))
            })
            .collect()
    }
}
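
/// A single token produced by a tokenizer. Wraps either an `i32` or a `usize`
/// id, depending on the backing tokenizer.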
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(transparent)]
pub struct Token(TokenImpl);

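/// Internal representation of a token id.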
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TokenImpl {
    I32(i32),
    USize(usize),
}

impl From<i32> for Token {
    fn from(value: i32) -> Self {
        Token(TokenImpl::I32(value))
    }
}

impl From<usize> for Token {
    fn from(value: usize) -> Self {
        Token(TokenImpl::USize(value))
    }
}

impl Token {
    /// Returns the token as an `i32`, or `None` if it is not backed by an `i32`.
    pub fn to_i32(&self) -> Option<i32> {
        match &self.0 {
            TokenImpl::I32(x) => Some(*x),
            _ => None,
        }
    }

    /// Returns the token as a `usize`, or `None` if it is not backed by a `usize`.
    pub fn to_usize(&self) -> Option<usize> {
        match &self.0 {
            TokenImpl::USize(x) => Some(*x),
            _ => None,
        }
    }
}

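/// A collection of tokens, backed by either `Vec<i32>` or `Vec<usize>` to
/// mirror [`Token`].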
#[derive(Debug)]
pub struct TokenCollection(TokenCollectionImpl);

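/// Internal storage for a token collection.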
#[derive(Debug)]
enum TokenCollectionImpl {
    I32(Vec<i32>),
    Usize(Vec<usize>),
}

impl TokenCollection {
    /// Consumes the collection and returns the underlying `Vec<i32>`, or a
    /// [`TokenizerError::TokenCollectionTypeMismatch`] if it is backed by `usize`.
    pub fn as_i32(self) -> Result<Vec<i32>, TokenizerError> {
        match self.0 {
            TokenCollectionImpl::I32(v) => Ok(v),
            _ => Err(TokenizerError::TokenCollectionTypeMismatch),
        }
    }

    /// Consumes the collection and returns the underlying `Vec<usize>`, or a
    /// [`TokenizerError::TokenCollectionTypeMismatch`] if it is backed by `i32`.
    pub fn as_usize(self) -> Result<Vec<usize>, TokenizerError> {
        match self.0 {
            TokenCollectionImpl::Usize(v) => Ok(v),
            _ => Err(TokenizerError::TokenCollectionTypeMismatch),
        }
    }

    /// Returns the number of tokens in the collection.
    pub fn len(&self) -> usize {
        match &self.0 {
            TokenCollectionImpl::I32(x) => x.len(),
            TokenCollectionImpl::Usize(x) => x.len(),
        }
    }

    /// Returns `true` if the collection contains no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a new collection containing the tokens in `start..end`.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        match &self.0 {
            TokenCollectionImpl::I32(v) => Vec::from(&v[start..end]).into(),
            TokenCollectionImpl::Usize(v) => Vec::from(&v[start..end]).into(),
        }
    }
}

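// Conversions from raw token-id vectors into `TokenCollection`.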
impl From<Vec<i32>> for TokenCollection {
    fn from(v: Vec<i32>) -> Self {
        TokenCollection(TokenCollectionImpl::I32(v))
    }
}

impl From<Vec<usize>> for TokenCollection {
    fn from(v: Vec<usize>) -> Self {
        TokenCollection(TokenCollectionImpl::Usize(v))
    }
}
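
// Minimal usage sketch for `TokenCount` and `TokenCollection` (illustrative
// tests exercising only the types defined in this module).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_count_budget() {
        let count = TokenCount::new(4096, 1000);
        assert_eq!(count.tokens_remaining(), 3096);
        assert!(count.has_tokens_remaining());
        assert!(count.has_room_for(3096));
        assert!(!count.has_room_for(3097));
    }

    #[test]
    fn token_collection_roundtrip() {
        let tokens: TokenCollection = vec![1i32, 2, 3, 4].into();
        assert_eq!(tokens.len(), 4);

        // Slicing keeps the underlying representation.
        let slice = tokens.slice(1, 3);
        assert_eq!(slice.as_i32().unwrap(), vec![2, 3]);

        // Asking for the wrong representation yields a type-mismatch error.
        let usize_tokens: TokenCollection = vec![1usize, 2].into();
        assert!(usize_tokens.as_i32().is_err());
    }
}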