use crate::step::Step;
use crate::{traits, Parameters};
use serde::{Deserialize, Serialize};
use std::cmp::max;
use thiserror::Error;
/// Errors that can occur while measuring or computing prompt token usage.
#[derive(Clone, Debug, Error)]
pub enum PromptTokensError {
    /// The executor/step combination does not expose token counts at all.
    #[error("The prompt tokens are not accessible for this type of step.")]
    NotAvailable,
    /// Token counting was attempted but could not produce a result.
    #[error("The prompt tokens could not be computed.")]
    UnableToCompute,
    /// Formatting the prompt template failed; wraps the template error.
    #[error("Formatting prompt failed: {0}")]
    PromptFormatFailed(#[from] crate::prompt::StringTemplateError),
    /// An underlying tokenizer operation failed; wraps the tokenizer error.
    #[error("Tokenizer error: {0}")]
    TokenizerError(#[from] crate::tokens::TokenizerError),
}
/// Extension trait adding token-aware text splitting to any [`traits::Executor`].
pub trait ExecutorTokenCountExt: traits::Executor {
    /// Splits the text of `doc` into chunks that each fit in the token
    /// budget left over after formatting `step` with `base_parameters`.
    ///
    /// `chunk_overlap` is the number of tokens consecutive chunks share
    /// (defaults to 0 when `None`).
    ///
    /// # Errors
    /// Returns [`PromptTokensError::UnableToCompute`] when the tokenizer is
    /// unavailable, `doc` has no text, the split fails, or the formatted
    /// prompt already consumes the entire token budget. Template-formatting
    /// and token-counting errors are propagated via `?`.
    fn split_to_fit(
        &self,
        step: &Step,
        doc: &Parameters,
        base_parameters: &Parameters,
        chunk_overlap: Option<usize>,
    ) -> Result<Vec<Parameters>, PromptTokensError> {
        let splitter = self
            .get_tokenizer(step.options())
            .map_err(|_e| PromptTokensError::UnableToCompute)?;
        let text = doc.get_text().ok_or(PromptTokensError::UnableToCompute)?;

        // Measure the prompt's fixed overhead by formatting it with an
        // empty text placeholder.
        let prompt = step.format(&base_parameters.combine(&Parameters::new_with_text("")))?;
        let tokens_used = self.tokens_used(step.options(), &prompt)?;

        // Guard before converting to usize: the previous
        // `max_tokens as usize - tokens_used as usize` underflowed (debug
        // panic / release wrap-around) whenever the prompt alone met or
        // exceeded the context limit.
        let remaining = tokens_used.tokens_remaining();
        if remaining <= 0 {
            return Err(PromptTokensError::UnableToCompute);
        }

        let chunk_overlap = chunk_overlap.unwrap_or(0);
        let split_params = splitter
            .split_text(&text, remaining as usize, chunk_overlap)
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .into_iter()
            .map(Parameters::new_with_text)
            .collect();
        Ok(split_params)
    }
}
impl<E: traits::Executor> ExecutorTokenCountExt for E {}
/// Tracks a model context's token budget: the maximum allowed
/// (`max_tokens`) versus how many are already consumed (`tokens_used`).
///
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq`/`Eq` — the type is two
/// `i32`s, and public types should be debuggable and freely copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenCount {
    // Maximum number of tokens the model/context permits.
    max_tokens: i32,
    // Tokens already consumed (e.g. by the formatted prompt).
    tokens_used: i32,
}

impl TokenCount {
    /// Creates a new count from a context limit and current usage.
    pub fn new(max_tokens: i32, tokens_used: i32) -> Self {
        Self {
            max_tokens,
            tokens_used,
        }
    }

    /// Tokens still available. May be negative when usage exceeds the limit.
    pub fn tokens_remaining(&self) -> i32 {
        self.max_tokens - self.tokens_used
    }

    /// True when at least one token is still available.
    pub fn has_tokens_remaining(&self) -> bool {
        self.has_room_for(1)
    }

    /// True when at least `tokens` tokens are still available.
    pub fn has_room_for(&self, tokens: i32) -> bool {
        self.tokens_remaining() >= tokens
    }
}
/// Errors produced by [`Tokenizer`] implementations and by
/// [`TokenCollection`] type conversions.
#[derive(Error, Debug, Clone)]
pub enum TokenizerError {
    /// Converting input text to tokens failed.
    #[error("Error tokenizing input text")]
    TokenizationError,
    /// Converting tokens back to text failed.
    #[error("Error stringifying tokens to text")]
    ToStringError,
    /// The tokenizer itself could not be constructed.
    #[error("Error creating tokenizer")]
    TokenizerCreationError,
    /// A `TokenCollection` held the other integer width than requested
    /// (see `TokenCollection::as_i32` / `as_usize`).
    #[error("Token Collection type mismatch")]
    TokenCollectionTypeMismatch,
}
/// Converts text to and from token collections, and splits long documents
/// into token-bounded, optionally overlapping chunks.
pub trait Tokenizer {
    /// Tokenizes `doc` into this tokenizer's native token representation.
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError>;

    /// Reassembles `tokens` back into a string.
    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError>;

    /// Splits `doc` into chunks of at most `max_tokens_per_chunk` tokens,
    /// with consecutive chunks sharing `chunk_overlap` tokens.
    ///
    /// # Errors
    /// Propagates any tokenization or stringification failure.
    fn split_text(
        &self,
        doc: &str,
        max_tokens_per_chunk: usize,
        chunk_overlap: usize,
    ) -> Result<Vec<String>, TokenizerError> {
        let tokens = self.tokenize_str(doc)?;
        // Advance by chunk size minus overlap, clamped to at least 1 so the
        // window always moves forward (this also covers
        // `chunk_overlap >= max_tokens_per_chunk`). `saturating_sub` + one
        // clamp replaces the previous redundant
        // `checked_sub(..).unwrap_or(1)` inside `max(.., 1)` — identical
        // results for every input — and makes the old
        // `debug_assert_ne!(step_size, 0)` (provably unreachable) moot.
        let step_size = max(max_tokens_per_chunk.saturating_sub(chunk_overlap), 1);
        (0..tokens.len())
            .step_by(step_size)
            .map(|start_idx| {
                // Clamp the chunk end to the token count.
                let end_idx = usize::min(start_idx + max_tokens_per_chunk, tokens.len());
                self.to_string(tokens.slice(start_idx, end_idx))
            })
            .collect()
    }
}
/// A single token produced by a tokenizer; wraps either an `i32` or a
/// `usize` id depending on the backing implementation.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(transparent)]
pub struct Token(TokenImpl);

/// Internal storage for [`Token`], hiding the concrete integer width.
#[derive(Serialize, Deserialize, Clone, Debug)]
enum TokenImpl {
    I32(i32),
    USize(usize),
}

impl From<i32> for Token {
    fn from(value: i32) -> Self {
        Self(TokenImpl::I32(value))
    }
}

impl From<usize> for Token {
    fn from(value: usize) -> Self {
        Self(TokenImpl::USize(value))
    }
}

impl Token {
    /// Returns the id as `i32`, or `None` when it is stored as `usize`.
    pub fn to_i32(&self) -> Option<i32> {
        if let TokenImpl::I32(id) = &self.0 {
            Some(*id)
        } else {
            None
        }
    }

    /// Returns the id as `usize`, or `None` when it is stored as `i32`.
    pub fn to_usize(&self) -> Option<usize> {
        if let TokenImpl::USize(id) = &self.0 {
            Some(*id)
        } else {
            None
        }
    }
}
/// An opaque, homogeneous collection of tokens, stored as either
/// `Vec<i32>` or `Vec<usize>` depending on the tokenizer backend.
#[derive(Debug)]
pub struct TokenCollection(TokenCollectionImpl);

/// Internal storage variants for [`TokenCollection`].
#[derive(Debug)]
enum TokenCollectionImpl {
    I32(Vec<i32>),
    Usize(Vec<usize>),
}

impl TokenCollection {
    /// Consumes the collection and returns the `i32` tokens, or
    /// `TokenCollectionTypeMismatch` when it holds `usize` tokens.
    pub fn as_i32(self) -> Result<Vec<i32>, TokenizerError> {
        if let TokenCollectionImpl::I32(tokens) = self.0 {
            Ok(tokens)
        } else {
            Err(TokenizerError::TokenCollectionTypeMismatch)
        }
    }

    /// Consumes the collection and returns the `usize` tokens, or
    /// `TokenCollectionTypeMismatch` when it holds `i32` tokens.
    pub fn as_usize(self) -> Result<Vec<usize>, TokenizerError> {
        if let TokenCollectionImpl::Usize(tokens) = self.0 {
            Ok(tokens)
        } else {
            Err(TokenizerError::TokenCollectionTypeMismatch)
        }
    }

    /// Number of tokens held, regardless of storage width.
    pub fn len(&self) -> usize {
        match &self.0 {
            TokenCollectionImpl::I32(tokens) => tokens.len(),
            TokenCollectionImpl::Usize(tokens) => tokens.len(),
        }
    }

    /// True when the collection holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Copies the half-open range `[start, end)` into a new collection.
    /// Panics when the range is out of bounds, like slice indexing.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        match &self.0 {
            TokenCollectionImpl::I32(tokens) => tokens[start..end].to_vec().into(),
            TokenCollectionImpl::Usize(tokens) => tokens[start..end].to_vec().into(),
        }
    }
}

impl From<Vec<i32>> for TokenCollection {
    fn from(tokens: Vec<i32>) -> Self {
        Self(TokenCollectionImpl::I32(tokens))
    }
}

impl From<Vec<usize>> for TokenCollection {
    fn from(tokens: Vec<usize>) -> Self {
        Self(TokenCollectionImpl::Usize(tokens))
    }
}