// candle_transformers/models/quantized_phi3.rs
//! Phi3 model implementation with quantization support.
//!
//! Phi3 is a language model intended for research purposes.
//! This implementation loads quantized weights from GGUF files to reduce memory usage.
//!
//! Key characteristics:
//! - Multi-head attention, with grouped-query attention when the key/value head count is
//!   smaller than the query head count
//! - Fused QKV projection and fused gate/up feed-forward projection (SiLU-gated)
//! - RMSNorm for layer normalization
//! - Rotary positional embeddings (RoPE)
//! - Support for quantization
//!
//! References:
//! - [Model Card](https://huggingface.co/microsoft/phi-3)
//!

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};

23#[derive(Debug, Clone)]
24struct QLinear {
25    inner: candle::quantized::QMatMul,
26    span: tracing::Span,
27}
28
29impl QLinear {
30    fn new<R: std::io::Read + std::io::Seek>(
31        ct: &gguf_file::Content,
32        r: &mut R,
33        name: &str,
34        device: &Device,
35    ) -> Result<Self> {
36        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
37        let w = ct.tensor(r, &format!("{name}.weight"), device)?;
38        let inner = candle::quantized::QMatMul::from_qtensor(w)?;
39        Ok(Self { inner, span })
40    }
41}
42
43impl Module for QLinear {
44    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
45        let _enter = self.span.enter();
46        self.inner.forward(xs)
47    }
48}
49
50#[derive(Debug, Clone)]
51struct Mlp {
52    ffn_up: QLinear,
53    ffn_down: QLinear,
54    i_size: usize,
55}
56
57impl Module for Mlp {
58    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
59        let up_states = xs.apply(&self.ffn_up)?;
60        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
61        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
62        let up_states = (up_states * gate.silu()?)?;
63        up_states.apply(&self.ffn_down)
64    }
65}
66
67fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
68    let w = w.dequantize(&w.device())?;
69    let rms = RmsNorm::new(w, eps);
70    Ok(rms)
71}
72
73#[derive(Debug, Clone)]
74struct LayerWeights {
75    attn_qkv: QLinear,
76    attn_output: QLinear,
77    attn_norm: RmsNorm,
78    ffn_norm: RmsNorm,
79    mlp: Mlp,
80    n_head: usize,
81    n_kv_head: usize,
82    head_dim: usize,
83    cos: Tensor,
84    sin: Tensor,
85    neg_inf: Tensor,
86    kv_cache: KvCache,
87    use_flash_attn: bool,
88    span_attn: tracing::Span,
89    span_rot: tracing::Span,
90}
91
92fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
93    let shape = mask.shape();
94    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
95    Ok(m)
96}
97
98impl LayerWeights {
99    fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
100        let _enter = self.span_rot.enter();
101        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
102        let cos = self.cos.narrow(0, index_pos, seq_len)?;
103        let sin = self.sin.narrow(0, index_pos, seq_len)?;
104        candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin)
105    }
106
107    fn forward_attn(
108        &mut self,
109        x: &Tensor,
110        mask: Option<&Tensor>,
111        index_pos: usize,
112    ) -> Result<Tensor> {
113        let _enter = self.span_attn.enter();
114        let (b_sz, seq_len, n_embd) = x.dims3()?;
115        let qkv = self.attn_qkv.forward(x)?;
116
117        let query_pos = self.n_head * self.head_dim;
118        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
119        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
120        let v = qkv.narrow(
121            D::Minus1,
122            query_pos + self.n_kv_head * self.head_dim,
123            self.n_kv_head * self.head_dim,
124        )?;
125
126        let q = q
127            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
128            .transpose(1, 2)?;
129        let k = k
130            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
131            .transpose(1, 2)?;
132        let v = v
133            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
134            .transpose(1, 2)?;
135
136        let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
137        let k = self.apply_rotary_emb(&k, index_pos)?;
138
139        if index_pos == 0 {
140            self.kv_cache.reset();
141        }
142        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
143
144        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
145        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
146
147        let y = if self.use_flash_attn {
148            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
149            let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?;
150            let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?;
151            let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?;
152            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
153            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
154                .to_dtype(DType::F32)?
155                .transpose(1, 2)?
156        } else {
157            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
158            let att = match mask {
159                None => att,
160                Some(mask) => {
161                    let mask = mask.broadcast_as(att.shape())?;
162                    masked_fill(&att, &mask, &self.neg_inf)?
163                }
164            };
165            let att = candle_nn::ops::softmax_last_dim(&att)?;
166            // Convert to contiguous as matmul doesn't support strided vs for now.
167            att.matmul(&v)?
168        };
169        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
170        let y = self.attn_output.forward(&y)?;
171        Ok(y)
172    }
173}
174
175#[cfg(feature = "flash-attn")]
176fn flash_attn(
177    q: &Tensor,
178    k: &Tensor,
179    v: &Tensor,
180    softmax_scale: f32,
181    causal: bool,
182) -> Result<Tensor> {
183    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
184}
185
186#[cfg(not(feature = "flash-attn"))]
187fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
188    unimplemented!("compile with '--features flash-attn'")
189}
190
191#[derive(Debug, Clone)]
192pub struct ModelWeights {
193    tok_embeddings: Embedding,
194    layers: Vec<LayerWeights>,
195    output_norm: RmsNorm,
196    output: QLinear,
197    masks: HashMap<(usize, usize), Tensor>,
198    span: tracing::Span,
199    span_output: tracing::Span,
200}
201
202fn precomput_freqs_cis(
203    head_dim: usize,
204    max_seq_len: usize,
205    freq_base: f32,
206    device: &Device,
207) -> Result<(Tensor, Tensor)> {
208    let theta: Vec<_> = (0..head_dim)
209        .step_by(2)
210        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
211        .collect();
212    let theta = Tensor::new(theta.as_slice(), device)?;
213    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
214        .to_dtype(DType::F32)?
215        .reshape((max_seq_len, 1))?
216        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
217    let cos = idx_theta.cos()?;
218    let sin = idx_theta.sin()?;
219    Ok((cos, sin))
220}
221
222impl ModelWeights {
223    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
224        use_flash_attn: bool,
225        ct: gguf_file::Content,
226        reader: &mut R,
227        device: &Device,
228    ) -> Result<Self> {
229        let md_get = |s: &str| match ct.metadata.get(s) {
230            None => candle::bail!("cannot find {s} in metadata"),
231            Some(v) => Ok(v),
232        };
233
234        // Parameter extraction from metadata.
235        let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
236        let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
237        let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
238        let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
239        let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
240        let head_dim = embedding_length / head_count;
241        let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
242        let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
243        let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
244        let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
245        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
246
247        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
248        let tok_embeddings = tok_embeddings.dequantize(device)?;
249        let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
250        let output = QLinear::new(&ct, reader, "output", device)?;
251
252        let mut layers = Vec::with_capacity(block_count);
253        for layer_idx in 0..block_count {
254            let prefix = format!("blk.{layer_idx}");
255            let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
256            let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
257            let mlp = Mlp {
258                ffn_up,
259                ffn_down,
260                i_size,
261            };
262            let attn_norm = rms_norm(
263                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
264                rms_eps,
265            )?;
266            let ffn_norm = rms_norm(
267                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
268                rms_eps,
269            )?;
270            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
271            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
272            let kv_cache = KvCache::new(2, max_seq_len);
273            layers.push(LayerWeights {
274                attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
275                attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
276                attn_norm,
277                ffn_norm,
278                mlp,
279                n_head: head_count,
280                n_kv_head: head_count_kv,
281                head_dim,
282                cos: cos.clone(),
283                sin: sin.clone(),
284                neg_inf: neg_inf.clone(),
285                kv_cache,
286                use_flash_attn,
287                span_attn,
288                span_rot,
289            })
290        }
291        let span = tracing::span!(tracing::Level::TRACE, "model");
292        let span_output = tracing::span!(tracing::Level::TRACE, "output");
293        Ok(Self {
294            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
295            layers,
296            output_norm,
297            output,
298            masks: HashMap::new(),
299            span,
300            span_output,
301        })
302    }
303
304    fn mask(&mut self, seq_len: usize, index_pos: usize, device: &Device) -> Result<Tensor> {
305        let kv_len = index_pos + seq_len;
306        if let Some(mask) = self.masks.get(&(seq_len, kv_len)) {
307            Ok(mask.clone())
308        } else {
309            let mask = crate::utils::build_causal_mask(seq_len, index_pos, device)?;
310            self.masks.insert((seq_len, kv_len), mask.clone());
311            Ok(mask)
312        }
313    }
314
315    pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
316        let (_b_sz, seq_len) = xs.dims2()?;
317        let mask = if seq_len == 1 {
318            None
319        } else {
320            Some(self.mask(seq_len, index_pos, xs.device())?)
321        };
322        let _enter = self.span.enter();
323        let mut xs = self.tok_embeddings.forward(xs)?;
324        for layer in self.layers.iter_mut() {
325            let residual = &xs;
326            let ys = xs.apply(&layer.attn_norm)?;
327            let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?;
328            let ys = (ys + residual)?;
329            let residual = &ys;
330            let ys = ys.apply(&layer.ffn_norm)?;
331            let ys = layer.mlp.forward(&ys)?;
332            xs = (ys + residual)?
333        }
334        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
335        let _enter = self.span_output.enter();
336        self.output.forward(&xs)
337    }
338}