llm_weaver/
types.rs

1use std::collections::VecDeque;
2
3use async_openai::types::Role;
4use num_traits::{CheckedAdd, FromPrimitive, SaturatingAdd};
5use serde::{Deserialize, Serialize};
6use tracing::error;
7
8use crate::{Config, Llm};
9
/// Token counter type of the configured prompt model.
pub type PromptModelTokens<T> = <<T as Config>::PromptModel as Llm<T>>::Tokens;
/// Token counter type of the configured summary model.
pub type SummaryModelTokens<T> = <<T as Config>::SummaryModel as Llm<T>>::Tokens;
/// Request (message) type of the configured prompt model.
pub type PromptModelRequest<T> = <<T as Config>::PromptModel as Llm<T>>::Request;

/// Base type for all configuration parameters.
pub type F32 = f32;

// Canonical string names for the chat roles, used by the `WrapperRole`
// conversions below.
pub const SYSTEM_ROLE: &str = "system";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const USER_ROLE: &str = "user";
pub const FUNCTION_ROLE: &str = "function";
21
/// Wrapped [`Role`] for custom implementations.
///
/// Newtype-style wrapper around the third-party [`Role`] so this crate can
/// provide its own trait impls (`Default`, the `From` conversions below).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WrapperRole {
	Role(Role),
}
27
28impl Default for WrapperRole {
29	fn default() -> Self {
30		Self::Role(Role::User)
31	}
32}
33
34impl From<&str> for WrapperRole {
35	fn from(role: &str) -> Self {
36		match role {
37			SYSTEM_ROLE => Self::Role(Role::System),
38			ASSISTANT_ROLE => Self::Role(Role::Assistant),
39			USER_ROLE => Self::Role(Role::User),
40			FUNCTION_ROLE => Self::Role(Role::Function),
41			_ => panic!(
42				"Invalid role: {} \n Valid roles: {} | {} | {} | {}",
43				role, SYSTEM_ROLE, ASSISTANT_ROLE, USER_ROLE, FUNCTION_ROLE
44			),
45		}
46	}
47}
48
49impl From<WrapperRole> for String {
50	fn from(role: WrapperRole) -> Self {
51		match role {
52			WrapperRole::Role(Role::System) => SYSTEM_ROLE.to_string(),
53			WrapperRole::Role(Role::Assistant) => ASSISTANT_ROLE.to_string(),
54			WrapperRole::Role(Role::User) => USER_ROLE.to_string(),
55			WrapperRole::Role(Role::Function) => FUNCTION_ROLE.to_string(),
56			WrapperRole::Role(_) => panic!("Invalid role"),
57		}
58	}
59}
60
/// Top-level error type for the crate.
#[derive(Debug, thiserror::Error)]
pub enum LoomError {
	/// A weaving-related failure (see [`WeaveError`]).
	#[error("Weave error: {0}")]
	Weave(#[from] WeaveError),
	/// A storage-backend failure (see [`StorageError`]).
	#[error("Storage error: {0}")]
	Storage(#[from] StorageError),
	/// Catch-all for errors carried as plain strings.
	#[error("Error: {0}")]
	Error(String),
}
70
71impl From<Box<dyn std::error::Error + Send + Sync>> for LoomError {
72	fn from(error: Box<dyn std::error::Error + Send + Sync>) -> Self {
73		match error.downcast::<LoomError>() {
74			Ok(loom_error) => *loom_error,
75			Err(error) => LoomError::Error(error.to_string()),
76		}
77	}
78}
79
/// Errors that can occur while weaving a prompt.
#[derive(Debug, thiserror::Error)]
pub enum WeaveError {
	// NOTE(review): the display text says "Exceeds max prompt tokens" but the
	// variant name says the completion-token budget reached zero — these are
	// plausibly two sides of the same condition, but confirm the message is
	// intentional and not a copy-paste leftover.
	#[error("Exceeds max prompt tokens")]
	MaxCompletionTokensIsZero,
	/// Configuration was invalid; the payload describes what was wrong.
	#[error("Bad configuration: {0}")]
	BadConfig(String),
}
87
/// Errors produced by the storage backends.
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
	// NOTE(review): unlike `LoomError`'s variants this has no `#[from]`, so
	// `rocksdb::Error` must be wrapped manually — confirm that is deliberate.
	#[cfg(feature = "rocksdb")]
	#[error("RocksDb error: {0}")]
	RocksDb(rocksdb::Error),
	/// Stored data could not be parsed.
	#[error("Parsing error")]
	Parsing,
	/// The requested key/record does not exist.
	#[error("Not found")]
	NotFound,
	/// The per-key instance counter could not be read.
	#[error("Failed to read instance count")]
	FailedToReadInstanceCount,
	/// Generic database failure, with backend-supplied detail.
	#[error("Database error: {0}")]
	DatabaseError(String),
	/// Value could not be serialized for storage.
	#[error("Serialization error: {0}")]
	SerializationError(String),
	/// Stored value could not be deserialized.
	#[error("Deserialization error: {0}")]
	DeserializationError(String),
	/// Unexpected internal failure.
	#[error("Internal error: {0}")]
	InternalError(String),
}
108
/// A helper struct to manage the prompt messages in a deque while keeping track of the tokens
/// added or removed.
pub struct VecPromptMsgsDeque<T: Config, L: Llm<T>> {
	// Running token total of all requests currently in `inner`.
	pub tokens: <L as Llm<T>>::Tokens,
	// The queued prompt requests, in prompt order (front = earliest).
	pub inner: VecDeque<<L as Llm<T>>::Request>,
}
115
116impl<T: Config, L: Llm<T>> VecPromptMsgsDeque<T, L> {
117	pub fn new() -> Self {
118		Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::new() }
119	}
120
121	pub fn with_capacity(capacity: usize) -> Self {
122		Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::with_capacity(capacity) }
123	}
124
125	pub fn push_front(&mut self, msg_reqs: L::Request) {
126		let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
127		self.tokens = self.tokens.saturating_add(&tokens);
128		self.inner.push_front(msg_reqs);
129	}
130
131	pub fn push_back(&mut self, msg_reqs: L::Request) {
132		let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
133		self.tokens = self.tokens.saturating_add(&tokens);
134		self.inner.push_back(msg_reqs);
135	}
136
137	pub fn append(&mut self, msg_reqs: &mut VecDeque<L::Request>) {
138		msg_reqs.iter().for_each(|msg_req| {
139			let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
140			self.tokens = self.tokens.saturating_add(&msg_tokens);
141		});
142		self.inner.append(msg_reqs);
143	}
144
145	pub fn truncate(&mut self, len: usize) {
146		let mut tokens = L::Tokens::from_u8(0).unwrap();
147		for msg_req in self.inner.iter().take(len) {
148			let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
149			tokens = tokens.saturating_add(&msg_tokens);
150		}
151		self.inner.truncate(len);
152		self.tokens = tokens;
153	}
154
155	pub fn extend(&mut self, msg_reqs: Vec<L::Request>) {
156		let mut tokens = L::Tokens::from_u8(0).unwrap();
157		for msg_req in &msg_reqs {
158			let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
159			tokens = tokens.saturating_add(&msg_tokens);
160		}
161		self.inner.extend(msg_reqs);
162		match self.tokens.checked_add(&tokens) {
163			Some(v) => self.tokens = v,
164			None => {
165				error!("Token overflow");
166			},
167		}
168	}
169
170	pub fn into_vec(self) -> Vec<L::Request> {
171		self.inner.into()
172	}
173}