1use candle::{DType, Device, IndexOp, Result, Tensor, D};
25use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
26
27fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
28 let weight = vb.get(size, "weight")?;
29 let bias = vb.get(size, "bias")?;
30 Ok(LayerNorm::new(weight, bias, eps))
31}
32
33fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
34 let mask: Vec<_> = (0..t)
35 .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
36 .collect();
37 let mask = Tensor::from_slice(&mask, (t, t), device)?;
38 Ok(mask)
39}
40
/// Hyper-parameters of a GPT-BigCode style model.
#[derive(Debug)]
pub struct Config {
    /// Number of rows of the token embedding table (and of output logits).
    pub vocab_size: usize,
    /// Number of rows of the position embedding table; also the side of the causal mask.
    pub max_position_embeddings: usize,
    /// Number of transformer blocks.
    pub num_hidden_layers: usize,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Epsilon passed to every layer normalization.
    pub layer_norm_epsilon: f64,
    /// Inner width of the MLP; `4 * hidden_size` is used when this is `None`.
    pub n_inner: Option<usize>,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// When `true`, a single key/value head is shared by all query heads.
    pub multi_query: bool,
    /// When `true`, attention layers keep their key/value tensors between forward calls.
    pub use_cache: bool,
}

impl Config {
    /// Fills in the settings shared by every StarCoder preset below; only the
    /// depth, widths and head count differ between the model sizes.
    fn starcoder_family(
        num_hidden_layers: usize,
        hidden_size: usize,
        n_inner: usize,
        num_attention_heads: usize,
    ) -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers,
            hidden_size,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(n_inner),
            num_attention_heads,
            multi_query: true,
            use_cache: true,
        }
    }

    /// Preset with 24 layers, hidden size 2048 and 16 heads.
    #[allow(dead_code)]
    pub fn starcoder_1b() -> Self {
        Self::starcoder_family(24, 2048, 8192, 16)
    }

    /// Preset with 36 layers, hidden size 2816 and 22 heads.
    #[allow(dead_code)]
    pub fn starcoder_3b() -> Self {
        Self::starcoder_family(36, 2816, 11264, 22)
    }

    /// Preset with 42 layers, hidden size 4096 and 32 heads.
    #[allow(dead_code)]
    pub fn starcoder_7b() -> Self {
        Self::starcoder_family(42, 4096, 16384, 32)
    }

    /// Preset with 40 layers, hidden size 6144 and 48 heads.
    #[allow(dead_code)]
    pub fn starcoder() -> Self {
        Self::starcoder_family(40, 6144, 24576, 48)
    }
}
119
120struct Attention {
121 c_attn: Linear,
122 c_proj: Linear,
123 kv_cache: Option<Tensor>,
124 use_cache: bool,
125 embed_dim: usize,
126 kv_dim: usize,
127 num_heads: usize,
128 head_dim: usize,
129 multi_query: bool,
130}
131
132impl Attention {
133 pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
134 let hidden_size = cfg.hidden_size;
135 let head_dim = hidden_size / cfg.num_attention_heads;
136 let kv_heads = if cfg.multi_query {
137 1
138 } else {
139 cfg.num_attention_heads
140 };
141 let kv_dim = kv_heads * head_dim;
142 let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
143 let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
144 Ok(Self {
145 c_proj,
146 c_attn,
147 embed_dim: hidden_size,
148 kv_cache: None,
149 use_cache: cfg.use_cache,
150 kv_dim,
151 head_dim,
152 num_heads: cfg.num_attention_heads,
153 multi_query: cfg.multi_query,
154 })
155 }
156
157 fn attn(
158 &self,
159 query: &Tensor,
160 key: &Tensor,
161 value: &Tensor,
162 attention_mask: &Tensor,
163 ) -> Result<Tensor> {
164 if query.dtype() != DType::F32 {
165 candle::bail!("upcasting is not supported {:?}", query.dtype())
168 }
169 let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
170 let initial_query_shape = query.shape();
171 let key_len = key.dim(D::Minus1)?;
172 let (query, key, attn_shape, attn_view) = if self.multi_query {
173 let (b_sz, query_len, _) = query.dims3()?;
174 let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
175 let attn_shape = (b_sz, query_len, self.num_heads, key_len);
176 let attn_view = (b_sz, query_len * self.num_heads, key_len);
177 (query, key.clone(), attn_shape, attn_view)
178 } else {
179 let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
180 let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
181 let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
182 let attn_shape = (b_sz, self.num_heads, query_len, key_len);
183 let attn_view = (b_sz * self.num_heads, query_len, key_len);
184 (query, key, attn_shape, attn_view)
185 };
186
187 let attn_weights =
188 (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
189 let attention_mask = attention_mask.broadcast_as(attn_shape)?;
190 let mask_value =
191 Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
192 let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
193 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
194 let value = value.contiguous()?;
195 let attn_output = if self.multi_query {
196 attn_weights
197 .reshape(attn_view)?
198 .matmul(&value)?
199 .reshape(initial_query_shape)?
200 } else {
201 attn_weights.matmul(&value)?
202 };
203 Ok(attn_output)
204 }
205
206 fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
207 let qkv = self.c_attn.forward(hidden_states)?;
208 let (query, key_value) = if self.multi_query {
209 let query = qkv.i((.., .., ..self.embed_dim))?;
210 let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
211 (query, key_value)
212 } else {
213 let mut dims = qkv.dims().to_vec();
214 dims.pop();
215 dims.push(self.embed_dim);
216 dims.push(self.head_dim * 3);
217 let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
218 let query = qkv.i((.., .., .., ..self.head_dim))?;
219 let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
220 (query, key_value)
221 };
222 let mut key_value = key_value;
223 if self.use_cache {
224 if let Some(kv_cache) = &self.kv_cache {
225 key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
228 }
229 self.kv_cache = Some(key_value.clone())
230 }
231
232 let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
233 let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
234 let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
235 let attn_output = if self.multi_query {
236 attn_output
237 } else {
238 attn_output
239 .transpose(1, 2)?
240 .reshape(hidden_states.shape())?
241 };
242 let attn_output = self.c_proj.forward(&attn_output)?;
243 Ok(attn_output)
244 }
245}
246
247struct Mlp {
248 c_fc: Linear,
249 c_proj: Linear,
250}
251
252impl Mlp {
253 fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
254 let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
255 let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
256 Ok(Self { c_fc, c_proj })
257 }
258
259 fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
260 let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
261 let hidden_states = self.c_proj.forward(&hidden_states)?;
262 Ok(hidden_states)
263 }
264}
265
266struct Block {
268 ln_1: LayerNorm,
269 attn: Attention,
270 ln_2: LayerNorm,
271 mlp: Mlp,
272}
273
274impl Block {
275 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
276 let hidden_size = cfg.hidden_size;
277 let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
278 let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
279 let attn = Attention::load(vb.pp("attn"), cfg)?;
280 let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
281 let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
282 Ok(Self {
283 ln_1,
284 attn,
285 ln_2,
286 mlp,
287 })
288 }
289
290 fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
291 let residual = hidden_states;
292 let hidden_states = self.ln_1.forward(hidden_states)?;
293 let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
294 let hidden_states = (&attn_outputs + residual)?;
295 let residual = &hidden_states;
296 let hidden_states = self.ln_2.forward(&hidden_states)?;
297 let hidden_states = self.mlp.forward(&hidden_states)?;
298 let hidden_states = (&hidden_states + residual)?;
299 Ok(hidden_states)
300 }
301}
302
303pub struct GPTBigCode {
304 wte: Embedding,
305 wpe: Embedding,
306 blocks: Vec<Block>,
307 ln_f: LayerNorm,
308 lm_head: Linear,
309 bias: Tensor,
310 config: Config,
311}
312
313impl GPTBigCode {
314 pub fn config(&self) -> &Config {
315 &self.config
316 }
317
318 pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
319 let hidden_size = cfg.hidden_size;
320 let vb_t = vb.pp("transformer");
321 let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
322 let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
323 let blocks = (0..cfg.num_hidden_layers)
324 .map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg))
325 .collect::<Result<Vec<_>>>()?;
326 let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
327 let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
328 let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
329 Ok(Self {
330 wte,
331 wpe,
332 blocks,
333 lm_head,
334 ln_f,
335 bias,
336 config: cfg,
337 })
338 }
339
340 pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
341 let dev = input_ids.device();
342 let (b_sz, seq_len) = input_ids.dims2()?;
343
344 let key_len = past_len + seq_len;
345 let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
346 let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
349 let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;
350
351 let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
352 let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
353 let input_embeds = self.wte.forward(input_ids)?;
354 let position_embeds = self.wpe.forward(&position_ids)?;
355
356 let mut hidden_states = (&input_embeds + &position_embeds)?;
357 for block in self.blocks.iter_mut() {
358 hidden_states = block.forward(&hidden_states, &attention_mask)?;
359 }
360 let hidden_states = self.ln_f.forward(&hidden_states)?;
361 let hidden_states = hidden_states
362 .reshape((b_sz, seq_len, self.config.hidden_size))?
363 .narrow(1, seq_len - 1, 1)?;
364 let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
365 Ok(logits)
366 }
367}