use std::path::PathBuf;
use std::sync::Mutex;
use anyhow::{anyhow, Context, Result};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_qwen2::ModelWeights;
use tokenizers::Tokenizer;
use super::{Backend, GenAnswer, GenRequest, Generator};
/// A selectable model configuration: a GGUF weights file plus the Hugging Face
/// repo that supplies the matching `tokenizer.json`.
#[derive(Clone, Copy, Debug)]
struct Preset {
    /// Short name matched against the `model` argument of `CandleGenerator::new`.
    id: &'static str,
    /// Hugging Face repo hosting the GGUF weights.
    gguf_repo: &'static str,
    /// File name of the GGUF weights inside `gguf_repo`.
    gguf_file: &'static str,
    /// Hugging Face repo whose `tokenizer.json` pairs with the weights.
    tokenizer_repo: &'static str,
}
/// Built-in model presets, selectable by `id` (see `resolve_preset`).
///
/// Each entry pairs a GGUF weights file with the repo providing the tokenizer for
/// the same model. The repo and file strings are passed verbatim to hf-hub when
/// fetching, so they must match the Hub exactly. The first entry is the default.
const PRESETS: &[Preset] = &[
    Preset {
        id: "qwen2-0.5b",
        gguf_repo: "Qwen/Qwen2-0.5B-Instruct-GGUF",
        gguf_file: "qwen2-0_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-0.5B-Instruct",
    },
    Preset {
        id: "qwen2-1.5b",
        gguf_repo: "Qwen/Qwen2-1.5B-Instruct-GGUF",
        gguf_file: "qwen2-1_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-1.5B-Instruct",
    },
];
/// Preset used when the caller passes an empty model name: the first `PRESETS` entry.
const DEFAULT_PRESET: &Preset = &PRESETS[0];
/// Maps a user-supplied model name to a built-in preset.
///
/// An empty name selects `DEFAULT_PRESET`; any other name must equal some
/// preset's `id` exactly.
///
/// # Errors
/// Returns an error listing every known preset id when `model` matches none.
fn resolve_preset(model: &str) -> Result<&'static Preset> {
    if model.is_empty() {
        return Ok(DEFAULT_PRESET);
    }
    match PRESETS.iter().find(|preset| preset.id == model) {
        Some(preset) => Ok(preset),
        None => {
            let known = PRESETS
                .iter()
                .map(|preset| preset.id)
                .collect::<Vec<_>>()
                .join(", ");
            Err(anyhow!("candle: unknown model `{model}` — known presets: {}", known))
        }
    }
}
/// Model state built by `CandleGenerator::load` and cached in
/// `CandleGenerator::loaded` so later `complete` calls reuse it.
struct Loaded {
    /// Quantized Qwen2 weights read from the GGUF file.
    model: ModelWeights,
    /// Tokenizer read from the resolved tokenizer file.
    tokenizer: Tokenizer,
    /// Device the model runs on (`load_from` always uses the CPU).
    device: Device,
}
/// Text generator backed by a Qwen2 GGUF model run in-process with candle.
///
/// Construction only resolves the preset. Weights and tokenizer are loaded on the
/// first `complete` call and then kept for subsequent calls.
pub struct CandleGenerator {
    /// Backend id, formatted as `candle:<preset id>`.
    id: String,
    /// Preset selected at construction.
    preset: &'static Preset,
    /// Lazily loaded model; `None` until the first successful load.
    loaded: Mutex<Option<Loaded>>,
}
impl CandleGenerator {
    /// Builds a generator for the named preset without loading any model data.
    ///
    /// An empty `model` selects the default preset. Weights and tokenizer are
    /// loaded lazily on the first `complete` call.
    ///
    /// # Errors
    /// Returns an error if `model` is non-empty and matches no known preset.
    pub fn new(model: &str) -> Result<Self> {
        let preset = resolve_preset(model)?;
        Ok(Self {
            id: format!("candle:{}", preset.id),
            preset,
            loaded: Mutex::new(None),
        })
    }

