kproc-llm 0.6.1

Knowledge Processing library, using LLMs.
Documentation
//! Interface to large language models running on the `candle` framework.

use std::{future::Future, sync::Arc};

use async_stream::try_stream;
use candle_core::{
  quantized::{ggml_file, gguf_file},
  Device, Tensor,
};
use candle_transformers::{
  generation::{LogitsProcessor, Sampling},
  models::quantized_llama as model,
};
use model::ModelWeights;
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;

use crate::{generate_with_chat, prelude::*};

pub mod factory;

/// Build the default chat template from the `llama` template file embedded
/// at compile time.
///
/// Panics if `template::Template::new` rejects the embedded source.
fn create_llama_template() -> template::Template
{
  let llama_source = include_str!("../data/templates/llama");
  template::Template::new(llama_source).unwrap()
}

/// Sampling parameters used when generating tokens.
///
/// Defaults are provided through the `#[default(...)]` attributes below;
/// fields without one (`top_p`, `top_k`) have no attribute and are `Option`s.
#[derive(Debug, Default, Clone)]
struct Params
{
  /// The temperature used to generate samples, use 0 for greedy sampling.
  /// Any value `<= 0.0` selects arg-max sampling in `chat_stream`.
  #[default(0.8)]
  temperature: f64,

  /// Nucleus sampling probability cutoff.
  /// When unset, no top-p filtering is requested.
  top_p: Option<f64>,

  /// Only sample among the top K samples.
  /// When unset, no top-k filtering is requested.
  top_k: Option<usize>,

  /// The seed to use when generating random samples.
  #[default(299792458)]
  seed: u64,

  /// Penalty to be applied for repeating tokens, 1. means no penalty.
  #[default(1.1)]
  repeat_penalty: f32,

  /// The context size to consider for the repeat penalty.
  /// This is the number of most recent tokens handed to the penalty function.
  #[default(64)]
  repeat_last_n: usize,
}

/// Builder for configuring candle interface
///
/// Obtained from `Candle::build()` and consumed by `Builder::build()`.
#[derive(Debug, Default)]
pub struct Builder
{
  /// Local path of the model file. When set, `build` opens this file
  /// directly instead of fetching `repo`/`model` through `hf_hub`.
  model_path: Option<String>,
  /// Repository that holds the model file (used when `model_path` is unset).
  repo: Option<String>,
  /// Name of the model file inside `repo` (used when `model_path` is unset).
  model: Option<String>,
  /// Revision of `repo` to fetch the model from.
  #[default("main".into())]
  revision: String,
  /// Local path of the tokenizer file. When set, `build` loads it directly
  /// instead of fetching `tokenizer.json` from `tokenizer_repo`.
  tokenizer_path: Option<String>,
  /// Repository from which `tokenizer.json` is fetched when
  /// `tokenizer_path` is unset.
  tokenizer_repo: String,

  /// Text of the token that marks the end of generation; `build` looks it up
  /// in the tokenizer vocabulary and fails if it is not present.
  end_of_stream: String,

  /// Template used to render chat messages into a prompt string.
  #[default(create_llama_template())]
  template: template::Template,

  /// Sampling parameters handed to the built `Candle`.
  params: Params,

  /// Run on CPU rather than GPU even if a GPU is available.
  #[default(true)]
  cpu: bool,

  /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
  /// Only passed to the GGML loading path in `build`.
  #[default(1)]
  gqa: usize,
}

/// Render a byte count as a short human-readable string.
///
/// Uses decimal (powers of 1000) units: plain bytes below 1 000, then `KB`,
/// `MB` and `GB`, each printed with two decimals.
fn format_size(size_in_bytes: usize) -> String
{
  let bytes = size_in_bytes as f64;
  match size_in_bytes
  {
    0..=999 => format!("{size_in_bytes}B"),
    1_000..=999_999 => format!("{:.2}KB", bytes / 1e3),
    1_000_000..=999_999_999 => format!("{:.2}MB", bytes / 1e6),
    _ => format!("{:.2}GB", bytes / 1e9),
  }
}

