1use std::future::Future;
4
5use ccutils::streams::pin_stream;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8use yaaral::prelude::*;
9
10use crate::prelude::*;
11
/// Client for a llama.cpp-style HTTP inference server, exposing both the
/// OpenAI-compatible chat endpoint and the native completion endpoint.
#[derive(Debug)]
pub struct SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  // Runtime providing task support and the HTTP client used for requests.
  runtime: RT,
  // Scheme + host (e.g. "http://localhost"); joined with `port` to build URLs.
  baseuri: String,
  // Server TCP port.
  port: u16,
  // Optional model name forwarded with every request body.
  model: Option<String>,
}
23
/// A single chat message in the OpenAI-compatible wire format.
#[derive(Debug, Serialize, Deserialize)]
struct Message
{
  // Speaker role: "user", "system", "assistant", or a custom value.
  role: String,
  // The message text.
  content: String,
}
30
/// Request body for the OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequestBody
{
  // Model name; omitted from the JSON entirely when not configured.
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  // Conversation history, in order.
  messages: Vec<Message>,
  // Sampling temperature.
  temperature: f32,
  // Upper bound on the number of generated tokens.
  max_tokens: u32,
  // Always set to true by this client: responses are consumed as a stream.
  stream: bool,
  // Optional GBNF grammar constraining the output.
  // NOTE(review): unlike `model`, a `None` grammar is serialized as JSON
  // `null` instead of being omitted — confirm the server accepts that.
  grammar: Option<String>,
}
42
/// Request body for the native `/v1/completions` endpoint.
#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
  // Model name; omitted from the JSON entirely when not configured.
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  // The raw prompt text to complete.
  prompt: String,
  // Sampling temperature.
  temperature: f32,
  // Upper bound on the number of generated tokens.
  max_tokens: u32,
  // Always set to true by this client: responses are consumed as a stream.
  stream: bool,
  // Optional GBNF grammar constraining the output.
  // NOTE(review): unlike `model`, a `None` grammar is serialized as JSON
  // `null` instead of being omitted — confirm the server accepts that.
  grammar: Option<String>,
}
54
/// One streamed payload (`data:` line) from `/v1/chat/completions`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
  // Identifier of the completion this chunk belongs to.
  pub id: String,
  // Object type tag from the server — presumably "chat.completion.chunk";
  // verify against the server's output.
  pub object: String,
  // Creation timestamp (Unix seconds).
  pub created: u64,
  // Model that produced this chunk.
  pub model: String,
  // Streamed choices; only the first is inspected by this client.
  pub choices: Vec<Choice>,
  // Optional server-side timing statistics; absent on most chunks.
  #[serde(default)]
  pub timings: Option<Timings>,
}
67
/// One choice entry inside a [`ChatChunk`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
  // Position of this choice in the response.
  pub index: u32,
  // Set on the terminal chunk (e.g. "stop"); absent on intermediate deltas.
  #[serde(default)]
  pub finish_reason: Option<String>,
  // Incremental message content; defaults to an empty delta when missing.
  #[serde(default)]
  pub delta: Delta,
}
78
/// Incremental message fragment inside a streaming [`Choice`].
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
  // Speaker role; presumably present only on the first chunk of a message.
  #[serde(default)]
  pub role: Option<String>,
  // New text carried by this chunk, if any.
  #[serde(default)]
  pub content: Option<String>,
}
88
/// Server-side performance statistics occasionally attached to chunks.
/// Field names mirror the server's JSON; all fields are optional because the
/// server may omit any of them. (Presumably llama.cpp-style timings — token
/// counts, milliseconds and per-second rates; verify against the server.)
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
  pub prompt_n: Option<i32>,
  pub prompt_ms: Option<f64>,
  pub prompt_per_token_ms: Option<f64>,
  pub prompt_per_second: Option<f64>,
  pub predicted_n: Option<i32>,
  pub predicted_ms: Option<f64>,
  pub predicted_per_token_ms: Option<f64>,
  pub predicted_per_second: Option<f64>,
}
102
103#[derive(Debug, Deserialize, Serialize)]
104#[serde(untagged)]
105#[allow(clippy::large_enum_variant)]
106enum StreamFrame
107{
108 Delta(GenerationDelta),
109 Final(Final),
110}
111
/// Intermediate streaming frame from the native completion endpoint.
#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
  // Text generated since the previous frame.
  pub content: String,
  // True when the server has finished generating.
  pub stop: bool,
  // OpenAI-compatibility message diffs, when the server provides them.
  #[serde(skip_serializing_if = "Option::is_none")]
  pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}
