kproc-llm 0.2.1

Knowledge Processing library, using LLMs.
Documentation
//! Module for using `llama.cpp`.

use futures::{channel::mpsc, SinkExt, StreamExt};
use llama_cpp_2::{
  context::params::LlamaContextParams, llama_backend::LlamaBackend, llama_batch::LlamaBatch,
  model::LlamaModel, sampling::LlamaSampler,
};
use std::{future::Future, num::NonZeroU32};

use crate::{prelude::*, template};

/// A request handed to the inference worker: the prompt to run, paired with the channel on
/// which the generated text fragments (or an error) are sent back.
type InferenceRequest = (Prompt, mpsc::Sender<Result<String>>);

/// Handle used to run a model with `llama.cpp`.
///
/// The handle only stores the sending half of the request channel; requests pushed on it are
/// consumed by whoever owns the receiving half.
#[derive(Debug)]
pub struct LlamaCpp
{
  /// Sending half of the channel on which inference requests are queued.
  server_request_tx: mpsc::Sender<InferenceRequest>,
}

/// Parameters for the `llama.cpp` model
///
/// Start from [`Default::default`] and adjust the values with the builder-style setters.
#[derive(Debug, Clone)]
pub struct LlamaCppParameters
{
  /// When `true`, `n_gpu_layers` is ignored and the default model parameters are used.
  /// Only read when the `gpu` feature is enabled.
  #[cfg_attr(not(feature = "gpu"), allow(dead_code))]
  disable_gpu: bool,
  /// Number of layers offloaded to the GPU. Only read when the `gpu` feature is enabled.
  #[cfg_attr(not(feature = "gpu"), allow(dead_code))]
  n_gpu_layers: u32,
  /// Context size requested when creating an inference context.
  context_size: u32,
  /// Seed given to the random sampler.
  seed: u32,
  /// Location of the model: either a `hf://<repository>/<file>` uri or a local path.
  model_uri: String,
  /// Source text of the template used to render prompts.
  template: String,
  /// Number of threads used for generation; `None` keeps the `llama.cpp` default.
  threads: Option<i32>,
  /// Number of threads used for batch processing; falls back to `threads` when `None`.
  threads_batch: Option<i32>,
}

impl LlamaCppParameters
{
  /// Set the number of layers offloaded to the GPU.
  pub fn n_gpu_layers(mut self, n_gpu_layers: u32) -> LlamaCppParameters
  {
    self.n_gpu_layers = n_gpu_layers;
    self
  }
}

impl Default for LlamaCppParameters
{
  /// Default parameters: GPU enabled with 20 offloaded layers, a context of 2048, seed 1234,
  /// the Llama 3.1 8B instruct model fetched through a `hf://` uri, the bundled `llama`
  /// template, and no explicit thread counts.
  fn default() -> Self
  {
    let model_uri = String::from(
      "hf://ggml-org/Meta-Llama-3.1-8B-Instruct-Q4_0-GGUF/meta-llama-3.1-8b-instruct-q4_0.gguf",
    );
    // The template text is embedded in the binary at compile time.
    let template = String::from(include_str!("../data/templates/llama"));

    Self {
      model_uri,
      template,
      context_size: 2048,
      seed: 1234,
      n_gpu_layers: 20,
      disable_gpu: false,
      threads: None,
      threads_batch: None,
    }
  }
}

