1use std::future::Future;
4
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use yaaral::prelude::*;
8
9use crate::prelude::*;
10
/// Client for a llama.cpp-style HTTP completion server.
///
/// Holds the runtime handle used for HTTP requests plus the server
/// location (`baseuri`/`port`) and an optional model name that is
/// forwarded with every request.
//
// NOTE(review): the trait bounds live on the impl blocks rather than
// the struct definition — bounds on a struct definition add nothing
// and force every mention of the type to repeat them.
#[derive(Debug)]
pub struct SimpleApi<RT>
{
    // Runtime providing task spawning and the HTTP client.
    runtime: RT,
    // Base URI of the server, e.g. "http://localhost".
    baseuri: String,
    // TCP port the server listens on.
    port: u16,
    // Optional model identifier; omitted from requests when `None`.
    model: Option<String>,
}
22
23#[derive(Debug, Serialize, Deserialize)]
24struct Message
25{
26 role: String,
27 content: String,
28}
29
30#[derive(Debug, Serialize)]
31struct ChatRequestBody
32{
33 #[serde(skip_serializing_if = "Option::is_none")]
34 model: Option<String>,
35 messages: Vec<Message>,
36 temperature: f32,
37 max_tokens: u32,
38 stream: bool,
39 grammar: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43struct GenerationRequestBody
44{
45 #[serde(skip_serializing_if = "Option::is_none")]
46 model: Option<String>,
47 prompt: String,
48 temperature: f32,
49 max_tokens: u32,
50 stream: bool,
51 grammar: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
55#[allow(dead_code)]
56struct ChatChunk
57{
58 pub id: String,
59 pub object: String,
60 pub created: u64,
61 pub model: String,
62 pub choices: Vec<Choice>,
63 #[serde(default)]
64 pub timings: Option<Timings>, }
66
67#[derive(Debug, Deserialize)]
68#[allow(dead_code)]
69struct Choice
70{
71 pub index: u32,
72 #[serde(default)]
73 pub finish_reason: Option<String>,
74 #[serde(default)]
75 pub delta: Delta,
76}
77
78#[derive(Debug, Deserialize, Default)]
79#[allow(dead_code)]
80struct Delta
81{
82 #[serde(default)]
83 pub role: Option<String>,
84 #[serde(default)]
85 pub content: Option<String>,
86}
87
88#[derive(Debug, Deserialize)]
89#[allow(dead_code)]
90struct Timings
91{
92 pub prompt_n: Option<i32>,
93 pub prompt_ms: Option<f64>,
94 pub prompt_per_token_ms: Option<f64>,
95 pub prompt_per_second: Option<f64>,
96 pub predicted_n: Option<i32>,
97 pub predicted_ms: Option<f64>,
98 pub predicted_per_token_ms: Option<f64>,
99 pub predicted_per_second: Option<f64>,
100}
101
102#[derive(Debug, Deserialize, Serialize)]
103#[serde(untagged)]
104#[allow(clippy::large_enum_variant)]
105enum StreamFrame
106{
107 Delta(GenerationDelta),
108 Final(Final),
109}
110
111#[derive(Debug, Deserialize, Serialize)]
112struct GenerationDelta
113{
114 pub content: String,
115 pub stop: bool,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
118}
119
120#[derive(Debug, Deserialize, Serialize)]
121struct OaiDelta
122{
123 pub content_delta: String,
124}
125
126#[derive(Debug, Deserialize, Serialize)]
127struct Final
128{
129 pub content: String,
130 pub generated_text: String,
131 pub stop: bool,
132 pub model: String,
133 pub tokens_predicted: u64,
134 pub tokens_evaluated: u64,
135 pub generation_settings: serde_json::Value,
136 pub prompt: String,
137 pub truncated: bool,
138 pub stopped_eos: bool,
139 pub stopped_word: bool,
140 pub stopped_limit: bool,
141 pub tokens_cached: u64,
142 pub timings: serde_json::Value,
143}
144
145#[derive(Debug, Deserialize, Serialize)]
146pub(crate) struct ApiError
147{
148 pub code: Option<u32>,
149 pub message: String,
150 #[serde(rename = "type")]
151 pub typ: String,
152}
153
154impl<RT> SimpleApi<RT>
155where
156 RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
157{
158 pub fn new(
161 runtime: RT,
162 baseuri: impl Into<String>,
163 port: u16,
164 model: Option<String>,
165 ) -> Result<Self>
166 {
167 Ok(Self {
168 baseuri: baseuri.into(),
169 port,
170 model,
171 runtime,
172 })
173 }
174}
175
/// Minimal accessor interface shared by the two streaming frame
/// formats: the text a frame contributes (if any), and whether it
/// marks the end of the stream.
trait Data
{
    /// Text payload carried by this frame, when present.
    fn content(&self) -> Option<&String>;
    /// `true` once the stream should stop after this frame.
    fn is_finished(&self) -> bool;
}
181
182impl Data for ChatChunk
183{
184 fn content(&self) -> Option<&String>
185 {
186 self.choices.first().and_then(|c| c.delta.content.as_ref())
187 }
188 fn is_finished(&self) -> bool
189 {
190 if let Some(reason) = self
191 .choices
192 .first()
193 .and_then(|c| c.finish_reason.as_deref())
194 {
195 if reason == "stop"
196 {
197 return true;
198 }
199 }
200 false
201 }
202}
203
204impl Data for StreamFrame
205{
206 fn content(&self) -> Option<&String>
207 {
208 match self
209 {
210 StreamFrame::Delta(delta) => Some(&delta.content),
211 StreamFrame::Final(_) => None,
212 }
213 }
214 fn is_finished(&self) -> bool
215 {
216 match self
217 {
218 Self::Delta(_) => false,
219 Self::Final(_) => true,
220 }
221 }
222}
223
224fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
225 response: impl yaaral::http::Response,
226) -> Result<StringStream>
227{
228 let stream = response.into_stream().map(|chunk_result| {
229 let mut results = vec![];
230 match chunk_result
231 {
232 Ok(chunk) =>
233 {
234 let chunk_str = String::from_utf8_lossy(&chunk);
235 for line in chunk_str.lines()
236 {
237 let line = line.trim();
238 if line.starts_with("data:")
239 {
240 let json_str = line.trim_start_matches("data:");
241 match serde_json::from_str::<D>(json_str)
242 {
243 Ok(chunk) =>
244 {
245 if let Some(content) = chunk.content()
246 {
247 results.push(Ok(content.to_owned()));
248 }
249
250 if chunk.is_finished()
251 {
252 break;
253 }
254 }
255 Err(e) => results.push(Err(e.into())),
256 }
257 }
258 else if line.starts_with("error:")
259 {
260 let json_str = line.trim_start_matches("error:");
261 if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
262 {
263 results.push(Err(Error::SimpleApiError {
264 code: chunk.code.unwrap_or_default(),
265 message: chunk.message,
266 error_type: chunk.typ,
267 }));
268 }
269 }
270 else if !line.is_empty()
271 {
272 log::error!("Unhandled line: {}.", line);
273 }
274 }
275 }
276 Err(e) =>
277 {
278 results.push(Err(Error::HttpError(format!("{:?}", e))));
279 }
280 }
281 futures::stream::iter(results)
282 });
283
284 let flat_stream = stream.flatten().boxed();
286
287 Ok(pin_stream(flat_stream))
288}
289
290fn grammar_for(format: crate::Format) -> Option<String>
291{
292 match format
293 {
294 crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
295 crate::Format::Text => None,
296 }
297}
298
impl<RT> LargeLanguageModel for SimpleApi<RT>
where
    RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
    /// Streams a chat completion from the OpenAI-compatible
    /// `/v1/chat/completions` endpoint.
    ///
    /// Request construction (including any serialization error from
    /// `json()`) happens synchronously; the returned future performs
    /// the POST and converts the SSE response into a `StringStream`.
    fn chat_stream(
        &self,
        prompt: ChatPrompt,
    ) -> Result<impl Future<Output = Result<StringStream>> + Send>
    {
        let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

        // Map the crate's role enum onto the wire-format role strings.
        let messages = prompt
            .messages
            .into_iter()
            .map(|m| Message {
                role: match m.role
                {
                    Role::User => "user".to_string(),
                    Role::System => "system".to_string(),
                    Role::Assistant => "assistant".to_string(),
                    Role::Custom(custom) => custom,
                },
                content: m.content,
            })
            .collect();

        let request_body = ChatRequestBody {
            model: self.model.to_owned(),
            messages,
            // Fixed sampling parameters; not configurable through
            // this API at present.
            temperature: 0.7,
            max_tokens: 2560,
            stream: true,
            grammar: grammar_for(prompt.format),
        };

        // Clone the runtime handle so the async block owns it.
        let rt = self.runtime.clone();

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        Ok(async move {
            let response = rt.wpost(yreq).await;

            if !response.status().is_success()
            {
                return Err(Error::HttpError(format!(
                    "Error code {}",
                    response.status()
                )));
            }

            response_to_stream::<ChatChunk>(response)
        })
    }
    /// Streams a raw completion.
    ///
    /// Prompts with only a `user` part go to `/v1/completions` and
    /// are parsed as `StreamFrame`s; prompts that also carry a
    /// `system` or `assistant` part are delegated to
    /// `crate::generate_with_chat`, i.e. routed through the chat
    /// endpoint instead.
    ///
    /// NOTE(review): the body uses the native `grammar` field and the
    /// response is parsed with the native `/completion` frame format,
    /// yet the URL is the OpenAI-compat `/v1/completions` — confirm
    /// the server actually streams native frames on that path.
    fn generate_stream(
        &self,
        prompt: GenerationPrompt,
    ) -> Result<impl Future<Output = Result<StringStream>> + Send>
    {
        let rt = self.runtime.clone();
        // Unlike `chat_stream`, request building happens inside the
        // future here, so serialization errors surface on `.await`.
        Ok(async move {
            if prompt.system.is_none() && prompt.assistant.is_none()
            {
                let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

                let request_body = GenerationRequestBody {
                    model: self.model.to_owned(),
                    prompt: prompt.user,
                    temperature: 0.7,
                    max_tokens: 2560,
                    stream: true,
                    grammar: grammar_for(prompt.format),
                };

                let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

                let response = rt.wpost(yreq).await;

                if !response.status().is_success()
                {
                    return Err(Error::HttpError(format!(
                        "Error code {}",
                        response.status()
                    )));
                }

                response_to_stream::<StreamFrame>(response)
            }
            else
            {
                crate::generate_with_chat(self, prompt)?.await
            }
        })
    }
}