use std::path::PathBuf;
use std::sync::Mutex;
use anyhow::{anyhow, Context, Result};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_qwen2::ModelWeights;
use tokenizers::Tokenizer;
use super::{Backend, GenAnswer, GenRequest, Generator};
/// A built-in model configuration: which Hugging Face Hub files make up one
/// selectable model.
#[derive(Debug, Clone, Copy)]
struct Preset {
/// Short name compared against the requested model string in
/// `resolve_preset`; also embedded in the generator id as `candle:<id>`.
id: &'static str,
/// Hub repository that hosts the GGUF weights file.
gguf_repo: &'static str,
/// File name of the GGUF weights inside `gguf_repo`.
gguf_file: &'static str,
/// Hub repository from which `tokenizer.json` is fetched.
tokenizer_repo: &'static str,
}
/// All presets selectable by id. The first entry doubles as the default
/// (`DEFAULT_PRESET` points at index 0), so keep the intended default first.
const PRESETS: &[Preset] = &[
Preset {
id: "qwen2-0.5b",
gguf_repo: "Qwen/Qwen2-0.5B-Instruct-GGUF",
gguf_file: "qwen2-0_5b-instruct-q4_0.gguf",
tokenizer_repo: "Qwen/Qwen2-0.5B-Instruct",
},
Preset {
id: "qwen2-1.5b",
gguf_repo: "Qwen/Qwen2-1.5B-Instruct-GGUF",
gguf_file: "qwen2-1_5b-instruct-q4_0.gguf",
tokenizer_repo: "Qwen/Qwen2-1.5B-Instruct",
},
];
/// Preset used when the caller passes an empty model name: the first entry of
/// `PRESETS`.
const DEFAULT_PRESET: &Preset = &PRESETS[0];
/// Maps a user-supplied model name to one of the built-in presets.
///
/// An empty name selects [`DEFAULT_PRESET`]; otherwise the name must match a
/// preset `id` exactly.
///
/// # Errors
///
/// Returns an error naming the unknown model and listing every known preset
/// id when no preset matches.
fn resolve_preset(model: &str) -> Result<&'static Preset> {
    if model.is_empty() {
        return Ok(DEFAULT_PRESET);
    }

    for candidate in PRESETS {
        if candidate.id == model {
            return Ok(candidate);
        }
    }

    // No match: build the comma-separated id list for the error message.
    let known = PRESETS
        .iter()
        .map(|candidate| candidate.id)
        .collect::<Vec<&str>>()
        .join(", ");
    Err(anyhow!("candle: unknown model `{model}` — known presets: {}", known))
}
/// Everything needed to run inference, as produced by `CandleGenerator::load`.
struct Loaded {
/// Model weights built from the GGUF file.
model: ModelWeights,
/// Tokenizer loaded from `tokenizer.json`.
tokenizer: Tokenizer,
/// Device the weights were built on and on which input tensors are created.
device: Device,
}
/// Text generator backed by a GGUF model run through candle.
///
/// Construction only resolves the preset; the model and tokenizer are loaded
/// on the first `complete` call and then kept behind the mutex.
pub struct CandleGenerator {
/// Backend identifier of the form `candle:<preset id>`.
id: String,
/// Preset chosen at construction time.
preset: &'static Preset,
/// Loaded model state; `None` until the first `complete` call fills it in.
loaded: Mutex<Option<Loaded>>,
}
impl CandleGenerator {
pub fn new(model: &str) -> Result<Self> {
let preset = resolve_preset(model)?;
Ok(Self {
id: format!("candle:{}", preset.id),
preset,
loaded: Mutex::new(None),
})
}
fn cached_gguf(&self) -> Option<PathBuf> {
cached_hub_file(self.preset.gguf_repo, self.preset.gguf_file)
}
fn cached_tokenizer(&self) -> Option<PathBuf> {
cached_hub_file(self.preset.tokenizer_repo, "tokenizer.json")
}
fn load(&self) -> Result<Loaded> {
let api = hf_hub::api::sync::Api::new().context("candle: hf-hub api init")?;
let gguf_path = api
.model(self.preset.gguf_repo.to_string())
.get(self.preset.gguf_file)
.context("candle: fetch GGUF weights")?;
let tok_path = api
.model(self.preset.tokenizer_repo.to_string())
.get("tokenizer.json")
.context("candle: fetch tokenizer.json")?;
let device = Device::Cpu;
let mut file = std::fs::File::open(&gguf_path)
.with_context(|| format!("candle: open {}", gguf_path.display()))?;
let content = gguf_file::Content::read(&mut file)
.map_err(|e| anyhow!("candle: read GGUF: {e}"))?;
let model = ModelWeights::from_gguf(content, &mut file, &device)
.map_err(|e| anyhow!("candle: build model from GGUF: {e}"))?;
let tokenizer =
Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("candle: load tokenizer: {e}"))?;
Ok(Loaded { model, tokenizer, device })
}
fn format_prompt(req: &GenRequest) -> String {
let mut s = String::new();
if let Some(sys) = &req.system {
s.push_str(&format!("<|im_start|>system\n{sys}<|im_end|>\n"));
}
s.push_str(&format!(
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
req.prompt
));
s
}
}
impl Backend for CandleGenerator {
    /// Backend identifier, `candle:<preset id>`.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether both the GGUF weights and the tokenizer file are found
    /// by the cache probes. The tokenizer is only probed when the weights
    /// were found.
    fn available(&self) -> bool {
        if self.cached_gguf().is_none() {
            return false;
        }
        self.cached_tokenizer().is_some()
    }
}
impl Generator for CandleGenerator {
    /// Generates a completion for `req`.
    ///
    /// On the first call the model and tokenizer are loaded and cached behind
    /// the mutex; the lock is held for the whole generation. The prompt is
    /// run through the model once, then tokens are sampled one at a time
    /// until `<|im_end|>` is sampled, `req.max_tokens` tokens were produced,
    /// or the decoded output contains one of `req.stop`.
    ///
    /// Sampling uses a fixed seed (42). A `req.temperature` of zero or below
    /// passes no temperature to the logits processor.
    ///
    /// Note that when a stop string matches, the returned text still contains
    /// it: generation merely stops at that token.
    ///
    /// The reported latency covers the whole call, including the model load
    /// on the first call.
    ///
    /// # Errors
    ///
    /// Fails when loading, tokenizing, tensor construction, a forward pass,
    /// sampling, or decoding fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();

