1use std::collections::VecDeque;
2
3use async_openai::types::Role;
4use num_traits::{CheckedAdd, FromPrimitive, SaturatingAdd};
5use serde::{Deserialize, Serialize};
6use tracing::error;
7
8use crate::{Config, Llm};
9
/// The `Tokens` associated type of the `PromptModel` selected by the config `T`.
pub type PromptModelTokens<T> = <<T as Config>::PromptModel as Llm<T>>::Tokens;
/// The `Tokens` associated type of the `SummaryModel` selected by the config `T`.
pub type SummaryModelTokens<T> = <<T as Config>::SummaryModel as Llm<T>>::Tokens;
/// The `Request` associated type of the `PromptModel` selected by the config `T`.
pub type PromptModelRequest<T> = <<T as Config>::PromptModel as Llm<T>>::Request;

/// Crate-level alias for the primitive `f32`.
pub type F32 = f32;
16
/// Role name used for the system role.
pub const SYSTEM_ROLE: &str = "system";
/// Role name used for the assistant role.
pub const ASSISTANT_ROLE: &str = "assistant";
/// Role name used for the user role.
pub const USER_ROLE: &str = "user";
/// Role name used for the function role.
pub const FUNCTION_ROLE: &str = "function";
21
/// Serializable wrapper around the `async_openai` [`Role`] type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WrapperRole {
    /// The wrapped role value.
    Role(Role),
}
27
28impl Default for WrapperRole {
29 fn default() -> Self {
30 Self::Role(Role::User)
31 }
32}
33
34impl From<&str> for WrapperRole {
35 fn from(role: &str) -> Self {
36 match role {
37 SYSTEM_ROLE => Self::Role(Role::System),
38 ASSISTANT_ROLE => Self::Role(Role::Assistant),
39 USER_ROLE => Self::Role(Role::User),
40 FUNCTION_ROLE => Self::Role(Role::Function),
41 _ => panic!(
42 "Invalid role: {} \n Valid roles: {} | {} | {} | {}",
43 role, SYSTEM_ROLE, ASSISTANT_ROLE, USER_ROLE, FUNCTION_ROLE
44 ),
45 }
46 }
47}
48
49impl From<WrapperRole> for String {
50 fn from(role: WrapperRole) -> Self {
51 match role {
52 WrapperRole::Role(Role::System) => SYSTEM_ROLE.to_string(),
53 WrapperRole::Role(Role::Assistant) => ASSISTANT_ROLE.to_string(),
54 WrapperRole::Role(Role::User) => USER_ROLE.to_string(),
55 WrapperRole::Role(Role::Function) => FUNCTION_ROLE.to_string(),
56 WrapperRole::Role(_) => panic!("Invalid role"),
57 }
58 }
59}
60
/// Top-level error type: wraps weave and storage errors, or carries a free-form message.
#[derive(Debug, thiserror::Error)]
pub enum LoomError {
    /// A [`WeaveError`]; `#[from]` provides the `From<WeaveError>` conversion.
    #[error("Weave error: {0}")]
    Weave(#[from] WeaveError),
    /// A [`StorageError`]; `#[from]` provides the `From<StorageError>` conversion.
    #[error("Storage error: {0}")]
    Storage(#[from] StorageError),
    /// Any other failure, described only by its message text.
    #[error("Error: {0}")]
    Error(String),
}
70
71impl From<Box<dyn std::error::Error + Send + Sync>> for LoomError {
72 fn from(error: Box<dyn std::error::Error + Send + Sync>) -> Self {
73 match error.downcast::<LoomError>() {
74 Ok(loom_error) => *loom_error,
75 Err(error) => LoomError::Error(error.to_string()),
76 }
77 }
78}
79
/// Errors wrapped by the `Weave` variant of [`LoomError`].
#[derive(Debug, thiserror::Error)]
pub enum WeaveError {
    // NOTE(review): the display text mentions exceeding the prompt token maximum while the
    // variant name says the completion token budget is zero — presumably the same condition
    // seen from two sides; confirm with the code that raises it.
    #[error("Exceeds max prompt tokens")]
    MaxCompletionTokensIsZero,
    /// A bad configuration; the payload is the description shown in the message.
    #[error("Bad configuration: {0}")]
    BadConfig(String),
}
87
/// Errors wrapped by the `Storage` variant of [`LoomError`].
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
    /// An error from the `rocksdb` crate; only compiled in with the `rocksdb` feature.
    #[cfg(feature = "rocksdb")]
    #[error("RocksDb error: {0}")]
    RocksDb(rocksdb::Error),
    /// A parsing failure; carries no further detail.
    #[error("Parsing error")]
    Parsing,
    /// The requested item was not found.
    #[error("Not found")]
    NotFound,
    /// The instance count could not be read.
    #[error("Failed to read instance count")]
    FailedToReadInstanceCount,
    /// A database failure described by the contained message.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// A serialization failure described by the contained message.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// A deserialization failure described by the contained message.
    #[error("Deserialization error: {0}")]
    DeserializationError(String),
    /// An internal failure described by the contained message.
    #[error("Internal error: {0}")]
    InternalError(String),
}
108
/// A deque of request messages for the model `L`, paired with a running token total.
pub struct VecPromptMsgsDeque<T: Config, L: Llm<T>> {
    /// Token total kept up to date by this type's mutating methods.
    pub tokens: <L as Llm<T>>::Tokens,
    /// The queued request messages.
    pub inner: VecDeque<<L as Llm<T>>::Request>,
}
115
116impl<T: Config, L: Llm<T>> VecPromptMsgsDeque<T, L> {
117 pub fn new() -> Self {
118 Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::new() }
119 }
120
121 pub fn with_capacity(capacity: usize) -> Self {
122 Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::with_capacity(capacity) }
123 }
124
125 pub fn push_front(&mut self, msg_reqs: L::Request) {
126 let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
127 self.tokens = self.tokens.saturating_add(&tokens);
128 self.inner.push_front(msg_reqs);
129 }
130
131 pub fn push_back(&mut self, msg_reqs: L::Request) {
132 let tokens = L::count_tokens(&msg_reqs.to_string()).unwrap_or_default();
133 self.tokens = self.tokens.saturating_add(&tokens);
134 self.inner.push_back(msg_reqs);
135 }
136
137 pub fn append(&mut self, msg_reqs: &mut VecDeque<L::Request>) {
138 msg_reqs.iter().for_each(|msg_req| {
139 let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
140 self.tokens = self.tokens.saturating_add(&msg_tokens);
141 });
142 self.inner.append(msg_reqs);
143 }
144
145 pub fn truncate(&mut self, len: usize) {
146 let mut tokens = L::Tokens::from_u8(0).unwrap();
147 for msg_req in self.inner.iter().take(len) {
148 let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
149 tokens = tokens.saturating_add(&msg_tokens);
150 }
151 self.inner.truncate(len);
152 self.tokens = tokens;
153 }
154
155 pub fn extend(&mut self, msg_reqs: Vec<L::Request>) {
156 let mut tokens = L::Tokens::from_u8(0).unwrap();
157 for msg_req in &msg_reqs {
158 let msg_tokens = L::count_tokens(&msg_req.to_string()).unwrap_or_default();
159 tokens = tokens.saturating_add(&msg_tokens);
160 }
161 self.inner.extend(msg_reqs);
162 match self.tokens.checked_add(&tokens) {
163 Some(v) => self.tokens = v,
164 None => {
165 error!("Token overflow");
166 },
167 }
168 }
169
170 pub fn into_vec(self) -> Vec<L::Request> {
171 self.inner.into()
172 }
173}