1use std::future::Future;
4
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use yaaral::prelude::*;
8
9use crate::prelude::*;
10
/// Client for an OpenAI-compatible / llama.cpp-style LLM HTTP server.
///
/// Holds the async runtime handle plus the connection parameters used to
/// build request URLs (`{baseuri}:{port}/...`).
///
/// Trait bounds are intentionally left off the struct definition — they
/// remain on the impl blocks that need them, which is the idiomatic place
/// and keeps the type usable in bound-free contexts.
#[derive(Debug)]
pub struct SimpleApi<RT>
{
    /// Runtime providing task spawning and the HTTP client.
    runtime: RT,
    /// Base URI of the server without the port (e.g. "http://localhost").
    baseuri: String,
    /// TCP port the server listens on.
    port: u16,
    /// Optional model name; omitted from request bodies when `None`.
    model: Option<String>,
}
22
/// A single chat message in the OpenAI wire format.
#[derive(Debug, Serialize, Deserialize)]
struct Message
{
    /// Wire-format role string: "user", "system", "assistant", or a
    /// custom value (see the `Role` mapping in `chat_stream`).
    role: String,
    /// Message text.
    content: String,
}
29
30#[derive(Debug, Serialize)]
31struct ChatRequestBody
32{
33 #[serde(skip_serializing_if = "Option::is_none")]
34 model: Option<String>,
35 messages: Vec<Message>,
36 temperature: f32,
37 max_tokens: u32,
38 stream: bool,
39 grammar: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43struct GenerationRequestBody
44{
45 #[serde(skip_serializing_if = "Option::is_none")]
46 model: Option<String>,
47 prompt: String,
48 temperature: f32,
49 max_tokens: u32,
50 stream: bool,
51 grammar: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
55#[allow(dead_code)]
56struct ChatChunk
57{
58 pub id: String,
59 pub object: String,
60 pub created: u64,
61 pub model: String,
62 pub choices: Vec<Choice>,
63 #[serde(default)]
64 pub timings: Option<Timings>, }
66
/// One choice entry inside a `ChatChunk`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
    /// Position of this choice in the response.
    pub index: u32,
    /// Why generation ended (e.g. "stop"); `None` while streaming continues.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Incremental message content; falls back to an empty delta when absent.
    #[serde(default)]
    pub delta: Delta,
}
77
/// Incremental piece of an assistant message in a streamed chat response.
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
    /// Role string, when the server includes one on this delta.
    #[serde(default)]
    pub role: Option<String>,
    /// Newly produced text fragment, if any.
    #[serde(default)]
    pub content: Option<String>,
}
87
/// Token-timing statistics reported by llama.cpp; every field is optional
/// because the server may omit any of them.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
    // Prompt-processing phase.
    pub prompt_n: Option<i32>,
    pub prompt_ms: Option<f64>,
    pub prompt_per_token_ms: Option<f64>,
    pub prompt_per_second: Option<f64>,
    // Token-prediction phase.
    pub predicted_n: Option<i32>,
    pub predicted_ms: Option<f64>,
    pub predicted_per_token_ms: Option<f64>,
    pub predicted_per_second: Option<f64>,
}
101
102#[derive(Debug, Deserialize, Serialize)]
103#[serde(untagged)]
104enum StreamFrame
105{
106 Delta(GenerationDelta),
107 Final(Final),
108}
109
/// Intermediate streaming frame from the native completion endpoint,
/// carrying newly generated text.
#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
    /// Newly generated text fragment.
    pub content: String,
    /// Stop flag reported by the server for this frame.
    pub stop: bool,
    /// OpenAI-compatibility message diffs; omitted from output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}
