llm_weaver/loom.rs

use std::{collections::VecDeque, marker::PhantomData};

use num_traits::{CheckedAdd, FromPrimitive, SaturatingAdd, SaturatingSub, Zero};
use tracing::{debug, error, instrument, trace};

use crate::{
	types::{
		LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, VecPromptMsgsDeque,
		WeaveError, WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
	},
	Config, ContextMessage, Llm, LlmConfig, TapestryChestHandler, TapestryFragment, TapestryId,
};

/// The machine that drives all of the core methods that should be used across any service
/// that needs to prompt an LLM and receive a response.
///
/// This is implemented over the [`Config`] trait.
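///
/// # Example
///
/// A minimal construction sketch; `MyConfig` is a hypothetical type implementing [`Config`]:
///
/// ```ignore
/// let loom = Loom::<MyConfig>::new();
/// ```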
#[derive(Debug)]
pub struct Loom<T: Config> {
	pub chest: T::Chest,
	_phantom: PhantomData<T>,
}

impl<T: Config> Loom<T> {
	/// Creates a new instance of `Loom`.
	pub fn new() -> Self {
		Self { chest: <T::Chest as TapestryChestHandler<T>>::new(), _phantom: PhantomData }
	}

	/// Prompt LLM Weaver for a response for the given [`TapestryId`].
	///
	/// Prompts the LLM with the current [`TapestryFragment`] instance and the new `msgs`.
	///
	/// A summary of the current [`TapestryFragment`] instance is generated if the total number of
	/// tokens in its `context_messages` exceeds the maximum number of tokens allowed for the
	/// current [`Config::PromptModel`], or a custom maximum. This threshold is affected by the
	/// [`Config::TOKEN_THRESHOLD_PERCENTILE`].
	///
	/// # Parameters
	///
	/// - `prompt_llm_config`: The [`Config::PromptModel`] to use for prompting the LLM.
	/// - `summary_llm_config`: The [`Config::SummaryModel`] to use for generating summaries.
	/// - `tapestry_id`: The [`TapestryId`] to use for storing the [`TapestryFragment`] instance.
	/// - `instructions`: The instruction message to be used for the current [`TapestryFragment`]
	///   instance.
	/// - `msgs`: The messages to prompt the LLM with.
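	///
	/// # Example
	///
	/// An illustrative sketch only; `MyConfig`, `prompt_llm_config`, `summary_llm_config`,
	/// `my_id`, `instructions`, and `user_msgs` are hypothetical values for a concrete [`Config`]
	/// implementation:
	///
	/// ```ignore
	/// // `user_msgs` is a `Vec<ContextMessage<MyConfig>>`, built e.g. via
	/// // `Loom::<MyConfig>::build_context_message`.
	/// let loom = Loom::<MyConfig>::new();
	/// let (response, fragment_id, summarized) = loom
	/// 	.weave(prompt_llm_config, summary_llm_config, my_id, instructions, user_msgs)
	/// 	.await?;
	/// ```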
	#[instrument(skip(self, instructions, msgs))]
	pub async fn weave<TID: TapestryId>(
		&self,
		prompt_llm_config: LlmConfig<T, T::PromptModel>,
		summary_llm_config: LlmConfig<T, T::SummaryModel>,
		tapestry_id: TID,
		instructions: String,
		mut msgs: Vec<ContextMessage<T>>,
	) -> Result<
		(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool),
		Box<dyn std::error::Error + Send + Sync>,
	> {
		let instructions_ctx_msg =
			Self::build_context_message(SYSTEM_ROLE.into(), instructions, None);
		let instructions_req_msg: PromptModelRequest<T> = instructions_ctx_msg.clone().into();

		trace!("Fetching current tapestry fragment for ID: {:?}", tapestry_id);

		let current_tapestry_fragment = self
			.chest
			.get_tapestry_fragment(tapestry_id.clone(), None)
			.await?
			.unwrap_or_default();

		// Get max token limit which cannot be exceeded in a tapestry fragment
		let max_prompt_tokens_limit = prompt_llm_config.model.get_max_prompt_token_limit();

		// Request messages which will be sent as a whole to the LLM
		let mut req_msgs = VecPromptMsgsDeque::<T, T::PromptModel>::with_capacity(
			current_tapestry_fragment.context_messages.len() + 1,
		);

		// Add instructions as the first message
		req_msgs.push_front(instructions_req_msg);

		// Convert and append all tapestry fragment messages to the request messages.
		let mut ctx_msgs = VecDeque::from(
			prompt_llm_config
				.model
				.ctx_msgs_to_prompt_requests(&current_tapestry_fragment.context_messages),
		);
		req_msgs.append(&mut ctx_msgs);

		// The new messages are not added yet; we first check whether they would push the tapestry
		// fragment past the maximum token limit, in which case a summary is generated and a new
		// tapestry fragment is started.
		//
		// Either we start a new tapestry fragment with the instruction and summary messages, or we
		// continue the current tapestry fragment.
		let msgs_tokens = Self::count_tokens_in_messages(msgs.iter());

		trace!(
			"Total tokens after adding new messages: {:?}, maximum allowed: {:?}",
			req_msgs.tokens.saturating_add(&msgs_tokens),
			max_prompt_tokens_limit
		);

		// Check whether the total number of tokens in the tapestry fragment, after adding the new
		// messages and reserving the minimum response length, exceeds the maximum allowed.
		let exceeds_max_token_limit = max_prompt_tokens_limit <=
			req_msgs.tokens.saturating_add(&msgs_tokens).saturating_add(
				&PromptModelTokens::<T>::from_u64(T::MINIMUM_RESPONSE_LENGTH).unwrap(),
			);

		let (mut tapestry_fragment_to_persist, was_summary_generated) =
			if exceeds_max_token_limit {
				trace!("Generating summary because the token limit was exceeded");

				// Cap the summary so that, once added to the new tapestry fragment, the prompt
				// model's maximum token limit is not exceeded
				let summary_max_tokens: PromptModelTokens<T> = prompt_llm_config
					.model
					.max_context_length()
					.saturating_sub(&max_prompt_tokens_limit);

				let summary = Self::generate_summary(
					&summary_llm_config,
					&current_tapestry_fragment,
					T::convert_prompt_tokens_to_summary_model_tokens(summary_max_tokens),
				)
				.await?;

				let summary_ctx_msg = Self::build_context_message(
					SYSTEM_ROLE.into(),
					format!("\n\"\"\"\nSummary\n {}", summary),
					None,
				);

				// Truncate all tapestry fragment messages except for the instructions and add the
				// summary
				req_msgs.truncate(1);
				req_msgs.push_back(summary_ctx_msg.clone().into());

				// Create a new tapestry fragment seeded with the summary
				let mut new_tapestry_fragment = TapestryFragment::new();
				new_tapestry_fragment.push_message(summary_ctx_msg)?;

				(new_tapestry_fragment, true)
			} else {
				(current_tapestry_fragment, false)
			};

		// Add new messages to the request messages
		req_msgs.extend(msgs.iter().map(|m| m.clone().into()).collect::<Vec<_>>());

		// Tokens remaining for the LLM response without exceeding the maximum token limit
		let max_completion_tokens = max_prompt_tokens_limit.saturating_sub(&req_msgs.tokens);

		trace!("Max completion tokens available: {:?}", max_completion_tokens);

		if max_completion_tokens.is_zero() {
			return Err(LoomError::from(WeaveError::MaxCompletionTokensIsZero).into());
		}

		trace!("Prompting LLM with request messages");

		let response = prompt_llm_config
			.model
			.prompt(
				false,
				req_msgs.tokens,
				req_msgs.into_vec(),
				&prompt_llm_config.params,
				max_completion_tokens,
			)
			.await
			.map_err(|e| {
				error!("Failed to prompt LLM: {}", e);
				e
			})?;

		// Add the LLM response to the messages that will be saved to the tapestry fragment
		msgs.push(Self::build_context_message(
			ASSISTANT_ROLE.into(),
			response.clone().into().unwrap_or_default(),
			None,
		));

		// Add the new messages and the response to the tapestry fragment which will be persisted
		// in the database
		tapestry_fragment_to_persist.extend_messages(msgs)?;

		debug!("Saving tapestry fragment: {:?}", tapestry_fragment_to_persist);

		// Save the tapestry fragment to the database.
		// If a summary was generated, the tapestry fragment is saved as a new instance.
		let tapestry_fragment_id = self
			.chest
			.save_tapestry_fragment(
				&tapestry_id,
				tapestry_fragment_to_persist,
				was_summary_generated,
			)
			.await
			.map_err(|e| {
				error!("Failed to save tapestry fragment: {}", e);
				e
			})?;

		Ok((response, tapestry_fragment_id, was_summary_generated))
	}

	/// Generates the summary of the current [`TapestryFragment`] instance.
	///
	/// Returns the summary message as a string.
	async fn generate_summary(
		summary_model_config: &LlmConfig<T, T::SummaryModel>,
		tapestry_fragment: &TapestryFragment<T>,
		summary_max_tokens: SummaryModelTokens<T>,
	) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
		trace!(
			"Generating summary with max tokens: {:?}, for tapestry fragment: {:?}",
			summary_max_tokens,
			tapestry_fragment
		);

		let mut summary_generation_prompt = VecPromptMsgsDeque::<T, T::SummaryModel>::new();

		summary_generation_prompt.extend(
			summary_model_config
				.model
				.ctx_msgs_to_prompt_requests(tapestry_fragment.context_messages.as_slice()),
		);

		let res = summary_model_config
			.model
			.prompt(
				true,
				summary_generation_prompt.tokens,
				summary_generation_prompt.into_vec(),
				&summary_model_config.params,
				summary_max_tokens,
			)
			.await
			.map_err(|e| {
				error!("Failed to prompt LLM: {}", e);
				e
			})?;

		let summary_response_content = res.into();

		trace!("Generated summary: {:?}", summary_response_content);

		Ok(summary_response_content.unwrap_or_default())
	}

	/// Helper method to build a [`ContextMessage`].
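	///
	/// # Example
	///
	/// An illustrative sketch; `MyConfig` is a hypothetical [`Config`] implementation:
	///
	/// ```ignore
	/// let msg = Loom::<MyConfig>::build_context_message(
	/// 	SYSTEM_ROLE.into(),
	/// 	"You are a helpful assistant.".to_string(),
	/// 	None,
	/// );
	/// ```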
	pub fn build_context_message(
		role: WrapperRole,
		content: String,
		account_id: Option<String>,
	) -> ContextMessage<T> {
		trace!("Building context message for role: {:?}, content: {}", role, content);

		ContextMessage {
			role,
			content,
			account_id,
			timestamp: chrono::Utc::now().to_rfc3339(),
			_phantom: PhantomData,
		}
	}

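	/// Sums the prompt-model token counts of `msgs`.
	///
	/// If adding a message's tokens would overflow, an error is logged and the running total is
	/// kept as-is.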
	fn count_tokens_in_messages(
		msgs: impl Iterator<Item = &ContextMessage<T>>,
	) -> <T::PromptModel as Llm<T>>::Tokens {
		msgs.fold(<T::PromptModel as Llm<T>>::Tokens::zero(), |acc, m| {
			let tokens = T::PromptModel::count_tokens(&m.content).unwrap_or_default();
			match acc.checked_add(&tokens) {
				Some(v) => v,
				None => {
					error!("Token overflow");
					acc
				},
			}
		})
	}
}