/// Select the device used for inference.
///
/// Returns the CPU device when `cpu` is `true`. Otherwise the first CUDA
/// device is tried, then the first Metal device, and the CPU is used as a
/// last resort when neither backend reports itself as available.
fn device(cpu: bool) -> Result<Device>
{
  if cpu
  {
    return Ok(Device::Cpu);
  }
  if candle_core::utils::cuda_is_available()
  {
    return Ok(Device::new_cuda(0)?);
  }
  if candle_core::utils::metal_is_available()
  {
    return Ok(Device::new_metal(0)?);
  }

  // No GPU backend is compiled in: warn on Apple silicon, where Metal could
  // have been used, then fall back to the CPU.
  #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
  {
    log::warn!(
      "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
    );
  }
  Ok(Device::Cpu)
}
impl Builder
{
  /// Set the model
  pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
  {
    self.repo = Some(repo.into());
    self.model = Some(model.into());
    self
  }
  /// Set the revision used for the model
  pub fn revision(mut self, revision: impl Into<String>) -> Self
  {
    self.revision = revision.into();
    self
  }
  /// Set the tokenizer_repo
  pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
  {
    self.tokenizer_repo = tokenizer_repo.into();
    self
  }
  /// Set the token used for end of stream
  pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
  {
    self.end_of_stream = end_of_stream.into();
    self
  }
  /// Set the template
  pub fn template(mut self, template: impl Into<template::Template>) -> Self
  {
    self.template = template.into();
    self
  }
  /// Build the candle interface
  pub async fn build(self) -> Result<Candle>
  {
    let tokenizer_path = match self.tokenizer_path
    {
      Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
      None =>
      {
        let api = hf_hub::api::tokio::Api::new()?;
        let api = api.model(self.tokenizer_repo.clone());
        api.get("tokenizer.json").await?
      }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

    let model_path = match self.model_path
    {
      Some(model_path) => std::path::PathBuf::from(model_path),
      None => match (self.repo, self.model)
      {
        (Some(repo), Some(model)) =>
        {
          let api = hf_hub::api::tokio::Api::new()?;
          api
            .repo(hf_hub::Repo::with_revision(
              repo.to_string(),
              hf_hub::RepoType::Model,
              self.revision,
            ))
            .get(&model)
            .await?
        }
        _ => Err(Error::UndefinedModel)?,
      },
    };

    let device = device(self.cpu)?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();

    let model_weights = match model_path.extension().and_then(|v| v.to_str())
    {
      Some("gguf") =>
      {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter()
        {
          let elem_count = tensor.shape.elem_count();
          total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensor_infos.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
      }
      Some("ggml" | "bin") | Some(_) | None =>
      {
        let model =
          ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensors.iter()
        {
          let elem_count = tensor.shape().elem_count();
          total_size_in_bytes +=
            elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensors.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        log::info!("params: {:?}", model.hparams);
        ModelWeights::from_ggml(model, self.gqa)?
      }
    };
    let eos_token = *tokenizer
      .get_vocab(true)
      .get(&self.end_of_stream)
      .ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;

    Ok(Candle {
      model_weights: model_weights.into(),
      tokenizer: tokenizer.into(),
      template: self.template,
      params: self.params,
      eos_token,
      device,
    })
  }
}

/// Interface to candle
///
/// Created through `Candle::build()` followed by `Builder::build()`.
pub struct Candle
{
  /// Quantized llama weights, shared through `ccutils::futures::ArcMutex`
  /// and locked for each forward pass.
  model_weights: ccutils::futures::ArcMutex<ModelWeights>,
  /// Tokenizer used to encode prompts and decode generated tokens.
  tokenizer: Arc<tokenizers::Tokenizer>,
  /// Template used to render chat messages into a prompt string.
  template: template::Template,

  /// Sampling parameters applied to every generation.
  params: Params,

  /// Id of the token that ends generation.
  eos_token: u32,

  /// Device on which input tensors are created.
  device: Device,
}

impl Candle
{
  /// Instantiate a `llama` model
  pub fn build() -> Builder
  {
    Builder::default()
  }
}

impl LargeLanguageModel for Candle
{
  /// Render `prompt` with the configured template and stream the generated
  /// answer as text fragments.
  ///
  /// The template is rendered eagerly, so rendering errors are returned
  /// directly. Tokenization happens when the returned future is awaited, and
  /// inference happens lazily as the resulting stream is polled. Generation
  /// stops when the end-of-stream token is sampled; no other length limit is
  /// applied here.
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let prompt_str = self.template.render(&prompt.messages)?;

    let device = self.device.clone();
    let model_weights = self.model_weights.clone();
    let tokenizer = self.tokenizer.clone();
    let params = self.params.clone();
    let eos_token = self.eos_token;

    Ok(Box::pin(async move {
      // Encode prompt
      let prompt_tokens = tokenizer.encode(prompt_str, true)?.get_ids().to_vec();
      let prompt_len = prompt_tokens.len();
      // Prompt plus generated tokens, used as context for the repeat penalty.
      let mut all_tokens = prompt_tokens.clone();

      // Build logits processor
      let mut logits_processor = {
        let temperature = params.temperature;
        let sampling = if temperature <= 0.0
        {
          Sampling::ArgMax
        }
        else
        {
          match (params.top_k, params.top_p)
          {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
          }
        };
        LogitsProcessor::from_sampling(params.seed, sampling)
      };

      let stream = try_stream! {
          let mut tokenizer_output_stream = tokenizer.decode_stream(false);

          // Feed the prompt one token at a time. Only the logits produced
          // after the last prompt token predict new text, so sampling is
          // done once instead of once per prompt token.
          // NOTE(review): the lock is taken per forward call, so two streams
          // running concurrently on the same `Candle` could interleave their
          // forward passes. Confirm whether `ModelWeights` tolerates that.
          let mut next_token = 0;
          for (pos, token) in prompt_tokens.iter().enumerate() {
              let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
              let logits = model_weights.lock().await.forward(&input, pos)?;
              let logits = logits.squeeze(0)?;

              if pos + 1 == prompt_len {
                  next_token = logits_processor.sample(&logits)?;
              }
          }

          // Number of tokens generated so far; offsets the position passed
          // to the model past the prompt.
          let mut index = 0;

          while next_token != eos_token {
              all_tokens.push(next_token);

              // Try to convert token to text and yield it
              if let Some(fragment) = tokenizer_output_stream
                  .step(next_token)?
              {
                  yield fragment;
              }

              let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
              let logits = model_weights
                  .lock().await
                  .forward(&input, prompt_len + index)?;
              let logits = logits.squeeze(0)?;

              // `apply_repeat_penalty` returns a new tensor: the result has
              // to replace `logits`, otherwise the penalty has no effect.
              let logits = if params.repeat_penalty == 1.0 {
                  logits
              } else {
                  let start_at = all_tokens.len()
                      .saturating_sub(params.repeat_last_n);

                  candle_transformers::utils::apply_repeat_penalty(
                      &logits,
                      params.repeat_penalty,
                      &all_tokens[start_at..],
                  )?
              };

              next_token = logits_processor.sample(&logits)?;
              index += 1;
          }
      };

      Ok(Box::pin(stream) as StringStream)
    }))
  }
  /// Stream the completion of a generation prompt by delegating to
  /// `generate_with_chat`.
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    generate_with_chat(self, prompt)
  }
}