kproc-llm 0.6.1

Knowledge Processing library, using LLMs.
Documentation
//! Module for using with simple API (like llama-cpp-server).

use std::future::Future;

use futures::StreamExt;
use serde::{Deserialize, Serialize};
use yaaral::prelude::*;

use crate::prelude::*;

/// Simple API.
///
/// Client for an LLM HTTP end-point reached at `{baseuri}:{port}`. Requests are issued through
/// the runtime `RT`, which provides both task and HTTP-client capabilities.
#[derive(Debug)]
pub struct SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  /// Runtime handle; cloned for each request and used to send the HTTP POST.
  runtime: RT,
  /// Base URI of the end-point, without the port (e.g. `http://localhost`).
  baseuri: String,
  /// Port appended to `baseuri` when building request URLs.
  port: u16,
  /// Optional model name forwarded in the request body.
  model: Option<String>,
}

/// A single chat message as it appears in the JSON request body.
#[derive(Debug, Serialize, Deserialize)]
struct Message
{
  /// Role of the author, as a plain string (e.g. `user`, `system`, `assistant` or a custom role).
  role: String,
  /// Text of the message.
  content: String,
}

/// JSON body posted to the chat completion end-point.
#[derive(Debug, Serialize)]
struct ChatRequestBody
{
  /// Model name; left out of the JSON when no model is configured.
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  /// Conversation to complete, in order.
  messages: Vec<Message>,
  /// Sampling temperature.
  temperature: f32,
  /// Upper bound on the number of generated tokens.
  max_tokens: u32,
  /// Whether the server should stream the answer.
  stream: bool,
  /// Optional grammar constraining the output. Like `model`, it is left out of the JSON when
  /// absent instead of being sent as an explicit `null`.
  #[serde(skip_serializing_if = "Option::is_none")]
  grammar: Option<String>,
}

/// JSON body posted to the text completion end-point.
#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
  /// Model name; left out of the JSON when no model is configured.
  #[serde(skip_serializing_if = "Option::is_none")]
  model: Option<String>,
  /// Prompt text to complete.
  prompt: String,
  /// Sampling temperature.
  temperature: f32,
  /// Upper bound on the number of generated tokens.
  max_tokens: u32,
  /// Whether the server should stream the answer.
  stream: bool,
  /// Optional grammar constraining the output. Like `model`, it is left out of the JSON when
  /// absent instead of being sent as an explicit `null`.
  #[serde(skip_serializing_if = "Option::is_none")]
  grammar: Option<String>,
}

/// One streamed chunk of a chat completion response, deserialized from a `data:` line.
///
/// Only `choices` is read by this module; the other fields are required by the deserializer
/// but otherwise unused (hence the `dead_code` allowance).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
  // NOTE(review): `id`, `object`, `created` and `model` presumably mirror the OpenAI-style
  // chunk envelope — confirm against the server in use.
  pub id: String,
  pub object: String,
  pub created: u64,
  pub model: String,
  /// Choices carried by this chunk; only the first one is inspected.
  pub choices: Vec<Choice>,
  #[serde(default)]
  pub timings: Option<Timings>, // optional, in case it's not always present
}

/// One choice inside a [`ChatChunk`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
  /// Index of the choice as reported by the server.
  pub index: u32,
  /// Reason the generation ended, if any; defaults to `None` when missing.
  #[serde(default)]
  pub finish_reason: Option<String>,
  /// Incremental payload of this chunk; defaults to an empty [`Delta`] when missing.
  #[serde(default)]
  pub delta: Delta,
}

/// Incremental part of a chat message carried by a [`Choice`]. Both fields are optional.
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
  /// Role of the author, when the chunk carries one.
  #[serde(default)]
  pub role: Option<String>,
  /// Newly generated text, when the chunk carries some.
  #[serde(default)]
  pub content: Option<String>,
}

/// Optional timing statistics attached to a [`ChatChunk`]. Never read by this module.
///
/// NOTE(review): the field names suggest prompt-processing and prediction counters
/// (counts, milliseconds, per-token and per-second rates) — confirm against the server.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
  pub prompt_n: Option<i32>,
  pub prompt_ms: Option<f64>,
  pub prompt_per_token_ms: Option<f64>,
  pub prompt_per_second: Option<f64>,
  pub predicted_n: Option<i32>,
  pub predicted_ms: Option<f64>,
  pub predicted_per_token_ms: Option<f64>,
  pub predicted_per_second: Option<f64>,
}

