use std::path::Path;
use candle_core::quantized::gguf_file;
use candle_core::{Device as CandleDevice, Tensor};
use candle_transformers::models::quantized_qwen3 as qwen;
use tokenizers::Tokenizer;
use super::moe::Qwen3MoeModel;
use crate::{Device, InferenceError};
/// Dispatch wrapper over the two Qwen3 weight layouts this backend supports;
/// the variant is selected from GGUF metadata in `CandleBackend::load`.
enum QwenModel {
/// Dense (standard) quantized Qwen3 weights.
Standard(qwen::ModelWeights),
/// Mixture-of-experts variant (GGUF architecture `"qwen3moe"`).
Moe(Qwen3MoeModel),
}
/// Candle-based inference backend bundling model weights, the tokenizer,
/// and the device all tensors are created on.
pub struct CandleBackend {
/// Loaded model weights; variant chosen from GGUF metadata at load time.
model: QwenModel,
/// Tokenizer loaded from `tokenizer.json` in the model directory.
pub tokenizer: Tokenizer,
/// Device used for input tensors and the forward pass.
pub device: CandleDevice,
}
impl CandleBackend {
    /// Loads `model.gguf` plus `tokenizer.json` from `model_dir` onto `device`.
    ///
    /// Dispatches on the `general.architecture` GGUF metadata key:
    /// `"qwen3moe"` selects the MoE weights, anything else the dense weights.
    ///
    /// # Errors
    /// `InferenceError::InferenceFailed` when the model file cannot be opened,
    /// parsed, or its weights loaded; `InferenceError::TokenizationError` when
    /// the tokenizer file fails to load; `InferenceError::DeviceError` when the
    /// requested device is unavailable.
    pub fn load(model_dir: &Path, device: Device) -> Result<Self, InferenceError> {
        let candle_device = to_candle_device(device)?;
        let model_path = model_dir.join("model.gguf");
        let mut file = std::fs::File::open(&model_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("open model: {e}")))?;
        let gguf = gguf_file::Content::read(&mut file)
            .map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;
        // `Value::to_string` yields `Result<&String, _>`; clone only on success.
        let arch = gguf
            .metadata
            .get("general.architecture")
            .and_then(|v| v.to_string().ok())
            .cloned()
            .unwrap_or_default();
        let model = if arch == "qwen3moe" {
            let moe = Qwen3MoeModel::from_gguf(gguf, &mut file, &candle_device)
                .map_err(|e| InferenceError::InferenceFailed(format!("load moe weights: {e}")))?;
            QwenModel::Moe(moe)
        } else {
            let std = qwen::ModelWeights::from_gguf(gguf, &mut file, &candle_device)
                .map_err(|e| InferenceError::InferenceFailed(format!("load weights: {e}")))?;
            QwenModel::Standard(std)
        };
        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;
        Ok(Self {
            model,
            tokenizer,
            device: candle_device,
        })
    }

    /// Resets the model's KV cache; call between independent generations.
    pub fn clear_kv_cache(&mut self) {
        match &mut self.model {
            QwenModel::Standard(m) => m.clear_kv_cache(),
            QwenModel::Moe(m) => m.clear_kv_cache(),
        }
    }

    /// Runs one forward pass over `tokens` at KV-cache position `pos` and
    /// returns the logits tensor produced by the underlying model.
    pub fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Tensor, InferenceError> {
        // Shape the flat token slice into a [1, seq_len] batch.
        let input = Tensor::new(tokens, &self.device)
            .map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;
        match &mut self.model {
            QwenModel::Standard(m) => m
                .forward(&input, pos)
                .map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}"))),
            QwenModel::Moe(m) => m
                .forward(&input, pos)
                .map_err(|e| InferenceError::InferenceFailed(format!("forward moe: {e}"))),
        }
    }

    /// Shared encode path; `add_special` toggles special-token insertion.
    fn encode_with(&self, text: &str, add_special: bool) -> Result<Vec<u32>, InferenceError> {
        let encoding = self
            .tokenizer
            .encode(text, add_special)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Shared decode path; `skip_special` toggles special-token stripping.
    fn decode_with(&self, tokens: &[u32], skip_special: bool) -> Result<String, InferenceError> {
        self.tokenizer
            .decode(tokens, skip_special)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))
    }

    /// Encodes `text` into token ids, adding special tokens.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
        self.encode_with(text, true)
    }

    /// Decodes `tokens` into text, skipping special tokens.
    pub fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
        self.decode_with(tokens, true)
    }

    /// Encodes `text` without inserting special tokens.
    pub fn tokenize_raw(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
        self.encode_with(text, false)
    }

    /// Decodes `tokens` verbatim, keeping special tokens in the output.
    pub fn detokenize_raw(&self, tokens: &[u32]) -> Result<String, InferenceError> {
        self.decode_with(tokens, false)
    }

    /// Best-effort end-of-sequence token id.
    ///
    /// Qwen3 instruct/chat models terminate turns with `<|im_end|>`, so it is
    /// preferred; `<|endoftext|>` and `</s>` remain as fallbacks for other
    /// vocabularies. Without the `<|im_end|>` check, Qwen tokenizers resolved
    /// to `<|endoftext|>` and chat generations could never hit their stop token.
    pub fn eos_token_id(&self) -> Option<u32> {
        self.tokenizer
            .token_to_id("<|im_end|>")
            .or_else(|| self.tokenizer.token_to_id("<|endoftext|>"))
            .or_else(|| self.tokenizer.token_to_id("</s>"))
    }

    /// Looks up the id of an arbitrary token string, if present in the vocab.
    pub fn token_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    /// Maximum supported context length in tokens.
    ///
    /// NOTE(review): hard-coded to 32768; the true value lives in the GGUF
    /// metadata (e.g. `qwen3.context_length`) — consider capturing it in
    /// `load` instead of assuming it here.
    pub fn context_length(&self) -> Option<usize> {
        Some(32768)
    }
}
/// Public re-export of the private [`to_candle_device`] helper so callers
/// outside this module can map a crate [`Device`] to a candle device.
pub fn to_candle_device_pub(device: Device) -> Result<CandleDevice, InferenceError> {
to_candle_device(device)
}
/// Maps the crate's [`Device`] selection onto a concrete candle device.
///
/// # Errors
/// Returns `InferenceError::DeviceError` when the requested backend was not
/// compiled in (feature flag disabled) or fails to initialize.
fn to_candle_device(device: Device) -> Result<CandleDevice, InferenceError> {
    match device {
        Device::Cpu => Ok(CandleDevice::Cpu),
        Device::Metal => {
            #[cfg(feature = "metal")]
            {
                // `map_err` already yields the desired Result; no `Ok(..?)` needed.
                CandleDevice::new_metal(0)
                    .map_err(|e| InferenceError::DeviceError(format!("metal: {e}")))
            }
            #[cfg(not(feature = "metal"))]
            {
                Err(InferenceError::DeviceError(
                    "metal feature not enabled".to_string(),
                ))
            }
        }
        Device::Cuda(ordinal) => {
            #[cfg(feature = "cuda")]
            {
                CandleDevice::new_cuda(ordinal)
                    .map_err(|e| InferenceError::DeviceError(format!("cuda({ordinal}): {e}")))
            }
            #[cfg(not(feature = "cuda"))]
            {
                // Silence the unused-variable warning when CUDA is compiled out.
                let _ = ordinal;
                Err(InferenceError::DeviceError(
                    "cuda feature not enabled".to_string(),
                ))
            }
        }
    }
}