        let mut guard = self.loaded.lock().expect("candle loaded mutex");
        if guard.is_none() {
            *guard = Some(self.load()?);
        }
        let Loaded { model, tokenizer, device } = guard.as_mut().expect("loaded set above");

        let prompt = Self::format_prompt(req);
        let encoding = tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow!("candle: encode prompt: {e}"))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let tokens_in = prompt_tokens.len() as i64;

        let mut logits_processor = LogitsProcessor::new(
            42,
            if req.temperature > 0.0 { Some(req.temperature as f64) } else { None },
            None,
        );

        // Prefill: run the whole prompt starting at position 0.
        let input = Tensor::new(prompt_tokens.as_slice(), device)
            .map_err(|e| anyhow!("candle: prompt tensor: {e}"))?
            .unsqueeze(0)
            .map_err(|e| anyhow!("candle: unsqueeze: {e}"))?;
        let logits = model
            .forward(&input, 0)
            .map_err(|e| anyhow!("candle: prefill forward: {e}"))?
            .squeeze(0)
            .map_err(|e| anyhow!("candle: squeeze logits: {e}"))?;

        // `None` when the tokenizer has no such token; then no sampled id can
        // ever compare equal and generation is bounded by the other limits.
        let eos = tokenizer.token_to_id("<|im_end|>");

        let mut next = logits_processor
            .sample(&logits)
            .map_err(|e| anyhow!("candle: sample: {e}"))?;

        let mut all_tokens: Vec<u32> = Vec::new();
        for step in 0..req.max_tokens {
            if Some(next) == eos {
                break;
            }
            all_tokens.push(next);

            // Check the stop strings before running the model again: once a
            // stop string is present, the next forward pass would be wasted.
            if !req.stop.is_empty() {
                let so_far = tokenizer
                    .decode(&all_tokens, true)
                    .map_err(|e| anyhow!("candle: decode: {e}"))?;
                if req.stop.iter().any(|s| so_far.contains(s)) {
                    break;
                }
            }

            // Token budget reached: the token that another forward pass would
            // produce is never emitted, so skip that pass entirely.
            if step + 1 == req.max_tokens {
                break;
            }

            // Position of the token just emitted: it follows the prompt and
            // every previously generated token.
            let index = prompt_tokens.len() + all_tokens.len() - 1;
            let input = Tensor::new(&[next], device)
                .map_err(|e| anyhow!("candle: step tensor: {e}"))?
                .unsqueeze(0)
                .map_err(|e| anyhow!("candle: step unsqueeze: {e}"))?;
            let step_logits = model
                .forward(&input, index)
                .map_err(|e| anyhow!("candle: decode forward: {e}"))?
                .squeeze(0)
                .map_err(|e| anyhow!("candle: decode squeeze: {e}"))?;
            next = logits_processor
                .sample(&step_logits)
                .map_err(|e| anyhow!("candle: sample: {e}"))?;
        }

        let text = tokenizer
            .decode(&all_tokens, true)
            .map_err(|e| anyhow!("candle: final decode: {e}"))?;

        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = all_tokens.len() as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };

        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}
fn cached_hub_file(repo: &str, file: &str) -> Option<PathBuf> {
let api = hf_hub::api::sync::Api::new().ok()?;
let cached = api.model(repo.to_string()).get(file);
match cached {
Ok(p) if p.exists() => Some(p),
_ => {
let cache = hf_hub::Cache::default();
cache.model(repo.to_string()).get(file).filter(|p| p.exists())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An unknown preset name must fail, and the error must list known ids.
    #[test]
    fn unknown_preset_errors_with_known_list() {
        // `CandleGenerator` is matched by hand rather than via `unwrap_err`,
        // which would require it to implement `Debug`.
        let err = match CandleGenerator::new("no-such-model") {
            Ok(_) => panic!("unknown preset must error"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("unknown model"), "{err}");
        assert!(err.contains("qwen2-0.5b"), "lists presets: {err}");
    }

    /// An empty model name selects the default preset.
    #[test]
    fn default_preset_when_empty() {
        // Named `generator`, not `gen`: `gen` is a reserved keyword in the
        // Rust 2024 edition.
        let generator = CandleGenerator::new("").unwrap();
        assert_eq!(generator.id(), "candle:qwen2-0.5b");
    }

    /// Construction and the availability probe must work without generating.
    #[test]
    fn constructs_and_reports_availability_without_loading() {
        let generator = CandleGenerator::new("qwen2-0.5b").unwrap();
        assert_eq!(generator.id(), "candle:qwen2-0.5b");
        // Either answer is acceptable; the probe just must not panic.
        let _ = generator.available();
    }

    /// End-to-end generation against the real model; opt-in only.
    #[test]
    #[ignore = "downloads a multi-GB GGUF model"]
    fn real_generation_round_trips() {
        let generator = CandleGenerator::new("qwen2-0.5b").unwrap();
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = generator.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(ans.tokens_out > 0);
        assert!(!ans.text.is_empty());
    }
}