    /// Resolves this preset's GGUF weights file via `cached_hub_file`.
    fn cached_gguf(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.gguf_repo, self.preset.gguf_file)
    }

    /// Resolves this preset's `tokenizer.json` via `cached_hub_file`.
    fn cached_tokenizer(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.tokenizer_repo, "tokenizer.json")
    }

    /// Resolves the weight and tokenizer paths, then loads them.
    ///
    /// The env-var override (see `override_paths`) takes precedence. Otherwise the
    /// files are obtained through hf-hub's sync API.
    ///
    /// # Errors
    /// Fails if hf-hub cannot be initialised or a file cannot be fetched, or if
    /// `load_from` fails.
    fn load(&self) -> Result<Loaded> {
        let (gguf_path, tok_path) = match Self::override_paths() {
            Some(paths) => paths,
            None => {
                let api = hf_hub::api::sync::Api::new().context("candle: hf-hub api init")?;
                let gguf_path = api
                    .model(self.preset.gguf_repo.to_string())
                    .get(self.preset.gguf_file)
                    .context("candle: fetch GGUF weights")?;
                let tok_path = api
                    .model(self.preset.tokenizer_repo.to_string())
                    .get("tokenizer.json")
                    .context("candle: fetch tokenizer.json")?;
                (gguf_path, tok_path)
            }
        };
        Self::load_from(&gguf_path, &tok_path)
    }

    /// Returns explicit local paths from `NORNIR_CANDLE_GGUF` and
    /// `NORNIR_CANDLE_TOKENIZER`.
    ///
    /// Both variables must be set. If either is missing, the override is ignored
    /// entirely rather than mixing an explicit path with a Hub download.
    fn override_paths() -> Option<(PathBuf, PathBuf)> {
        let gguf = std::env::var_os("NORNIR_CANDLE_GGUF")?;
        let tok = std::env::var_os("NORNIR_CANDLE_TOKENIZER")?;
        Some((PathBuf::from(gguf), PathBuf::from(tok)))
    }

    /// Loads quantized Qwen2 weights and a tokenizer from local files onto the CPU.
    ///
    /// # Errors
    /// Fails if the GGUF file cannot be opened or parsed, the model cannot be built
    /// from it, or the tokenizer file cannot be loaded.
    fn load_from(gguf_path: &std::path::Path, tok_path: &std::path::Path) -> Result<Loaded> {
        let device = Device::Cpu;
        let file = std::fs::File::open(gguf_path)
            .with_context(|| format!("candle: open {}", gguf_path.display()))?;
        // GGUF metadata (including large embedded vocab arrays) is decoded field by
        // field with many tiny reads. Buffering avoids one syscall per field; tensor
        // loads still seek freely because `BufReader<File>` implements `Seek`.
        let mut reader = std::io::BufReader::new(file);
        let content = gguf_file::Content::read(&mut reader)
            .map_err(|e| anyhow!("candle: read GGUF: {e}"))?;
        let model = ModelWeights::from_gguf(content, &mut reader, &device)
            .map_err(|e| anyhow!("candle: build model from GGUF: {e}"))?;
        let tokenizer =
            Tokenizer::from_file(tok_path).map_err(|e| anyhow!("candle: load tokenizer: {e}"))?;
        Ok(Loaded { model, tokenizer, device })
    }

    /// Renders `req` in the ChatML layout.
    ///
    /// The layout is: an optional system turn, the user turn, and then an opened
    /// assistant turn for the model to continue.
    fn format_prompt(req: &GenRequest) -> String {
        let mut s = String::new();
        if let Some(sys) = &req.system {
            s.push_str(&format!("<|im_start|>system\n{sys}<|im_end|>\n"));
        }
        s.push_str(&format!(
            "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
            req.prompt
        ));
        s
    }
}
impl Backend for CandleGenerator {
    /// Backend id, formatted as `candle:<preset id>`.
    fn id(&self) -> &str {
        &self.id
    }

    /// Reports whether the files `load` would use are present locally.
    ///
    /// This mirrors `load`'s path selection. When both override env vars are set,
    /// only those two files are checked. Otherwise, the preset's GGUF weights and
    /// `tokenizer.json` must both resolve from the Hugging Face cache.
    fn available(&self) -> bool {
        match Self::override_paths() {
            Some((gguf, tokenizer)) => gguf.is_file() && tokenizer.is_file(),
            None => self.cached_gguf().is_some() && self.cached_tokenizer().is_some(),
        }
    }
}
impl Generator for CandleGenerator {
    /// Generates a chat completion for `req`, loading the model on first use.
    ///
    /// Calls are serialized through the `loaded` mutex. Decoding stops at the
    /// earliest of three conditions:
    /// - the first `<|im_end|>` or `<|endoftext|>` token (not included in the output);
    /// - `req.max_tokens` generated tokens;
    /// - decoded text containing a non-empty entry of `req.stop` (the matched stop
    ///   text is kept in `text`).
    ///
    /// `latency_ms` and `tokens_per_s` are measured from entry. They therefore
    /// include lock wait and, on the first call, model loading.
    ///
    /// # Errors
    /// Fails if loading, tokenization, a forward pass, sampling, or decoding fails.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();

