Skip to main content

candle_transformers/models/
phi.rs

1//! Microsoft Phi model implementation
2//!
3//! The Phi series are decoder-only transformers designed for code and language tasks.
4//!
5//! Key characteristics:
6//! - Decoder-only transformer architecture
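//! - Partial rotary position embeddings (RoPE), applied to a fraction of each head's channels
//! - Layer normalization
//! - Optional query/key layer normalization (`qk_layernorm`)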
10//!
11//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-phi1-phi2-wasm-demo)
12//! - 🤗 [HF Link](https://huggingface.co/microsoft/phi-2)
13//!
14
15use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
16/// Phi model.
17/// https://huggingface.co/microsoft/phi-2
18/// There is an alternative implementation of the phi model in mixformers.rs.
19/// This corresponds to the model update made with the following commit:
20/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
21use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
22use candle_nn::{Activation, VarBuilder};
23use serde::Deserialize;
24
25// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
26#[derive(Debug, Clone, PartialEq, Deserialize)]
27pub struct Config {
28    pub(crate) vocab_size: usize,
29    pub(crate) hidden_size: usize,
30    pub(crate) intermediate_size: usize,
31    pub(crate) num_hidden_layers: usize,
32    pub(crate) num_attention_heads: usize,
33    pub(crate) num_key_value_heads: Option<usize>,
34    pub(crate) hidden_act: Activation,
35    pub(crate) max_position_embeddings: usize,
36    pub(crate) layer_norm_eps: f64,
37    pub(crate) tie_word_embeddings: bool,
38    pub(crate) rope_theta: f32,
39    pub(crate) partial_rotary_factor: f64,
40    pub(crate) qk_layernorm: bool,
41}
42
43impl Config {
44    fn num_key_value_heads(&self) -> usize {
45        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
46    }
47
48    fn head_dim(&self) -> usize {
49        self.hidden_size / self.num_attention_heads
50    }
51}
52
53#[derive(Debug, Clone)]
54struct RotaryEmbedding {
55    dim: usize,
56    sin: Tensor,
57    cos: Tensor,
58}
59
60impl RotaryEmbedding {
61    fn new(cfg: &Config, dev: &Device) -> Result<Self> {
62        let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
63        let inv_freq: Vec<_> = (0..dim)
64            .step_by(2)
65            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
66            .collect();
67        let inv_freq_len = inv_freq.len();
68        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
69        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
70            .to_dtype(DType::F32)?
71            .reshape((cfg.max_position_embeddings, 1))?;
72        let freqs = t.matmul(&inv_freq)?;
73        Ok(Self {
74            dim,
75            sin: freqs.sin()?,
76            cos: freqs.cos()?,
77        })
78    }
79
80    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
81        let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
82        let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
83        let xs_pass = xs.i((.., .., .., self.dim..))?;
84        let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
85        let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
86        let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
87        Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
88    }
89}
90
91#[derive(Debug, Clone)]
92#[allow(clippy::upper_case_acronyms)]
93struct MLP {
94    fc1: Linear,
95    fc2: Linear,
96    act: Activation,
97}
98
99impl MLP {
100    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
101        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
102        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
103        Ok(Self {
104            fc1,
105            fc2,
106            // This does not match the mixformers implementation where Gelu is used rather than
107            // GeluNew.
108            act: cfg.hidden_act,
109        })
110    }
111}
112
113impl Module for MLP {
114    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
115        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
116    }
117}
118
/// Self-attention with optional grouped key/value heads, partial rotary embeddings,
/// optional query/key layer norms and an incremental KV cache.
#[derive(Clone)]
struct Attention {
    /// Query projection, `hidden_size -> num_heads * head_dim`.
    q_proj: Linear,
    /// Key projection, `hidden_size -> num_kv_heads * head_dim`.
    k_proj: Linear,
    /// Value projection, `hidden_size -> num_kv_heads * head_dim`.
    v_proj: Linear,
    /// Output projection, `num_heads * head_dim -> hidden_size`.
    dense: Linear,
    /// Cached (keys, values) after rotary embedding and before head repetition, laid out as
    /// `(batch, num_kv_heads, cached_seq_len, head_dim)`; `None` until the first forward.
    kv_cache: Option<(Tensor, Tensor)>,
    /// Present only when `cfg.qk_layernorm` is set.
    q_layernorm: Option<LayerNorm>,
    /// Present only when `cfg.qk_layernorm` is set.
    k_layernorm: Option<LayerNorm>,
    rotary_emb: RotaryEmbedding,
    /// `1 / sqrt(head_dim)`.
    softmax_scale: f64,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    span: tracing::Span,
}
135
136fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
137    let mask: Vec<_> = (0..size)
138        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
139        .collect();
140    Tensor::from_slice(&mask, (size, size), device)
141}
142
143fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
144    let shape = mask.shape();
145    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
146    let m = mask.where_cond(&on_true, on_false)?;
147    Ok(m)
148}
149
150impl Attention {
151    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
152        let num_heads = cfg.num_attention_heads;
153        let num_kv_heads = cfg.num_key_value_heads();
154        let head_dim = cfg.head_dim();
155        let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
156        let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
157        let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
158        let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
159        // Alternative rope scalings are not supported.
160        let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
161        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
162            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
163            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
164            (Some(q_layernorm), Some(k_layernorm))
165        } else {
166            (None, None)
167        };
168        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
169        Ok(Self {
170            q_proj,
171            k_proj,
172            v_proj,
173            dense,
174            kv_cache: None,
175            q_layernorm,
176            k_layernorm,
177            rotary_emb,
178            softmax_scale,
179            num_heads,
180            num_kv_heads,
181            head_dim,
182            span: tracing::span!(tracing::Level::TRACE, "attention"),
183        })
184    }
185
186    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
187        crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
188    }
189
190    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
191        let _enter = self.span.enter();
192        let (b_size, seq_len, _n_embd) = xs.dims3()?;
193        let query_states = self.q_proj.forward(xs)?;
194        let key_states = self.k_proj.forward(xs)?;
195        let value_states = self.v_proj.forward(xs)?;
196
197        let query_states = match &self.q_layernorm {
198            None => query_states,
199            Some(ln) => query_states.apply(ln)?,
200        };
201        let key_states = match &self.k_layernorm {
202            None => key_states,
203            Some(ln) => key_states.apply(ln)?,
204        };
205
206        let query_states = query_states
207            .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
208            .transpose(1, 2)?;
209        let key_states = key_states
210            .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
211            .transpose(1, 2)?;
212        let value_states = value_states
213            .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
214            .transpose(1, 2)?;
215
216        // Rotary embeddings.
217        let seqlen_offset = match &self.kv_cache {
218            None => 0,
219            Some((prev_k, _)) => prev_k.dim(2)?,
220        };
221        let query_states = self
222            .rotary_emb
223            .apply_rotary_emb(&query_states, seqlen_offset)?;
224        let key_states = self
225            .rotary_emb
226            .apply_rotary_emb(&key_states, seqlen_offset)?;
227
228        // KV cache.
229        let (key_states, value_states) = match &self.kv_cache {
230            None => (key_states, value_states),
231            Some((prev_k, prev_v)) => {
232                let k = Tensor::cat(&[prev_k, &key_states], 2)?;
233                let v = Tensor::cat(&[prev_v, &value_states], 2)?;
234                (k, v)
235            }
236        };
237        self.kv_cache = Some((key_states.clone(), value_states.clone()));
238
239        // Repeat kv.
240        let key_states = self.repeat_kv(key_states)?.contiguous()?;
241        let value_states = self.repeat_kv(value_states)?.contiguous()?;
242
243        let attn_weights = (query_states
244            .to_dtype(DType::F32)?
245            .contiguous()?
246            .matmul(&key_states.to_dtype(DType::F32)?.t()?)?
247            * self.softmax_scale)?;
248        let attn_weights = match mask {
249            None => attn_weights,
250            Some(mask) => masked_fill(
251                &attn_weights,
252                &mask.broadcast_left((b_size, self.num_heads))?,
253                f32::NEG_INFINITY,
254            )?,
255        };
256        let attn_weights =
257            candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
258        let attn_output = attn_weights.matmul(&value_states)?;
259        let attn_output = attn_output
260            .transpose(1, 2)?
261            .reshape((b_size, seq_len, ()))?;
262        attn_output.apply(&self.dense)
263    }
264
265    fn clear_kv_cache(&mut self) {
266        self.kv_cache = None
267    }
268}
269
270#[derive(Clone)]
271struct DecoderLayer {
272    self_attn: Attention,
273    mlp: MLP,
274    input_layernorm: LayerNorm,
275    span: tracing::Span,
276}
277
278impl DecoderLayer {
279    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
280        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
281        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
282        let input_layernorm = layer_norm(
283            cfg.hidden_size,
284            cfg.layer_norm_eps,
285            vb.pp("input_layernorm"),
286        )?;
287        Ok(Self {
288            self_attn,
289            mlp,
290            input_layernorm,
291            span: tracing::span!(tracing::Level::TRACE, "block"),
292        })
293    }
294
295    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
296        let _enter = self.span.enter();
297        let residual = xs;
298        let xs = xs.apply(&self.input_layernorm)?;
299        let attn_outputs = self.self_attn.forward(&xs, mask)?;
300        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
301        attn_outputs + feed_forward_hidden_states + residual
302    }
303
304    fn clear_kv_cache(&mut self) {
305        self.self_attn.clear_kv_cache()
306    }
307}
308
309#[derive(Clone)]
310pub struct Model {
311    embed_tokens: Embedding,
312    layers: Vec<DecoderLayer>,
313    final_layernorm: LayerNorm,
314    lm_head: Linear,
315    span: tracing::Span,
316}
317
318impl Model {
319    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
320        let vb_m = vb.pp("model");
321        let embed_tokens =
322            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
323        let final_layernorm = layer_norm(
324            cfg.hidden_size,
325            cfg.layer_norm_eps,
326            vb_m.pp("final_layernorm"),
327        )?;
328        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
329        let vb_m = vb_m.pp("layers");
330        for layer_idx in 0..cfg.num_hidden_layers {
331            let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?;
332            layers.push(layer)
333        }
334        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
335        Ok(Self {
336            embed_tokens,
337            layers,
338            final_layernorm,
339            lm_head,
340            span: tracing::span!(tracing::Level::TRACE, "model"),
341        })
342    }
343
344    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
345        let _enter = self.span.enter();
346        let (_b_size, seq_len) = xs.dims2()?;
347        let mut xs = xs.apply(&self.embed_tokens)?;
348        let mask = if seq_len <= 1 {
349            None
350        } else {
351            Some(get_mask(seq_len, xs.device())?)
352        };
353        for layer in self.layers.iter_mut() {
354            xs = layer.forward(&xs, mask.as_ref())?;
355        }
356        xs.apply(&self.final_layernorm)?
357            .narrow(1, seq_len - 1, 1)?
358            .apply(&self.lm_head)?
359            .squeeze(1)
360    }
361
362    pub fn clear_kv_cache(&mut self) {
363        self.layers.iter_mut().for_each(|b| b.clear_kv_cache())
364    }
365}