rustylms/chat.rs
1use anyhow::Result;
2
3use crate::{
4 lmsserver::LMSServer,
5 models::{ChatCompletionsRequest, ChatCompletionsResponse, Message},
6};
7
8#[derive(Debug)]
9pub struct Chat {
10 model: String,
11 messages: Vec<Message>,
12 temperature: f32,
13 max_tokens: i32,
14}
15
16impl Chat {
17 /// Creates a new `Chat` with the selected model.
18 ///
19 /// # Example
20 ///
21 /// ```rust
22 /// use rustylms::{
23 /// chat::Chat,
24 /// lmsserver::LMSServer
25 /// };
26 ///
27 /// let server = LMSServer::new("http://localhost:1234");
28 /// let chat = Chat::new("model-name").system_prompt("You are a helpful assistant.").user_prompt("Why does iron rust?");
29 /// let completion = chat.get_completions(&server).await.expect("Could not get completions");
30 ///
31 /// println!("From assistant: {}", completion.get_message().unwrap().content);
32 /// ```
33 pub fn new<T>(model: T) -> Self
34 where
35 T: Into<String>,
36 {
37 Self {
38 model: model.into(),
39 messages: vec![],
40 temperature: 0.7,
41 max_tokens: -1,
42 }
43 }
44
45 /// Sets the temperature of the model. The default value for this is `0.7`.
46 pub fn temperature(mut self, temperature: f32) -> Self {
47 self.temperature = temperature;
48
49 self
50 }
51
52 /// Sets the maximum tokens a completion can generate. The default value is `-1` meaning no limit.
53 pub fn max_tokens(mut self, max_tokens: i32) -> Self {
54 self.max_tokens = max_tokens;
55
56 self
57 }
58
59 /// This function adds a system prompt to the messages.
60 ///
61 /// **NOTE:** This function doesn't clear the messages array before adding the system prompt!
62 pub fn system_prompt<T>(mut self, system_prompt: T) -> Self
63 where
64 T: Into<String>,
65 {
66 self.add_system_message(system_prompt);
67
68 self
69 }
70
71 /// This function adds a user prompt to the messages.
72 ///
73 /// **NOTE:** This function doesn't clear the messages array before adding the user prompt!
74 pub fn user_prompt<T>(mut self, user_prompt: T) -> Self
75 where
76 T: Into<String>,
77 {
78 self.add_user_message(user_prompt);
79
80 self
81 }
82
83 /// Clears all messages in the chat **including system prompts**.
84 pub fn clear_messages(&mut self) {
85 self.messages.clear()
86 }
87
88 /// Adds the provided `Message` to the chat.
89 pub fn add_message(&mut self, message: Message) {
90 self.messages.push(message)
91 }
92
93 /// Adds the provided message content as a system message.
94 pub fn add_system_message<T>(&mut self, message: T)
95 where
96 T: Into<String>,
97 {
98 self.add_message(Message::system(message))
99 }
100
101 /// Adds the provided message content as a message from the assistant.
102 pub fn add_assistant_message<T>(&mut self, message: T)
103 where
104 T: Into<String>,
105 {
106 self.add_message(Message::assistant(message))
107 }
108
109 /// Adds the provided message content as a message from the user.
110 pub fn add_user_message<T>(&mut self, message: T)
111 where
112 T: Into<String>,
113 {
114 self.add_message(Message::user(message))
115 }
116
117 /// Gets the completion from the server by sending the current `Chat` struct.
118 pub async fn get_completions(&self, server: &LMSServer) -> Result<ChatCompletionsResponse> {
119 let request = ChatCompletionsRequest {
120 max_tokens: self.max_tokens,
121 messages: self.messages.clone(),
122 model: self.model.clone(),
123 temperature: self.temperature,
124 };
125 let response = server.get_chat_completion(request).await?;
126
127 Ok(response)
128 }
129}