llm_weaver/
lib.rs

//! A flexible library for creating and managing coherent narratives that leverage LLMs
//! (Large Language Models) to generate dynamic responses.
3//!
4//! Built based on [OpenAI's recommended tactics](https://platform.openai.com/docs/guides/gpt-best-practices/tactic-for-dialogue-applications-that-require-very-long-conversations-summarize-or-filter-previous-dialogue),
5//! LLM Weaver facilitates extended interactions with any LLM, seamlessly handling conversations
6//! that exceed a model's maximum context token limitation.
7//!
//! [`Loom`] is the core of this library. It prompts the configured LLM and stores the message
//! history as [`TapestryFragment`] instances. Its behavior is highly configurable through the
//! [`Config`] trait to support a wide range of use cases.
11//!
12//! # Nomenclature
13//!
14//! - **Tapestry**: A collection of [`TapestryFragment`] instances.
15//! - **TapestryFragment**: A single part of a conversation containing a list of messages along with
16//!   other metadata.
17//! - **ContextMessage**: Represents a single message in a [`TapestryFragment`] instance.
//! - **Loom**: The machine that drives all of the core methods that should be used across any
//!   service that needs to prompt an LLM and receive a response.
//! - **LLM**: Large Language Model.
21//!
22//! # Architecture
23//!
24//! Please refer to the [`architecture::Diagram`] for a visual representation of the core
25//! components of this library.
26//!
27//! # Usage
28//!
29//! You must implement the [`Config`] trait, which defines the necessary types and methods needed by
30//! [`Loom`].
31//!
32//! This library uses Redis as the default storage backend for storing [`TapestryFragment`]. It is
33//! expected that a Redis instance is running and that the following environment variables are set:
34//!
35//! - `REDIS_PROTOCOL`
36//! - `REDIS_HOST`
37//! - `REDIS_PORT`
38//! - `REDIS_PASSWORD`
39//!
//! To integrate a different storage backend, create a custom handler by implementing the
//! [`TapestryChestHandler`] trait and setting it as the [`Config::Chest`] associated type.
43#![feature(async_closure)]
44#![feature(associated_type_defaults)]
45#![feature(more_qualified_paths)]
46#![feature(const_option)]
47#![feature(anonymous_lifetime_in_impl_trait)]
48#![feature(once_cell_try)]
49
50use std::{
51	fmt::{Debug, Display},
52	marker::PhantomData,
53	str::FromStr,
54};
55
56use async_trait::async_trait;
57pub use bounded_integer::BoundedU8;
58use num_traits::{
59	CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, SaturatingAdd, SaturatingMul,
60	SaturatingSub, ToPrimitive, Unsigned,
61};
62use serde::{de::DeserializeOwned, Deserialize, Serialize};
63use tracing::trace;
64
65pub mod architecture;
66pub mod loom;
67pub mod storage;
68pub mod types;
69
70#[cfg(test)]
71mod mock;
72#[cfg(test)]
73mod tests;
74
75pub use storage::TapestryChestHandler;
76use types::{LoomError, SummaryModelTokens, WeaveError};
77
78use crate::types::{PromptModelTokens, WrapperRole};
79
/// Crate-wide result type whose error is a boxed, thread-safe [`std::error::Error`] trait object.
pub type Result<T> = core::result::Result<T, Box<dyn std::error::Error + Send + Sync + 'static>>;
81
/// Represents a unique identifier for any arbitrary entity.
///
/// This trait provides a method for generating a standardized key, which can be utilized across
/// various implementations in the library, such as a [`TapestryChestHandler`] implementation that
/// stores keys in Redis using the `base_key` method.
///
/// Implementors must also be `Debug`, `Clone`, `Send`, `Sync` and `'static`.
///
/// ```ignore
/// use llm_weaver::TapestryId;
///
/// #[derive(Debug, Clone)]
/// struct MyTapestryId {
///     id: String,
///     sub_id: String,
///     // ...
/// }
///
/// impl TapestryId for MyTapestryId {
///     fn base_key(&self) -> String {
///         format!("{}:{}", self.id, self.sub_id)
///     }
/// }
/// ```
pub trait TapestryId: Debug + Clone + Send + Sync + 'static {
	/// Returns the base key.
	///
	/// This method should produce a unique string identifier that will serve as a key for
	/// associated objects or data within [`TapestryChestHandler`] implementations.
	fn base_key(&self) -> String;
}
111
112#[derive(Debug)]
113pub struct LlmConfig<T: Config, L: Llm<T>> {
114	pub model: L,
115	pub params: L::Parameters,
116}
117
118#[async_trait]
119pub trait Llm<T: Config>:
120	Default + Sized + PartialEq + Eq + Clone + Debug + Copy + Send + Sync
121{
122	/// Token to word ratio.
123	///
124	/// Defaults to `75%`
125	const TOKEN_WORD_RATIO: BoundedU8<0, 100> = BoundedU8::new(75).unwrap();
126
127	/// Tokens are an LLM concept which represents pieces of words. For example, each ChatGPT token
128	/// represents roughly 75% of a word.
129	///
130	/// This type is used primarily for tracking the number of tokens in a [`TapestryFragment`] and
131	/// counting the number of tokens in a string.
132	///
133	/// This type is configurable to allow for different types of tokens to be used. For example,
134	/// [`u16`] can be used to represent the number of tokens in a string.
135	type Tokens: Copy
136		+ ToString
137		+ FromStr
138		+ Display
139		+ Debug
140		+ ToString
141		+ Serialize
142		+ DeserializeOwned
143		+ Default
144		+ TryFrom<usize>
145		+ Unsigned
146		+ FromPrimitive
147		+ ToPrimitive
148		+ std::iter::Sum
149		+ CheckedAdd
150		+ CheckedSub
151		+ SaturatingAdd
152		+ SaturatingSub
153		+ SaturatingMul
154		+ CheckedDiv
155		+ CheckedMul
156		+ Ord
157		+ Sync
158		+ Send;
159	/// Type representing the prompt request.
160	type Request: Clone + From<ContextMessage<T>> + Display + Send;
161	/// Type representing the response to a prompt.
162	type Response: Clone + Into<Option<String>> + Send;
163	/// Type representing the parameters for a prompt.
164	type Parameters: Debug + Clone + Send + Sync;
165
166	/// The maximum number of tokens that can be processed at once by an LLM model.
167	fn max_context_length(&self) -> Self::Tokens;
168	/// Get the model name.
169	///
170	/// This is used for logging purposes but also can be used to fetch a specific model based on
171	/// `&self`. For example, the model passed to [`Loom::weave`] can be represented as an enum with
172	/// a multitude of variants, each representing a different model.
173	fn name(&self) -> &'static str;
174	/// Alias for the model.
175	///
176	/// Can be used for any unforseen use cases where the model name is not sufficient.
177	fn alias(&self) -> &'static str;
178	/// Calculates the number of tokens in a string.
179	///
180	/// This may vary depending on the type of tokens used by the LLM. In the case of ChatGPT, can be calculated using the [tiktoken-rs](https://github.com/zurawiki/tiktoken-rs#counting-token-length) crate.
181	fn count_tokens(content: &str) -> Result<Self::Tokens>;
182	/// Prompt LLM with the supplied messages and parameters.
183	async fn prompt(
184		&self,
185		is_summarizing: bool,
186		prompt_tokens: Self::Tokens,
187		msgs: Vec<Self::Request>,
188		params: &Self::Parameters,
189		max_tokens: Self::Tokens,
190	) -> Result<Self::Response>;
191	/// Compute cost of a message based on model.
192	fn compute_cost(&self, prompt_tokens: Self::Tokens, response_tokens: Self::Tokens) -> f64;
193	/// Calculate the upperbound of tokens allowed for the current [`Config::PromptModel`] before a
194	/// summary is generated.
195	///
196	/// This is calculated by multiplying the maximum context length (tokens) for the current
197	/// [`Config::PromptModel`] by the [`Config::TOKEN_THRESHOLD_PERCENTILE`] and dividing by 100.
198	fn get_max_prompt_token_limit(&self) -> Self::Tokens {
199		let max_context_length = self.max_context_length();
200		let token_threshold = Self::Tokens::from_u8(T::TOKEN_THRESHOLD_PERCENTILE.get()).unwrap();
201		let tokens = match max_context_length.checked_mul(&token_threshold) {
202			Some(tokens) => tokens,
203			None => max_context_length,
204		};
205
206		tokens.checked_div(&Self::Tokens::from_u8(100).unwrap()).unwrap()
207	}
208	/// Get optional max completion token limit.
209	fn get_max_completion_token_limit(&self) -> Option<Self::Tokens> {
210		None
211	}
212	/// [`ContextMessage`]s to [`Llm::Request`] conversion.
213	fn ctx_msgs_to_prompt_requests(&self, msgs: &[ContextMessage<T>]) -> Vec<Self::Request> {
214		msgs.iter().map(|m| m.clone().into()).collect()
215	}
216	/// Convert tokens to words.
217	///
218	/// In the case of ChatGPT, each token represents roughly 75% of a word.
219	fn convert_tokens_to_words(&self, tokens: Self::Tokens) -> Self::Tokens {
220		tokens.saturating_mul(&Self::Tokens::from_u8(Self::TOKEN_WORD_RATIO.get()).unwrap()) /
221			Self::Tokens::from_u8(100).unwrap()
222	}
223}
224
225/// A trait consisting of the main configuration needed to implement [`Loom`].
226#[async_trait]
227pub trait Config: Debug + Sized + Clone + Default + Send + Sync + 'static {
228	/// Number between 0 and 100. Represents the percentile of the maximum number of tokens allowed
229	/// for the current [`Config::PromptModel`] before a summary is generated.
230	///
231	/// Defaults to `85%`
232	const TOKEN_THRESHOLD_PERCENTILE: BoundedU8<0, 100> = BoundedU8::new(85).unwrap();
233	/// Ensures that the maximum completion tokens is at least the minimum response length.
234	///
235	/// If the maximum completion tokens is less than the minimum response length, a summary
236	/// will be generated and a new tapestry fragment will be created.
237	const MINIMUM_RESPONSE_LENGTH: u64;
238
239	/// The LLM to use for generating responses to prompts.
240	type PromptModel: Llm<Self>;
241	/// The LLM to use for generating summaries of the current [`TapestryFragment`] instance.
242	///
243	/// This is separate from [`Config::PromptModel`] to allow for a larger model to be used for
244	/// generating summaries.
245	type SummaryModel: Llm<Self>;
246	/// Storage handler interface for storing and retrieving tapestry fragments.
247	///
248	/// You can optionally enable the `redis` or `rocksdb` features to use the default storage
249	/// implementations for these storage backends.
250	type Chest: TapestryChestHandler<Self>;
251
252	/// Convert [`Config::PromptModel`] to [`Config::SummaryModel`] tokens.
253	fn convert_prompt_tokens_to_summary_model_tokens(
254		tokens: PromptModelTokens<Self>,
255	) -> SummaryModelTokens<Self>;
256}
257
258/// Context message that represent a single message in a [`TapestryFragment`] instance.
259#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
260pub struct ContextMessage<T: Config> {
261	pub role: WrapperRole,
262	pub content: String,
263	pub account_id: Option<String>,
264	pub timestamp: String,
265
266	_phantom: PhantomData<T>,
267}
268
269impl<T: Config> ContextMessage<T> {
270	/// Create a new `ContextMessage` instance.
271	pub fn new(
272		role: WrapperRole,
273		content: String,
274		account_id: Option<String>,
275		timestamp: String,
276	) -> Self {
277		Self { role, content, account_id, timestamp, _phantom: PhantomData }
278	}
279}
280
281/// Represents a single part of a conversation containing a list of messages along with other
282/// metadata.
283///
284/// LLM can only hold a limited amount of tokens in a the entire message history/context.
285/// The total number of `context_tokens` is tracked when [`Loom::weave`] is executed and if it
286/// exceeds the maximum number of tokens allowed for the current GPT [`Config::PromptModel`], then a
287/// summary is generated and a new [`TapestryFragment`] instance is created.
288#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
289pub struct TapestryFragment<T: Config> {
290	/// Total number of _GPT tokens_ in the `context_messages`.
291	pub context_tokens: <T::PromptModel as Llm<T>>::Tokens,
292	/// List of [`ContextMessage`]s that represents the message history.
293	pub context_messages: Vec<ContextMessage<T>>,
294}
295
296impl<T: Config> TapestryFragment<T> {
297	fn new() -> Self {
298		Self::default()
299	}
300
301	/// Add a [`ContextMessage`] to the `context_messages` list.
302	///
303	/// Also increments the `context_tokens` by the number of tokens in the message.
304	fn push_message(&mut self, msg: ContextMessage<T>) -> Result<()> {
305		let tokens = T::PromptModel::count_tokens(&msg.content)?;
306		let new_token_count = self.context_tokens.checked_add(&tokens).ok_or_else(|| {
307			LoomError::from(WeaveError::BadConfig(
308				"Number of tokens exceeds max tokens for model".to_string(),
309			))
310		})?;
311
312		trace!("Pushing message: {:?}, new token count: {}", msg, new_token_count);
313
314		self.context_tokens = new_token_count;
315		self.context_messages.push(msg);
316		Ok(())
317	}
318
319	/// Add a [`ContextMessage`] to the `context_messages` list.
320	///
321	/// Also increments the `context_tokens` by the number of tokens in the message.
322	fn extend_messages(&mut self, msgs: Vec<ContextMessage<T>>) -> Result<()> {
323		let total_new_tokens = msgs
324			.iter()
325			.map(|m| T::PromptModel::count_tokens(&m.content).unwrap())
326			.collect::<Vec<_>>();
327
328		let sum: PromptModelTokens<T> = total_new_tokens
329			.iter()
330			.fold(PromptModelTokens::<T>::default(), |acc, x| acc.saturating_add(x));
331
332		trace!("Extending messages with token sum: {}", sum);
333
334		let new_token_count = self.context_tokens.checked_add(&sum).ok_or_else(|| {
335			LoomError::from(WeaveError::BadConfig(
336				"Number of tokens exceeds max tokens for model".to_string(),
337			))
338		})?;
339
340		// Update the token count and messages only if all checks pass
341		self.context_tokens = new_token_count;
342		for m in msgs {
343			self.context_messages.push(m);
344		}
345
346		Ok(())
347	}
348}