kproc_llm/candle.rs

//! Interface with LLMs through the candle crate.

use std::{future::Future, sync::Arc};

use async_stream::try_stream;

#[cfg(feature = "candle-git")]
use candle_git_core as candle_core;

#[cfg(feature = "candle-git")]
use candle_git_transformers as candle_transformers;

use candle_core::{
  quantized::{ggml_file, gguf_file},
  Device, Tensor,
};
use candle_transformers::{
  generation::{LogitsProcessor, Sampling},
  models,
};
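// `SmartDefault` is aliased to `Default` so the `#[derive(Default)]` attributes
// below pick up the `#[default(...)]` field values.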
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;

use crate::{generate_with_chat, prelude::*};

pub mod factory;

fn create_llama_template() -> template::Template
{
  template::Template::new(include_str!("../data/templates/llama")).unwrap()
}

/// Enum to select the underlying base model
#[derive(Default, Debug)]
pub enum BaseModel
{
  /// For Llama3 models
  #[default]
  QuantizedLlama,
  /// For SmolLM3 models
  #[cfg(feature = "candle-git")]
  SmolLM3,
}

#[derive(Debug, Default, Clone)]
struct Params
{
  /// The temperature used to generate samples; use 0 for greedy sampling.
  #[default(0.8)]
  temperature: f64,

  /// Nucleus sampling probability cutoff.
  top_p: Option<f64>,

  /// Only sample among the top K samples.
  top_k: Option<usize>,

  /// The seed to use when generating random samples.
  #[default(299792458)]
  seed: u64,

  /// Penalty applied to repeated tokens; 1.0 means no penalty.
  #[default(1.1)]
  repeat_penalty: f32,

  /// The context size to consider for the repeat penalty.
  #[default(64)]
  repeat_last_n: usize,
}

/// Builder for configuring the candle interface
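///
/// A minimal usage sketch (illustrative only; the repository, file name, and
/// end-of-stream token below are placeholders, not values shipped with this
/// crate):
///
/// ```ignore
/// let candle = Candle::build()
///   .model("some-org/some-model-GGUF", "some-model.Q4_K_M.gguf")
///   .tokenizer_repo("some-org/some-model")
///   .end_of_stream("<|eot_id|>")
///   .build()
///   .await?;
/// ```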
#[derive(Debug, Default)]
pub struct Builder
{
  base_model: BaseModel,

  model_path: Option<String>,
  repo: Option<String>,
  model: Option<String>,
  #[default("main".into())]
  revision: String,
  tokenizer_path: Option<String>,
  tokenizer_repo: String,

  end_of_stream: String,

  #[default(create_llama_template())]
  template: template::Template,

  params: Params,

  /// Run on CPU rather than GPU even if a GPU is available.
  #[default(true)]
  cpu: bool,

  /// Group-Query Attention; use 8 for the 70B version of LLaMAv2.
  #[default(1)]
  gqa: usize,
}

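/// Human-readable byte count; for example (illustrative), `format_size(1_500)`
/// returns `"1.50KB"` and `format_size(3_000_000_000)` returns `"3.00GB"`.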
fn format_size(size_in_bytes: usize) -> String
{
  if size_in_bytes < 1_000
  {
    format!("{size_in_bytes}B")
  }
  else if size_in_bytes < 1_000_000
  {
    format!("{:.2}KB", size_in_bytes as f64 / 1e3)
  }
  else if size_in_bytes < 1_000_000_000
  {
    format!("{:.2}MB", size_in_bytes as f64 / 1e6)
  }
  else
  {
    format!("{:.2}GB", size_in_bytes as f64 / 1e9)
  }
}

fn device(cpu: bool) -> Result<Device>
{
  if cpu
  {
    Ok(Device::Cpu)
  }
  else if candle_core::utils::cuda_is_available()
  {
    Ok(Device::new_cuda(0)?)
  }
  else if candle_core::utils::metal_is_available()
  {
    Ok(Device::new_metal(0)?)
  }
  else
  {
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
      log::warn!(
        "Running on CPU; to run on the GPU (Metal), build with `--features metal`"
      );
    }
    Ok(Device::Cpu)
  }
}

impl Builder
{
  /// Set the base model
  pub fn base_model(mut self, base_model: BaseModel) -> Self
  {
    self.base_model = base_model;
    self
  }
  /// Set the model
  pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
  {
    self.repo = Some(repo.into());
    self.model = Some(model.into());
    self
  }
  /// Set the revision used for the model
  pub fn revision(mut self, revision: impl Into<String>) -> Self
  {
    self.revision = revision.into();
    self
  }
  /// Set the tokenizer_repo
  pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
  {
    self.tokenizer_repo = tokenizer_repo.into();
    self
  }
  /// Set the token used for end of stream
  pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
  {
    self.end_of_stream = end_of_stream.into();
    self
  }
  /// Set the template
  pub fn template(mut self, template: impl Into<template::Template>) -> Self
  {
    self.template = template.into();
    self
  }
  /// Build the candle interface
  pub async fn build(self) -> Result<Candle>
  {
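    // Resolve the tokenizer: use an explicit local path if given, otherwise
    // fetch `tokenizer.json` from the Hugging Face Hub.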
    let tokenizer_path = match self.tokenizer_path
    {
      Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
      None =>
      {
        let api = hf_hub::api::tokio::Api::new()?;
        let api = api.model(self.tokenizer_repo.clone());
        api.get("tokenizer.json").await?
      }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

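    // Resolve the model weights: use an explicit local path if given, otherwise
    // download the requested file from the Hub at the configured revision.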
    let model_path = match self.model_path
    {
      Some(model_path) => std::path::PathBuf::from(model_path),
      None => match (self.repo, self.model)
      {
        (Some(repo), Some(model)) =>
        {
          let api = hf_hub::api::tokio::Api::new()?;
          api
            .repo(hf_hub::Repo::with_revision(
              repo.to_string(),
              hf_hub::RepoType::Model,
              self.revision,
            ))
            .get(&model)
            .await?
        }
        _ => Err(Error::UndefinedModel)?,
      },
    };

    let device = device(self.cpu)?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();

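    // Load the quantized weights; the loader is chosen from the file extension
    // (gguf vs. legacy ggml/bin).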
    let model_weights = match model_path.extension().and_then(|v| v.to_str())
    {
      Some("gguf") => match self.base_model
      {
        BaseModel::QuantizedLlama =>
        {
          let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
          let mut total_size_in_bytes = 0;
          for (_, tensor) in model.tensor_infos.iter()
          {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
              elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
          }
          log::info!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
          );

          ModelWeights::QuantizedLlama(models::quantized_llama::ModelWeights::from_gguf(
            model, &mut file, &device,
          )?)
        }
        #[cfg(feature = "candle-git")]
        BaseModel::SmolLM3 =>
        {
          use models::smol::quantized_smollm3::QuantizedModelForCausalLM;
          ModelWeights::QuantizedSmolLM3(QuantizedModelForCausalLM::from_gguf(
            &model_path,
            &device,
          )?)
        }
      },
      Some("ggml" | "bin") | Some(_) | None =>
      {
        let model =
          ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensors.iter()
        {
          let elem_count = tensor.shape().elem_count();
          total_size_in_bytes +=
            elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensors.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        log::info!("params: {:?}", model.hparams);
        match self.base_model
        {
          BaseModel::QuantizedLlama => ModelWeights::QuantizedLlama(
            models::quantized_llama::ModelWeights::from_ggml(model, self.gqa)?,
          ),
          #[cfg(feature = "candle-git")]
          BaseModel::SmolLM3 => Err(Error::UnsupportedFileFormat)?,
        }
      }
    };
    let eos_token = *tokenizer
      .get_vocab(true)
      .get(&self.end_of_stream)
      .ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;

    Ok(Candle {
      model_weights: model_weights.into(),
      tokenizer: tokenizer.into(),
      template: self.template,
      params: self.params,
      eos_token,
      device,
    })
  }
}

enum ModelWeights
{
  QuantizedLlama(models::quantized_llama::ModelWeights),
  #[cfg(feature = "candle-git")]
  QuantizedSmolLM3(models::smol::quantized_smollm3::QuantizedModelForCausalLM),
}

impl ModelWeights
{
  fn forward(&mut self, input: &Tensor, pos: usize) -> Result<Tensor>
  {
    match self
    {
      Self::QuantizedLlama(model) => Ok(model.forward(input, pos)?),
      #[cfg(feature = "candle-git")]
      Self::QuantizedSmolLM3(model) => Ok(model.forward(input, pos)?),
    }
  }
}

/// Interface to candle
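///
/// Streaming sketch (illustrative; assumes a `ChatPrompt` built elsewhere and
/// `futures::StreamExt` in scope to consume the stream):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = candle.chat_stream(prompt)?.await?;
/// while let Some(fragment) = stream.next().await {
///   print!("{}", fragment?);
/// }
/// ```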
pub struct Candle
{
  model_weights: ccutils::futures::ArcMutex<ModelWeights>,
  tokenizer: Arc<tokenizers::Tokenizer>,
  template: template::Template,

  params: Params,

  eos_token: u32,

  device: Device,
}

impl Candle
{
  /// Create a [`Builder`] for instantiating a model (defaults to a quantized `llama` setup)
  pub fn build() -> Builder
  {
    Builder::default()
  }
}

impl LargeLanguageModel for Candle
{
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let prompt_str = self.template.render(
      &prompt.messages,
      prompt.options.thinking,
      prompt.template_context,
    )?;

    let device = self.device.clone();
    let model_weights = self.model_weights.clone();
    let tokenizer = self.tokenizer.clone();
    let params = self.params.clone();
    let eos_token = self.eos_token;

    Ok(Box::pin(async move {
      // Encode prompt
      let prompt_tokens_encoded = tokenizer.encode(prompt_str, true)?;
      let prompt_tokens = prompt_tokens_encoded.get_ids().to_vec();
      let mut all_tokens = prompt_tokens.clone();

      // Build logits processor
      let mut logits_processor = {
        let temperature = params.temperature;
        let sampling = if temperature <= 0.0
        {
          Sampling::ArgMax
        }
        else
        {
          match (params.top_k, params.top_p)
          {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
          }
        };
        LogitsProcessor::from_sampling(params.seed, sampling)
      };

      let prompt_len = prompt_tokens.len();
      let device_cl = device.clone();
      let model_cl = model_weights.clone();

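      // Produce a stream that yields decoded text fragments as tokens are generated.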
      let stream = try_stream! {
          let mut tokenizer_output_stream = tokenizer.decode_stream(false);

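          // Prime the model with the prompt tokens one position at a time;
          // the logits from the last prompt position seed the first sampled token.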
          let mut next_token = 0;
          for (pos, token) in prompt_tokens.iter().enumerate() {
              let input = Tensor::new(&[*token], &device_cl)?.unsqueeze(0)?;
              let logits = model_cl.lock().await.forward(&input, pos)?;
              let logits = logits.squeeze(0)?;
              let logits = logits.squeeze(0)?;
              next_token = logits_processor.sample(&logits)?;
          }

          let mut index = 0;

          loop {
              if next_token == eos_token {
                  break;
              }

              all_tokens.push(next_token);

              // Try to convert token to text and yield it
              if let Some(fragment) = tokenizer_output_stream
                  .step(next_token)?
              {
                  yield fragment;
              }

              let input = Tensor::new(&[next_token], &device_cl)?.unsqueeze(0)?;
              let logits = model_cl
                  .lock().await
                  .forward(&input, prompt_len + index)?;
              let logits = logits.squeeze(0)?;
              let logits = logits.squeeze(0)?;

              // Apply the repeat penalty over the recent context and bind the
              // penalized logits so they are actually used for sampling.
              let logits = if params.repeat_penalty == 1.0 {
                  logits
              } else {
                  let start_at = all_tokens.len()
                      .saturating_sub(params.repeat_last_n);

                  candle_transformers::utils::apply_repeat_penalty(
                      &logits,
                      params.repeat_penalty,
                      &all_tokens[start_at..],
                  )?
              };

              next_token = logits_processor.sample(&logits)?;
              index += 1;
          }
      };

      Ok(Box::pin(stream) as StringStream)
    }))
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    generate_with_chat(self, prompt)
  }
}