// lm_studio_api/chat/context.rs

1use crate::prelude::*;
2use super::{ Role, Message, SystemInfo, };
3use tiktoken_rs::{ CoreBPE, cl100k_base };
4
5/// Chat context
6#[derive(Clone)]
7pub struct Context {
8    tokenizer: CoreBPE,
9    pub messages: Vec<Message>,
10    pub system_prompt: Arc<Mutex<Box<dyn SystemInfo>>>,
11    pub tokens_limit: u32,
12    pub total_tokens: u32,
13}
14
15impl Context {
16    /// Creates a new chat context
17    pub fn new(system_prompt: Box<dyn SystemInfo>, tokens_limit: u32) -> Self {
18        // init tokeninzer:
19        let tokenizer = cl100k_base().expect("Failed to create tokenizer");
20        
21        // creating context:
22        let mut context = Self {
23            tokenizer,
24            messages: vec![
25                Message::new(Role::System, str!()),
26            ],
27            system_prompt: Arc::new(Mutex::new(system_prompt)),
28            tokens_limit,
29            total_tokens: 0,
30        };
31
32        // plus tokens count:
33        context.total_tokens = context.count_tokens(&context.messages[0].text());
34
35        context
36    }
37
38    /// Add a message to context
39    pub fn add<M: Into<Message>>(&mut self, message: M) {
40        let message = message.into();
41        let tokens_count = self.count_tokens(message.text());
42
43        // add message to context:
44        self.messages.push(message);
45        self.total_tokens += tokens_count;
46
47        // remove old messages:
48        while self.messages.len() > 3 && self.total_tokens > self.tokens_limit {
49            let removed_count = self.count_tokens(self.messages[1].text()) as u32;
50
51            self.messages.remove(1);
52            self.total_tokens -= removed_count;
53        }
54    }
55
56    /// Returns all context messages
57    pub fn get(&self) -> Vec<Message> {
58        self.messages.clone()
59    }
60
61    /// Returns all context messages as single string
62    pub fn get_as_string(&self) -> String {
63        self.messages.clone()
64            .into_iter()
65            .map(|msg| msg.text().to_owned())
66            .collect::<Vec<_>>()
67            .join("\n\n")
68    }
69
70    /// Clears context messages
71    pub fn clear(&mut self) {
72        self.messages.truncate(1);
73    }
74
75    /// Tokenizes text
76    pub fn tokenize(&self, text: &str) -> Vec<u32> {
77        self.tokenizer.encode_with_special_tokens(text)
78    }
79    
80    /// Calculates tokens count
81    pub fn count_tokens(&self, text: &str) -> u32 {
82        self.tokenize(text).len() as u32
83    }
84
85    /// Updates the actual system info
86    pub async fn update_system_info(&mut self) {
87        self.messages[0].content = self.system_prompt.lock().await.update();
88    }
89}