use std::{future::Future, sync::Arc};

use async_stream::try_stream;

#[cfg(feature = "candle-git")]
use candle_git_core as candle_core;

#[cfg(feature = "candle-git")]
use candle_git_transformers as candle_transformers;

use candle_core::{
  quantized::{ggml_file, gguf_file},
  Device, Tensor,
};
use candle_transformers::{
  generation::{LogitsProcessor, Sampling},
  models,
};
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;

use crate::{generate_with_chat, prelude::*};

pub mod factory;

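/// Builds the default Llama-style chat [`template::Template`] embedded in this crate.
///
/// The template source is included at compile time, so the `unwrap` can only fail if
/// the bundled template itself is invalid.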
fn create_llama_template() -> template::Template
{
  template::Template::new(include_str!("../data/templates/llama")).unwrap()
}

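/// Supported base model architectures for the quantized weights.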
#[derive(Default, Debug)]
pub enum BaseModel
{
  #[default]
  QuantizedLlama,
  #[cfg(feature = "candle-git")]
  SmolLM3,
}

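/// Sampling parameters used when generating tokens.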
#[derive(Debug, Default, Clone)]
struct Params
{
  #[default(0.8)]
  temperature: f64,

  top_p: Option<f64>,

  top_k: Option<usize>,

  #[default(299792458)]
  seed: u64,

  #[default(1.1)]
  repeat_penalty: f32,

  #[default(64)]
  repeat_last_n: usize,
}

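/// Configures and loads a [`Candle`] model.
///
/// Model and tokenizer files are either taken from local paths or fetched from the
/// Hugging Face Hub. Create one with [`Candle::build`], chain the setter methods, and
/// finish with [`Builder::build`].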
#[derive(Debug, Default)]
pub struct Builder
{
  base_model: BaseModel,

  model_path: Option<String>,
  repo: Option<String>,
  model: Option<String>,
  #[default("main".into())]
  revision: String,
  tokenizer_path: Option<String>,
  tokenizer_repo: String,

  end_of_stream: String,

  #[default(create_llama_template())]
  template: template::Template,

  params: Params,

  #[default(true)]
  cpu: bool,

  #[default(1)]
  gqa: usize,
}

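/// Formats a byte count as a human-readable decimal size (B, KB, MB, GB).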
fn format_size(size_in_bytes: usize) -> String
{
  if size_in_bytes < 1_000
  {
    format!("{size_in_bytes}B")
  }
  else if size_in_bytes < 1_000_000
  {
    format!("{:.2}KB", size_in_bytes as f64 / 1e3)
  }
  else if size_in_bytes < 1_000_000_000
  {
    format!("{:.2}MB", size_in_bytes as f64 / 1e6)
  }
  else
  {
    format!("{:.2}GB", size_in_bytes as f64 / 1e9)
  }
}

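/// Selects the compute [`Device`]: the CPU when requested, otherwise the first
/// available CUDA or Metal device, falling back to the CPU.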
fn device(cpu: bool) -> Result<Device>
{
  if cpu
  {
    Ok(Device::Cpu)
  }
  else if candle_core::utils::cuda_is_available()
  {
    Ok(Device::new_cuda(0)?)
  }
  else if candle_core::utils::metal_is_available()
  {
    Ok(Device::new_metal(0)?)
  }
  else
  {
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
      log::warn!("Running on CPU; to run on the GPU (Metal), build with the `metal` feature enabled");
    }
    Ok(Device::Cpu)
  }
}

impl Builder
{
  pub fn base_model(mut self, base_model: BaseModel) -> Self
  {
    self.base_model = base_model;
    self
  }

  pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
  {
    self.repo = Some(repo.into());
    self.model = Some(model.into());
    self
  }

  pub fn revision(mut self, revision: impl Into<String>) -> Self
  {
    self.revision = revision.into();
    self
  }

  pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
  {
    self.tokenizer_repo = tokenizer_repo.into();
    self
  }

  pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
  {
    self.end_of_stream = end_of_stream.into();
    self
  }

  pub fn template(mut self, template: impl Into<template::Template>) -> Self
  {
    self.template = template.into();
    self
  }

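  /// Resolves the tokenizer and model files (downloading them from the Hugging Face
  /// Hub when no local paths are set), loads the quantized weights onto the selected
  /// device, and returns a ready-to-use [`Candle`] instance.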
  pub async fn build(self) -> Result<Candle>
  {
    // Resolve the tokenizer: use the explicit path when set, otherwise download
    // `tokenizer.json` from the configured tokenizer repository.
    let tokenizer_path = match self.tokenizer_path
    {
      Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
      None =>
      {
        let api = hf_hub::api::tokio::Api::new()?;
        let api = api.model(self.tokenizer_repo.clone());
        api.get("tokenizer.json").await?
      }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

    // Resolve the model weights: an explicit local path wins, otherwise the file is
    // downloaded from the configured repository at the requested revision.
    let model_path = match self.model_path
    {
      Some(model_path) => std::path::PathBuf::from(model_path),
      None => match (self.repo, self.model)
      {
        (Some(repo), Some(model)) =>
        {
          let api = hf_hub::api::tokio::Api::new()?;
          api
            .repo(hf_hub::Repo::with_revision(
              repo,
              hf_hub::RepoType::Model,
              self.revision,
            ))
            .get(&model)
            .await?
        }
        _ => Err(Error::UndefinedModel)?,
      },
    };

    let device = device(self.cpu)?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();

    let model_weights = match model_path.extension().and_then(|v| v.to_str())
    {
      Some("gguf") => match self.base_model
      {
        BaseModel::QuantizedLlama =>
        {
          let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
          let mut total_size_in_bytes = 0;
          for (_, tensor) in model.tensor_infos.iter()
          {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
              elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
          }
          log::info!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
          );

          ModelWeights::QuantizedLlama(models::quantized_llama::ModelWeights::from_gguf(
            model, &mut file, &device,
          )?)
        }
        #[cfg(feature = "candle-git")]
        BaseModel::SmolLM3 =>
        {
          use models::smol::quantized_smollm3::QuantizedModelForCausalLM;
          ModelWeights::QuantizedSmolLM3(QuantizedModelForCausalLM::from_gguf(
            &model_path,
            &device,
          )?)
        }
      },
      // Anything without a `.gguf` extension is treated as a legacy ggml/bin file.
      _ =>
      {
        let model =
          ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensors.iter()
        {
          let elem_count = tensor.shape().elem_count();
          total_size_in_bytes +=
            elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensors.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        log::info!("params: {:?}", model.hparams);
        match self.base_model
        {
          BaseModel::QuantizedLlama => ModelWeights::QuantizedLlama(
            models::quantized_llama::ModelWeights::from_ggml(model, self.gqa)?,
          ),
          #[cfg(feature = "candle-git")]
          BaseModel::SmolLM3 => Err(Error::UnsupportedFileFormat)?,
        }
      }
    };

    // Map the configured end-of-stream string to its token id.
    let eos_token = *tokenizer
      .get_vocab(true)
      .get(&self.end_of_stream)
      .ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;

    Ok(Candle {
      model_weights: model_weights.into(),
      tokenizer: tokenizer.into(),
      template: self.template,
      params: self.params,
      eos_token,
      device,
    })
  }
}

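/// The loaded quantized weights, wrapped per supported architecture.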
enum ModelWeights
{
  QuantizedLlama(models::quantized_llama::ModelWeights),
  #[cfg(feature = "candle-git")]
  QuantizedSmolLM3(models::smol::quantized_smollm3::QuantizedModelForCausalLM),
}

impl ModelWeights
{
  fn forward(&mut self, input: &Tensor, pos: usize) -> Result<Tensor>
  {
    match self
    {
      Self::QuantizedLlama(model) => Ok(model.forward(input, pos)?),
      #[cfg(feature = "candle-git")]
      Self::QuantizedSmolLM3(model) => Ok(model.forward(input, pos)?),
    }
  }
}

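/// A Candle-backed large language model.
///
/// Holds the quantized weights, tokenizer, chat template, sampling parameters and the
/// device they run on. A minimal usage sketch (the repository, file name and
/// end-of-stream token below are illustrative placeholders, not values shipped with
/// this crate):
///
/// ```ignore
/// let llm = Candle::build()
///   .model("TheBloke/Llama-2-7B-Chat-GGUF", "llama-2-7b-chat.Q4_K_M.gguf")
///   .tokenizer_repo("meta-llama/Llama-2-7b-chat-hf")
///   .end_of_stream("</s>")
///   .build()
///   .await?;
/// ```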
pub struct Candle
{
  model_weights: ccutils::futures::ArcMutex<ModelWeights>,
  tokenizer: Arc<tokenizers::Tokenizer>,
  template: template::Template,

  params: Params,

  eos_token: u32,

  device: Device,
}

impl Candle
{
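  /// Returns a [`Builder`] with default settings for configuring and loading a model.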
  pub fn build() -> Builder
  {
    Builder::default()
  }
}

impl LargeLanguageModel for Candle
{
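  /// Renders the chat messages through the configured template and returns a future
  /// that resolves to a stream of decoded text fragments, generated token by token.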
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let prompt_str = self.template.render(
      &prompt.messages,
      prompt.options.thinking,
      prompt.template_context,
    )?;

    // Clone everything the generation future needs so it can outlive `&self`.
    let device = self.device.clone();
    let model_weights = self.model_weights.clone();
    let tokenizer = self.tokenizer.clone();
    let params = self.params.clone();
    let eos_token = self.eos_token;

    Ok(Box::pin(async move {
      // Tokenize the rendered prompt; `all_tokens` accumulates prompt and generated
      // tokens for the repeat-penalty window.
      let prompt_tokens_encoded = tokenizer.encode(prompt_str, true)?;
      let prompt_tokens = prompt_tokens_encoded.get_ids().to_vec();
      let mut all_tokens = prompt_tokens.clone();

      // Pick the sampling strategy from the configured parameters.
      let mut logits_processor = {
        let temperature = params.temperature;
        let sampling = if temperature <= 0.0
        {
          Sampling::ArgMax
        }
        else
        {
          match (params.top_k, params.top_p)
          {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
          }
        };
        LogitsProcessor::from_sampling(params.seed, sampling)
      };

      let prompt_len = prompt_tokens.len();
      let device_cl = device.clone();
      let model_cl = model_weights.clone();

      let stream = try_stream! {
        let mut tokenizer_output_stream = tokenizer.decode_stream(false);

        // Run the prompt through the model token by token; the logits from the final
        // position yield the first generated token.
        let mut next_token = 0;
        for (pos, token) in prompt_tokens.iter().enumerate() {
          let input = Tensor::new(&[*token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl.lock().await.forward(&input, pos)?;
          let logits = logits.squeeze(0)?;
          let logits = logits.squeeze(0)?;
          next_token = logits_processor.sample(&logits)?;
        }

        let mut index = 0;

        // Generation loop: stream each decoded fragment, then run the next forward
        // pass until the end-of-stream token is sampled.
        loop {
          if next_token == eos_token {
            break;
          }

          all_tokens.push(next_token);

          if let Some(fragment) = tokenizer_output_stream
            .step(next_token)?
          {
            yield fragment;
          }

          let input = Tensor::new(&[next_token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl
            .lock().await
            .forward(&input, prompt_len + index)?;
          let logits = logits.squeeze(0)?;
          let logits = logits.squeeze(0)?;

          // `apply_repeat_penalty` returns the adjusted logits rather than mutating in
          // place, so rebind `logits` instead of discarding the result.
          let logits = if params.repeat_penalty == 1.0 {
            logits
          } else {
            let start_at = all_tokens.len()
              .saturating_sub(params.repeat_last_n);

            candle_transformers::utils::apply_repeat_penalty(
              &logits,
              params.repeat_penalty,
              &all_tokens[start_at..],
            )?
          };

          next_token = logits_processor.sample(&logits)?;
          index += 1;
        }
      };

      Ok(Box::pin(stream) as StringStream)
    }))
  }

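  /// Plain-text generation is delegated to [`Self::chat_stream`] through
  /// [`generate_with_chat`].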
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    generate_with_chat(self, prompt)
  }
}