//! BigCode implementation in Rust based on the GPT-BigCode model.
//!
//! [StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
//! specialized in code generation. The initial model was trained on 80
//! programming languages. See "StarCoder: may the source be with you!", Li et al. 2023
//! - [Arxiv](https://arxiv.org/abs/2305.06161)
//! - [Github](https://github.com/bigcode-project/starcoder)
//!
//! ## Running an example
//!
//! ```bash
//! cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64"
//!
//! > fn fact(n: u64) -> u64  {
//! >     if n == 0 {
//! >         1
//! >     } else {
//! >         n * fact(n - 1)
//! >     }
//! > }
//! ```
//!

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};

fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight")?;
    let bias = vb.get(size, "bias")?;
    Ok(LayerNorm::new(weight, bias, eps))
}

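// Build a (t, t) lower-triangular mask of u8 values: entry (i, j) is 1 when
// j <= i, so each query position attends only to itself and earlier positions.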
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), device)?;
    Ok(mask)
}

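/// Model hyperparameters for GPT-BigCode; the `aka` comments give the GPT-2
/// style aliases used in the upstream Hugging Face configuration.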
#[derive(Debug)]
pub struct Config {
    pub vocab_size: usize,
    // max_position_embeddings aka n_positions
    pub max_position_embeddings: usize,
    // num_hidden_layers aka n_layer
    pub num_hidden_layers: usize,
    // hidden_size aka n_embd
    pub hidden_size: usize,
    pub layer_norm_epsilon: f64,
    pub n_inner: Option<usize>,
    // num_attention_heads aka n_head
    pub num_attention_heads: usize,
    pub multi_query: bool,
    pub use_cache: bool,
}

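// Preset configurations for the published StarCoder checkpoints; all of them
// use multi-query attention and an 8192-token context window.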
impl Config {
    #[allow(dead_code)]
    pub fn starcoder_1b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 24,
            hidden_size: 2048,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(8192),
            num_attention_heads: 16,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder_3b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 36,
            hidden_size: 2816,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(11264),
            num_attention_heads: 22,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder_7b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 42,
            hidden_size: 4096,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(16384),
            num_attention_heads: 32,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 40,
            hidden_size: 6144,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(24576),
            num_attention_heads: 48,
            multi_query: true,
            use_cache: true,
        }
    }
}

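// Self-attention with optional multi-query attention (MQA): when `multi_query`
// is set, all query heads share a single key/value head, which shrinks the KV
// cache by a factor of `num_heads`.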
struct Attention {
    c_attn: Linear,
    c_proj: Linear,
    kv_cache: Option<Tensor>,
    use_cache: bool,
    embed_dim: usize,
    kv_dim: usize,
    num_heads: usize,
    head_dim: usize,
    multi_query: bool,
}

impl Attention {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let head_dim = hidden_size / cfg.num_attention_heads;
        let kv_heads = if cfg.multi_query {
            1
        } else {
            cfg.num_attention_heads
        };
        let kv_dim = kv_heads * head_dim;
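        // Fused QKV projection: `hidden_size` query columns followed by
        // `kv_dim` key and `kv_dim` value columns (a single K/V head in MQA).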
        let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
        let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
        Ok(Self {
            c_proj,
            c_attn,
            embed_dim: hidden_size,
            kv_cache: None,
            use_cache: cfg.use_cache,
            kv_dim,
            head_dim,
            num_heads: cfg.num_attention_heads,
            multi_query: cfg.multi_query,
        })
    }

    fn attn(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        if query.dtype() != DType::F32 {
            // If we start supporting f16 models, we may need the upcasting scaling bits.
            // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133
            candle::bail!("upcasting is not supported {:?}", query.dtype())
        }
        let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
        let initial_query_shape = query.shape();
        let key_len = key.dim(D::Minus1)?;
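        // MQA folds the query heads into the sequence dimension so a single
        // K/V head serves all of them in one matmul; MHA folds the heads into
        // the batch dimension instead.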
        let (query, key, attn_shape, attn_view) = if self.multi_query {
            let (b_sz, query_len, _) = query.dims3()?;
            let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
            let attn_shape = (b_sz, query_len, self.num_heads, key_len);
            let attn_view = (b_sz, query_len * self.num_heads, key_len);
            (query, key.clone(), attn_shape, attn_view)
        } else {
            let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
            let query = query.reshape((b_sz * self.num_heads, query_len, self.head_dim))?;
            let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
            let attn_shape = (b_sz, self.num_heads, query_len, key_len);
            let attn_view = (b_sz * self.num_heads, query_len, key_len);
            (query, key, attn_shape, attn_view)
        };

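        // Scaled dot-product scores; masked positions are set to -inf so they
        // receive zero weight after the softmax.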
        let attn_weights =
            (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
        let attention_mask = attention_mask.broadcast_as(attn_shape)?;
        let mask_value =
            Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
        let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let value = value.contiguous()?;
        let attn_output = if self.multi_query {
            attn_weights
                .reshape(attn_view)?
                .matmul(&value)?
                .reshape(initial_query_shape)?
        } else {
            attn_weights.matmul(&value)?
        };
        Ok(attn_output)
    }

    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let qkv = self.c_attn.forward(hidden_states)?;
        let (query, key_value) = if self.multi_query {
            let query = qkv.i((.., .., ..self.embed_dim))?;
            let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
            (query, key_value)
        } else {
            let mut dims = qkv.dims().to_vec();
            dims.pop();
            dims.push(self.num_heads);
            dims.push(self.head_dim * 3);
            let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
            let query = qkv.i((.., .., .., ..self.head_dim))?;
            let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
            (query, key_value)
        };
        let mut key_value = key_value;
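        // Append the new keys/values to the cache along the sequence axis so
        // that later steps only have to process the freshly generated tokens.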
        if self.use_cache {
            if let Some(kv_cache) = &self.kv_cache {
                // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
                // arbitrarily large sizes.
                key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
            }
            self.kv_cache = Some(key_value.clone())
        }

        let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
        let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
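        // The key is transposed here so that `attn` can compute query @ key^T
        // with a plain batched matmul.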
        let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
        let attn_output = if self.multi_query {
            attn_output
        } else {
            attn_output
                .transpose(1, 2)?
                .reshape(hidden_states.shape())?
        };
        let attn_output = self.c_proj.forward(&attn_output)?;
        Ok(attn_output)
    }
}

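// Standard GPT-2 style feed-forward block: linear -> GELU -> linear.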
struct Mlp {
    c_fc: Linear,
    c_proj: Linear,
}

impl Mlp {
    fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
        let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
        Ok(Self { c_fc, c_proj })
    }

    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
        let hidden_states = self.c_proj.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

// TODO: Add cross-attention?
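// Pre-norm transformer block: layer norm -> attention -> residual add, then
// layer norm -> MLP -> residual add.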
struct Block {
    ln_1: LayerNorm,
    attn: Attention,
    ln_2: LayerNorm,
    mlp: Mlp,
}

impl Block {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
        let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
        let attn = Attention::load(vb.pp("attn"), cfg)?;
        let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
        let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
        Ok(Self {
            ln_1,
            attn,
            ln_2,
            mlp,
        })
    }

    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let residual = hidden_states;
        let hidden_states = self.ln_1.forward(hidden_states)?;
        let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (&attn_outputs + residual)?;
        let residual = &hidden_states;
        let hidden_states = self.ln_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        let hidden_states = (&hidden_states + residual)?;
        Ok(hidden_states)
    }
}

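/// The GPT-BigCode causal language model: token and position embeddings, a
/// stack of transformer blocks, a final layer norm, and an LM head tied to the
/// token embedding weights.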
pub struct GPTBigCode {
    wte: Embedding,
    wpe: Embedding,
    blocks: Vec<Block>,
    ln_f: LayerNorm,
    lm_head: Linear,
    bias: Tensor,
    config: Config,
}

impl GPTBigCode {
    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let vb_t = vb.pp("transformer");
        let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
        let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
        let blocks = (0..cfg.num_hidden_layers)
            .map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg))
            .collect::<Result<Vec<_>>>()?;
        let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
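        // Weight tying: the LM head reuses the token embedding matrix `wte`
        // (loaded from the same tensor, with no bias).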
        let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
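        // `bias` keeps the upstream name for the precomputed causal mask of
        // shape (max_position_embeddings, max_position_embeddings).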
        let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
        Ok(Self {
            wte,
            wpe,
            blocks,
            lm_head,
            ln_f,
            bias,
            config: cfg,
        })
    }

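    /// Runs the model on `input_ids` of shape `(batch, seq_len)` and returns
    /// the logits for the last position only, of shape `(batch, vocab_size)`.
    /// `past_len` is the number of tokens already held in the KV caches: pass
    /// 0 on the first call and the total number of processed tokens afterwards.
    ///
    /// Below is a minimal greedy-decoding sketch over this API; the
    /// zero-initialized `VarBuilder` and hard-coded prompt ids stand in for
    /// real checkpoint weights and a tokenizer, which are out of scope here.
    ///
    /// ```no_run
    /// use candle::{DType, Device, IndexOp, Tensor};
    /// use candle_nn::VarBuilder;
    /// use candle_transformers::models::bigcode::{Config, GPTBigCode};
    ///
    /// fn main() -> candle::Result<()> {
    ///     let device = Device::Cpu;
    ///     let vb = VarBuilder::zeros(DType::F32, &device);
    ///     let mut model = GPTBigCode::load(vb, Config::starcoder_1b())?;
    ///     let mut tokens: Vec<u32> = vec![1, 2, 3]; // stand-in prompt ids
    ///     let mut past_len = 0;
    ///     for _ in 0..8 {
    ///         // After the first step, only the new tokens are fed in.
    ///         let input = Tensor::new(&tokens[past_len..], &device)?.unsqueeze(0)?;
    ///         let logits = model.forward(&input, past_len)?;
    ///         past_len = tokens.len();
    ///         // Greedy decoding: pick the highest-logit token.
    ///         let next = logits.i(0)?.argmax(0)?.to_scalar::<u32>()?;
    ///         tokens.push(next);
    ///     }
    ///     Ok(())
    /// }
    /// ```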
    pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
        let dev = input_ids.device();
        let (b_sz, seq_len) = input_ids.dims2()?;

        let key_len = past_len + seq_len;
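        // Slice the precomputed causal mask to the rows for the new queries
        // and the columns for every key seen so far (cache plus current step).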
        let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
        // MQA models: (batch_size, query_length, n_heads, key_length)
        // MHA models: (batch_size, n_heads, query_length, key_length)
        let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
        let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;

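        // Position ids continue from `past_len`, so cached tokens and new
        // tokens get consistent absolute position embeddings.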
        let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
        let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
        let input_embeds = self.wte.forward(input_ids)?;
        let position_embeds = self.wpe.forward(&position_ids)?;

        let mut hidden_states = (&input_embeds + &position_embeds)?;
        for block in self.blocks.iter_mut() {
            hidden_states = block.forward(&hidden_states, &attention_mask)?;
        }
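        // Only the last position is needed to predict the next token, so the
        // hidden states are narrowed before the large vocabulary projection.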
        let hidden_states = self.ln_f.forward(&hidden_states)?;
        let hidden_states = hidden_states
            .reshape((b_sz, seq_len, self.config.hidden_size))?
            .narrow(1, seq_len - 1, 1)?;
        let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
        Ok(logits)
    }
}
367}