        // Poisoning only means an earlier call panicked while holding the lock.
        // Reusing the cached model then relies on the same assumption as ordinary
        // reuse across calls: the prefill below starts at position 0.
        // NOTE(review): candle's quantized models treat position 0 as a fresh KV
        // cache — confirm for `quantized_qwen2`.
        let mut guard = self.loaded.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        let loaded = match guard.take() {
            Some(loaded) => loaded,
            None => self.load()?,
        };
        let Loaded { model, tokenizer, device } = guard.insert(loaded);

        let prompt = Self::format_prompt(req);
        let encoding =
            tokenizer.encode(prompt, true).map_err(|e| anyhow!("candle: encode prompt: {e}"))?;
        let prompt_tokens = encoding.get_ids();
        let tokens_in = prompt_tokens.len() as i64;

        let mut logits_processor = LogitsProcessor::new(
            42,
            if req.temperature > 0.0 { Some(req.temperature as f64) } else { None },
            None,
        );

        // Qwen2 instruct models end a turn with `<|im_end|>` but may also emit
        // `<|endoftext|>`; both terminate generation. Tokens absent from the vocab
        // are simply not checked.
        let eos_ids: Vec<u32> = ["<|im_end|>", "<|endoftext|>"]
            .iter()
            .filter_map(|t| tokenizer.token_to_id(t))
            .collect();

        // Prefill: run the whole prompt once and sample the first generated token.
        let input = Tensor::new(prompt_tokens, device)
            .map_err(|e| anyhow!("candle: prompt tensor: {e}"))?
            .unsqueeze(0)
            .map_err(|e| anyhow!("candle: unsqueeze: {e}"))?;
        let logits = model
            .forward(&input, 0)
            .map_err(|e| anyhow!("candle: prefill forward: {e}"))?
            .squeeze(0)
            .map_err(|e| anyhow!("candle: squeeze logits: {e}"))?;
        let mut next = logits_processor
            .sample(&logits)
            .map_err(|e| anyhow!("candle: sample: {e}"))?;

        // An empty stop string would match immediately, so only non-empty ones count.
        let stop_active = req.stop.iter().any(|s| !s.is_empty());
        let mut all_tokens: Vec<u32> = Vec::new();
        let mut pos = prompt_tokens.len();
        for step in 0..req.max_tokens {
            if step > 0 {
                // Feed the token accepted on the previous step to obtain the next one.
                // Doing this lazily means no forward pass is wasted after the final
                // accepted token.
                let input = Tensor::new(&[next], device)
                    .map_err(|e| anyhow!("candle: step tensor: {e}"))?
                    .unsqueeze(0)
                    .map_err(|e| anyhow!("candle: step unsqueeze: {e}"))?;
                let logits = model
                    .forward(&input, pos)
                    .map_err(|e| anyhow!("candle: decode forward: {e}"))?
                    .squeeze(0)
                    .map_err(|e| anyhow!("candle: decode squeeze: {e}"))?;
                next = logits_processor
                    .sample(&logits)
                    .map_err(|e| anyhow!("candle: sample: {e}"))?;
                pos += 1;
            }
            if eos_ids.contains(&next) {
                break;
            }
            all_tokens.push(next);
            if stop_active {
                let so_far = tokenizer
                    .decode(&all_tokens, true)
                    .map_err(|e| anyhow!("candle: decode: {e}"))?;
                if req.stop.iter().any(|s| !s.is_empty() && so_far.contains(s)) {
                    break;
                }
            }
        }

        let text = tokenizer
            .decode(&all_tokens, true)
            .map_err(|e| anyhow!("candle: final decode: {e}"))?;
        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = all_tokens.len() as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}
