llm_weaver/lib.rs
//! A flexible library for creating and managing coherent narratives that leverage LLMs (Large
//! Language Models) to generate dynamic responses.
//!
//! Built on [OpenAI's recommended tactics](https://platform.openai.com/docs/guides/gpt-best-practices/tactic-for-dialogue-applications-that-require-very-long-conversations-summarize-or-filter-previous-dialogue),
//! LLM Weaver facilitates extended interactions with any LLM, seamlessly handling conversations
//! that exceed a model's maximum context token limit.
//!
//! [`Loom`] is the core of this library. It prompts the configured LLM and stores the message
//! history as [`TapestryFragment`] instances. It is highly configurable through the [`Config`]
//! trait to support a wide range of use cases.
//!
//! # Nomenclature
//!
//! - **Tapestry**: A collection of [`TapestryFragment`] instances.
//! - **TapestryFragment**: A single part of a conversation containing a list of messages along
//!   with other metadata.
//! - **ContextMessage**: Represents a single message in a [`TapestryFragment`] instance.
//! - **Loom**: The machine that drives all of the core methods used by any service that needs to
//!   prompt an LLM and receive a response.
//! - **LLM**: Large Language Model.
//!
//! # Architecture
//!
//! Please refer to the [`architecture::Diagram`] for a visual representation of the core
//! components of this library.
//!
//! # Usage
//!
//! You must implement the [`Config`] trait, which defines the types and methods needed by
//! [`Loom`].
//!
//! This library uses Redis as the default storage backend for storing [`TapestryFragment`]
//! instances. It is expected that a Redis instance is running and that the following environment
//! variables are set:
//!
//! - `REDIS_PROTOCOL`
//! - `REDIS_HOST`
//! - `REDIS_PORT`
//! - `REDIS_PASSWORD`
//!
//! Should you need to integrate a different storage backend, you can create a custom handler by
//! implementing the [`TapestryChestHandler`] trait and injecting it into the [`Config::Chest`]
//! associated type.
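//!
//! As a minimal sketch, a [`Config`] implementation might look like the following. `MyLlm` and
//! `MyChest` are hypothetical placeholders for your own [`Llm`] and [`TapestryChestHandler`]
//! implementations, and the module paths are indicative only:
//!
//! ```ignore
//! use llm_weaver::{
//!     types::{PromptModelTokens, SummaryModelTokens},
//!     Config,
//! };
//!
//! #[derive(Debug, Clone, Default)]
//! struct MyConfig;
//!
//! impl Config for MyConfig {
//!     // Reserve at least this many tokens for the model's response before a new
//!     // tapestry fragment is created.
//!     const MINIMUM_RESPONSE_LENGTH: u64 = 300;
//!
//!     // LLMs used for prompting and for summarizing, respectively.
//!     type PromptModel = MyLlm;
//!     type SummaryModel = MyLlm;
//!     // Storage backend for tapestry fragments.
//!     type Chest = MyChest;
//!
//!     fn convert_prompt_tokens_to_summary_model_tokens(
//!         tokens: PromptModelTokens<Self>,
//!     ) -> SummaryModelTokens<Self> {
//!         // Both models share the same token type in this sketch.
//!         tokens
//!     }
//! }
//! ```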
#![feature(async_closure)]
#![feature(associated_type_defaults)]
#![feature(more_qualified_paths)]
#![feature(const_option)]
#![feature(anonymous_lifetime_in_impl_trait)]
#![feature(once_cell_try)]

use std::{
    fmt::{Debug, Display},
    marker::PhantomData,
    str::FromStr,
};

use async_trait::async_trait;
pub use bounded_integer::BoundedU8;
use num_traits::{
    CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, SaturatingAdd, SaturatingMul,
    SaturatingSub, ToPrimitive, Unsigned,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::trace;

pub mod architecture;
pub mod loom;
pub mod storage;
pub mod types;

#[cfg(test)]
mod mock;
#[cfg(test)]
mod tests;

pub use storage::TapestryChestHandler;
use types::{LoomError, SummaryModelTokens, WeaveError};

use crate::types::{PromptModelTokens, WrapperRole};

pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

/// Represents a unique identifier for any arbitrary entity.
///
/// This trait provides a method for generating a standardized key, which can be utilized across
/// various implementations in the library, such as the [`TapestryChest`] implementation for
/// storing keys in Redis using the `base_key` method.
///
/// ```ignore
/// use llm_weaver::TapestryId;
/// use std::fmt::{Debug, Display};
///
/// #[derive(Debug, Clone)]
/// struct MyTapestryId {
///     id: String,
///     sub_id: String,
///     // ...
/// }
///
/// impl TapestryId for MyTapestryId {
///     fn base_key(&self) -> String {
///         format!("{}:{}", self.id, self.sub_id)
///     }
/// }
/// ```
pub trait TapestryId: Debug + Clone + Send + Sync + 'static {
    /// Returns the base key.
    ///
    /// This method should produce a unique string identifier, which will serve as a key for
    /// associated objects or data within [`TapestryChestHandler`] implementations.
    fn base_key(&self) -> String;
}

#[derive(Debug)]
pub struct LlmConfig<T: Config, L: Llm<T>> {
    pub model: L,
    pub params: L::Parameters,
}

#[async_trait]
pub trait Llm<T: Config>:
    Default + Sized + PartialEq + Eq + Clone + Debug + Copy + Send + Sync
{
    /// Token to word ratio.
    ///
    /// Defaults to `75%`
    const TOKEN_WORD_RATIO: BoundedU8<0, 100> = BoundedU8::new(75).unwrap();

    /// Tokens are an LLM concept representing pieces of words. For example, each ChatGPT token
    /// represents roughly 75% of a word.
    ///
    /// This type is used primarily for tracking the number of tokens in a [`TapestryFragment`]
    /// and counting the number of tokens in a string.
    ///
    /// This type is configurable to allow for different types of tokens to be used. For example,
    /// [`u16`] can be used to represent the number of tokens in a string.
    type Tokens: Copy
        + ToString
        + FromStr
        + Display
        + Debug
        + Serialize
        + DeserializeOwned
        + Default
        + TryFrom<usize>
        + Unsigned
        + FromPrimitive
        + ToPrimitive
        + std::iter::Sum
        + CheckedAdd
        + CheckedSub
        + SaturatingAdd
        + SaturatingSub
        + SaturatingMul
        + CheckedDiv
        + CheckedMul
        + Ord
        + Sync
        + Send;
    /// Type representing the prompt request.
    type Request: Clone + From<ContextMessage<T>> + Display + Send;
    /// Type representing the response to a prompt.
    type Response: Clone + Into<Option<String>> + Send;
    /// Type representing the parameters for a prompt.
    type Parameters: Debug + Clone + Send + Sync;

    /// The maximum number of tokens that can be processed at once by an LLM model.
    fn max_context_length(&self) -> Self::Tokens;
    /// Get the model name.
    ///
    /// This is used for logging purposes but can also be used to fetch a specific model based on
    /// `&self`. For example, the model passed to [`Loom::weave`] can be represented as an enum
    /// with a multitude of variants, each representing a different model.
    fn name(&self) -> &'static str;
    /// Alias for the model.
    ///
    /// Can be used for any unforeseen use cases where the model name is not sufficient.
    fn alias(&self) -> &'static str;
    /// Calculates the number of tokens in a string.
    ///
    /// This may vary depending on the type of tokens used by the LLM. In the case of ChatGPT,
    /// tokens can be counted using the
    /// [tiktoken-rs](https://github.com/zurawiki/tiktoken-rs#counting-token-length) crate.
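    ///
    /// As a rough sketch (assuming the `tiktoken-rs` crate mentioned above and a token type that
    /// fits the count), an implementation for a ChatGPT-style model could look like this:
    ///
    /// ```ignore
    /// fn count_tokens(content: &str) -> Result<Self::Tokens> {
    ///     // cl100k_base is the encoding used by gpt-3.5-turbo and gpt-4.
    ///     let bpe = tiktoken_rs::cl100k_base().map_err(|e| e.to_string())?;
    ///     let count = bpe.encode_with_special_tokens(content).len();
    ///     Self::Tokens::try_from(count)
    ///         .map_err(|_| "token count does not fit in Self::Tokens".into())
    /// }
    /// ```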
    fn count_tokens(content: &str) -> Result<Self::Tokens>;
    /// Prompt the LLM with the supplied messages and parameters.
    async fn prompt(
        &self,
        is_summarizing: bool,
        prompt_tokens: Self::Tokens,
        msgs: Vec<Self::Request>,
        params: &Self::Parameters,
        max_tokens: Self::Tokens,
    ) -> Result<Self::Response>;
    /// Compute the cost of a message based on the model.
    fn compute_cost(&self, prompt_tokens: Self::Tokens, response_tokens: Self::Tokens) -> f64;
    /// Calculate the upper bound of tokens allowed for the current [`Config::PromptModel`] before
    /// a summary is generated.
    ///
    /// This is calculated by multiplying the maximum context length (tokens) for the current
    /// [`Config::PromptModel`] by the [`Config::TOKEN_THRESHOLD_PERCENTILE`] and dividing by 100.
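    ///
    /// For example, with a maximum context length of 8,192 tokens, the default
    /// [`Config::TOKEN_THRESHOLD_PERCENTILE`] of 85 yields 8,192 * 85 / 100 = 6,963 tokens.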
    fn get_max_prompt_token_limit(&self) -> Self::Tokens {
        let max_context_length = self.max_context_length();
        let token_threshold = Self::Tokens::from_u8(T::TOKEN_THRESHOLD_PERCENTILE.get()).unwrap();
        let tokens = match max_context_length.checked_mul(&token_threshold) {
            Some(tokens) => tokens,
            None => max_context_length,
        };

        tokens.checked_div(&Self::Tokens::from_u8(100).unwrap()).unwrap()
    }
    /// Get optional max completion token limit.
    fn get_max_completion_token_limit(&self) -> Option<Self::Tokens> {
        None
    }
    /// [`ContextMessage`]s to [`Llm::Request`] conversion.
    fn ctx_msgs_to_prompt_requests(&self, msgs: &[ContextMessage<T>]) -> Vec<Self::Request> {
        msgs.iter().map(|m| m.clone().into()).collect()
    }
    /// Convert tokens to words.
    ///
    /// In the case of ChatGPT, each token represents roughly 75% of a word.
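    ///
    /// For example, 1,000 tokens correspond to roughly 1,000 * 75 / 100 = 750 words.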
    fn convert_tokens_to_words(&self, tokens: Self::Tokens) -> Self::Tokens {
        tokens.saturating_mul(&Self::Tokens::from_u8(Self::TOKEN_WORD_RATIO.get()).unwrap()) /
            Self::Tokens::from_u8(100).unwrap()
    }
}

/// A trait consisting of the main configuration needed to implement [`Loom`].
#[async_trait]
pub trait Config: Debug + Sized + Clone + Default + Send + Sync + 'static {
    /// Number between 0 and 100 representing the percentage of the maximum number of tokens
    /// allowed for the current [`Config::PromptModel`] before a summary is generated.
    ///
    /// Defaults to `85%`
    const TOKEN_THRESHOLD_PERCENTILE: BoundedU8<0, 100> = BoundedU8::new(85).unwrap();
    /// Ensures that the maximum number of completion tokens is at least the minimum response
    /// length.
    ///
    /// If the maximum number of completion tokens is less than the minimum response length, a
    /// summary will be generated and a new tapestry fragment will be created.
    const MINIMUM_RESPONSE_LENGTH: u64;

    /// The LLM to use for generating responses to prompts.
    type PromptModel: Llm<Self>;
    /// The LLM to use for generating summaries of the current [`TapestryFragment`] instance.
    ///
    /// This is separate from [`Config::PromptModel`] to allow for a larger model to be used for
    /// generating summaries.
    type SummaryModel: Llm<Self>;
    /// Storage handler interface for storing and retrieving tapestry fragments.
    ///
    /// You can optionally enable the `redis` or `rocksdb` features to use the default storage
    /// implementations for these storage backends.
    type Chest: TapestryChestHandler<Self>;

    /// Convert [`Config::PromptModel`] to [`Config::SummaryModel`] tokens.
    fn convert_prompt_tokens_to_summary_model_tokens(
        tokens: PromptModelTokens<Self>,
    ) -> SummaryModelTokens<Self>;
}

/// Context message that represents a single message in a [`TapestryFragment`] instance.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ContextMessage<T: Config> {
    pub role: WrapperRole,
    pub content: String,
    pub account_id: Option<String>,
    pub timestamp: String,

    _phantom: PhantomData<T>,
}

impl<T: Config> ContextMessage<T> {
    /// Create a new `ContextMessage` instance.
    pub fn new(
        role: WrapperRole,
        content: String,
        account_id: Option<String>,
        timestamp: String,
    ) -> Self {
        Self { role, content, account_id, timestamp, _phantom: PhantomData }
    }
}

/// Represents a single part of a conversation containing a list of messages along with other
/// metadata.
///
/// An LLM can only hold a limited number of tokens in the entire message history/context.
/// The total number of `context_tokens` is tracked when [`Loom::weave`] is executed and, if it
/// exceeds the maximum number of tokens allowed for the current [`Config::PromptModel`], a
/// summary is generated and a new [`TapestryFragment`] instance is created.
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
pub struct TapestryFragment<T: Config> {
    /// Total number of tokens in the `context_messages`.
    pub context_tokens: <T::PromptModel as Llm<T>>::Tokens,
    /// List of [`ContextMessage`]s that represent the message history.
    pub context_messages: Vec<ContextMessage<T>>,
}

impl<T: Config> TapestryFragment<T> {
    fn new() -> Self {
        Self::default()
    }

    /// Add a [`ContextMessage`] to the `context_messages` list.
    ///
    /// Also increments the `context_tokens` by the number of tokens in the message.
    fn push_message(&mut self, msg: ContextMessage<T>) -> Result<()> {
        let tokens = T::PromptModel::count_tokens(&msg.content)?;
        let new_token_count = self.context_tokens.checked_add(&tokens).ok_or_else(|| {
            LoomError::from(WeaveError::BadConfig(
                "Number of tokens exceeds max tokens for model".to_string(),
            ))
        })?;

        trace!("Pushing message: {:?}, new token count: {}", msg, new_token_count);

        self.context_tokens = new_token_count;
        self.context_messages.push(msg);
        Ok(())
    }

    /// Add multiple [`ContextMessage`]s to the `context_messages` list.
    ///
    /// Also increments the `context_tokens` by the total number of tokens in the messages.
    fn extend_messages(&mut self, msgs: Vec<ContextMessage<T>>) -> Result<()> {
        // Count the tokens of each message, propagating any counting error instead of panicking.
        let total_new_tokens = msgs
            .iter()
            .map(|m| T::PromptModel::count_tokens(&m.content))
            .collect::<Result<Vec<_>>>()?;

        let sum: PromptModelTokens<T> = total_new_tokens
            .iter()
            .fold(PromptModelTokens::<T>::default(), |acc, x| acc.saturating_add(x));

        trace!("Extending messages with token sum: {}", sum);

        let new_token_count = self.context_tokens.checked_add(&sum).ok_or_else(|| {
            LoomError::from(WeaveError::BadConfig(
                "Number of tokens exceeds max tokens for model".to_string(),
            ))
        })?;

        // Update the token count and messages only if all checks pass
        self.context_tokens = new_token_count;
        self.context_messages.extend(msgs);

        Ok(())
    }
}