/// One frame of a streamed generation response.
///
/// The enum is `untagged`: serde tries the variants in declaration order and keeps the first
/// one that deserializes. `Final` has to be listed first because its required fields are a
/// superset of those of `GenerationDelta` and unknown fields are ignored by default; with
/// `Delta` first, every final frame would be accepted as a `Delta` and `Final` could never be
/// produced.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
#[allow(clippy::large_enum_variant)]
enum StreamFrame
{
  /// Terminal frame carrying the generation summary.
  Final(Final),
  /// Intermediate frame carrying newly generated text.
  Delta(GenerationDelta),
}

/// Intermediate frame of a streamed generation response.
#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
  /// Text carried by this frame; this is what gets forwarded to the output stream.
  pub content: String,
  /// Stop flag as reported by the server (not read by this module).
  pub stop: bool,
  /// Optional list of message diffs; omitted from serialization when `None`.
  #[serde(skip_serializing_if = "Option::is_none")]
  pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}

/// One entry of `GenerationDelta::oaicompat_msg_diffs`.
#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
  /// Text of the diff entry.
  pub content_delta: String,
}

/// Terminal frame of a streamed generation response (payload of `StreamFrame::Final`).
///
/// All fields are required for deserialization; none of them is read by this module.
/// NOTE(review): the field set presumably mirrors the server's final completion JSON —
/// verify it against the server version in use.
#[derive(Debug, Deserialize, Serialize)]
struct Final
{
  pub content: String,
  pub generated_text: String,
  pub stop: bool,
  pub model: String,
  pub tokens_predicted: u64,
  pub tokens_evaluated: u64,
  /// Kept as raw JSON; its structure is not interpreted here.
  pub generation_settings: serde_json::Value,
  pub prompt: String,
  pub truncated: bool,
  pub stopped_eos: bool,
  pub stopped_word: bool,
  pub stopped_limit: bool,
  pub tokens_cached: u64,
  /// Kept as raw JSON; its structure is not interpreted here.
  pub timings: serde_json::Value,
}

/// Error payload reported by the server, parsed from an `error:` line of the response stream.
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
  /// Numeric error code, when the server provides one.
  pub code: Option<u32>,
  /// Human-readable error message.
  pub message: String,
  /// Error category; named `type` in the JSON, renamed because `type` is a Rust keyword.
  #[serde(rename = "type")]
  pub typ: String,
}

impl<RT> SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  /// Build a `SimpleApi` that queries the LLM end-point located at `baseuri`
  /// (e.g. `http://localhost`) on the given `port` (e.g. `8080`).
  ///
  /// `runtime` is used to send the requests and `model`, when set, is forwarded in every
  /// request body. The current implementation always returns `Ok`.
  pub fn new(
    runtime: RT,
    baseuri: impl Into<String>,
    port: u16,
    model: Option<String>,
  ) -> Result<Self>
  {
    let baseuri = baseuri.into();
    let api = Self {
      runtime,
      baseuri,
      port,
      model,
    };
    Ok(api)
  }
}

/// Common view over the frame types that can be decoded from a response stream.
trait Data
{
  /// Text carried by the frame, if any.
  fn content(&self) -> Option<&String>;
  /// Whether the frame marks the end of the generation.
  fn is_finished(&self) -> bool;
}

impl Data for ChatChunk
{
  /// Text of the first choice's delta, or `None` when there is no choice or no text.
  fn content(&self) -> Option<&String>
  {
    let first = self.choices.first()?;
    first.delta.content.as_ref()
  }
  /// `true` only when the first choice reports the finish reason `"stop"`.
  fn is_finished(&self) -> bool
  {
    let reason = self
      .choices
      .first()
      .and_then(|choice| choice.finish_reason.as_deref());
    matches!(reason, Some("stop"))
  }
}