118
/// OpenAI-compatibility diff entry nested in `GenerationDelta`.
#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
    /// Text delta in OpenAI-compatible form.
    pub content_delta: String,
}
124
/// Terminal frame of a native completion stream, carrying the full result
/// and generation metadata.
#[derive(Debug, Deserialize, Serialize)]
struct Final
{
    pub content: String,
    /// Complete generated text as reported by the server.
    pub generated_text: String,
    /// Stop flag reported by the server.
    pub stop: bool,
    pub model: String,
    pub tokens_predicted: u64,
    pub tokens_evaluated: u64,
    /// Generation settings echoed by the server; kept untyped.
    pub generation_settings: serde_json::Value,
    pub prompt: String,
    pub truncated: bool,
    pub stopped_eos: bool,
    pub stopped_word: bool,
    pub stopped_limit: bool,
    pub tokens_cached: u64,
    /// Timing statistics; kept untyped (see `Timings` for the chat shape).
    pub timings: serde_json::Value,
}
143
/// Error payload the server emits on `error:` SSE lines
/// (see `response_to_stream`).
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
    /// Numeric error code, when the server provides one.
    pub code: Option<u32>,
    /// Human-readable error message.
    pub message: String,
    /// Server-side error category; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub typ: String,
}
152
153impl<RT> SimpleApi<RT>
154where
155 RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
156{
157 pub fn new(
160 runtime: RT,
161 baseuri: impl Into<String>,
162 port: u16,
163 model: Option<String>,
164 ) -> Result<Self>
165 {
166 Ok(Self {
167 baseuri: baseuri.into(),
168 port,
169 model,
170 runtime,
171 })
172 }
173}
174
/// Common view over streamed payload types so `response_to_stream` can
/// extract text and detect completion regardless of the wire format.
trait Data
{
    /// The text fragment carried by this payload, if any.
    fn content(&self) -> Option<&String>;
    /// Whether this payload marks the end of the stream.
    fn is_finished(&self) -> bool;
}
180
181impl Data for ChatChunk
182{
183 fn content(&self) -> Option<&String>
184 {
185 self.choices.get(0).and_then(|c| c.delta.content.as_ref())
186 }
187 fn is_finished(&self) -> bool
188 {
189 if let Some(reason) = self.choices.get(0).and_then(|c| c.finish_reason.as_deref())
190 {
191 if reason == "stop"
192 {
193 return true;
194 }
195 }
196 return false;
197 }
198}
199
200impl Data for StreamFrame
201{
202 fn content(&self) -> Option<&String>
203 {
204 match self
205 {
206 StreamFrame::Delta(delta) => Some(&delta.content),
207 StreamFrame::Final(_) => None,
208 }
209 }
210 fn is_finished(&self) -> bool
211 {
212 match self
213 {
214 Self::Delta(_) => false,
215 Self::Final(_) => true,
216 }
217 }
218}
219
220fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
221 response: impl yaaral::http::Response,
222) -> Result<StringStream>
223{
224 let stream = response.into_stream().map(|chunk_result| {
225 let mut results = vec![];
226 match chunk_result
227 {
228 Ok(chunk) =>
229 {
230 let chunk_str = String::from_utf8_lossy(&chunk);
231 for line in chunk_str.lines()
232 {
233 let line = line.trim();
234 if line.starts_with("data:")
235 {
236 let json_str = line.trim_start_matches("data:");
237 match serde_json::from_str::<D>(json_str)
238 {
239 Ok(chunk) =>
240 {
241 if let Some(content) = chunk.content()
242 {
243 results.push(Ok(content.to_owned()));
244 }
245
246 if chunk.is_finished()
247 {
248 break;
249 }
250 }
251 Err(e) => results.push(Err(e.into())),
252 }
253 }
254 else if line.starts_with("error:")
255 {
256 let json_str = line.trim_start_matches("error:");
257 if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
258 {
259 results.push(Err(Error::SimpleApiError {
260 code: chunk.code.unwrap_or_default(),
261 message: chunk.message,
262 error_type: chunk.typ,
263 }));
264 }
265 }
266 else if !line.is_empty()
267 {
268 log::error!("Unhandled line: {}.", line);
269 }
270 }
271 }
272 Err(e) =>
273 {
274 results.push(Err(Error::HttpError(format!("{:?}", e))));
275 }
276 }
277 futures::stream::iter(results)
278 });
279
280 let flat_stream = stream.flatten().boxed();
282
283 Ok(pin_stream(flat_stream))
284}
285
286fn grammar_for(format: crate::Format) -> Option<String>
287{
288 match format
289 {
290 crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
291 crate::Format::Text => None,
292 }
293}
294
impl<RT> LargeLanguageModel for SimpleApi<RT>
where
    RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
    /// Streams a chat completion from the OpenAI-compatible
    /// `/v1/chat/completions` endpoint.
    ///
    /// The request (URL, role mapping, body) is built eagerly; the returned
    /// future performs the POST and converts the SSE response into a
    /// `StringStream`. Fails with `Error::HttpError` on a non-success
    /// HTTP status.
    fn chat_stream(
        &self,
        prompt: ChatPrompt,
    ) -> Result<impl Future<Output = Result<StringStream>> + Send>
    {
        let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

        // Map the crate's role enum onto the wire-format role strings.
        let messages = prompt
            .messages
            .into_iter()
            .map(|m| Message {
                role: match m.role
                {
                    Role::User => "user".to_string(),
                    Role::System => "system".to_string(),
                    Role::Assistant => "assistant".to_string(),
                    Role::Custom(custom) => custom,
                },
                content: m.content,
            })
            .collect();

        // NOTE(review): temperature and max_tokens are hard-coded here and
        // in generate_stream — confirm whether they should be configurable.
        let request_body = ChatRequestBody {
            model: self.model.to_owned(),
            messages,
            temperature: 0.7,
            max_tokens: 2560,
            stream: true,
            grammar: grammar_for(prompt.format),
        };

        // Clone the runtime handle so the future is independent of &self.
        let rt = self.runtime.clone();

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        Ok(async move {
            let response = rt.wpost(yreq).await;

            if !response.status().is_success()
            {
                return Err(Error::HttpError(format!(
                    "Error code {}",
                    response.status()
                )));
            }

            response_to_stream::<ChatChunk>(response)
        })
    }
    /// Streams a plain completion.
    ///
    /// When the prompt has neither a system nor an assistant part, POSTs
    /// the user text to the native `/v1/completions` endpoint; otherwise
    /// falls back to `crate::generate_with_chat` (which presumably routes
    /// the prompt through the chat path — confirm against its definition).
    fn generate_stream(
        &self,
        prompt: GenerationPrompt,
    ) -> Result<impl Future<Output = Result<StringStream>> + Send>
    {
        let rt = self.runtime.clone();
        Ok(async move {
            if prompt.system.is_none() && prompt.assistant.is_none()
            {
                let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

                let request_body = GenerationRequestBody {
                    model: self.model.to_owned(),
                    prompt: prompt.user,
                    temperature: 0.7,
                    max_tokens: 2560,
                    stream: true,
                    grammar: grammar_for(prompt.format),
                };

                let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

                let response = rt.wpost(yreq).await;

                if !response.status().is_success()
                {
                    return Err(Error::HttpError(format!(
                        "Error code {}",
                        response.status()
                    )));
                }

                response_to_stream::<StreamFrame>(response)
            }
            else
            {
                crate::generate_with_chat(self, prompt)?.await
            }
        })
    }
}