use std::{collections::VecDeque, marker::PhantomData};
use num_traits::{CheckedAdd, FromPrimitive, SaturatingAdd, SaturatingSub, Zero};
use tracing::{debug, error, instrument, trace};
use crate::{
types::{
LoomError, PromptModelRequest, PromptModelTokens, SummaryModelTokens, VecPromptMsgsDeque,
WeaveError, WrapperRole, ASSISTANT_ROLE, SYSTEM_ROLE,
},
Config, ContextMessage, Llm, LlmConfig, TapestryChestHandler, TapestryFragment, TapestryId,
};
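/// Drives the weaving of a tapestry: it assembles prompts from stored context, summarizes
/// that context when it grows past the prompt model's token budget, and persists the
/// resulting fragments through the configured chest.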
#[derive(Debug)]
pub struct Loom<T: Config> {
pub chest: T::Chest,
_phantom: PhantomData<T>,
}
impl<T: Config> Loom<T> {
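/// Creates a `Loom` backed by a fresh instance of the configured [`TapestryChestHandler`].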
pub fn new() -> Self {
Self { chest: <T::Chest as TapestryChestHandler<T>>::new(), _phantom: PhantomData }
}
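/// Weaves `msgs` into the tapestry identified by `tapestry_id`.
///
/// The prompt is built from the system `instructions`, the context messages of the current
/// [`TapestryFragment`], and the new `msgs`. If that prompt, plus a minimum response budget,
/// would exceed the prompt model's token limit, the existing context is summarized and a new
/// fragment seeded with the summary is started. The LLM response is appended to the fragment,
/// which is then saved through the chest.
///
/// Returns the prompt model's response, the id of the saved fragment, and whether a summary
/// was generated.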
#[instrument(skip(self, instructions, msgs))]
pub async fn weave<TID: TapestryId>(
&self,
prompt_llm_config: LlmConfig<T, T::PromptModel>,
summary_llm_config: LlmConfig<T, T::SummaryModel>,
tapestry_id: TID,
instructions: String,
mut msgs: Vec<ContextMessage<T>>,
) -> Result<
(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool),
Box<dyn std::error::Error + Send + Sync>,
> {
let instructions_ctx_msg =
Self::build_context_message(SYSTEM_ROLE.into(), instructions, None);
let instructions_req_msg: PromptModelRequest<T> = instructions_ctx_msg.clone().into();
trace!("Fetching current tapestry fragment for ID: {:?}", tapestry_id);
let current_tapestry_fragment = self
.chest
.get_tapestry_fragment(tapestry_id.clone(), None)
.await?
.unwrap_or_default();
let max_prompt_tokens_limit = prompt_llm_config.model.get_max_prompt_token_limit();
let mut req_msgs = VecPromptMsgsDeque::<T, T::PromptModel>::with_capacity(
current_tapestry_fragment.context_messages.len() + 1,
);
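// The system instructions lead the request, followed by the context messages already
// stored in the current fragment.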
req_msgs.push_front(instructions_req_msg);
let mut ctx_msgs = VecDeque::from(
prompt_llm_config
.model
.ctx_msgs_to_prompt_requests(&current_tapestry_fragment.context_messages),
);
req_msgs.append(&mut ctx_msgs);
let msgs_tokens = Self::count_tokens_in_messages(msgs.iter());
trace!(
"Total tokens after adding new messages: {:?}, maximum allowed: {:?}",
req_msgs.tokens.saturating_add(&msgs_tokens),
max_prompt_tokens_limit
);
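// Reserve `MINIMUM_RESPONSE_LENGTH` tokens for the model's reply when checking whether
// the prompt would exceed the limit.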
let exceeds_max_token_limit = max_prompt_tokens_limit <=
req_msgs.tokens.saturating_add(&msgs_tokens).saturating_add(
&PromptModelTokens::<T>::from_u64(T::MINIMUM_RESPONSE_LENGTH).unwrap(),
);
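// If the limit would be exceeded, summarize the current fragment and start a new one
// seeded with that summary; otherwise keep extending the current fragment.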
let (mut tapestry_fragment_to_persist, was_summary_generated) =
if exceeds_max_token_limit {
trace!("Generating summary as the token limit exceeded");
let summary_max_tokens: PromptModelTokens<T> =
prompt_llm_config.model.max_context_length() - max_prompt_tokens_limit;
let summary = Self::generate_summary(
&summary_llm_config,
&current_tapestry_fragment,
T::convert_prompt_tokens_to_summary_model_tokens(summary_max_tokens),
)
.await?;
let summary_ctx_msg = Self::build_context_message(
SYSTEM_ROLE.into(),
format!("\n\"\"\"\nSummary\n {}", summary),
None,
);
req_msgs.truncate(1);
req_msgs.push_back(summary_ctx_msg.clone().into());
let mut new_tapestry_fragment = TapestryFragment::new();
new_tapestry_fragment.push_message(summary_ctx_msg)?;
(new_tapestry_fragment, true)
} else {
(current_tapestry_fragment, false)
};
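// Append the new messages now that the base context (instructions plus either the stored
// messages or their summary) is in place.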
req_msgs.extend(msgs.iter().map(|m| m.clone().into()).collect::<Vec<_>>());
let max_completion_tokens = max_prompt_tokens_limit.saturating_sub(&req_msgs.tokens);
trace!("Max completion tokens available: {:?}", max_completion_tokens);
if max_completion_tokens.is_zero() {
return Err(LoomError::from(WeaveError::MaxCompletionTokensIsZero).into());
}
trace!("Prompting LLM with request messages");
let response = prompt_llm_config
.model
.prompt(
false,
req_msgs.tokens,
req_msgs.into_vec(),
&prompt_llm_config.params,
max_completion_tokens,
)
.await
.map_err(|e| {
error!("Failed to prompt LLM: {}", e);
e
})?;
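// Record the assistant's response alongside the new messages so it is persisted with the
// fragment.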
msgs.push(Self::build_context_message(
ASSISTANT_ROLE.into(),
response.clone().into().unwrap_or_default(),
None,
));
tapestry_fragment_to_persist.extend_messages(msgs)?;
debug!("Saving tapestry fragment: {:?}", tapestry_fragment_to_persist);
let tapestry_fragment_id = self
.chest
.save_tapestry_fragment(
&tapestry_id,
tapestry_fragment_to_persist,
was_summary_generated,
)
.await
.map_err(|e| {
error!("Failed to save tapestry fragment: {}", e);
e
})?;
Ok((response, tapestry_fragment_id, was_summary_generated))
}
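/// Summarizes the context messages of `tapestry_fragment` with the summary model, bounded
/// by `summary_max_tokens`, returning the summary text (or an empty string if the model
/// returned no content).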
async fn generate_summary(
summary_model_config: &LlmConfig<T, T::SummaryModel>,
tapestry_fragment: &TapestryFragment<T>,
summary_max_tokens: SummaryModelTokens<T>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
trace!(
"Generating summary with max tokens: {:?}, for tapestry fragment: {:?}",
summary_max_tokens,
tapestry_fragment
);
let mut summary_generation_prompt = VecPromptMsgsDeque::<T, T::SummaryModel>::new();
summary_generation_prompt.extend(
summary_model_config
.model
.ctx_msgs_to_prompt_requests(tapestry_fragment.context_messages.as_slice()),
);
let res = summary_model_config
.model
.prompt(
true,
summary_generation_prompt.tokens,
summary_generation_prompt.into_vec(),
&summary_model_config.params,
summary_max_tokens,
)
.await
.map_err(|e| {
error!("Failed to prompt LLM: {}", e);
e
})?;
let summary_response_content = res.into();
trace!("Generated summary: {:?}", summary_response_content);
Ok(summary_response_content.unwrap_or_default())
}
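/// Builds a [`ContextMessage`] for the given role and content, timestamped with the current
/// UTC time in RFC 3339 format.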
pub fn build_context_message(
role: WrapperRole,
content: String,
account_id: Option<String>,
) -> ContextMessage<T> {
trace!("Building context message for role: {:?}, content: {}", role, content);
ContextMessage {
role,
content,
account_id,
timestamp: chrono::Utc::now().to_rfc3339(),
_phantom: PhantomData,
}
}
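/// Sums the prompt-model token counts of `msgs`, leaving the running total unchanged (and
/// logging an error) if an addition would overflow.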
fn count_tokens_in_messages(
msgs: impl Iterator<Item = &ContextMessage<T>>,
) -> <T::PromptModel as Llm<T>>::Tokens {
msgs.fold(<T::PromptModel as Llm<T>>::Tokens::from_u8(0).unwrap(), |acc, m| {
let tokens = T::PromptModel::count_tokens(&m.content).unwrap_or_default();
match acc.checked_add(&tokens) {
Some(v) => v,
None => {
error!("Token overflow");
acc
},
}
})
}
}