ai_chain/
traits.rs

//! # Traits Module
//!
//! Welcome to the `traits` module! This is where ai-chain houses its public traits, which define the essential behavior of steps and executors. These traits are the backbone of the library and provide the foundation for creating and working with different models in ai-chain.
//!
//! Here's a brief overview of the key concepts:
//! - **Steps**: These are the building blocks that make up chains. A step defines the parameters, including the prompt that is sent to the LLM (Large Language Model).
//! - **Executors**: These are responsible for running steps. An executor takes the prompt produced by a step, invokes the model with it, and returns the resulting output.
//! - **Embeddings and vector stores**: `Embeddings` turns text into vectors, and `VectorStore` stores documents and retrieves them by similarity search.
//!
//! By implementing these traits, you can set up a new model and use it in your application. Your step defines the input to the model, and your executor invokes the model and returns the output. The output of the executor is then passed to the next step in the chain, and so on.
//!
11
12use std::{error::Error, fmt::Debug};
13
14use crate::{
15    options::Options,
16    output::Output,
17    prompt::Prompt,
18    schema::{Document, EmptyMetadata},
19    tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError},
20};
21use async_trait::async_trait;
22
23#[derive(thiserror::Error, Debug)]
24#[error("unable to create executor")]
25pub enum ExecutorCreationError {
26    #[error("unable to create executor: {0}")]
27    InnerError(#[from] Box<dyn Error + Send + Sync>),
28    #[error("Field must be set: {0}")]
29    FieldRequiredError(String),
30    #[error("Invalid value. {0}")]
31    InvalidValue(String),
32}
33
34#[derive(thiserror::Error, Debug)]
35/// An error indicating that the model was not succesfully run.
36pub enum ExecutorError {
37    #[error("Unable to run model: {0}")]
38    /// An error occuring in the underlying executor code that doesn't fit any other category.
39    InnerError(#[from] Box<dyn Error + Send + Sync>),
40    #[error("Invalid options when calling the executor")]
41    /// An error indicating that the model was invoked with invalid options
42    InvalidOptions,
43    #[error(transparent)]
44    /// An error tokenizing the prompt.
45    PromptTokens(PromptTokensError),
46    #[error("the context was to small to fit your input")]
47    ContextTooSmall,
48}
49
50#[async_trait]
51/// The `Executor` trait represents an executor that performs a single step in a chain. It takes a
52/// step, executes it, and returns the output.
53pub trait Executor {
54    type StepTokenizer<'a>: Tokenizer
55    where
56        Self: 'a;
57
58    /// Create a new executor with the given options. If you don't need to set any options, you can use the `new` method instead.
59    /// # Parameters
60    /// * `options`: The options to set.
61    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError>
62    where
63        Self: Sized;
64
65    fn new() -> Result<Self, ExecutorCreationError>
66    where
67        Self: Sized,
68    {
69        Self::new_with_options(Options::empty().clone())
70    }
71
72    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError>;
73
74    /// Calculates the number of tokens used by the step given a set of parameters.
75    ///
76    /// The step and the parameters together are used to form full prompt, which is then tokenized
77    /// and the token count is returned.
78    ///
79    /// # Parameters
80    ///
81    /// * `options`: The per-invocation options that affect the token allowance.
82    /// * `prompt`: The prompt passed into step
83    ///
84    /// # Returns
85    ///
86    /// A `Result` containing the token count, or an error if there was a problem.
87    fn tokens_used(
88        &self,
89        options: &Options,
90        prompt: &Prompt,
91    ) -> Result<TokenCount, PromptTokensError>;
92
93    /// Returns the maximum number of input tokens allowed by the model used.
94    ///
95    /// # Parameters
96    ///
97    /// * `options`: The per-invocation options that affect the token allowance.
98    ///
99    /// # Returns
100    /// The max token count for the step
101    fn max_tokens_allowed(&self, options: &Options) -> i32;
102
103    /// Returns a possible answer prefix inserted by the model, during a certain prompt mode
104    ///
105    /// # Parameters
106    ///
107    /// * `prompt`: The prompt passed into step
108    ///
109    /// # Returns
110    ///
111    /// A `Option` containing a String if  prefix exists, or none if there is no prefix
112    fn answer_prefix(&self, prompt: &Prompt) -> Option<String>;
113
114    /// Creates a tokenizer, depending on the model used by `step`.
115    ///
116    /// # Parameters
117    ///
118    /// * `step`: The step to get an associated tokenizer for.
119    ///
120    /// # Returns
121    ///
122    /// A `Result` containing a tokenizer, or an error if there was a problem.
123    fn get_tokenizer(&self, options: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError>;
124}
125
126/// This marker trait is needed so the concrete VectorStore::Error can have a derived From<Embeddings::Error>
127pub trait EmbeddingsError {}
128
129#[async_trait]
130pub trait Embeddings {
131    type Error: Send + Debug + Error + EmbeddingsError;
132    async fn embed_texts(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, Self::Error>;
133    async fn embed_query(&self, query: String) -> Result<Vec<f32>, Self::Error>;
134}
135
136#[derive(thiserror::Error, Debug)]
137#[error("unable to create embeddings")]
138pub enum EmbeddingsCreationError {
139    #[error("unable to create embeddings: {0}")]
140    InnerError(#[from] Box<dyn Error + Send + Sync>),
141    #[error("Field must be set: {0}")]
142    FieldRequiredError(String),
143}
144
145/// This marker trait is needed so users of VectorStore can derive From<VectorStore::Error>
146pub trait VectorStoreError {}
147
148#[async_trait]
149pub trait VectorStore<E, M = EmptyMetadata>
150where
151    E: Embeddings,
152    M: serde::Serialize + serde::de::DeserializeOwned,
153{
154    type Error: Debug + Error + VectorStoreError;
155    async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error>;
156    async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error>;
157    async fn similarity_search(
158        &self,
159        query: String,
160        limit: u32,
161    ) -> Result<Vec<Document<M>>, Self::Error>;
162}