kproc_llm/candle.rs

//! Interface with the candle LLM backend.

use std::{future::Future, sync::Arc};

use async_stream::try_stream;
use candle_core::{
  quantized::{ggml_file, gguf_file},
  Device, Tensor,
};
use candle_transformers::{
  generation::{LogitsProcessor, Sampling},
  models::quantized_llama as model,
};
use model::ModelWeights;
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;

use crate::{generate_with_chat, prelude::*};

pub mod factory;

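/// Create the default llama chat template bundled with the crate
/// (loaded from `data/templates/llama`).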
fn create_llama_template() -> template::Template
{
  template::Template::new(include_str!("../data/templates/llama")).unwrap()
}

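/// Sampling parameters used when generating tokens.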
#[derive(Debug, Default, Clone)]
struct Params
{
  /// The temperature used to generate samples; use 0 for greedy sampling.
  #[default(0.8)]
  temperature: f64,

  /// Nucleus sampling probability cutoff.
  top_p: Option<f64>,

  /// Only sample among the top K samples.
  top_k: Option<usize>,

  /// The seed to use when generating random samples.
  #[default(299792458)]
  seed: u64,

  /// Penalty applied to repeated tokens; 1.0 means no penalty.
  #[default(1.1)]
  repeat_penalty: f32,

  /// The context size to consider for the repeat penalty.
  #[default(64)]
  repeat_last_n: usize,
}

/// Builder for configuring the candle interface.
#[derive(Debug, Default)]
pub struct Builder
{
  /// Local path to the model weights; when unset, the model is fetched from the hub.
  model_path: Option<String>,
  /// Hugging Face repository containing the model weights.
  repo: Option<String>,
  /// Name of the model file inside the repository.
  model: Option<String>,
  /// Revision of the model repository to use.
  #[default("main".into())]
  revision: String,
  /// Local path to `tokenizer.json`; when unset, it is fetched from `tokenizer_repo`.
  tokenizer_path: Option<String>,
  /// Hugging Face repository containing `tokenizer.json`.
  tokenizer_repo: String,

  /// Token marking the end of the stream.
  end_of_stream: String,

  /// Chat template used to render prompts.
  #[default(create_llama_template())]
  template: template::Template,

  /// Sampling parameters.
  params: Params,

  /// Run on CPU rather than GPU even if a GPU is available.
  #[default(true)]
  cpu: bool,

  /// Group-Query Attention; use 8 for the 70B version of LLaMAv2.
  #[default(1)]
  gqa: usize,
}

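/// Format a byte count as a human-readable decimal string.
///
/// For example, `format_size(1_234)` yields `"1.23KB"` and
/// `format_size(2_500_000_000)` yields `"2.50GB"`.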
fn format_size(size_in_bytes: usize) -> String
{
  if size_in_bytes < 1_000
  {
    format!("{size_in_bytes}B")
  }
  else if size_in_bytes < 1_000_000
  {
    format!("{:.2}KB", size_in_bytes as f64 / 1e3)
  }
  else if size_in_bytes < 1_000_000_000
  {
    format!("{:.2}MB", size_in_bytes as f64 / 1e6)
  }
  else
  {
    format!("{:.2}GB", size_in_bytes as f64 / 1e9)
  }
}

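/// Select the compute device: the CPU when requested, otherwise CUDA if it is
/// available, then Metal, falling back to the CPU.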
fn device(cpu: bool) -> Result<Device>
{
  if cpu
  {
    Ok(Device::Cpu)
  }
  else if candle_core::utils::cuda_is_available()
  {
    Ok(Device::new_cuda(0)?)
  }
  else if candle_core::utils::metal_is_available()
  {
    Ok(Device::new_metal(0)?)
  }
  else
  {
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
      log::warn!(
        "Running on CPU; to run on the GPU (Metal), build with the `metal` feature enabled"
      );
    }
    Ok(Device::Cpu)
  }
}

impl Builder
{
  /// Set the model repository and the model file name within that repository.
  pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
  {
    self.repo = Some(repo.into());
    self.model = Some(model.into());
    self
  }
  /// Set the revision used for the model repository.
  pub fn revision(mut self, revision: impl Into<String>) -> Self
  {
    self.revision = revision.into();
    self
  }
  /// Set the repository from which the tokenizer is fetched.
  pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
  {
    self.tokenizer_repo = tokenizer_repo.into();
    self
  }
  /// Set the token used to mark the end of the stream.
  pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
  {
    self.end_of_stream = end_of_stream.into();
    self
  }
  /// Set the chat template.
  pub fn template(mut self, template: impl Into<template::Template>) -> Self
  {
    self.template = template.into();
    self
  }
  /// Build the candle interface, downloading the model and tokenizer from the
  /// Hugging Face hub when no local paths are configured.
  pub async fn build(self) -> Result<Candle>
  {
    // Resolve the tokenizer: use the local path when given, otherwise fetch
    // `tokenizer.json` from the tokenizer repository on the Hugging Face hub.
    let tokenizer_path = match self.tokenizer_path
    {
      Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
      None =>
      {
        let api = hf_hub::api::tokio::Api::new()?;
        let api = api.model(self.tokenizer_repo.clone());
        api.get("tokenizer.json").await?
      }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

    // Resolve the model weights: use the local path when given, otherwise
    // download the requested file from the configured repository and revision.
    let model_path = match self.model_path
    {
      Some(model_path) => std::path::PathBuf::from(model_path),
      None => match (self.repo, self.model)
      {
        (Some(repo), Some(model)) =>
        {
          let api = hf_hub::api::tokio::Api::new()?;
          api
            .repo(hf_hub::Repo::with_revision(
              repo.to_string(),
              hf_hub::RepoType::Model,
              self.revision,
            ))
            .get(&model)
            .await?
        }
        _ => Err(Error::UndefinedModel)?,
      },
    };

    let device = device(self.cpu)?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();

    // Load the quantized weights, dispatching on the file extension.
    let model_weights = match model_path.extension().and_then(|v| v.to_str())
    {
      Some("gguf") =>
      {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter()
        {
          let elem_count = tensor.shape.elem_count();
          total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensor_infos.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
      }
      // Anything that is not gguf is treated as a legacy ggml/bin file.
      _ =>
      {
        let model =
          ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensors.iter()
        {
          let elem_count = tensor.shape().elem_count();
          total_size_in_bytes +=
            elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensors.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        log::info!("params: {:?}", model.hparams);
        ModelWeights::from_ggml(model, self.gqa)?
      }
    };

    // Look up the id of the end-of-stream token in the tokenizer vocabulary.
    let eos_token = *tokenizer
      .get_vocab(true)
      .get(&self.end_of_stream)
      .ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;

    Ok(Candle {
      model_weights: model_weights.into(),
      tokenizer: tokenizer.into(),
      template: self.template,
      params: self.params,
      eos_token,
      device,
    })
  }
}

/// Interface to a candle-backed large language model.
pub struct Candle
{
  model_weights: ccutils::futures::ArcMutex<ModelWeights>,
  tokenizer: Arc<tokenizers::Tokenizer>,
  template: template::Template,

  params: Params,

  eos_token: u32,

  device: Device,
}

impl Candle
{
  /// Create a [`Builder`] for configuring and loading a candle-backed model.
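  ///
  /// A minimal usage sketch; the repository, file, and tokenizer names below
  /// are illustrative placeholders rather than defaults provided by this crate:
  ///
  /// ```ignore
  /// let candle = Candle::build()
  ///   .model("TheBloke/Llama-2-7B-Chat-GGUF", "llama-2-7b-chat.Q4_K_M.gguf")
  ///   .tokenizer_repo("hf-internal-testing/llama-tokenizer")
  ///   .end_of_stream("</s>")
  ///   .build()
  ///   .await?;
  /// ```
  ///
  /// The returned [`Candle`] implements [`LargeLanguageModel`], so its
  /// `chat_stream` method can then be used to stream generated text.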
  pub fn build() -> Builder
  {
    Builder::default()
  }
}

impl LargeLanguageModel for Candle
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let prompt_str = self.template.render(&prompt.messages)?;

    let device = self.device.clone();
    let model_weights = self.model_weights.clone();
    let tokenizer = self.tokenizer.clone();
    let params = self.params.clone();
    let eos_token = self.eos_token;

    Ok(Box::pin(async move {
      // Encode the rendered prompt into token ids.
      let prompt_tokens_encoded = tokenizer.encode(prompt_str, true)?;
      let prompt_tokens = prompt_tokens_encoded.get_ids().to_vec();
      let mut all_tokens = prompt_tokens.clone();

      // Build the logits processor from the sampling parameters.
      let mut logits_processor = {
        let temperature = params.temperature;
        let sampling = if temperature <= 0.0
        {
          Sampling::ArgMax
        }
        else
        {
          match (params.top_k, params.top_p)
          {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
          }
        };
        LogitsProcessor::from_sampling(params.seed, sampling)
      };

      let prompt_len = prompt_tokens.len();
      let device_cl = device.clone();
      let model_cl = model_weights.clone();

      let stream = try_stream! {
        let mut tokenizer_output_stream = tokenizer.decode_stream(false);

        // Feed the prompt token by token to fill the KV cache; only the sample
        // taken at the last prompt position is kept as the first generated token.
        let mut next_token = 0;
        for (pos, token) in prompt_tokens.iter().enumerate() {
          let input = Tensor::new(&[*token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl.lock().await.forward(&input, pos)?;
          let logits = logits.squeeze(0)?;

          next_token = logits_processor.sample(&logits)?;
        }

        let mut index = 0;

        loop {
          if next_token == eos_token {
            break;
          }

          all_tokens.push(next_token);

          // Try to convert the token to text and yield the decoded fragment.
          if let Some(fragment) = tokenizer_output_stream.step(next_token)? {
            yield fragment;
          }

          let input = Tensor::new(&[next_token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl
            .lock().await
            .forward(&input, prompt_len + index)?;
          let logits = logits.squeeze(0)?;

          // Apply the repeat penalty over the last `repeat_last_n` tokens;
          // `apply_repeat_penalty` returns the penalised logits.
          let logits = if params.repeat_penalty == 1.0 {
            logits
          } else {
            let start_at = all_tokens.len().saturating_sub(params.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
              &logits,
              params.repeat_penalty,
              &all_tokens[start_at..],
            )?
          };

          next_token = logits_processor.sample(&logits)?;
          index += 1;
        }
      };

      Ok(Box::pin(stream) as StringStream)
    }))
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    generate_with_chat(self, prompt)
  }
}