use std::io::{Read, Seek};
use std::path::Path;
use std::sync::Arc;

use candle_core::quantized::gguf_file;
use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use candle_transformers::utils::repeat_kv;

use crate::backend::candle::to_candle_device_pub;
use crate::{Device, InferenceError};

type Result<T> = candle_core::Result<T>;
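/// Feed-forward block (SwiGLU): gate, up, and down projections with SiLU activation,
/// read from the `ffn_gate`/`ffn_up`/`ffn_down` GGUF tensors.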
#[derive(Debug, Clone)]
struct MlpWeights {
gate_proj: QMatMul,
up_proj: QMatMul,
down_proj: QMatMul,
act_fn: Activation,
}
impl MlpWeights {
fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
Ok(Self {
gate_proj: gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?,
up_proj: gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?,
down_proj: gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?,
act_fn: Activation::Silu,
})
}
}
impl Module for MlpWeights {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
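// SwiGLU: down(SiLU(gate(x)) * up(x)).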
let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
let up = self.up_proj.forward(x)?;
self.down_proj.forward(&(gate * up)?)
}
}
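/// Grouped-query attention with per-head Q/K RMSNorm (Qwen3 style), rotary position
/// embeddings, and a concatenating KV cache for incremental decoding.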
#[derive(Debug, Clone)]
struct AttentionWeights {
q_proj: QMatMul,
k_proj: QMatMul,
v_proj: QMatMul,
o_proj: QMatMul,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl AttentionWeights {
fn new<R: Read + Seek>(
gg: &mut Gguf<R>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rms_norm_eps: f64,
rotary_emb: Arc<RotaryEmbedding>,
prefix: &str,
) -> Result<Self> {
Ok(Self {
q_proj: gg.qmatmul(&format!("{prefix}.attn_q.weight"))?,
k_proj: gg.qmatmul(&format!("{prefix}.attn_k.weight"))?,
v_proj: gg.qmatmul(&format!("{prefix}.attn_v.weight"))?,
o_proj: gg.qmatmul(&format!("{prefix}.attn_output.weight"))?,
q_norm: gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?,
k_norm: gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?,
num_heads,
num_kv_heads,
num_kv_groups: num_heads / num_kv_heads,
head_dim,
rotary_emb,
kv_cache: ConcatKvCache::new(2),
})
}
fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let (b, l, _) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
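// Reshape the projections to (batch, heads, seq, head_dim).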
let q = q
.reshape((b, l, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let q = self.q_norm.forward(&q.flatten(0, 2)?)?.reshape((
b,
self.num_heads,
l,
self.head_dim,
))?;
let k = self.k_norm.forward(&k.flatten(0, 2)?)?.reshape((
b,
self.num_kv_heads,
l,
self.head_dim,
))?;
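// Apply rotary embeddings at the current cache offset, then extend the KV cache.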
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
let (k, v) = self.kv_cache.append(&k, &v)?;
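// Repeat KV heads so they line up with the query heads (grouped-query attention).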
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
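// Scaled dot-product attention; the causal mask is additive (0 or -inf).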
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
let mask = if m.dtype() != scores.dtype() {
m.to_dtype(scores.dtype())?
} else {
m.clone()
};
scores = scores.broadcast_add(&mask)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?;
let out = ctx
.transpose(1, 2)?
.reshape((b, l, self.num_heads * self.head_dim))?;
self.o_proj.forward(&out)
}
fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
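/// One transformer block: pre-norm attention and pre-norm MLP, each with a residual connection.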
#[derive(Debug, Clone)]
struct LayerWeights {
self_attn: AttentionWeights,
mlp: MlpWeights,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl LayerWeights {
fn new<R: Read + Seek>(
gg: &mut Gguf<R>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rms_norm_eps: f64,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
) -> Result<Self> {
let prefix = format!("blk.{layer_idx}");
Ok(Self {
ln1: gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?,
ln2: gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?,
self_attn: AttentionWeights::new(
gg,
num_heads,
num_kv_heads,
head_dim,
rms_norm_eps,
rotary,
&prefix,
)?,
mlp: MlpWeights::new(gg, &prefix)?,
})
}
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
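// x = x + Attn(RMSNorm(x)); then x = x + MLP(RMSNorm(x)).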
let h = self
.self_attn
.forward(&self.ln1.forward(x)?, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
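/// Qwen3-style embedding model loaded from a GGUF file. A forward pass returns the final
/// hidden state of the last token, which callers use as the text embedding.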
#[derive(Debug, Clone)]
pub struct EmbeddingModelWeights {
embed_tokens: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
device: CandleDevice,
dtype: DType,
hidden_size: usize,
}
impl EmbeddingModelWeights {
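/// Load all weights and hyperparameters from GGUF content. Hyperparameters are read from
/// the `qwen3.*` metadata keys used by llama.cpp-style GGUF exports.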
pub fn from_gguf<R: Read + Seek>(
ct: gguf_file::Content,
reader: &mut R,
device: &CandleDevice,
) -> Result<Self> {
let mut gg = Gguf::new(ct, reader, device.clone());
let md_get = |s: &str| match gg.metadata().get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;
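// Activation dtype: default to F16 when `general.dtype` is missing or unrecognized.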
let dtype = match gg.metadata().get("general.dtype") {
Some(v) => match v.to_u32() {
Ok(0) => DType::F32,
Ok(1) => DType::F16,
_ => DType::F16,
},
None => DType::F16,
};
let embed_tensor = gg.tensor("token_embd.weight")?;
let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(
dtype,
head_dim,
max_position_embeddings,
rope_freq_base,
device,
)?);
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
layers.push(LayerWeights::new(
&mut gg,
num_attention_heads,
num_kv_heads,
head_dim,
rms_norm_eps,
rotary.clone(),
i,
)?);
}
let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device: device.clone(),
dtype,
hidden_size,
})
}
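/// Run the model on a `(1, seq_len)` token tensor and return the hidden state of the last
/// position (last-token pooling) as a `(hidden_size,)` tensor.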
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (b, l) = input.dims2()?;
assert!(
b == 1,
"EmbeddingModelWeights only supports batch_size=1, got {b}"
);
let mut h = self.embed_tokens.forward(input)?;
let causal_mask = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal_mask.as_ref(), offset)?;
}
let h = self.norm.forward(&h)?;
h.narrow(1, l - 1, 1)?.squeeze(1)?.squeeze(0)
}
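/// Reset every layer's KV cache; call this between unrelated inputs.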
pub fn clear_kv_cache(&mut self) {
for layer in &mut self.layers {
layer.clear_kv_cache();
}
}
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
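/// Build a lower-triangular additive mask of shape `(b, 1, tgt, tgt + offset)`:
/// 0 where attention is allowed, negative infinity elsewhere.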
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
}
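/// High-level embedding backend: the GGUF model plus its tokenizer on a single device.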
pub struct EmbeddingBackend {
pub model: EmbeddingModelWeights,
pub tokenizer: tokenizers::Tokenizer,
pub device: CandleDevice,
}
impl EmbeddingBackend {
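/// Load `model.gguf` and `tokenizer.json` from `model_dir` onto the requested device.
///
/// A minimal usage sketch (the path is illustrative; `device` stands in for whichever
/// `crate::Device` value you construct):
/// ```ignore
/// let mut backend = EmbeddingBackend::load(Path::new("models/qwen3-embedding"), device)?;
/// let vector = backend.embed_one("hello world")?;
/// assert_eq!(vector.len(), backend.model.hidden_size());
/// ```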
pub fn load(model_dir: &Path, device: Device) -> std::result::Result<Self, InferenceError> {
let candle_device = to_candle_device_pub(device)?;
let model_path = model_dir.join("model.gguf");
let mut file = std::fs::File::open(&model_path)
.map_err(|e| InferenceError::InferenceFailed(format!("open embedding model: {e}")))?;
let gguf = gguf_file::Content::read(&mut file)
.map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;
let model = EmbeddingModelWeights::from_gguf(gguf, &mut file, &candle_device)
.map_err(|e| InferenceError::InferenceFailed(format!("load embedding weights: {e}")))?;
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;
Ok(Self {
model,
tokenizer,
device: candle_device,
})
}
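/// Tokenize `text` (with special tokens added) and return the token ids.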
pub fn encode(&self, text: &str) -> std::result::Result<Vec<u32>, InferenceError> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
Ok(encoding.get_ids().to_vec())
}
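/// Embed a single text: clear the KV cache, tokenize, run one full forward pass, and
/// return the L2-normalized last-token hidden state. Empty input yields a zero vector.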
pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, InferenceError> {
self.model.clear_kv_cache();
let tokens = self.encode(text)?;
if tokens.is_empty() {
return Ok(vec![0.0; self.model.hidden_size()]);
}
let input = Tensor::new(&tokens[..], &self.device)
.map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
.unsqueeze(0)
.map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;
let hidden = self
.model
.forward(&input, 0)
.map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}")))?;
let embedding: Vec<f32> = hidden
.to_dtype(candle_core::DType::F32)
.map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
.to_vec1()
.map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
Ok(car_ir::linalg::l2_normalize(&embedding))
}
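/// Embed each text independently; the model processes one sequence at a time.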
pub fn embed_batch(
&mut self,
texts: &[String],
) -> std::result::Result<Vec<Vec<f32>>, InferenceError> {
texts.iter().map(|t| self.embed_one(t)).collect()
}
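/// Embed a query using an instruction-prefixed prompt (`Instruct: {instruction}\nQuery: {text}`).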
pub fn embed_query(
&mut self,
text: &str,
instruction: &str,
) -> std::result::Result<Vec<f32>, InferenceError> {
let formatted = format!("Instruct: {instruction}\nQuery: {text}");
self.embed_one(&formatted)
}
}