/// Returns the path of `file` from Hugging Face model repo `repo` if it is
/// already present in the local hf-hub cache.
///
/// This deliberately consults only `hf_hub::Cache`. The sync API's
/// `ApiRepo::get` downloads the file on a cache miss, which would turn a cheap
/// availability probe into a (potentially multi-hundred-MB) network fetch.
fn cached_hub_file(repo: &str, file: &str) -> Option<PathBuf> {
    hf_hub::Cache::default()
        .model(repo.to_string())
        .get(file)
        .filter(|p| p.exists())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshots the given env vars on creation and restores them on drop.
    ///
    /// A failing assertion mid-test therefore cannot leak modified values into
    /// other tests running in the same process.
    struct EnvRestore {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvRestore {
        fn capture(keys: &[&'static str]) -> Self {
            Self { saved: keys.iter().map(|&k| (k, std::env::var_os(k))).collect() }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                match value {
                    Some(v) => std::env::set_var(key, v),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn unknown_preset_errors_with_known_list() {
        let err = match CandleGenerator::new("no-such-model") {
            Ok(_) => panic!("unknown preset must error"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("unknown model"), "{err}");
        assert!(err.contains("qwen2-0.5b"), "lists presets: {err}");
    }

    #[test]
    fn default_preset_when_empty() {
        let generator = CandleGenerator::new("").unwrap();
        assert_eq!(generator.id(), "candle:qwen2-0.5b");
    }

    #[test]
    fn constructs_and_reports_availability_without_loading() {
        let generator = CandleGenerator::new("qwen2-0.5b").unwrap();
        assert_eq!(generator.id(), "candle:qwen2-0.5b");
        let _ = generator.available();
    }

    #[test]
    fn synthetic_gguf_round_trips_through_the_real_reader() {
        use candle_core::quantized::{gguf_file::Value, GgmlDType, QTensor};
        let device = Device::Cpu;
        let src = Tensor::from_vec(
            (0..32).map(|i| i as f32).collect::<Vec<f32>>(),
            (1, 32),
            &device,
        )
        .expect("build source tensor");
        let qtensor = QTensor::quantize(&src, GgmlDType::Q4_0).expect("quantize tensor");
        let arch = Value::String("qwen2-synthetic".to_string());
        let ctx_len = Value::U32(2048);
        let metadata: Vec<(&str, &Value)> = vec![
            ("general.architecture", &arch),
            ("qwen2.context_length", &ctx_len),
        ];
        let tensors: Vec<(&str, &QTensor)> = vec![("token_embd.weight", &qtensor)];
        let mut tmp = tempfile::NamedTempFile::new().expect("tempfile");
        gguf_file::write(tmp.as_file_mut(), &metadata, &tensors).expect("write synthetic GGUF");
        // Read through a fresh handle opened by path; `tmp` stays alive (and the file
        // on disk) until the end of the test.
        let mut f = std::fs::File::open(tmp.path()).expect("reopen synthetic GGUF");
        let content = gguf_file::Content::read(&mut f).expect("read synthetic GGUF (real path)");
        match content.metadata.get("general.architecture") {
            Some(Value::String(s)) => assert_eq!(s, "qwen2-synthetic", "architecture round-trips"),
            other => panic!("architecture metadata missing/wrong: {other:?}"),
        }
        match content.metadata.get("qwen2.context_length") {
            Some(Value::U32(n)) => assert_eq!(*n, 2048, "context_length round-trips"),
            other => panic!("context_length metadata missing/wrong: {other:?}"),
        }
        assert!(
            content.tensor_infos.contains_key("token_embd.weight"),
            "tensor info round-trips; got {:?}",
            content.tensor_infos.keys().collect::<Vec<_>>()
        );
        let read_back = content
            .tensor(&mut f, "token_embd.weight", &device)
            .expect("read tensor back");
        assert_eq!(read_back.shape().dims(), &[1, 32], "tensor shape round-trips");
    }

    #[test]
    fn override_paths_require_both_env_vars() {
        // Restores both variables even if an assertion below panics.
        let _restore = EnvRestore::capture(&["NORNIR_CANDLE_GGUF", "NORNIR_CANDLE_TOKENIZER"]);
        std::env::remove_var("NORNIR_CANDLE_GGUF");
        std::env::remove_var("NORNIR_CANDLE_TOKENIZER");
        assert!(CandleGenerator::override_paths().is_none(), "no override when unset");
        std::env::set_var("NORNIR_CANDLE_GGUF", "/tmp/synthetic.gguf");
        assert!(
            CandleGenerator::override_paths().is_none(),
            "half-set override must not engage (would hide intent)"
        );
        std::env::set_var("NORNIR_CANDLE_TOKENIZER", "/tmp/tok.json");
        let resolved = CandleGenerator::override_paths().expect("both set → override");
        assert_eq!(resolved.0, PathBuf::from("/tmp/synthetic.gguf"));
        assert_eq!(resolved.1, PathBuf::from("/tmp/tok.json"));
    }

    #[test]
    #[ignore = "downloads a multi-GB GGUF model (real-data arm)"]
    fn real_generation_round_trips() {
        let generator = CandleGenerator::new("qwen2-0.5b").unwrap();
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = generator.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(ans.tokens_out > 0);
        assert!(!ans.text.is_empty());
    }
}