kproc_llm/simple_api.rs

//! Module for interacting with a simple API (like llama-cpp-server).
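//!
//! Usage sketch (illustrative only, not compiled: `rt` stands for a value
//! implementing the required `yaaral` runtime traits and `prompt` for a
//! `GenerationPrompt` built elsewhere):
//!
//! ```ignore
//! use futures::StreamExt;
//!
//! let api = SimpleApi::new(rt, "http://localhost", 8080, None)?;
//! let mut stream = api.generate_stream(prompt)?.await?;
//! while let Some(chunk) = stream.next().await
//! {
//!   print!("{}", chunk?);
//! }
//! ```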

use std::future::Future;

use ccutils::streams::pin_stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use yaaral::prelude::*;

use crate::prelude::*;

/// Client for a simple LLM API (such as llama-cpp-server).
#[derive(Debug)]
pub struct SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  runtime: RT,
  baseuri: String,
  port: u16,
  model: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Message
{
  role: String,
  content: String,
}

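/// Body of a streaming request to the OpenAI-style `/v1/chat/completions`
/// endpoint; `model` is omitted from the JSON when unset.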
#[derive(Debug, Serialize)]
struct ChatRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  messages: Vec<Message>,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

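/// Body of a streaming request to the `/v1/completions` endpoint used for
/// plain generation; `model` is omitted from the JSON when unset.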
#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  prompt: String,
  temperature: f32,
  max_tokens: u32,
  stream: bool,
  grammar: Option<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
  pub id: String,
  pub object: String,
  pub created: u64,
  pub model: String,
  pub choices: Vec<Choice>,
  #[serde(default)]
  pub timings: Option<Timings>, // optional, in case it's not always present
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
  pub index: u32,
  #[serde(default)]
  pub finish_reason: Option<String>,
  #[serde(default)]
  pub delta: Delta,
}

#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
  #[serde(default)]
  pub role: Option<String>,
  #[serde(default)]
  pub content: Option<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
  pub prompt_n: Option<i32>,
  pub prompt_ms: Option<f64>,
  pub prompt_per_token_ms: Option<f64>,
  pub prompt_per_second: Option<f64>,
  pub predicted_n: Option<i32>,
  pub predicted_ms: Option<f64>,
  pub predicted_per_token_ms: Option<f64>,
  pub predicted_per_second: Option<f64>,
}

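/// One parsed frame of the completion stream: either an intermediate delta
/// carrying newly generated text, or the final summary object.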
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
#[allow(clippy::large_enum_variant)]
enum StreamFrame
{
  Delta(GenerationDelta),
  Final(Final),
}

#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
  pub content: String,
  pub stop: bool,
  #[serde(skip_serializing_if = "Option::is_none")]
  pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}

#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
  pub content_delta: String,
}

#[derive(Debug, Deserialize, Serialize)]
struct Final
{
  pub content: String,
  pub generated_text: String,
  pub stop: bool,
  pub model: String,
  pub tokens_predicted: u64,
  pub tokens_evaluated: u64,
  pub generation_settings: serde_json::Value,
  pub prompt: String,
  pub truncated: bool,
  pub stopped_eos: bool,
  pub stopped_word: bool,
  pub stopped_limit: bool,
  pub tokens_cached: u64,
  pub timings: serde_json::Value,
}

#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
  pub code: Option<u32>,
  pub message: String,
  #[serde(rename = "type")]
  pub typ: String,
}

impl<RT> SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  /// Create a new SimpleApi object that will query an LLM endpoint at `baseuri` (e.g. http://localhost)
  /// on the given `port` (e.g. 8080), optionally using the given model.
  pub fn new(
    runtime: RT,
    baseuri: impl Into<String>,
    port: u16,
    model: Option<String>,
  ) -> Result<Self>
  {
    Ok(Self {
      baseuri: baseuri.into(),
      port,
      model,
      runtime,
    })
  }
}

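/// Common view over streamed chunk types: the text a chunk carries (if any)
/// and whether it marks the end of the stream.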
trait Data
{
  fn content(&self) -> Option<&String>;
  fn is_finished(&self) -> bool;
}

impl Data for ChatChunk
{
  fn content(&self) -> Option<&String>
  {
    self.choices.first().and_then(|c| c.delta.content.as_ref())
  }
  fn is_finished(&self) -> bool
  {
    if let Some(reason) = self
      .choices
      .first()
      .and_then(|c| c.finish_reason.as_deref())
    {
      if reason == "stop"
      {
        return true;
      }
    }
    false
  }
}

impl Data for StreamFrame
{
  fn content(&self) -> Option<&String>
  {
    match self
    {
      StreamFrame::Delta(delta) => Some(&delta.content),
      StreamFrame::Final(_) => None,
    }
  }
  fn is_finished(&self) -> bool
  {
    match self
    {
      Self::Delta(_) => false,
      Self::Final(_) => true,
    }
  }
}

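/// Convert a streaming HTTP response into a `StringStream` of text deltas.
///
/// Each received chunk is split into lines: lines prefixed with `data:` are
/// parsed as a JSON payload of type `D` (e.g. `ChatChunk` or `StreamFrame`),
/// lines prefixed with `error:` are parsed as an `ApiError`, and any other
/// non-empty line is logged as unhandled.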
fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
  response: impl yaaral::http::Response,
) -> Result<StringStream>
{
  let stream = response.into_stream().map(|chunk_result| {
    let mut results = vec![];
    match chunk_result
    {
      Ok(chunk) =>
      {
        let chunk_str = String::from_utf8_lossy(&chunk);
        for line in chunk_str.lines()
        {
          let line = line.trim();
          if line.starts_with("data:")
          {
            let json_str = line.trim_start_matches("data:");
            match serde_json::from_str::<D>(json_str)
            {
              Ok(chunk) =>
              {
                if let Some(content) = chunk.content()
                {
                  results.push(Ok(content.to_owned()));
                }

                if chunk.is_finished()
                {
                  break;
                }
              }
              Err(e) => results.push(Err(e.into())),
            }
          }
          else if line.starts_with("error:")
          {
            let json_str = line.trim_start_matches("error:");
            if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
            {
              results.push(Err(Error::SimpleApiError {
                code: chunk.code.unwrap_or_default(),
                message: chunk.message,
                error_type: chunk.typ,
              }));
            }
          }
          else if !line.is_empty()
          {
            log::error!("Unhandled line: {}.", line);
          }
        }
      }
      Err(e) =>
      {
        results.push(Err(Error::HttpError(format!("{:?}", e))));
      }
    }
    futures::stream::iter(results)
  });

  // Flatten nested streams and box it
  let flat_stream = stream.flatten().boxed();

  Ok(pin_stream(flat_stream))
}

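/// Map the requested output format to an optional GBNF grammar: JSON output
/// uses the bundled `data/grammar/json.gbnf`, plain text applies no grammar.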
fn grammar_for(format: crate::Format) -> Option<String>
{
  match format
  {
    crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
    crate::Format::Text => None,
  }
}

impl<RT> LargeLanguageModel for SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

    let messages = prompt
      .messages
      .into_iter()
      .map(|m| Message {
        role: match m.role
        {
          Role::User => "user".to_string(),
          Role::System => "system".to_string(),
          Role::Assistant => "assistant".to_string(),
          Role::Custom(custom) => custom,
        },
        content: m.content,
      })
      .collect();

    let request_body = ChatRequestBody {
      model: self.model.to_owned(),
      messages,
      temperature: 0.7,
      max_tokens: 2560,
      stream: true,
      grammar: grammar_for(prompt.options.format),
    };

    let rt = self.runtime.clone();

    let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

    Ok(async move {
      let response = rt.wpost(yreq).await;

      if !response.status().is_success()
      {
        return Err(Error::HttpError(format!(
          "Error code {}",
          response.status()
        )));
      }

      response_to_stream::<ChatChunk>(response)
    })
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let rt = self.runtime.clone();
    Ok(async move {
      if prompt.system.is_none() && prompt.assistant.is_none()
      {
        let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

        let request_body = GenerationRequestBody {
          model: self.model.to_owned(),
          prompt: prompt.user,
          temperature: 0.7,
          max_tokens: 2560,
          stream: true,
          grammar: grammar_for(prompt.options.format),
        };

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        let response = rt.wpost(yreq).await;

        if !response.status().is_success()
        {
          return Err(Error::HttpError(format!(
            "Error code {}",
            response.status()
          )));
        }

        response_to_stream::<StreamFrame>(response)
      }
      else
      {
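        // Prompts carrying a system or assistant part are routed through the
        // chat endpoint via `generate_with_chat`.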
        crate::generate_with_chat(self, prompt)?.await
      }
    })
  }
}