use std::{future::Future, sync::Arc};

use async_stream::try_stream;
use candle_core::{
  quantized::{ggml_file, gguf_file},
  Device, Tensor,
};
use candle_transformers::{
  generation::{LogitsProcessor, Sampling},
  models::quantized_llama as model,
};
use model::ModelWeights;
use smart_default::SmartDefault as Default;
use tokenizers::Tokenizer;

use crate::{generate_with_chat, prelude::*};

pub mod factory;

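/// Parses the LLaMA chat template bundled with the crate at
/// `data/templates/llama`; panics if the embedded template fails to parse.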
fn create_llama_template() -> template::Template
{
  template::Template::new(include_str!("../data/templates/llama")).unwrap()
}

#[derive(Debug, Default, Clone)]
struct Params
{
  /// Sampling temperature; values at or below zero switch to greedy (argmax) decoding.
  #[default(0.8)]
  temperature: f64,

  /// Nucleus (top-p) sampling threshold, if any.
  top_p: Option<f64>,

  /// Top-k sampling cutoff, if any.
  top_k: Option<usize>,

  /// Seed for the random sampler.
  #[default(299792458)]
  seed: u64,

  /// Penalty applied to recently generated tokens; 1.0 disables it.
  #[default(1.1)]
  repeat_penalty: f32,

  /// Number of trailing tokens the repeat penalty considers.
  #[default(64)]
  repeat_last_n: usize,
}

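/// Builder for [`Candle`].
///
/// Configures where the model weights and tokenizer come from (local paths or
/// Hugging Face Hub repositories, revision `main` by default), the chat
/// template, the end-of-stream token, and the sampling parameters. Inference
/// defaults to the CPU. See [`Candle`] for a usage sketch.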
#[derive(Debug, Default)]
pub struct Builder
{
  model_path: Option<String>,
  repo: Option<String>,
  model: Option<String>,
  #[default("main".into())]
  revision: String,
  tokenizer_path: Option<String>,
  tokenizer_repo: String,

  end_of_stream: String,

  #[default(create_llama_template())]
  template: template::Template,

  params: Params,

  #[default(true)]
  cpu: bool,

  #[default(1)]
  gqa: usize,
}

fn format_size(size_in_bytes: usize) -> String
{
  if size_in_bytes < 1_000
  {
    format!("{size_in_bytes}B")
  }
  else if size_in_bytes < 1_000_000
  {
    format!("{:.2}KB", size_in_bytes as f64 / 1e3)
  }
  else if size_in_bytes < 1_000_000_000
  {
    format!("{:.2}MB", size_in_bytes as f64 / 1e6)
  }
  else
  {
    format!("{:.2}GB", size_in_bytes as f64 / 1e9)
  }
}

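/// Selects the compute device: the CPU when `cpu` is set, otherwise CUDA if
/// available, then Metal, falling back to the CPU.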
fn device(cpu: bool) -> Result<Device>
{
  if cpu
  {
    Ok(Device::Cpu)
  }
  else if candle_core::utils::cuda_is_available()
  {
    Ok(Device::new_cuda(0)?)
  }
  else if candle_core::utils::metal_is_available()
  {
    Ok(Device::new_metal(0)?)
  }
  else
  {
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
      log::warn!(
        "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
      );
    }
    Ok(Device::Cpu)
  }
}

impl Builder
{
  pub fn model(mut self, repo: impl Into<String>, model: impl Into<String>) -> Self
  {
    self.repo = Some(repo.into());
    self.model = Some(model.into());
    self
  }

  pub fn revision(mut self, revision: impl Into<String>) -> Self
  {
    self.revision = revision.into();
    self
  }

  pub fn tokenizer_repo(mut self, tokenizer_repo: impl Into<String>) -> Self
  {
    self.tokenizer_repo = tokenizer_repo.into();
    self
  }

  pub fn end_of_stream(mut self, end_of_stream: impl Into<String>) -> Self
  {
    self.end_of_stream = end_of_stream.into();
    self
  }

  pub fn template(mut self, template: impl Into<template::Template>) -> Self
  {
    self.template = template.into();
    self
  }

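  /// Resolves the tokenizer and model weights (downloading them from the
  /// Hugging Face Hub when no local paths were given), loads GGUF or legacy
  /// GGML weights onto the selected device, and looks up the end-of-stream
  /// token id in the tokenizer vocabulary.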
  pub async fn build(self) -> Result<Candle>
  {
    let tokenizer_path = match self.tokenizer_path
    {
      Some(tokenizer_path) => std::path::PathBuf::from(tokenizer_path),
      None =>
      {
        let api = hf_hub::api::tokio::Api::new()?;
        let api = api.model(self.tokenizer_repo.clone());
        api.get("tokenizer.json").await?
      }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

    let model_path = match self.model_path
    {
      Some(model_path) => std::path::PathBuf::from(model_path),
      None => match (self.repo, self.model)
      {
        (Some(repo), Some(model)) =>
        {
          let api = hf_hub::api::tokio::Api::new()?;
          api
            .repo(hf_hub::Repo::with_revision(
              repo.to_string(),
              hf_hub::RepoType::Model,
              self.revision,
            ))
            .get(&model)
            .await?
        }
        _ => Err(Error::UndefinedModel)?,
      },
    };

    let device = device(self.cpu)?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();

    let model_weights = match model_path.extension().and_then(|v| v.to_str())
    {
      Some("gguf") =>
      {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter()
        {
          let elem_count = tensor.shape.elem_count();
          total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensor_infos.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
      }
      // Anything that is not GGUF (`ggml`, `bin`, or no extension at all)
      // goes through the legacy GGML loader.
      Some("ggml" | "bin") | Some(_) | None =>
      {
        let model =
          ggml_file::Content::read(&mut file, &device).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensors.iter()
        {
          let elem_count = tensor.shape().elem_count();
          total_size_in_bytes +=
            elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
        }
        log::info!(
          "loaded {:?} tensors ({}) in {:.2}s",
          model.tensors.len(),
          &format_size(total_size_in_bytes),
          start.elapsed().as_secs_f32(),
        );
        log::info!("params: {:?}", model.hparams);
        ModelWeights::from_ggml(model, self.gqa)?
      }
    };
    let eos_token = *tokenizer
      .get_vocab(true)
      .get(&self.end_of_stream)
      .ok_or_else(|| Error::UnknownEndOfStream(self.end_of_stream.to_string()))?;

    Ok(Candle {
      model_weights: model_weights.into(),
      tokenizer: tokenizer.into(),
      template: self.template,
      params: self.params,
      eos_token,
      device,
    })
  }
}

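/// A quantized LLaMA-family model loaded and run locally through candle.
///
/// A minimal usage sketch; the repository, file, and token names below are
/// placeholders for whichever GGUF model and tokenizer you actually use, not
/// values shipped with this module:
///
/// ```ignore
/// let llm = Candle::build()
///   .model("TheBloke/Llama-2-7B-Chat-GGUF", "llama-2-7b-chat.Q4_K_M.gguf")
///   .tokenizer_repo("hf-internal-testing/llama-tokenizer")
///   .end_of_stream("</s>")
///   .build()
///   .await?;
/// ```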
pub struct Candle
{
  model_weights: ccutils::futures::ArcMutex<ModelWeights>,
  tokenizer: Arc<tokenizers::Tokenizer>,
  template: template::Template,

  params: Params,

  eos_token: u32,

  device: Device,
}

impl Candle
{
  pub fn build() -> Builder
  {
    Builder::default()
  }
}

impl LargeLanguageModel for Candle
{
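  /// Renders the chat messages with the configured template, then streams
  /// decoded text fragments token by token until the end-of-stream token is
  /// sampled.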
  fn chat_stream(
    &self,
    prompt: ChatPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    let prompt_str = self.template.render(&prompt.messages)?;

    let device = self.device.clone();
    let model_weights = self.model_weights.clone();
    let tokenizer = self.tokenizer.clone();
    let params = self.params.clone();
    let eos_token = self.eos_token;

    Ok(Box::pin(async move {
      let prompt_tokens_encoded = tokenizer.encode(prompt_str, true)?;
      let prompt_tokens = prompt_tokens_encoded.get_ids().to_vec();
      let mut all_tokens = prompt_tokens.clone();

      let mut logits_processor = {
        let temperature = params.temperature;
        let sampling = if temperature <= 0.0
        {
          Sampling::ArgMax
        }
        else
        {
          match (params.top_k, params.top_p)
          {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
          }
        };
        LogitsProcessor::from_sampling(params.seed, sampling)
      };

      let prompt_len = prompt_tokens.len();
      let device_cl = device.clone();
      let model_cl = model_weights.clone();

      let stream = try_stream! {
        let mut tokenizer_output_stream = tokenizer.decode_stream(false);

        // Run the prompt through the model token by token; only the sample
        // drawn after the final prompt token becomes the first generated token.
        let mut next_token = 0;
        for (pos, token) in prompt_tokens.iter().enumerate() {
          let input = Tensor::new(&[*token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl.lock().await.forward(&input, pos)?;
          let logits = logits.squeeze(0)?;

          next_token = logits_processor.sample(&logits)?;
        }

        let mut index = 0;

        loop {
          if next_token == eos_token {
            break;
          }

          all_tokens.push(next_token);

          if let Some(fragment) = tokenizer_output_stream
            .step(next_token)?
          {
            yield fragment;
          }

          let input = Tensor::new(&[next_token], &device_cl)?.unsqueeze(0)?;
          let logits = model_cl
            .lock().await
            .forward(&input, prompt_len + index)?;
          let logits = logits.squeeze(0)?;

          // `apply_repeat_penalty` returns a new tensor, so rebind `logits`;
          // otherwise the penalised logits would be computed and then discarded.
          let logits = if params.repeat_penalty == 1.0 {
            logits
          } else {
            let start_at = all_tokens.len()
              .saturating_sub(params.repeat_last_n);

            candle_transformers::utils::apply_repeat_penalty(
              &logits,
              params.repeat_penalty,
              &all_tokens[start_at..],
            )?
          };

          next_token = logits_processor.sample(&logits)?;
          index += 1;
        }
      };

      Ok(Box::pin(stream) as StringStream)
    }))
  }
  fn generate_stream(
    &self,
    prompt: GenerationPrompt,
  ) -> Result<impl Future<Output = Result<StringStream>> + Send>
  {
    generate_with_chat(self, prompt)
  }
}