1use crate::chat_client::openai_api::message::{
26 AssistantMessage, Message, SystemMessage, UserMessage,
27};
28use iter_accumulate::IterAccumulate;
29
/// Chat conversation state: an optional system prompt plus the accumulated
/// (request, response) exchanges, optionally trimmed to a token budget by a
/// rolling window (see `keep_recent`).
#[derive(Debug, Default, Clone)]
pub struct Context {
    /// System prompt emitted first in every built request, when configured.
    system_message: Option<String>,
    /// Completed (user request, assistant response) pairs, oldest first.
    conversation: Vec<(String, String)>,
    /// Tokenizer used to size the rolling window; `None` disables trimming.
    tokenizer: Option<tiktoken_rs::CoreBPE>,
    /// Keep exchanges (newest first) while retained tokens are below this.
    min_history_tokens: Option<usize>,
    /// Stop keeping exchanges once retained tokens would exceed this.
    max_history_tokens: Option<usize>,
}
39
40impl Context {
41 pub fn new(system_message: Option<String>) -> Self {
43 Self {
44 system_message,
45 conversation: Vec::new(),
46 tokenizer: None,
47 min_history_tokens: None,
48 max_history_tokens: None,
49 }
50 }
51
52 pub fn new_with_rolling_window(
54 system_message: Option<String>,
55 tokenizer: tiktoken_rs::CoreBPE,
56 min_history_tokens: Option<usize>,
57 max_history_tokens: Option<usize>,
58 ) -> Self {
59 debug_assert!(min_history_tokens.is_some() || max_history_tokens.is_some());
60
61 Self {
62 system_message,
63 conversation: Vec::new(),
64 tokenizer: Some(tokenizer),
65 min_history_tokens,
66 max_history_tokens,
67 }
68 }
69
70 pub fn with_request(&self, request: String) -> impl Iterator<Item = Message> + '_ {
72 self.system_message
73 .iter()
74 .map(|system_message| SystemMessage::new(system_message.clone()).into())
75 .chain(self.conversation.iter().flat_map(|(request, response)| {
76 [
77 UserMessage::new(request.clone()).into(),
78 AssistantMessage::new(response.clone()).into(),
79 ]
80 .into_iter()
81 }))
82 .chain(std::iter::once(UserMessage::new(request).into()))
83 }
84
85 pub fn push(&mut self, request: String, response: String) {
87 self.conversation.push((request, response));
88 self.keep_recent();
89 }
90
91 fn keep_recent(&mut self) {
93 let Some(ref tokenizer) = self.tokenizer else {
94 return;
95 };
96
97 debug_assert!(self.min_history_tokens.is_some() || self.max_history_tokens.is_some());
99 let min_tokens = self.min_history_tokens.unwrap_or(usize::MAX);
100 let max_tokens = self.max_history_tokens.unwrap_or(usize::MAX);
101
102 let num_tokens = |m| tokenizer.encode_with_special_tokens(m).len();
103
104 let system_tokens = self
105 .system_message
106 .as_ref()
107 .map(|m| num_tokens(m))
108 .unwrap_or_default();
109
110 let keep = self
111 .conversation
112 .iter()
113 .rev()
114 .map(|transaction| num_tokens(&transaction.0) + num_tokens(&transaction.1))
115 .accumulate((0, system_tokens), |(_, acc), x| (acc, acc + x))
116 .map_while(|(prev, current)| (prev < min_tokens).then_some(current))
117 .take_while(|current| *current <= max_tokens)
118 .count();
119
120 let discard = self.conversation.len() - keep;
121 self.conversation.drain(0..discard);
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `context` through one push per entry of `expected_lens`,
    /// asserting the retained history length after each push.
    fn push_and_check(
        context: &mut Context,
        request: &str,
        response: &str,
        expected_lens: &[usize],
    ) {
        for &expected in expected_lens {
            context.push(request.to_string(), response.to_string());
            assert_eq!(context.conversation.len(), expected);
        }
    }

    #[test]
    fn empty() {
        let context = Context::default();

        let expected: Vec<Message> = vec![UserMessage::new("req".to_string()).into()];
        let actual: Vec<Message> = context.with_request("req".to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn non_empty() {
        let mut context = Context::default();
        context.push("req1".to_string(), "resp1".to_string());

        let expected: Vec<Message> = vec![
            UserMessage::new("req1".to_string()).into(),
            AssistantMessage::new("resp1".to_string()).into(),
            UserMessage::new("req2".to_string()).into(),
        ];
        let actual: Vec<Message> = context.with_request("req2".to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn empty_with_system_message() {
        let context = Context::new(Some("system".to_string()));

        let expected: Vec<Message> = vec![
            SystemMessage::new("system".to_string()).into(),
            UserMessage::new("req".to_string()).into(),
        ];
        let actual: Vec<Message> = context.with_request("req".to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn non_empty_with_system_message() {
        let mut context = Context::new(Some("system".to_string()));
        context.push("req1".to_string(), "resp1".to_string());

        let expected: Vec<Message> = vec![
            SystemMessage::new("system".to_string()).into(),
            UserMessage::new("req1".to_string()).into(),
            AssistantMessage::new("resp1".to_string()).into(),
            UserMessage::new("req2".to_string()).into(),
        ];
        let actual: Vec<Message> = context.with_request("req2".to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn min_history_tokens() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let count = |text: &str| tokenizer.encode_with_special_tokens(text).len();
        let (system, request, response) = ("to to to to to", "do do do do do", "be be be be be");
        assert_eq!(count(system), 5);
        assert_eq!(count(request), 5);
        assert_eq!(count(response), 5);

        let mut context = Context::new_with_rolling_window(
            Some(system.to_string()),
            tokenizer.clone(),
            Some(20),
            None,
        );
        assert!(context.conversation.is_empty());

        // 5 system tokens + 10 per exchange; the third exchange starts past
        // the 20-token minimum, so only two are retained.
        push_and_check(&mut context, request, response, &[1, 2, 2]);
    }

    #[test]
    fn min_history_tokens_exact() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let count = |text: &str| tokenizer.encode_with_special_tokens(text).len();
        let (request, response) = ("do do do do do", "be be be be be");
        assert_eq!(count(request), 5);
        assert_eq!(count(response), 5);

        let mut context = Context::new_with_rolling_window(None, tokenizer.clone(), Some(20), None);
        assert!(context.conversation.is_empty());

        // Two exchanges reach the 20-token minimum exactly; a third is dropped.
        push_and_check(&mut context, request, response, &[1, 2, 2]);
    }

    #[test]
    fn max_history_tokens() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let count = |text: &str| tokenizer.encode_with_special_tokens(text).len();
        let (system, request, response) = ("to to to to to", "do do do do do", "be be be be be");
        assert_eq!(count(system), 5);
        assert_eq!(count(request), 5);
        assert_eq!(count(response), 5);

        let mut context = Context::new_with_rolling_window(
            Some(system.to_string()),
            tokenizer.clone(),
            None,
            Some(30),
        );
        assert!(context.conversation.is_empty());

        // 5 system tokens + 10 per exchange: a third exchange would total 35,
        // past the 30-token budget, so only two are retained.
        push_and_check(&mut context, request, response, &[1, 2, 2]);
    }

    #[test]
    fn max_history_tokens_exact() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let count = |text: &str| tokenizer.encode_with_special_tokens(text).len();
        let (request, response) = ("do do do do do", "be be be be be");
        assert_eq!(count(request), 5);
        assert_eq!(count(response), 5);

        let mut context = Context::new_with_rolling_window(None, tokenizer.clone(), None, Some(30));
        assert!(context.conversation.is_empty());

        // Three exchanges fill the 30-token budget exactly; a fourth is dropped.
        push_and_check(&mut context, request, response, &[1, 2, 3, 3]);
    }
}
308}