impl Data for StreamFrame
{
  /// Text of a `Delta` frame; a `Final` frame carries no text to forward.
  fn content(&self) -> Option<&String>
  {
    if let Self::Delta(delta) = self
    {
      Some(&delta.content)
    }
    else
    {
      None
    }
  }
  /// `true` exactly for the `Final` frame.
  fn is_finished(&self) -> bool
  {
    matches!(self, Self::Final(_))
  }
}

fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
  response: impl yaaral::http::Response,
) -> Result<StringStream>
{
  let stream = response.into_stream().map(|chunk_result| {
    let mut results = vec![];
    match chunk_result
    {
      Ok(chunk) =>
      {
        let chunk_str = String::from_utf8_lossy(&chunk);
        for line in chunk_str.lines()
        {
          let line = line.trim();
          if line.starts_with("data:")
          {
            let json_str = line.trim_start_matches("data:");
            match serde_json::from_str::<D>(json_str)
            {
              Ok(chunk) =>
              {
                if let Some(content) = chunk.content()
                {
                  results.push(Ok(content.to_owned()));
                }

                if chunk.is_finished()
                {
                  break;
                }
              }
              Err(e) => results.push(Err(e.into())),
            }
          }
          else if line.starts_with("error:")
          {
            let json_str = line.trim_start_matches("error:");
            if let Ok(chunk) = serde_json::from_str::<ApiError>(json_str)
            {
              results.push(Err(Error::SimpleApiError {
                code: chunk.code.unwrap_or_default(),
                message: chunk.message,
                error_type: chunk.typ,
              }));
            }
          }
          else if !line.is_empty()
          {
            log::error!("Unhandled line: {}.", line);
          }
        }
      }
      Err(e) =>
      {
        results.push(Err(Error::HttpError(format!("{:?}", e))));
      }
    }
    futures::stream::iter(results)
  });

  // Flatten nested streams and box it
  let flat_stream = stream.flatten().boxed();

  Ok(pin_stream(flat_stream))
}

/// Select the grammar to send with a request for the given output format.
///
/// Returns the bundled `json.gbnf` grammar for `Format::Json` and `None` for `Format::Text`.
fn grammar_for(format: crate::Format) -> Option<String>
{
  // Embedded at compile time from the crate's data directory.
  const JSON_GRAMMAR: &str = include_str!("../data/grammar/json.gbnf");

  match format
  {
    crate::Format::Text => None,
    crate::Format::Json => Some(String::from(JSON_GRAMMAR)),
  }
}

impl<RT> LargeLanguageModel for SimpleApi<RT>
where
  RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);

    let messages = prompt
      .messages
      .into_iter()
      .map(|m| Message {
        role: match m.role
        {
          Role::User => "user".to_string(),
          Role::System => "system".to_string(),
          Role::Assistant => "assistant".to_string(),
          Role::Custom(custom) => custom,
        },
        content: m.content,
      })
      .collect();

    let request_body = ChatRequestBody {
      model: self.model.to_owned(),
      messages,
      temperature: 0.7,
      max_tokens: 2560,
      stream: true,
      grammar: grammar_for(prompt.format),
    };

    let rt = self.runtime.clone();

    let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

    Ok(async move {
      let response = rt.wpost(yreq).await;

      if !response.status().is_success()
      {
        return Err(Error::HttpError(format!(
          "Error code {}",
          response.status()
        )));
      }

      response_to_stream::<ChatChunk>(response)
    })
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let rt = self.runtime.clone();
    Ok(async move {
      if prompt.system.is_none() && prompt.assistant.is_none()
      {
        let url = format!("{}:{}/v1/completions", self.baseuri, self.port);

        let request_body = GenerationRequestBody {
          model: self.model.to_owned(),
          prompt: prompt.user,
          temperature: 0.7,
          max_tokens: 2560,
          stream: true,
          grammar: grammar_for(prompt.format),
        };

        let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;

        let response = rt.wpost(yreq).await;

        if !response.status().is_success()
        {
          return Err(Error::HttpError(format!(
            "Error code {}",
            response.status()
          )));
        }

        response_to_stream::<StreamFrame>(response)
      }
      else
      {
        crate::generate_with_chat(self, prompt)?.await
      }
    })
  }
}