use std::collections::HashMap;

use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm};

pub const MAX_SEQ_LEN: usize = 4096;

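// Quantized linear layer: the weight stays quantized inside `QMatMul`, while the
// bias is dequantized once at load time and broadcast-added after the matmul.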
#[derive(Debug, Clone)]
struct QLinear {
    inner: candle::quantized::QMatMul,
    bias: Tensor,
    span: tracing::Span,
}

impl QLinear {
    fn new<R: std::io::Read + std::io::Seek>(
        ct: &gguf_file::Content,
        r: &mut R,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
        let w = ct.tensor(r, &format!("{name}.weight"), device)?;
        let b = ct.tensor(r, &format!("{name}.bias"), device)?;
        let inner = candle::quantized::QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self { inner, bias, span })
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)?.broadcast_add(&self.bias)
    }
}

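// Phi-2 feed-forward block: up projection, GELU, then down projection (no gating).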
#[derive(Debug, Clone)]
struct Mlp {
    ffn_up: QLinear,
    ffn_down: QLinear,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.ffn_up)?.gelu()?.apply(&self.ffn_down)
    }
}

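// Per-layer weights together with that layer's KV cache. `rope_dim` may be
// smaller than `head_dim`; only the first `rope_dim` dimensions of each head
// receive rotary embeddings.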
#[derive(Debug, Clone)]
struct LayerWeights {
    attn_qkv: QLinear,
    attn_output: QLinear,
    attn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    rope_dim: usize,
    neg_inf: Tensor,
    kv_cache: Option<(Tensor, Tensor)>,
    span_attn: tracing::Span,
    span_rot: tracing::Span,
}

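// Where `mask` is non-zero, take `on_true` (broadcast to the mask shape); keep
// `on_false` elsewhere. Used to push masked attention scores to -infinity.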
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
    let shape = mask.shape();
    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
    Ok(m)
}

impl LayerWeights {
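    // Apply rotary embeddings to the first `rope_dim` dimensions of each head
    // only; the remaining dimensions pass through unchanged (partial RoPE).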
    fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
        let _enter = self.span_rot.enter();
        let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
        let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
        let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
        let cos = self.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.sin.narrow(0, index_pos, seq_len)?;
        let xs_rot = candle_nn::rotary_emb::rope(&xs_rot.contiguous()?, &cos, &sin)?;
        Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
    }

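    // Attention with a fused QKV projection and a growing KV cache. When
    // `index_pos` is 0 the cache is discarded and rebuilt from scratch.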
    fn forward_attn(
        &mut self,
        x: &Tensor,
        mask: Option<&Tensor>,
        index_pos: usize,
    ) -> Result<Tensor> {
        let _enter = self.span_attn.enter();
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv =
            self.attn_qkv
                .forward(x)?
                .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?;

        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let v = v.contiguous()?;

        let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, index_pos)?;

        let (k, v) = match &self.kv_cache {
            None => (k.contiguous()?, v.contiguous()?),
            Some((k_cache, v_cache)) => {
                if index_pos == 0 {
                    (k.contiguous()?, v.contiguous()?)
                } else {
                    let k = Tensor::cat(&[k_cache, &k], 2)?;
                    let v = Tensor::cat(&[v_cache, &v], 2)?;
                    (k.contiguous()?, v.contiguous()?)
                }
            }
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        // Repeat the k/v heads so grouped-query attention lines up with the query head count.
        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let att = match mask {
            None => att,
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                masked_fill(&att, &mask, &self.neg_inf)?
            }
        };
        let att = candle_nn::ops::softmax_last_dim(&att)?;
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.forward(&y)?;
        Ok(y)
    }
}

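// The complete quantized Phi-2 model as read from a GGUF file. Attention masks
// are cached per `(seq_len, kv_len)` pair so they are only built once.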
#[derive(Debug, Clone)]
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QLinear,
    masks: HashMap<(usize, usize), Tensor>,
    span: tracing::Span,
    span_output: tracing::Span,
}

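// Precompute the rotary cos/sin tables for positions 0..MAX_SEQ_LEN, with
// frequencies 1 / freq_base^(i / head_dim) for even i; the caller passes the
// rope dimension count as `head_dim`.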
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((MAX_SEQ_LEN, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?;
    let sin = idx_theta.sin()?;
    Ok((cos, sin))
}

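// Build a `LayerNorm` from quantized weight and bias tensors by dequantizing
// them on their own device.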
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

impl ModelWeights {
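    /// Load all weights from a parsed GGUF file.
    ///
    /// Minimal usage sketch (the file name is illustrative only):
    ///
    /// ```ignore
    /// let mut file = std::fs::File::open("phi-2.Q4_K_M.gguf")?;
    /// let content = gguf_file::Content::read(&mut file)?;
    /// let mut model = ModelWeights::from_gguf(content, &mut file, &Device::Cpu)?;
    /// ```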
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let md_get = |s: &str| match ct.metadata.get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
        let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
        let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
        let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
        let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
        let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = layer_norm(
            ct.tensor(reader, "output_norm.weight", device)?,
            ct.tensor(reader, "output_norm.bias", device)?,
            ln_eps,
        )?;
        let output = QLinear::new(&ct, reader, "output", device)?;
        let mut layers = Vec::with_capacity(block_count);
        for layer_idx in 0..block_count {
            let prefix = format!("blk.{layer_idx}");
            let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
            let mlp = Mlp { ffn_up, ffn_down };
            let attn_norm = layer_norm(
                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?,
                ln_eps,
            )?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
            layers.push(LayerWeights {
                attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
                attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
                attn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.clone(),
                sin: sin.clone(),
                rope_dim,
                neg_inf: neg_inf.clone(),
                kv_cache: None,
                span_attn,
                span_rot,
            });
        }
        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            masks: HashMap::new(),
            span,
            span_output,
        })
    }

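    // Fetch a cached attention mask for `seq_len` query positions over
    // `index_pos + seq_len` key positions, building it on a cache miss.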
    fn mask(&mut self, seq_len: usize, index_pos: usize, device: &Device) -> Result<Tensor> {
        let kv_len = index_pos + seq_len;
        if let Some(mask) = self.masks.get(&(seq_len, kv_len)) {
            Ok(mask.clone())
        } else {
            let mask = crate::utils::build_causal_mask(seq_len, index_pos, device)?;
            self.masks.insert((seq_len, kv_len), mask.clone());
            Ok(mask)
        }
    }

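    // Run the model on a batch of token ids starting at position `index_pos`
    // and return the logits for the last position only. Each block uses the
    // Phi-2 parallel residual: attention and MLP both read the same normalized
    // input and their outputs are summed with the residual.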
    pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = xs.dims2()?;
        let mask = if seq_len == 1 {
            None
        } else {
            Some(self.mask(seq_len, index_pos, xs.device())?)
        };
        let _enter = self.span.enter();
        let mut xs = self.tok_embeddings.forward(xs)?;
        for layer in self.layers.iter_mut() {
            let residual = &xs;
            let xs_norm = xs.apply(&layer.attn_norm)?;
            let attn_outputs = layer.forward_attn(&xs_norm, mask.as_ref(), index_pos)?;
            let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
            xs = (attn_outputs + feed_forward_hidden_states + residual)?;
        }
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        let _enter = self.span_output.enter();
        self.output.forward(&xs)
    }
}