lm_studio_api/chat/
chat.rs1use crate::prelude::*;
2use super::*;
3
4use reqwest::Client;
5use futures_util::StreamExt;
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::UnboundedReceiverStream;
8
9pub struct Chat {
11 pub(crate) model: Model,
12 pub(crate) context: Context,
13 pub(crate) host: String,
14 pub(crate) client: Client,
15 pub(crate) reader: Option<ResponseReader>,
16}
17
18impl Chat {
19 pub async fn new<M: Into<Model>, C: Into<Context>>(model: M, context: C, port: u16) -> Result<Self> {
21 let mut chat = Self {
22 model: model.into(),
23 context: context.into(),
24 host: fmt!("http://localhost:{port}"),
25 client: Client::new(),
26 reader: None,
27 };
28
29 chat.load_model(chat.model.clone()).await?;
31
32 Ok(chat)
33 }
34
35 pub async fn load_model<M: Into<Model>>(&mut self, model: M) -> Result<()> {
37 let request = Prompt {
38 model: model.into(),
39 prompt: str!("Hello"),
40 stream: false,
41 ..Default::default()
42 }.into();
43
44 let _ = self.send(request).await?;
45
46 Ok(())
47 }
48
49 pub async fn send(&mut self, request: Request) -> Result<Option<Response>> {
51 self.context.update_system_info().await;
52
53 match request {
54 Request::Messages(request) => self.handle_messages(request).await,
55 Request::Prompt(request) => self.handle_prompt(request).await,
56 Request::Embeddings(request) => self.handle_embeddings(request).await,
57 }
58 }
59
60 async fn handle_messages(&mut self, mut request: Messages) -> Result<Option<Response>> {
62 let url = fmt!("{}/v1/chat/completions", self.host);
63
64 if let Model::Other(s) = &request.model {
66 if s.is_empty() {
67 request.model = self.model.clone();
68 }
69 }
70
71 request.messages = if request.context {
73 for msg in request.messages {
74 self.context.add(msg);
75 }
76
77 self.context.get()
78 } else {
79 let mut context = self.context.clone();
80 for msg in request.messages {
81 context.add(msg);
82 }
83
84 context.get()
85 };
86
87 if !request.stream {
89 let mut response = self.client.post(&url)
90 .json(&request)
91 .send()
92 .await?
93 .error_for_status()?
94 .json::<Response>()
95 .await?;
96
97 if request.skip_think {
99 let re = re!(r"(?s)<think>.*?</think>");
100
101 for choice in &mut response.choices {
102 let message = choice.message.as_mut().unwrap();
103 message.content = re.replace_all(&message.content, "").trim().to_string();
104 }
105 }
106
107 if request.context {
109 if let Some(choice) = response.choices.get(0) {
110 let message = choice.message.as_ref().unwrap();
111 let answer = Message::new(Role::Assistant, message.content.clone());
112 self.context.add(answer);
113 }
114 }
115
116 Ok(Some(response))
117 } else {
118 self.spawn_reader(url.clone(), request.clone(), request.context, request.skip_think).await?;
120
121 Ok(None)
122 }
123 }
124
125 async fn handle_prompt(&mut self, mut request: Prompt) -> Result<Option<Response>> {
127 let url = fmt!("{}/v1/completions", self.host);
128
129 if let Model::Other(s) = &request.model {
131 if s.is_empty() {
132 request.model = self.model.clone();
133 }
134 }
135
136 request.prompt = if request.context {
138 self.context.add(request.prompt);
139 self.context.get_as_string()
140 } else {
141 let mut context = self.context.clone();
142 context.add(request.prompt);
143 context.get_as_string()
144 };
145
146 if !request.stream {
148 let mut response = self.client.post(&url)
149 .json(&request)
150 .send()
151 .await?
152 .error_for_status()?
153 .json::<Response>()
154 .await?;
155
156 if request.skip_think {
157 let re = re!(r"(?s)<think>.*?</think>");
158
159 for choice in &mut response.choices {
160 let text = choice.text.as_mut().unwrap();
161 *text = re.replace_all(&text, "").trim().to_string();
162 }
163 }
164
165 Ok(Some(response))
166 } else {
167 self.spawn_reader(url.clone(), request.clone(), false, request.skip_think).await?;
169
170 Ok(None)
171 }
172 }
173
174 async fn handle_embeddings(&mut self, mut request: Embeddings) -> Result<Option<Response>> {
176 let url = fmt!("{}/v1/embeddings", self.host);
177
178 if let Model::Other(s) = &request.model {
180 if s.is_empty() {
181 request.model = self.model.clone();
182 }
183 }
184
185 let response = self.client.post(&url)
187 .json(&request)
188 .send()
189 .await?
190 .error_for_status()?
191 .json::<Response>()
192 .await?;
193
194 Ok(Some(response))
195 }
196
197 async fn spawn_reader<J>(&mut self, url: String, request: J, context: bool, skip_think: bool) -> Result<()>
199 where J: Serialize + Send + Sync + 'static,
200 {
201 let (tx, rx) = mpsc::unbounded_channel::<Result<StreamChoice>>();
202 let client = self.client.clone();
203
204 self.reader = Some( ResponseReader::new(UnboundedReceiverStream::new(rx), context) );
205
206 tokio::spawn(async move {
207 let mut is_thinking = false;
208 let mut is_after_thinking = false;
209
210 let response = client.post(&url)
211 .json(&request)
212 .send()
213 .await;
214
215 match response {
216 Ok(response) => {
217 let mut stream = response.bytes_stream();
218
219 while let Some(item) = stream.next().await {
220 match item {
221 Ok(chunk) => {
222 let chunk = String::from_utf8_lossy(&chunk);
223
224 for line in chunk.lines() {
225 if line.starts_with("data: ") {
227 let data = &line[6..];
228 if data == "[DONE]" {
229 break;
230 }
231
232 let stream: Result<Stream> = json::from_str(data).map_err(Into::into);
233 let stream = if let Ok(r) = stream { r }else{ continue };
234
235 for mut choice in stream.choices {
236 if let Some(text) = choice.text_mut() {
237 if skip_think {
239 if is_thinking {
240 if text.contains("</think>") {
241 is_thinking = false;
242 is_after_thinking = true;
243 }
244 continue;
245 }
246 else if text.contains("<think>") {
247 is_thinking = true;
248 continue;
249 }
250
251 if is_after_thinking {
253 *text = text.trim_start().to_string();
254 is_after_thinking = false;
255 }
256 }
257
258 if tx.send(Ok(choice)).is_err() {
260 break;
261 }
262 } else {
263 continue;
264 }
265 }
266 }
267 }
268 },
269
270 Err(e) => {
271 let _ = tx.send(Err(e.into()));
272 break;
273 }
274 }
275 }
276 },
277
278 Err(e) => {
279 let _ = tx.send(Err(e.into()));
280 }
281 }
282 });
283
284 Ok(())
285 }
286
287 pub async fn next(&mut self) -> Option<Result<StreamChoice>> {
289 if let Some(reader) = &mut self.reader {
290 let result = reader.next().await;
291
292 if reader.context && reader.is_ready {
293 self.context.add(reader.message.clone())
294 }
295
296 result
297 } else {
298 self.reader = None;
299 None
300 }
301 }
302}