lm_studio_api/chat/
chat.rs

1use crate::prelude::*;
2use super::*;
3
4use reqwest::Client;
5use futures_util::StreamExt;
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::UnboundedReceiverStream;
8
9/// The LM Studio chat
10pub struct Chat {
11    pub(crate) model: Model,
12    pub(crate) context: Context,
13    pub(crate) host: String,
14    pub(crate) client: Client,
15    pub(crate) reader: Option<ResponseReader>,
16}
17
18impl Chat {
19    /// Creates a new simple chat
20    pub async fn new<M: Into<Model>, C: Into<Context>>(model: M, context: C, port: u16) -> Result<Self> {
21        let mut chat = Self {
22            model: model.into(),
23            context: context.into(),
24            host: fmt!("http://localhost:{port}"),
25            client: Client::new(),
26            reader: None,
27        };
28
29        // loading model:
30        chat.load_model(chat.model.clone()).await?;
31
32        Ok(chat)
33    }
34
35    /// Loads AI model to memory
36    pub async fn load_model<M: Into<Model>>(&mut self, model: M) -> Result<()> {
37        let request = Prompt {
38            model: model.into(),
39            prompt: str!("Hello"),
40            stream: false,
41            ..Default::default()
42        }.into();
43
44        let _ = self.send(request).await?;
45
46        Ok(())
47    }
48
49    /// Send request to chat
50    pub async fn send(&mut self, request: Request) -> Result<Option<Response>> {
51        self.context.update_system_info().await;
52        
53        match request {
54            Request::Messages(request) => self.handle_messages(request).await,
55            Request::Prompt(request) => self.handle_prompt(request).await,
56            Request::Embeddings(request) => self.handle_embeddings(request).await,
57        }
58    }
59
60    /// Handle messages request
61    async fn handle_messages(&mut self, mut request: Messages) -> Result<Option<Response>> {
62        let url = fmt!("{}/v1/chat/completions", self.host);
63
64        // choose AI model:
65        if let Model::Other(s) = &request.model {
66            if s.is_empty() {
67                request.model = self.model.clone();
68            }
69        }
70        
71        // add request to context:
72        request.messages = if request.context {
73            for msg in request.messages {
74                self.context.add(msg);
75            }
76
77            self.context.get()
78        } else {
79            let mut context = self.context.clone();
80            for msg in request.messages {
81                context.add(msg);
82            }
83
84            context.get()
85        };
86
87        // handle request:
88        if !request.stream {
89            let mut response = self.client.post(&url)
90                .json(&request)
91                .send()
92                .await?
93                .error_for_status()?
94                .json::<Response>()
95                .await?;
96
97            // filtering <think>..</think> block:
98            if request.skip_think {
99                let re = re!(r"(?s)<think>.*?</think>");
100                
101                for choice in &mut response.choices {
102                    let message = choice.message.as_mut().unwrap();
103                    message.content = re.replace_all(&message.content, "").trim().to_string();
104                }
105            }
106
107            // add response to context:
108            if request.context {
109                if let Some(choice) = response.choices.get(0) {
110                    let message = choice.message.as_ref().unwrap();
111                    let answer = Message::new(Role::Assistant, message.content.clone());
112                    self.context.add(answer);
113                }
114            }
115
116            Ok(Some(response))
117        } else {
118            // spawning stream reader:
119            self.spawn_reader(url.clone(), request.clone(), request.context, request.skip_think).await?;
120
121            Ok(None)
122        }
123    }
124
125    /// Handle prompt request
126    async fn handle_prompt(&mut self, mut request: Prompt) -> Result<Option<Response>> {
127        let url = fmt!("{}/v1/completions", self.host);
128
129        // choose AI model:
130        if let Model::Other(s) = &request.model {
131            if s.is_empty() {
132                request.model = self.model.clone();
133            }
134        }
135        
136        // add request to context:
137        request.prompt = if request.context {
138            self.context.add(request.prompt);
139            self.context.get_as_string()
140        } else {
141            let mut context = self.context.clone();
142            context.add(request.prompt);
143            context.get_as_string()
144        };
145
146        // handle request:
147        if !request.stream {
148            let mut response = self.client.post(&url)
149                .json(&request)
150                .send()
151                .await?
152                .error_for_status()?
153                .json::<Response>()
154                .await?;
155
156            if request.skip_think {
157                let re = re!(r"(?s)<think>.*?</think>");
158                
159                for choice in &mut response.choices {
160                    let text = choice.text.as_mut().unwrap();
161                    *text = re.replace_all(&text, "").trim().to_string();
162                }
163            }
164
165            Ok(Some(response))
166        } else {
167            // spawning stream reader:
168            self.spawn_reader(url.clone(), request.clone(), false, request.skip_think).await?;
169
170            Ok(None)
171        }
172    }
173
174    /// Handle embeddings request
175    async fn handle_embeddings(&mut self, mut request: Embeddings) -> Result<Option<Response>> {
176        let url = fmt!("{}/v1/embeddings", self.host);
177
178        // choose AI model:
179        if let Model::Other(s) = &request.model {
180            if s.is_empty() {
181                request.model = self.model.clone();
182            }
183        }
184        
185        // handle request:
186        let response = self.client.post(&url)
187            .json(&request)
188            .send()
189            .await?
190            .error_for_status()?
191            .json::<Response>()
192            .await?;
193
194        Ok(Some(response))
195    }
196
197    /// Spawns stream reader
198    async fn spawn_reader<J>(&mut self, url: String, request: J, context: bool, skip_think: bool) -> Result<()>
199    where J: Serialize + Send + Sync + 'static,
200    {
201        let (tx, rx) = mpsc::unbounded_channel::<Result<StreamChoice>>();
202        let client = self.client.clone();
203
204        self.reader = Some( ResponseReader::new(UnboundedReceiverStream::new(rx), context) );
205        
206        tokio::spawn(async move {
207            let mut is_thinking = false;
208            let mut is_after_thinking = false;
209            
210            let response = client.post(&url)
211                .json(&request)
212                .send()
213                .await;
214
215            match response {
216                Ok(response) => {
217                    let mut stream = response.bytes_stream();
218
219                    while let Some(item) = stream.next().await {
220                        match item {
221                            Ok(chunk) => {
222                                let chunk = String::from_utf8_lossy(&chunk);
223
224                                for line in chunk.lines() {
225                                    // parsing response line:
226                                    if line.starts_with("data: ") {
227                                        let data = &line[6..];
228                                        if data == "[DONE]" {
229                                            break;
230                                        }
231
232                                        let stream: Result<Stream> = json::from_str(data).map_err(Into::into);
233                                        let stream = if let Ok(r) = stream { r }else{ continue };
234
235                                        for mut choice in stream.choices {
236                                            if let Some(text) = choice.text_mut() {
237                                                // filtering <think>..</think> block:
238                                                if skip_think {
239                                                    if is_thinking {
240                                                        if text.contains("</think>") {
241                                                            is_thinking = false;
242                                                            is_after_thinking = true;
243                                                        }
244                                                        continue;
245                                                    }
246                                                    else if text.contains("<think>") {
247                                                        is_thinking = true;
248                                                        continue;
249                                                    }
250
251                                                    // trim extra spaces:
252                                                    if is_after_thinking {
253                                                        *text = text.trim_start().to_string();
254                                                        is_after_thinking = false;
255                                                    }
256                                                }
257                                                
258                                                // send answer part to channel:
259                                                if tx.send(Ok(choice)).is_err() {
260                                                    break;
261                                                }
262                                            } else {
263                                                continue;
264                                            }
265                                        }
266                                    }
267                                }
268                            },
269
270                            Err(e) => {
271                                let _ = tx.send(Err(e.into()));
272                                break;
273                            }
274                        }
275                    }
276                },
277                
278                Err(e) => {
279                    let _ = tx.send(Err(e.into()));
280                }
281            }
282        });
283
284        Ok(())
285    }
286
287    /// Read next stream choice
288    pub async fn next(&mut self) -> Option<Result<StreamChoice>> {
289        if let Some(reader) = &mut self.reader {
290            let result = reader.next().await;
291
292            if reader.context && reader.is_ready {
293                self.context.add(reader.message.clone())
294            }
295            
296            result
297        } else {
298            self.reader = None;
299            None
300        }
301    }
302}