kproc_llm/simple_api.rs

//! Module for use with a simple LLM API (such as llama-cpp-server).

use std::future::Future;

use futures::StreamExt;
use serde::{Deserialize, Serialize};
use yaaral::prelude::*;

use crate::prelude::*;

/// Client for a simple LLM HTTP API (such as the one exposed by llama-cpp-server).
#[derive(Debug)]
pub struct SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  runtime: RT,
  baseuri: String,
  port: u16,
  model: Option<String>,
}

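/// A single chat message (role and content), as sent to the chat endpoint.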
#[derive(Debug, Serialize, Deserialize)]
struct Message
{
  role: String,
  content: String,
}

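/// Request body for the `/v1/chat/completions` endpoint.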
#[derive(Debug, Serialize)]
struct ChatRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  messages: Vec<Message>,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

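/// Request body for the `/v1/completions` endpoint.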
#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  prompt: String,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

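/// One decoded `data:` chunk of a streamed chat completion response.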
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
  pub id: String,
  pub object: String,
  pub created: u64,
  pub model: String,
  pub choices: Vec<Choice>,
  #[serde(default)]
  pub timings: Option<Timings>, // optional, in case it's not always present
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
  pub index: u32,
  #[serde(default)]
  pub finish_reason: Option<String>,
  #[serde(default)]
  pub delta: Delta,
}

#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
  #[serde(default)]
  pub role: Option<String>,
  #[serde(default)]
  pub content: Option<String>,
}

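/// Timing statistics that the server may attach to a chat chunk.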
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
  pub prompt_n: Option<i32>,
  pub prompt_ms: Option<f64>,
  pub prompt_per_token_ms: Option<f64>,
  pub prompt_per_second: Option<f64>,
  pub predicted_n: Option<i32>,
  pub predicted_ms: Option<f64>,
  pub predicted_per_token_ms: Option<f64>,
  pub predicted_per_second: Option<f64>,
}

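/// A frame of a streamed `/v1/completions` response: either an incremental
/// delta or the final summary object.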
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
#[allow(clippy::large_enum_variant)]
enum StreamFrame
{
  Delta(GenerationDelta),
  Final(Final),
}

#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
  pub content: String,
  pub stop: bool,
  #[serde(skip_serializing_if = "Option::is_none")]
  pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}

#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
  pub content_delta: String,
}

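/// Final object closing a streamed completion; carries the full generated
/// text, stop flags and token counts.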
#[derive(Debug, Deserialize, Serialize)]
struct Final
{
  pub content: String,
  pub generated_text: String,
  pub stop: bool,
  pub model: String,
  pub tokens_predicted: u64,
  pub tokens_evaluated: u64,
  pub generation_settings: serde_json::Value,
  pub prompt: String,
  pub truncated: bool,
  pub stopped_eos: bool,
  pub stopped_word: bool,
  pub stopped_limit: bool,
  pub tokens_cached: u64,
  pub timings: serde_json::Value,
}

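/// Error payload carried on `error:` lines of the event stream.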
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
  pub code: Option<u32>,
  pub message: String,
  #[serde(rename = "type")]
  pub typ: String,
}

impl<RT> SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  /// Create a new `SimpleApi` that will query an LLM endpoint at `baseuri` (e.g. http://localhost)
  /// on the given `port` (e.g. 8080), optionally targeting a specific `model`.
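  ///
  /// # Example
  ///
  /// A minimal sketch; how `runtime` is obtained depends on the `yaaral`
  /// backend in use, so it is only indicative here:
  ///
  /// ```ignore
  /// let api = SimpleApi::new(runtime, "http://localhost", 8080, None)?;
  /// ```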
  pub fn new(
    runtime: RT,
    baseuri: impl Into<String>,
    port: u16,
    model: Option<String>,
  ) -> Result<Self>
  {
    Ok(Self {
      baseuri: baseuri.into(),
      port,
      model,
      runtime,
    })
  }
}

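/// Common view over decoded stream items: the text content they carry (if any)
/// and whether they end the stream.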
trait Data
{
  fn content(&self) -> Option<&String>;
  fn is_finished(&self) -> bool;
}

impl Data for ChatChunk
{
  fn content(&self) -> Option<&String>
  {
    self.choices.first().and_then(|c| c.delta.content.as_ref())
  }
  fn is_finished(&self) -> bool
  {
    if let Some(reason) = self
      .choices
      .first()
      .and_then(|c| c.finish_reason.as_deref())
    {
      if reason == "stop"
      {
        return true;
      }
    }
    false
  }
}

impl Data for StreamFrame
{
  fn content(&self) -> Option<&String>
  {
    match self
    {
      StreamFrame::Delta(delta) => Some(&delta.content),
      StreamFrame::Final(_) => None,
    }
  }
  fn is_finished(&self) -> bool
  {
    match self
    {
      Self::Delta(_) => false,
      Self::Final(_) => true,
    }
  }
}

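/// Convert a streaming HTTP response made of SSE-style `data:` / `error:`
/// lines into a `StringStream` of content fragments.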
fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
  response: impl yaaral::http::Response,
) -> Result<StringStream>
{
  let stream = response.into_stream().map(|chunk_result| {
    let mut results = vec![];
    match chunk_result
    {
      Ok(chunk) =>
      {
        // Each HTTP chunk may carry several SSE lines; handle them one by one.
        let chunk_str = String::from_utf8_lossy(&chunk);
        for line in chunk_str.lines()
        {
          let line = line.trim();
          if line.starts_with("data:")
          {
            let json_str = line.trim_start_matches("data:").trim();
            // OpenAI-style streams terminate with "data: [DONE]", which is not JSON.
            if json_str == "[DONE]"
            {
              break;
            }
            match serde_json::from_str::<D>(json_str)
            {
              Ok(chunk) =>
              {
                if let Some(content) = chunk.content()
                {
                  results.push(Ok(content.to_owned()));
                }

                if chunk.is_finished()
                {
                  break;
                }
              }
              Err(e) => results.push(Err(e.into())),
            }
          }
          else if line.starts_with("error:")
          {
            let json_str = line.trim_start_matches("error:");
            if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
            {
              results.push(Err(Error::SimpleApiError {
                code: chunk.code.unwrap_or_default(),
                message: chunk.message,
                error_type: chunk.typ,
              }));
            }
          }
          else if !line.is_empty()
          {
            log::error!("Unhandled line: {}.", line);
          }
        }
      }
      Err(e) =>
      {
        results.push(Err(Error::HttpError(format!("{:?}", e))));
      }
    }
    futures::stream::iter(results)
  });

  // Flatten nested streams and box it
  let flat_stream = stream.flatten().boxed();

  Ok(pin_stream(flat_stream))
}

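/// Return the GBNF grammar used to constrain output for the requested format, if any.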
fn grammar_for(format: crate::Format) -> Option<String>
{
  match format
  {
    crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
    crate::Format::Text => None,
  }
}

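// Chat requests go through `/v1/chat/completions`; plain generation uses
// `/v1/completions` and falls back to the chat endpoint (via
// `crate::generate_with_chat`) when system or assistant context is supplied.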
impl<RT> LargeLanguageModel for SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

    let messages = prompt
      .messages
      .into_iter()
      .map(|m| Message {
        role: match m.role
        {
          Role::User => "user".to_string(),
          Role::System => "system".to_string(),
          Role::Assistant => "assistant".to_string(),
          Role::Custom(custom) => custom,
        },
        content: m.content,
      })
      .collect();

    let request_body = ChatRequestBody {
      model: self.model.to_owned(),
      messages,
      temperature: 0.7,
      max_tokens: 2560,
      stream: true,
      grammar: grammar_for(prompt.format),
    };

    let rt = self.runtime.clone();

    let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

    Ok(async move {
      let response = rt.wpost(yreq).await;

      if !response.status().is_success()
      {
        return Err(Error::HttpError(format!(
          "Error code {}",
          response.status()
        )));
      }

      response_to_stream::<ChatChunk>(response)
    })
  }

  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let rt = self.runtime.clone();
    Ok(async move {
      if prompt.system.is_none() && prompt.assistant.is_none()
      {
        let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

        let request_body = GenerationRequestBody {
          model: self.model.to_owned(),
          prompt: prompt.user,
          temperature: 0.7,
          max_tokens: 2560,
          stream: true,
          grammar: grammar_for(prompt.format),
        };

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        let response = rt.wpost(yreq).await;

        if !response.status().is_success()
        {
          return Err(Error::HttpError(format!(
            "Error code {}",
            response.status()
          )));
        }

        response_to_stream::<StreamFrame>(response)
      }
      else
      {
        crate::generate_with_chat(self, prompt)?.await
      }
    })
  }
}