kproc_llm/simple_api.rs

//! Module for interacting with a simple API (like llama-cpp-server).

use std::future::Future;

use futures::StreamExt;
use serde::{Deserialize, Serialize};
use yaaral::prelude::*;

use crate::prelude::*;

/// Simple API.
#[derive(Debug)]
pub struct SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  runtime: RT,
  baseuri: String,
  port: u16,
  model: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Message
{
  role: String,
  content: String,
}

#[derive(Debug, Serialize)]
struct ChatRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  messages: Vec<Message>,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  prompt: String,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
  pub id: String,
  pub object: String,
  pub created: u64,
  pub model: String,
  pub choices: Vec<Choice>,
  #[serde(default)]
  pub timings: Option<Timings>, // optional, in case it's not always present
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
  pub index: u32,
  #[serde(default)]
  pub finish_reason: Option<String>,
  #[serde(default)]
  pub delta: Delta,
}

#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
  #[serde(default)]
  pub role: Option<String>,
  #[serde(default)]
  pub content: Option<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
  pub prompt_n: Option<i32>,
  pub prompt_ms: Option<f64>,
  pub prompt_per_token_ms: Option<f64>,
  pub prompt_per_second: Option<f64>,
  pub predicted_n: Option<i32>,
  pub predicted_ms: Option<f64>,
  pub predicted_per_token_ms: Option<f64>,
  pub predicted_per_second: Option<f64>,
}

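/// One frame of the completion stream: either an incremental generation delta
/// or the final summary object sent when generation stops.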
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum StreamFrame
{
  Delta(GenerationDelta),
  Final(Final),
}

#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
  pub content: String,
  pub stop: bool,
  #[serde(skip_serializing_if = "Option::is_none")]
  pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}

#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
  pub content_delta: String,
}

#[derive(Debug, Deserialize, Serialize)]
struct Final
{
  pub content: String,
  pub generated_text: String,
  pub stop: bool,
  pub model: String,
  pub tokens_predicted: u64,
  pub tokens_evaluated: u64,
  pub generation_settings: serde_json::Value,
  pub prompt: String,
  pub truncated: bool,
  pub stopped_eos: bool,
  pub stopped_word: bool,
  pub stopped_limit: bool,
  pub tokens_cached: u64,
  pub timings: serde_json::Value,
}

#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
  pub code: Option<u32>,
  pub message: String,
  #[serde(rename = "type")]
  pub typ: String,
}

impl<RT> SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  /// Create a new `SimpleApi` object that will query an LLM endpoint at `baseuri` (e.g. `http://localhost`)
  /// on the given `port` (e.g. 8080), using the optional `model`.
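  ///
  /// A minimal usage sketch (the `runtime` value below is assumed to come from
  /// the surrounding `yaaral` setup and is not constructed here):
  ///
  /// ```ignore
  /// let api = SimpleApi::new(runtime, "http://localhost", 8080, None)?;
  /// ```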
  pub fn new(
    runtime: RT,
    baseuri: impl Into<String>,
    port: u16,
    model: Option<String>,
  ) -> Result<Self>
  {
    Ok(Self {
      baseuri: baseuri.into(),
      port,
      model,
      runtime,
    })
  }
}

/// Common accessors over the per-endpoint stream frame types.
trait Data
{
  fn content(&self) -> Option<&String>;
  fn is_finished(&self) -> bool;
}

impl Data for ChatChunk
{
  fn content(&self) -> Option<&String>
  {
    self.choices.first().and_then(|c| c.delta.content.as_ref())
  }
  fn is_finished(&self) -> bool
  {
    if let Some(reason) = self.choices.first().and_then(|c| c.finish_reason.as_deref())
    {
      if reason == "stop"
      {
        return true;
      }
    }
    false
  }
}

impl Data for StreamFrame
{
  fn content(&self) -> Option<&String>
  {
    match self
    {
      StreamFrame::Delta(delta) => Some(&delta.content),
      StreamFrame::Final(_) => None,
    }
  }
  fn is_finished(&self) -> bool
  {
    match self
    {
      Self::Delta(_) => false,
      Self::Final(_) => true,
    }
  }
}

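/// Convert a streaming HTTP response into a `StringStream` of content chunks.
///
/// The body is read as server-sent-events-style lines: `data:` lines are
/// deserialized as `D` and their content forwarded, `error:` lines are
/// deserialized as [`ApiError`] and surfaced as errors, and any other
/// non-empty line is logged and skipped.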
fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
  response: impl yaaral::http::Response,
) -> Result<StringStream>
{
  let stream = response.into_stream().map(|chunk_result| {
    let mut results = vec![];
    match chunk_result
    {
      Ok(chunk) =>
      {
        let chunk_str = String::from_utf8_lossy(&chunk);
        for line in chunk_str.lines()
        {
          let line = line.trim();
          if line.starts_with("data:")
          {
            let json_str = line.trim_start_matches("data:");
            match serde_json::from_str::<D>(json_str)
            {
              Ok(chunk) =>
              {
                if let Some(content) = chunk.content()
                {
                  results.push(Ok(content.to_owned()));
                }

                if chunk.is_finished()
                {
                  break;
                }
              }
              Err(e) => results.push(Err(e.into())),
            }
          }
          else if line.starts_with("error:")
          {
            let json_str = line.trim_start_matches("error:");
            if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
            {
              results.push(Err(Error::SimpleApiError {
                code: chunk.code.unwrap_or_default(),
                message: chunk.message,
                error_type: chunk.typ,
              }));
            }
          }
          else if !line.is_empty()
          {
            log::error!("Unhandled line: {}.", line);
          }
        }
      }
      Err(e) =>
      {
        results.push(Err(Error::HttpError(format!("{:?}", e))));
      }
    }
    futures::stream::iter(results)
  });

  // Flatten nested streams and box it
  let flat_stream = stream.flatten().boxed();

  Ok(pin_stream(flat_stream))
}

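/// Map the requested output format to a bundled GBNF grammar, if any.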
fn grammar_for(format: crate::Format) -> Option<String>
{
  match format
  {
    crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
    crate::Format::Text => None,
  }
}

impl<RT> LargeLanguageModel for SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    // OpenAI-compatible chat completions endpoint.
    let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

    let messages = prompt
      .messages
      .into_iter()
      .map(|m| Message {
        role: match m.role
        {
          Role::User => "user".to_string(),
          Role::System => "system".to_string(),
          Role::Assistant => "assistant".to_string(),
          Role::Custom(custom) => custom,
        },
        content: m.content,
      })
      .collect();

    let request_body = ChatRequestBody {
      model: self.model.to_owned(),
      messages,
      temperature: 0.7,
      max_tokens: 2560,
      stream: true,
      grammar: grammar_for(prompt.format),
    };

    let rt = self.runtime.clone();

    let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

    Ok(async move {
      let response = rt.wpost(yreq).await;

      if !response.status().is_success()
      {
        return Err(Error::HttpError(format!(
          "Error code {}",
          response.status()
        )));
      }

      response_to_stream::<ChatChunk>(response)
    })
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let rt = self.runtime.clone();
    Ok(async move {
      // The plain completion endpoint only accepts a single prompt string, so it
      // can only be used when no system/assistant parts are present; otherwise
      // delegate to `crate::generate_with_chat`, which goes through the chat API.
      if prompt.system.is_none() && prompt.assistant.is_none()
      {
        let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

        let request_body = GenerationRequestBody {
          model: self.model.to_owned(),
          prompt: prompt.user,
          temperature: 0.7,
          max_tokens: 2560,
          stream: true,
          grammar: grammar_for(prompt.format),
        };

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        let response = rt.wpost(yreq).await;

        if !response.status().is_success()
        {
          return Err(Error::HttpError(format!(
            "Error code {}",
            response.status()
          )));
        }

        response_to_stream::<StreamFrame>(response)
      }
      else
      {
        crate::generate_with_chat(self, prompt)?.await
      }
    })
  }
}