//! StarCoder2 model implementation.
//!
//! StarCoder2 is a family of large language models optimized for code
//! generation, trained by the BigCode project.
//!
//! Key characteristics:
//! - Causal self-attention with grouped-query attention (GQA)
//! - Rotary positional embeddings (RoPE)
//! - LayerNorm for normalization
//! - Optional sliding-window attention
//! - Input embeddings tied to the language-model head
//!
//! References:
//! - 📝 [StarCoder2 Paper](https://arxiv.org/abs/2402.19173)
//! - 🤗 [Model Card](https://huggingface.co/bigcode/starcoder2-7b)
//!
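//! Example usage (a minimal sketch: the file names are hypothetical, and a
//! real harness would sample from the logits and loop):
//!
//! ```ignore
//! use candle::{DType, Device, Tensor};
//! use candle_nn::VarBuilder;
//!
//! let device = Device::Cpu;
//! // `Config` derives `serde::Deserialize`, so a checkpoint's config.json
//! // can be loaded directly.
//! let cfg: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
//! let vb = unsafe {
//!     VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
//! };
//! let mut model = Model::new(&cfg, vb)?;
//! // Prompt pass: offset 0; `forward` returns logits for the last position only.
//! let prompt = Tensor::new(&[[1u32, 2, 3]], &device)?;
//! let logits = model.forward(&prompt, 0)?;
//! // Incremental decoding: feed one token at a time, offset = tokens so far.
//! let next = Tensor::new(&[[4u32]], &device)?;
//! let logits = model.forward(&next, 3)?;
//! ```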

use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: candle_nn::Activation,
    max_position_embeddings: usize,
    norm_epsilon: f64,
    rope_theta: f64,
    use_bias: bool,
    sliding_window: Option<usize>,
}
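
// An illustrative `config.json` that deserializes into `Config` above. The
// values are a hypothetical example in the style of the released StarCoder2
// checkpoints, not copied from any specific one:
//
// {
//   "vocab_size": 49152,
//   "hidden_size": 3072,
//   "intermediate_size": 12288,
//   "num_hidden_layers": 30,
//   "num_attention_heads": 24,
//   "num_key_value_heads": 2,
//   "hidden_act": "gelu_pytorch_tanh",
//   "max_position_embeddings": 16384,
//   "norm_epsilon": 1e-5,
//   "rope_theta": 1000000.0,
//   "use_bias": true,
//   "sliding_window": 4096
// }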

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

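/// Rotates half of the last dimension: `[x0, x1, x2, x3]` becomes
/// `[-x2, -x3, x0, x1]`. Combined with `apply_rotary_emb_qkv` below, this
/// implements the non-interleaved ("half-split") form of RoPE, where
/// `rope(x) = x * cos + rotate_half(x) * sin`.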
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

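    /// Applies rotary embeddings to `q` and `k`, both of shape
    /// `(batch, heads, seq_len, head_dim)`. `seqlen_offset` is the number of
    /// positions already consumed, so cached decode steps keep their absolute
    /// positions in the precomputed sin/cos tables.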
    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}
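/// Position-wise feed-forward block: `c_fc` expands to the intermediate
/// size, the configured activation is applied, and `c_proj` projects back.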
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Linear,
    c_proj: Linear,
    act: candle_nn::Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(h_size, i_size, cfg.use_bias, vb.pp("c_fc"))?;
        let c_proj = linear_b(i_size, h_size, cfg.use_bias, vb.pp("c_proj"))?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.c_fc)?.apply(&self.act)?.apply(&self.c_proj)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
        let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
        let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
        let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

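    /// Grouped-query attention with a KV cache. The `num_kv_heads` key/value
    /// heads are shared across `num_kv_groups = num_heads / num_kv_heads`
    /// query heads; `repeat_kv` expands K and V before the usual scaled
    /// dot-product attention. New keys/values are appended to the cache along
    /// the sequence dimension on every call.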
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&value_states)?;
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }

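    /// Drops the cached keys and values; call this before starting an
    /// unrelated generation so stale context is not attended to.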
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

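    /// Pre-norm transformer block: LayerNorm -> attention -> residual add,
    /// then LayerNorm -> MLP -> residual add.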
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Linear,
    sliding_window: Option<usize>,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb_m.pp("norm"))?;
        let lm_head = candle_nn::Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

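    /// Builds the additive attention mask. Position `i` may attend to `j`
    /// when `j <= i` and `i - j <= sliding_window`; everything else is set to
    /// `-inf` before the softmax. For `tgt_len = 3` with no window this is:
    ///
    /// ```text
    ///    0  -inf -inf
    ///    0    0  -inf
    ///    0    0    0
    /// ```
    ///
    /// When no sliding window is configured, the `tgt_len + 42` fallback can
    /// never satisfy `j + sliding_window < i`, so the mask is purely causal.
    /// With a KV cache (`seqlen_offset > 0`) the cached prefix is left
    /// unmasked (zeros).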
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 42);
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

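    /// Runs the model on `input_ids` of shape `(batch, seq_len)`.
    /// `seqlen_offset` is the number of tokens already held in the KV cache
    /// (0 for the initial prompt, then the running total while decoding).
    /// Returns logits for the final position only, with shape
    /// `(batch, 1, vocab_size)`.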
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

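    /// Clears the KV cache in every layer; call between unrelated prompts.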
    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}