use std::{collections::VecDeque, marker::PhantomData};

use num_traits::{CheckedAdd, FromPrimitive, SaturatingAdd, SaturatingSub, Zero};
use tracing::{debug, error, instrument, trace};

use crate::{
	types::{
		LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, VecPromptMsgsDeque,
		WeaveError, WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
	},
	Config, ContextMessage, Llm, LlmConfig, TapestryChestHandler, TapestryFragment, TapestryId,
};

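/// Weaves prompts to the configured LLMs and persists the resulting conversation context as
/// [`TapestryFragment`]s through its `chest` storage handler.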
#[derive(Debug)]
pub struct Loom<T: Config> {
	pub chest: T::Chest,
	_phantom: PhantomData<T>,
}

impl<T: Config> Loom<T> {
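	/// Creates a `Loom` backed by a newly constructed chest handler.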
	pub fn new() -> Self {
		Self { chest: <T::Chest as TapestryChestHandler<T>>::new(), _phantom: PhantomData }
	}

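	/// Prompts the prompt model with the `instructions`, the stored context of `tapestry_id`,
	/// and the new `msgs`, then persists the updated [`TapestryFragment`].
	///
	/// If the instructions, stored context, and new messages would not leave at least
	/// `T::MINIMUM_RESPONSE_LENGTH` tokens within the prompt model's token limit, the current
	/// fragment is first summarized by the summary model and a new fragment seeded with that
	/// summary is used instead.
	///
	/// Returns the model response, the ID of the saved fragment, and whether a summary was
	/// generated.
	///
	/// A minimal usage sketch (not compiled), assuming a hypothetical `MyConfig` type that
	/// implements [`Config`] and pre-built `prompt_cfg`/`summary_cfg` LLM configs:
	///
	/// ```ignore
	/// let loom = Loom::<MyConfig>::new();
	/// let (response, fragment_id, summarized) = loom
	/// 	.weave(prompt_cfg, summary_cfg, tapestry_id, instructions, msgs)
	/// 	.await?;
	/// ```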
	#[instrument(skip(self, instructions, msgs))]
	pub async fn weave<TID: TapestryId>(
		&self,
		prompt_llm_config: LlmConfig<T, T::PromptModel>,
		summary_llm_config: LlmConfig<T, T::SummaryModel>,
		tapestry_id: TID,
		instructions: String,
		mut msgs: Vec<ContextMessage<T>>,
	) -> Result<
		(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool),
		Box<dyn std::error::Error + Send + Sync>,
	> {
		let instructions_ctx_msg =
			Self::build_context_message(SYSTEM_ROLE.into(), instructions, None);
		let instructions_req_msg: PromptModelRequest<T> = instructions_ctx_msg.clone().into();

		trace!("Fetching current tapestry fragment for ID: {:?}", tapestry_id);

		let current_tapestry_fragment = self
			.chest
			.get_tapestry_fragment(tapestry_id.clone(), None)
			.await?
			.unwrap_or_default();

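		// Token budget the prompt model allows for the prompt, as reported by the model.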
		let max_prompt_tokens_limit = prompt_llm_config.model.get_max_prompt_token_limit();

		let mut req_msgs = VecPromptMsgsDeque::<T, T::PromptModel>::with_capacity(
			current_tapestry_fragment.context_messages.len() + 1,
		);

		req_msgs.push_front(instructions_req_msg);

		let mut ctx_msgs = VecDeque::from(
			prompt_llm_config
				.model
				.ctx_msgs_to_prompt_requests(&current_tapestry_fragment.context_messages),
		);
		req_msgs.append(&mut ctx_msgs);

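		// Tokens contributed by the caller-supplied messages that are about to be appended.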
		let msgs_tokens = Self::count_tokens_in_messages(msgs.iter());

		trace!(
			"Total tokens after adding new messages: {:?}, maximum allowed: {:?}",
			req_msgs.tokens.saturating_add(&msgs_tokens),
			max_prompt_tokens_limit
		);

		let exceeds_max_token_limit = max_prompt_tokens_limit <=
			req_msgs.tokens.saturating_add(&msgs_tokens).saturating_add(
				&PromptModelTokens::<T>::from_u64(T::MINIMUM_RESPONSE_LENGTH).unwrap(),
			);

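		// If the existing context plus the new messages would leave too little room for a
		// response, summarize the current fragment and start a new one seeded with the summary.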
		let (mut tapestry_fragment_to_persist, was_summary_generated) =
			if exceeds_max_token_limit {
				trace!("Generating summary as the token limit was exceeded");

				let summary_max_tokens: PromptModelTokens<T> =
					prompt_llm_config.model.max_context_length() - max_prompt_tokens_limit;

				let summary = Self::generate_summary(
					&summary_llm_config,
					&current_tapestry_fragment,
					T::convert_prompt_tokens_to_summary_model_tokens(summary_max_tokens),
				)
				.await?;

				let summary_ctx_msg = Self::build_context_message(
					SYSTEM_ROLE.into(),
					format!("\n\"\"\"\nSummary\n {}", summary),
					None,
				);

				// Keep only the instructions message, then append the summary in place of the
				// full context.
				req_msgs.truncate(1);
				req_msgs.push_back(summary_ctx_msg.clone().into());

				let mut new_tapestry_fragment = TapestryFragment::new();
				new_tapestry_fragment.push_message(summary_ctx_msg)?;

				(new_tapestry_fragment, true)
			} else {
				(current_tapestry_fragment, false)
			};

		req_msgs.extend(msgs.iter().map(|m| m.clone().into()).collect::<Vec<_>>());

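		// Whatever remains of the prompt token limit is offered to the model as the completion
		// budget.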
		let max_completion_tokens = max_prompt_tokens_limit.saturating_sub(&req_msgs.tokens);

		trace!("Max completion tokens available: {:?}", max_completion_tokens);

		if max_completion_tokens.is_zero() {
			return Err(LoomError::from(WeaveError::MaxCompletionTokensIsZero).into());
		}

		trace!("Prompting LLM with request messages");

		let response = prompt_llm_config
			.model
			.prompt(
				false,
				req_msgs.tokens,
				req_msgs.into_vec(),
				&prompt_llm_config.params,
				max_completion_tokens,
			)
			.await
			.map_err(|e| {
				error!("Failed to prompt LLM: {}", e);
				e
			})?;

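		// Record the assistant's reply as a context message so it is persisted with the rest.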
		msgs.push(Self::build_context_message(
			ASSISTANT_ROLE.into(),
			response.clone().into().unwrap_or_default(),
			None,
		));

		tapestry_fragment_to_persist.extend_messages(msgs)?;

		debug!("Saving tapestry fragment: {:?}", tapestry_fragment_to_persist);

		let tapestry_fragment_id = self
			.chest
			.save_tapestry_fragment(
				&tapestry_id,
				tapestry_fragment_to_persist,
				was_summary_generated,
			)
			.await
			.map_err(|e| {
				error!("Failed to save tapestry fragment: {}", e);
				e
			})?;

		Ok((response, tapestry_fragment_id, was_summary_generated))
	}

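	/// Generates a summary of the context messages in `tapestry_fragment` using the summary
	/// model, with the response capped at `summary_max_tokens`.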
	async fn generate_summary(
		summary_model_config: &LlmConfig<T, T::SummaryModel>,
		tapestry_fragment: &TapestryFragment<T>,
		summary_max_tokens: SummaryModelTokens<T>,
	) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
		trace!(
			"Generating summary with max tokens: {:?}, for tapestry fragment: {:?}",
			summary_max_tokens,
			tapestry_fragment
		);

		let mut summary_generation_prompt = VecPromptMsgsDeque::<T, T::SummaryModel>::new();

		summary_generation_prompt.extend(
			summary_model_config
				.model
				.ctx_msgs_to_prompt_requests(tapestry_fragment.context_messages.as_slice()),
		);

		let res = summary_model_config
			.model
			.prompt(
				true,
				summary_generation_prompt.tokens,
				summary_generation_prompt.into_vec(),
				&summary_model_config.params,
				summary_max_tokens,
			)
			.await
			.map_err(|e| {
				error!("Failed to prompt LLM: {}", e);
				e
			})?;

		let summary_response_content = res.into();

		trace!("Generated summary: {:?}", summary_response_content);

		Ok(summary_response_content.unwrap_or_default())
	}

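	/// Builds a [`ContextMessage`] for `role` and `content`, stamped with the current UTC time
	/// in RFC 3339 format.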
	pub fn build_context_message(
		role: WrapperRole,
		content: String,
		account_id: Option<String>,
	) -> ContextMessage<T> {
		trace!("Building context message for role: {:?}, content: {}", role, content);

		ContextMessage {
			role,
			content,
			account_id,
			timestamp: chrono::Utc::now().to_rfc3339(),
			_phantom: PhantomData,
		}
	}

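	/// Sums the prompt-model token counts of `msgs`, logging an error and keeping the running
	/// total if an addition would overflow.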
	fn count_tokens_in_messages(
		msgs: impl Iterator<Item = &ContextMessage<T>>,
	) -> <T::PromptModel as Llm<T>>::Tokens {
		msgs.fold(<T::PromptModel as Llm<T>>::Tokens::from_u8(0).unwrap(), |acc, m| {
			let tokens = T::PromptModel::count_tokens(&m.content).unwrap_or_default();
			match acc.checked_add(&tokens) {
				Some(v) => v,
				None => {
					error!("Token overflow");
					acc
				},
			}
		})
	}
}