// jutella/chat_client/context.rs

1// Copyright (c) 2024 Dmitry Markin
2//
3// SPDX-License-Identifier: MIT
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23//! Chatbot context.
24
25use crate::chat_client::openai_api::message::{
26    AssistantMessage, Message, SystemMessage, UserMessage,
27};
28use iter_accumulate::IterAccumulate;
29
/// Chatbot context: the conversation history accompanying every request.
#[derive(Debug, Default, Clone)]
pub struct Context {
    // Optional system prompt prepended to every request.
    system_message: Option<String>,
    // (request, response) pairs in chronological order, oldest first.
    conversation: Vec<(String, String)>,
    // Tokenizer used to measure history size; `None` disables trimming.
    tokenizer: Option<tiktoken_rs::CoreBPE>,
    // Keep trimming-off old pairs only while the retained history is still
    // below this many tokens (see `keep_recent`).
    min_history_tokens: Option<usize>,
    // Hard cap: never retain more history tokens than this (system message
    // tokens count towards the budget).
    max_history_tokens: Option<usize>,
}
39
40impl Context {
41    /// Create a new chat context.
42    pub fn new(system_message: Option<String>) -> Self {
43        Self {
44            system_message,
45            conversation: Vec::new(),
46            tokenizer: None,
47            min_history_tokens: None,
48            max_history_tokens: None,
49        }
50    }
51
52    /// Create a new chat context wth tokenizer.
53    pub fn new_with_rolling_window(
54        system_message: Option<String>,
55        tokenizer: tiktoken_rs::CoreBPE,
56        min_history_tokens: Option<usize>,
57        max_history_tokens: Option<usize>,
58    ) -> Self {
59        debug_assert!(min_history_tokens.is_some() || max_history_tokens.is_some());
60
61        Self {
62            system_message,
63            conversation: Vec::new(),
64            tokenizer: Some(tokenizer),
65            min_history_tokens,
66            max_history_tokens,
67        }
68    }
69
70    /// Context so far with a new request message.
71    pub fn with_request(&self, request: String) -> impl Iterator<Item = Message> + '_ {
72        self.system_message
73            .iter()
74            .map(|system_message| SystemMessage::new(system_message.clone()).into())
75            .chain(self.conversation.iter().flat_map(|(request, response)| {
76                [
77                    UserMessage::new(request.clone()).into(),
78                    AssistantMessage::new(response.clone()).into(),
79                ]
80                .into_iter()
81            }))
82            .chain(std::iter::once(UserMessage::new(request).into()))
83    }
84
85    /// Extend the context with a new pair of request and response.
86    pub fn push(&mut self, request: String, response: String) {
87        self.conversation.push((request, response));
88        self.keep_recent();
89    }
90
91    /// Discard old records to keep the context within the limits.
92    fn keep_recent(&mut self) {
93        let Some(ref tokenizer) = self.tokenizer else {
94            return;
95        };
96
97        // At least one of the numbers is limited if tokenizer is set.
98        debug_assert!(self.min_history_tokens.is_some() || self.max_history_tokens.is_some());
99        let min_tokens = self.min_history_tokens.unwrap_or(usize::MAX);
100        let max_tokens = self.max_history_tokens.unwrap_or(usize::MAX);
101
102        let num_tokens = |m| tokenizer.encode_with_special_tokens(m).len();
103
104        let system_tokens = self
105            .system_message
106            .as_ref()
107            .map(|m| num_tokens(m))
108            .unwrap_or_default();
109
110        let keep = self
111            .conversation
112            .iter()
113            .rev()
114            .map(|transaction| num_tokens(&transaction.0) + num_tokens(&transaction.1))
115            .accumulate((0, system_tokens), |(_, acc), x| (acc, acc + x))
116            .map_while(|(prev, current)| (prev < min_tokens).then_some(current))
117            .take_while(|current| *current <= max_tokens)
118            .count();
119
120        let discard = self.conversation.len() - keep;
121        self.conversation.drain(0..discard);
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty() {
        let ctx = Context::default();

        let messages = ctx.with_request("req".to_string()).collect::<Vec<_>>();
        assert_eq!(messages, vec![UserMessage::new("req".to_string()).into()]);
    }

    #[test]
    fn non_empty() {
        let mut ctx = Context::default();
        ctx.push("req1".to_string(), "resp1".to_string());

        let messages = ctx.with_request("req2".to_string()).collect::<Vec<_>>();
        assert_eq!(
            messages,
            vec![
                UserMessage::new("req1".to_string()).into(),
                AssistantMessage::new("resp1".to_string()).into(),
                UserMessage::new("req2".to_string()).into(),
            ],
        );
    }

    #[test]
    fn empty_with_system_message() {
        let ctx = Context::new(Some("system".to_string()));

        let messages = ctx.with_request("req".to_string()).collect::<Vec<_>>();
        assert_eq!(
            messages,
            vec![
                SystemMessage::new("system".to_string()).into(),
                UserMessage::new("req".to_string()).into(),
            ],
        );
    }

    #[test]
    fn non_empty_with_system_message() {
        let mut ctx = Context::new(Some("system".to_string()));
        ctx.push("req1".to_string(), "resp1".to_string());

        let messages = ctx.with_request("req2".to_string()).collect::<Vec<_>>();
        assert_eq!(
            messages,
            vec![
                SystemMessage::new("system".to_string()).into(),
                UserMessage::new("req1".to_string()).into(),
                AssistantMessage::new("resp1".to_string()).into(),
                UserMessage::new("req2".to_string()).into(),
            ],
        );
    }

    #[test]
    fn min_history_tokens() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let token_count = |m| tokenizer.encode_with_special_tokens(m).len();
        let system = "to to to to to".to_string();
        let request = "do do do do do".to_string();
        let response = "be be be be be".to_string();
        // Every fixture message is exactly 5 tokens long.
        assert_eq!(token_count(&system), 5);
        assert_eq!(token_count(&request), 5);
        assert_eq!(token_count(&response), 5);

        let mut ctx = Context::new_with_rolling_window(
            Some(system.clone()),
            tokenizer.clone(),
            Some(20),
            None,
        );
        assert!(ctx.conversation.is_empty());

        // Totals after each push (system message included): 15, 25, 25 —
        // the third push causes the oldest transaction to be discarded.
        for expected_len in [1, 2, 2] {
            ctx.push(request.clone(), response.clone());
            assert_eq!(ctx.conversation.len(), expected_len);
        }
    }

    #[test]
    fn min_history_tokens_exact() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let token_count = |m| tokenizer.encode_with_special_tokens(m).len();
        let request = "do do do do do".to_string();
        let response = "be be be be be".to_string();
        // Every fixture message is exactly 5 tokens long.
        assert_eq!(token_count(&request), 5);
        assert_eq!(token_count(&response), 5);

        let mut ctx = Context::new_with_rolling_window(None, tokenizer.clone(), Some(20), None);
        assert!(ctx.conversation.is_empty());

        // Totals after each push: 10, 20, 20 — the third push causes the
        // oldest transaction to be discarded.
        for expected_len in [1, 2, 2] {
            ctx.push(request.clone(), response.clone());
            assert_eq!(ctx.conversation.len(), expected_len);
        }
    }

    #[test]
    fn max_history_tokens() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let token_count = |m| tokenizer.encode_with_special_tokens(m).len();
        let system = "to to to to to".to_string();
        let request = "do do do do do".to_string();
        let response = "be be be be be".to_string();
        // Every fixture message is exactly 5 tokens long.
        assert_eq!(token_count(&system), 5);
        assert_eq!(token_count(&request), 5);
        assert_eq!(token_count(&response), 5);

        let mut ctx = Context::new_with_rolling_window(
            Some(system.clone()),
            tokenizer.clone(),
            None,
            Some(30),
        );
        assert!(ctx.conversation.is_empty());

        // Totals after each push (system message included): 15, 25, 25 —
        // the third push causes the oldest transaction to be discarded.
        for expected_len in [1, 2, 2] {
            ctx.push(request.clone(), response.clone());
            assert_eq!(ctx.conversation.len(), expected_len);
        }
    }

    #[test]
    fn max_history_tokens_exact() {
        let tokenizer = tiktoken_rs::o200k_base().unwrap();
        let token_count = |m| tokenizer.encode_with_special_tokens(m).len();
        let request = "do do do do do".to_string();
        let response = "be be be be be".to_string();
        // Every fixture message is exactly 5 tokens long.
        assert_eq!(token_count(&request), 5);
        assert_eq!(token_count(&response), 5);

        let mut ctx = Context::new_with_rolling_window(None, tokenizer.clone(), None, Some(30));
        assert!(ctx.conversation.is_empty());

        // Totals after each push: 10, 20, 30, 30 — only the fourth push
        // causes the oldest transaction to be discarded.
        for expected_len in [1, 2, 3, 3] {
            ctx.push(request.clone(), response.clone());
            assert_eq!(ctx.conversation.len(), expected_len);
        }
    }
}