use std::{future::Future, sync::Arc};
use async_stream::try_stream;
use candle_core::{
quantized::{ggml_file, gguf_file},
Device, Tensor,
};
use candle_transformers::{
generation::{LogitsProcessor, Sampling},
models::quantized_llama as model,
};
use model::ModelWeights;
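// `SmartDefault` is aliased to `Default` so that the `#[derive(Default)]` below
// expands the smart-default derive and honours the `#[default(...)]` field attributes.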
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;
use crate::{generate_with_chat, prelude::*};
pub mod factory;
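/// Loads the chat template bundled with the crate at `data/templates/llama`.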
fn create_llama_template() -> template::Template
{
template::Template::new(include_str!("../data/templates/llama")).unwrap()
}
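/// Sampling parameters used during generation: temperature, optional top-p /
/// top-k filtering, RNG seed, and repeat-penalty settings.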
#[derive(Debug, Default, Clone)]
struct Params
{
#[default(0.8)]
temperature: f64,
top_p: Option<f64>,
top_k: Option<usize>,
#[default(299792458)]
seed: u64,
#[default(1.1)]
repeat_penalty: f32,
#[default(64)]
repeat_last_n: usize,
}
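/// Builder for [`Candle`]. The model weights and tokenizer can be read from
/// local paths or fetched from the Hugging Face Hub.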
#[derive(Debug, Default)]
pub struct Builder
{
model_path: Option<String>,
repo: Option<String>,
model: Option<String>,
#[default("main".into())]
revision: String,
tokenizer_path: Option<String>,
tokenizer_repo: String,
end_of_stream: String,
#[default(create_llama_template())]
template: template::Template,
params: Params,
#[default(true)]
cpu: bool,
#[default(1)]
gqa: usize,
}
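/// Formats a byte count with decimal units (B, KB, MB, GB) for logging.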
fn format_size(size_in_bytes: usize) -> String
{
if size_in_bytes < 1_000
{
format!("{size_in_bytes}B")
}
else if size_in_bytes < 1_000_000
{
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
}
else if size_in_bytes < 1_000_000_000
{
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
}
else
{
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
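/// Picks the compute device: CPU when explicitly requested, otherwise CUDA if
/// available, then Metal, and finally CPU as a fallback.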
fn device(cpu: bool) -> Result<Device>
{
if cpu
{
Ok(Device::Cpu)
}
else if candle_core::utils::cuda_is_available()
{
Ok(Device::new_cuda(0)?)
}
else if candle_core::utils::metal_is_available()
{
Ok(Device::new_metal(0)?)
}
else
{
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
      log::warn!(
        "Running on CPU. To run on the GPU (Metal), build with `--features metal`."
      );
}
Ok(Device::Cpu)
}
}
impl Builder
{
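  /// Sets the Hugging Face repository and the model file to fetch from it.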
pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
{
self.repo = Some(repo.into());
self.model = Some(model.into());
self
}
pub fn revision(mut self, revision: impl Into<String>) -> Self
{
self.revision = revision.into();
self
}
pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
{
self.tokenizer_repo = tokenizer_repo.into();
self
}
pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
{
self.end_of_stream = end_of_stream.into();
self
}
pub fn template(mut self, template: impl Into<template::Template>) -> Self
{
self.template = template.into();
self
}
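  /// Resolves the tokenizer and model files (downloading them from the Hub if
  /// no local paths were given), loads the quantized weights (GGUF or legacy
  /// GGML), and looks up the end-of-stream token id.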
pub async fn build(self) -> Result<Candle>
{
let tokenizer_path = match self.tokenizer_path
{
Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
None =>
{
let api = hf_hub::api::tokio::Api::new()?;
let api = api.model(self.tokenizer_repo.clone());
api.get("tokenizer.json").await?
}
};
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
let model_path = match self.model_path
{
Some(model_path) => std::path::PathBuf::from(model_path),
None => match (self.repo, self.model)
{
(Some(repo), Some(model)) =>
{
let api = hf_hub::api::tokio::Api::new()?;
api
.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
self.revision,
))
.get(&model)
.await?
}
_ => Err(Error::UndefinedModel)?,
},
};
let device = device(self.cpu)?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
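    // Load the weights in either GGUF or legacy GGML format, logging the
    // tensor count and total size.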
let model_weights = match model_path.extension().and_then(|v| v.to_str())
{
Some("gguf") =>
{
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter()
{
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
log::info!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None =>
{
let model =
ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter()
{
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
}
log::info!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensors.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
log::info!("params: {:?}", model.hparams);
ModelWeights::from_ggml(model, self.gqa)?
}
};
let eos_token = *tokenizer
.get_vocab(true)
.get(&self.end_of_stream)
.ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;
Ok(Candle {
model_weights: model_weights.into(),
tokenizer: tokenizer.into(),
template: self.template,
params: self.params,
eos_token,
device,
})
}
}
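/// A quantized Llama-family model running on Candle, holding everything needed
/// to stream chat completions: weights, tokenizer, chat template, sampling
/// parameters, end-of-stream token, and target device.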
pub struct Candle
{
model_weights: ccutils::futures::ArcMutex<ModelWeights>,
tokenizer: Arc<tokenizers::Tokenizer>,
template: template::Template,
params: Params,
eos_token: u32,
device: Device,
}
impl Candle
{
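  /// Returns a [`Builder`] with default parameters.
  ///
  /// A hypothetical usage sketch; the repository, model file, tokenizer repo,
  /// and end-of-stream token below are illustrative placeholders, not values
  /// shipped with this crate:
  ///
  /// ```ignore
  /// let llm = Candle::build()
  ///   .model("TheBloke/Llama-2-7B-Chat-GGUF", "llama-2-7b-chat.Q4_K_M.gguf")
  ///   .tokenizer_repo("hf-internal-testing/llama-tokenizer")
  ///   .end_of_stream("</s>")
  ///   .build()
  ///   .await?;
  /// ```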
pub fn build() -> Builder
{
Builder::default()
}
}
impl LargeLanguageModel for Candle
{
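  // Renders the chat prompt with the template, then returns a future that
  // feeds the prompt through the model and yields decoded text fragments as a
  // stream until the end-of-stream token is sampled.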
fn chat_stream(
&self,
prompt: ChatPrompt,
) -> Result<impl Future<Output = Result<StringStream>> + Send>
{
let prompt_str = self.template.render(&prompt.messages)?;
let device = self.device.clone();
let model_weights = self.model_weights.clone();
let tokenizer = self.tokenizer.clone();
let params = self.params.clone();
let eos_token = self.eos_token;
Ok(Box::pin(async move {
let prompt_tokens_encoded = tokenizer.encode(prompt_str, true)?;
let prompt_tokens = prompt_tokens_encoded.get_ids().to_vec();
      let mut all_tokens = prompt_tokens.clone();
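      // Pick the sampling strategy: greedy arg-max when temperature is
      // non-positive, otherwise temperature sampling optionally restricted by
      // top-k and/or top-p.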
let mut logits_processor = {
let temperature = params.temperature;
let sampling = if temperature <= 0.0
{
Sampling::ArgMax
}
else
{
match (params.top_k, params.top_p)
{
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(params.seed, sampling)
};
let prompt_len = prompt_tokens.len();
let device_cl = device.clone();
let model_cl = model_weights.clone();
let stream = try_stream! {
let mut tokenizer_output_stream = tokenizer.decode_stream(false);
let mut next_token = 0;
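        // Feed the prompt one token at a time to build up the model's KV cache;
        // only the sample taken after the final prompt token is kept.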
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device_cl)?.unsqueeze(0)?;
let logits = model_cl.lock().await.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?;
}
let mut index = 0;
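        // Generation loop: emit the decoded fragment for each sampled token,
        // then run one forward pass to sample the next, stopping at the
        // end-of-stream token.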
loop {
if next_token == eos_token {
break;
}
all_tokens.push(next_token);
if let Some(fragment) = tokenizer_output_stream
.step(next_token)?
{
yield fragment;
}
let input = Tensor::new(&[next_token], &device_cl)?.unsqueeze(0)?;
let logits = model_cl
.lock().await
.forward(&input, prompt_len + index)?;
          let logits = logits.squeeze(0)?;
          // `apply_repeat_penalty` returns a new tensor instead of mutating in
          // place, so rebind `logits`; otherwise the penalty is silently dropped.
          let logits = if params.repeat_penalty == 1.0 {
            logits
          } else {
            let start_at = all_tokens.len()
              .saturating_sub(params.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
              &logits,
              params.repeat_penalty,
              &all_tokens[start_at..],
            )?
          };
next_token = logits_processor.sample(&logits)?;
index += 1;
}
};
Ok(Box::pin(stream) as StringStream)
}))
}
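  // Plain text generation is delegated to the chat path via `generate_with_chat`.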
fn generate_stream(
&self,
prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
{
generate_with_chat(self, prompt)
}
}