impl LlamaCpp
{
  /// Create a new `llama.cpp` model from the given uri, using default parameters
  ///
  /// Uri starting with `hf://` are downloaded from hugging face in a local cache.
  pub fn from_model(model_uri: impl Into<String>) -> Result<Self>
  {
    let model_uri = model_uri.into();
    Self::new(LlamaCppParameters {
      model_uri,
      ..Default::default()
    })
  }
  /// Create a new `llama.cpp` model from the given parameters.
  pub fn new(params: LlamaCppParameters) -> Result<Self>
  {
    let backend = LlamaBackend::init()?;

    let model_params = {
      #[cfg(feature = "gpu")]
      if !params.disable_gpu
      {
        llama_cpp_2::model::params::LlamaModelParams::default()
          .with_n_gpu_layers(params.n_gpu_layers)
      }
      else
      {
        llama_cpp_2::model::params::LlamaModelParams::default()
      }
      #[cfg(not(feature = "gpu"))]
      llama_cpp_2::model::params::LlamaModelParams::default()
    };
    let mut ctx_params =
      LlamaContextParams::default().with_n_ctx(NonZeroU32::new(params.context_size));

    if let Some(threads) = params.threads
    {
      ctx_params = ctx_params.with_n_threads(threads);
    }
    if let Some(threads_batch) = params.threads_batch.or(params.threads)
    {
      ctx_params = ctx_params.with_n_threads_batch(threads_batch);
    }

    // Get the path to the model
    let model_path = if params.model_uri.starts_with("hf://")
    {
      let last_slash = params.model_uri[5..]
        .rfind("/")
        .ok_or(Error::HfInvalidUri)?
        + 5;
      let model_repo = &params.model_uri[5..last_slash];
      let model_name = &params.model_uri[(last_slash + 1)..];
      hf_hub::api::sync::ApiBuilder::new()
        .with_progress(true)
        .build()?
        .model(model_repo.to_string())
        .get(&model_name)?
    }
    else
    {
      params.model_uri.to_owned().into()
    };

    let model = LlamaModel::load_from_file(&backend, model_path, &model_params)?;
    let template = template::Template::new(params.template)?;

    let (server_request_tx, mut server_request_rx) =
      mpsc::channel::<(Prompt, mpsc::Sender<Result<String>>)>(10);

    std::thread::spawn(move || {
      futures::executor::block_on(async {
        while let Some((prompt, mut answer_tx)) = server_request_rx.next().await
        {
          let r = async {
            let mut ctx = model.new_context(&backend, ctx_params.clone())?;
            let prompt_str = template.render(
              Some(&prompt.prompt),
              prompt.assistant.as_ref().map(|x| x.as_str()),
              prompt.system.as_ref().map(|x| x.as_str()),
            )?;
            let tokens_list =
              model.str_to_token(&prompt_str, llama_cpp_2::model::AddBos::Always)?;

            // create a llama_batch
            // we use this object to submit token data for decoding
            let mut batch = LlamaBatch::new(ctx_params.n_batch() as usize, 1);

            let last_index: i32 = (tokens_list.len() - 1) as i32;
            for (i, token) in (0_i32..).zip(tokens_list.into_iter())
            {
              // llama_decode will output logits only for the last token of the prompt
              let is_last = i == last_index;
              batch.add(token, i, &[0], is_last)?;
            }

            ctx.decode(&mut batch)?;

            // main loop

            let mut n_cur = batch.n_tokens();

            // The `Decoder`
            let mut decoder = encoding_rs::UTF_8.new_decoder();

            let mut sampler = match prompt.format
            {
              Format::Text => LlamaSampler::chain_simple([
                LlamaSampler::dist(params.seed),
                LlamaSampler::greedy(),
              ]),
              Format::Json => LlamaSampler::chain_simple([
                LlamaSampler::grammar(&model, include_str!("../data/grammar/json.gbnf"), "root"),
                LlamaSampler::min_p(0.05, 1),
                LlamaSampler::temp(0.8),
                LlamaSampler::dist(params.seed),
              ]),
            };

            loop
            {
              // sample the next token

              let token = sampler.sample(&ctx, batch.n_tokens() - 1);

              sampler.accept(token);

              // is it an end of stream?
              if model.is_eog_token(token)
              {
                return Ok(());
              }

              let output_bytes =
                model.token_to_bytes(token, llama_cpp_2::model::Special::Tokenize)?;
              // use `Decoder.decode_to_string()` to avoid the intermediate buffer
              let mut output_string = String::with_capacity(32);
              let _decode_result =
                decoder.decode_to_string(&output_bytes, &mut output_string, false);

              answer_tx.send(Ok(output_string)).await?;

              batch.clear();
              batch.add(token, n_cur, &[0], true)?;

              n_cur += 1;

              ctx.decode(&mut batch)?;
            }
          }
          .await;
          if let Err(e) = r
          {
            let tx_r = answer_tx.send(Err(e)).await;
            if let Err(e) = tx_r
            {
              log::error!("Failed to send error, with error {:?}", e);
            }
          }
        }
      });
    });

    Ok(Self { server_request_tx })
  }
}

impl Default for LlamaCpp
{
  fn default() -> Self
  {
    Self::new(Default::default()).unwrap()
  }
}

impl LargeLanguageModel for LlamaCpp
{
  /// Queue `prompt` for inference and return the stream of generated text fragments.
  ///
  /// The returned future creates an answer channel (capacity 20), sends the prompt together
  /// with the sending half on the request channel, and resolves to the receiving half
  /// wrapped by `pin_stream`. It fails if the request cannot be sent.
  fn infer_stream(
    &self,
    prompt: Prompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    Ok(async {
      let (answer_tx, answer_rx) = mpsc::channel(20);
      // `send` needs a mutable sender, hence the clone of the shared one.
      let mut request_tx = self.server_request_tx.clone();
      request_tx.send((prompt, answer_tx)).await?;
      Ok(pin_stream(answer_rx))
    })
  }
}