120
/// One OpenAI-compatibility content diff inside a [`GenerationDelta`].
#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
  // The incremental content carried by this diff.
  pub content_delta: String,
}
126
/// Terminal frame of a native completion stream, summarizing the whole run.
/// Fields mirror the server's JSON verbatim.
#[derive(Debug, Deserialize, Serialize)]
struct Final
{
  // Content of this final frame (presumably empty once streaming is done —
  // verify against the server).
  pub content: String,
  // The full generated text for the run.
  pub generated_text: String,
  // Always true on this frame: generation has stopped.
  pub stop: bool,
  // Model that produced the completion.
  pub model: String,
  pub tokens_predicted: u64,
  pub tokens_evaluated: u64,
  // Opaque server-reported settings; kept as raw JSON.
  pub generation_settings: serde_json::Value,
  // The prompt as the server saw it.
  pub prompt: String,
  // Whether the prompt was truncated to fit the context.
  pub truncated: bool,
  // Which stop condition ended generation.
  pub stopped_eos: bool,
  pub stopped_word: bool,
  pub stopped_limit: bool,
  pub tokens_cached: u64,
  // Opaque timing statistics; kept as raw JSON.
  pub timings: serde_json::Value,
}
145
/// Error payload the server sends on `error:` stream lines.
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
  // Numeric error code, when the server provides one.
  pub code: Option<u32>,
  // Human-readable error message.
  pub message: String,
  // Error category; the wire field is named `type`, a Rust keyword.
  #[serde(rename = "type")]
  pub typ: String,
}
154
155impl<RT> SimpleApi<RT>
156where
157 RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
158{
159 pub fn new(
162 runtime: RT,
163 baseuri: impl Into<String>,
164 port: u16,
165 model: Option<String>,
166 ) -> Result<Self>
167 {
168 Ok(Self {
169 baseuri: baseuri.into(),
170 port,
171 model,
172 runtime,
173 })
174 }
175}
176
/// Common view over the two streaming payload formats (chat chunks and
/// completion frames), letting `response_to_stream` extract text and detect
/// end-of-stream generically.
trait Data
{
  /// The text carried by this frame, if any.
  fn content(&self) -> Option<&String>;
  /// Whether this frame marks the end of the stream.
  fn is_finished(&self) -> bool;
}
182
183impl Data for ChatChunk
184{
185 fn content(&self) -> Option<&String>
186 {
187 self.choices.first().and_then(|c| c.delta.content.as_ref())
188 }
189 fn is_finished(&self) -> bool
190 {
191 if let Some(reason) = self
192 .choices
193 .first()
194 .and_then(|c| c.finish_reason.as_deref())
195 {
196 if reason == "stop"
197 {
198 return true;
199 }
200 }
201 false
202 }
203}
204
205impl Data for StreamFrame
206{
207 fn content(&self) -> Option<&String>
208 {
209 match self
210 {
211 StreamFrame::Delta(delta) => Some(&delta.content),
212 StreamFrame::Final(_) => None,
213 }
214 }
215 fn is_finished(&self) -> bool
216 {
217 match self
218 {
219 Self::Delta(_) => false,
220 Self::Final(_) => true,
221 }
222 }
223}
224
225fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
226 response: impl yaaral::http::Response,
227) -> Result<StringStream>
228{
229 let stream = response.into_stream().map(|chunk_result| {
230 let mut results = vec![];
231 match chunk_result
232 {
233 Ok(chunk) =>
234 {
235 let chunk_str = String::from_utf8_lossy(&chunk);
236 for line in chunk_str.lines()
237 {
238 let line = line.trim();
239 if line.starts_with("data:")
240 {
241 let json_str = line.trim_start_matches("data:");
242 match serde_json::from_str::<D>(json_str)
243 {
244 Ok(chunk) =>
245 {
246 if let Some(content) = chunk.content()
247 {
248 results.push(Ok(content.to_owned()));
249 }
250
251 if chunk.is_finished()
252 {
253 break;
254 }
255 }
256 Err(e) => results.push(Err(e.into())),
257 }
258 }
259 else if line.starts_with("error:")
260 {
261 let json_str = line.trim_start_matches("error:");
262 if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
263 {
264 results.push(Err(Error::SimpleApiError {
265 code: chunk.code.unwrap_or_default(),
266 message: chunk.message,
267 error_type: chunk.typ,
268 }));
269 }
270 }
271 else if !line.is_empty()
272 {
273 log::error!("Unhandled line: {}.", line);
274 }
275 }
276 }
277 Err(e) =>
278 {
279 results.push(Err(Error::HttpError(format!("{:?}", e))));
280 }
281 }
282 futures::stream::iter(results)
283 });
284
285 let flat_stream = stream.flatten().boxed();
287
288 Ok(pin_stream(flat_stream))
289}
290
291fn grammar_for(format: crate::Format) -> Option<String>
292{
293 match format
294 {
295 crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
296 crate::Format::Text => None,
297 }
298}
299
300impl<RT> LargeLanguageModel for SimpleApi<RT>
301where
302 RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
303{
304 fn chat_stream(
305 &self,
306 prompt: ChatPrompt,
307 ) -> Result<impl Future<Output = Result<StringStream>> + Send>
308 {
309 let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);
310
311 let messages = prompt
312 .messages
313 .into_iter()
314 .map(|m| Message {
315 role: match m.role
316 {
317 Role::User => "user".to_string(),
318 Role::System => "system".to_string(),
319 Role::Assistant => "assistant".to_string(),
320 Role::Custom(custom) => custom,
321 },
322 content: m.content,
323 })
324 .collect();
325
326 let request_body = ChatRequestBody {
327 model: self.model.to_owned(),
328 messages,
329 temperature: 0.7,
330 max_tokens: 2560,
331 stream: true,
332 grammar: grammar_for(prompt.options.format),
333 };
334
335 let rt = self.runtime.clone();
336
337 let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;
338
339 Ok(async move {
340 let response = rt.wpost(yreq).await;
341
342 if !response.status().is_success()
343 {
344 return Err(Error::HttpError(format!(
345 "Error code {}",
346 response.status()
347 )));
348 }
349
350 response_to_stream::<ChatChunk>(response)
351 })
352 }
353 fn generate_stream(
354 &self,
355 prompt: GenerationPrompt,
356 ) -> Result<impl Future<Output = Result<StringStream>> + Send>
357 {
358 let rt = self.runtime.clone();
359 Ok(async move {
360 if prompt.system.is_none() && prompt.assistant.is_none()
361 {
362 let url = format!("{}:{}/v1/completions", self.baseuri, self.port);
363
364 let request_body = GenerationRequestBody {
365 model: self.model.to_owned(),
366 prompt: prompt.user,
367 temperature: 0.7,
368 max_tokens: 2560,
369 stream: true,
370 grammar: grammar_for(prompt.options.format),
371 };
372
373 let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;
374
375 let response = rt.wpost(yreq).await;
376
377 if !response.status().is_success()
378 {
379 return Err(Error::HttpError(format!(
380 "Error code {}",
381 response.status()
382 )));
383 }
384
385 response_to_stream::<StreamFrame>(response)
386 }
387 else
388 {
389 crate::generate_with_chat(self, prompt)?.await
390 }
391 })
392 }
393}