codegeex4_candle/
codegeex4.rs

1use candle_transformers::models::with_tracing::{linear_b as linear, Linear};
2use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
3use candle_core as candle;
4use candle_nn::VarBuilder;
5
/// Hyper-parameters of the ChatGLM-style decoder used by CodeGeeX4.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of transformer blocks in the encoder stack.
    pub num_layers: usize,
    /// Vocabulary size of the embedding table and of the LM head.
    pub padded_vocab_size: usize,
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Width of the MLP activation; the up-projection emits twice this for SwiGLU.
    pub ffn_hidden_size: usize,
    /// Per-head dimension of queries, keys and values.
    pub kv_channels: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of positions covered by the precomputed rotary table.
    pub seq_length: usize,
    /// Epsilon used by the normalisation layers.
    pub layernorm_epsilon: f64,
    /// Use RMSNorm instead of LayerNorm.
    pub rmsnorm: bool,
    /// Take residuals from the normalised activations instead of the block input.
    pub apply_residual_connection_post_layernorm: bool,
    /// Apply a final norm after the last block.
    pub post_layer_norm: bool,
    /// Add biases to the linear layers.
    pub add_bias_linear: bool,
    /// Add a bias to the fused query/key/value projection.
    pub add_qkv_bias: bool,
    /// Not read by this implementation.
    pub bias_dropout_fusion: bool,
    /// Grouped (multi-query) attention; only `true` is supported at runtime.
    pub multi_query_attention: bool,
    /// Number of key/value groups shared by the query heads.
    pub multi_query_group_num: usize,
    /// Fold the 1-based layer number into the score normalisation (divide, then re-multiply).
    pub apply_query_key_layer_scaling: bool,
    /// Not read by this implementation.
    pub attention_softmax_in_fp32: bool,
    /// Cast the embeddings to f32 before they enter the encoder.
    pub fp32_residual_connection: bool,
}

impl Config {
    /// Returns the hyper-parameters of the CodeGeeX4 checkpoint.
    pub fn codegeex4() -> Self {
        Self {
            // Model dimensions.
            num_layers: 40,
            hidden_size: 4096,
            ffn_hidden_size: 13_696,
            kv_channels: 128,
            num_attention_heads: 32,
            padded_vocab_size: 151_552,
            seq_length: 131_072,
            // Attention layout.
            multi_query_attention: true,
            multi_query_group_num: 2,
            apply_query_key_layer_scaling: true,
            attention_softmax_in_fp32: true,
            // Normalisation.
            rmsnorm: true,
            layernorm_epsilon: 1e-5,
            post_layer_norm: true,
            apply_residual_connection_post_layernorm: false,
            // Linear layers and residual stream.
            add_bias_linear: false,
            add_qkv_bias: true,
            bias_dropout_fusion: true,
            fp32_residual_connection: false,
        }
    }
}
54
55#[derive(Debug, Clone)]
56struct RotaryEmbedding {
57    cache: Tensor,
58}
59
60impl RotaryEmbedding {
61    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
62        let rotary_dim = cfg.kv_channels;
63        let n_elem = rotary_dim / 2;
64        let inv_freq: Vec<_> = (0..n_elem)
65            .step_by(2)
66            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
67            .collect();
68        let inv_freq_len = inv_freq.len();
69        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
70        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
71            .to_dtype(dtype).expect("unalbe to dytpe in Rotray Embedding new")
72            .reshape((cfg.seq_length, 1))?;
73        let freqs = t.matmul(&inv_freq)?;
74        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
75        Ok(Self { cache })
76    }
77
78    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
79        let (seqlen, _b, np, _hn) = xs.dims4()?;
80        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
81        let rot_dim = cache.dim(D::Minus2)? * 2;
82        let (xs, xs_pass) = (
83            xs.narrow(D::Minus1, 0, rot_dim)?,
84            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
85        );
86        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
87        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
88        let (xshaped0, xshaped1) = (
89            xshaped.i((.., .., .., .., 0))?,
90            xshaped.i((.., .., .., .., 1))?,
91        );
92        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
93        let xs_out = Tensor::stack(
94            &[
95                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
96                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
97            ],
98            D::Minus1,
99        )?;
100        let xs_out = xs_out.flatten_from(3)?;
101        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
102    }
103}
104
105#[derive(Debug, Clone)]
106struct CoreAttention {
107    coeff: Option<f64>,
108    norm_factor: f64,
109}
110
111fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
112    let shape = mask.shape();
113    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
114    let m = mask.where_cond(&on_true, on_false)?;
115    Ok(m)
116}
117
118impl CoreAttention {
119    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
120        let norm_factor = (cfg.kv_channels as f64).sqrt();
121        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
122            let coeff = f64::max(1.0, layer_number as f64);
123            (norm_factor * coeff, Some(coeff))
124        } else {
125            (norm_factor, None)
126        };
127        Ok(Self { coeff, norm_factor })
128    }
129
130    fn forward(
131        &self,
132        query_layer: &Tensor,
133        key_layer: &Tensor,
134        value_layer: &Tensor,
135        attention_mask: &Option<Tensor>,
136    ) -> Result<Tensor> {
137        let output_size = (
138            query_layer.dim(1)?, // b
139            query_layer.dim(2)?, // np
140            query_layer.dim(0)?, // sq
141            key_layer.dim(0)?,   // sk
142        );
143        let query_layer =
144            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
145        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
146        let matmul_result = Tensor::matmul(
147            &query_layer.transpose(0, 1)?,
148            &key_layer.transpose(0, 1)?.transpose(1, 2)?,
149        )?;
150        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
151        let matmul_result = match self.coeff {
152            None => matmul_result,
153            Some(coeff) => (matmul_result * coeff)?,
154        };
155        let attention_scores = match attention_mask {
156            Some(mask) => masked_fill(
157                &matmul_result,
158                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
159                f32::NEG_INFINITY,
160            )?,
161            None => matmul_result,
162        };
163        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
164
165        let output_size = (
166            value_layer.dim(1)?,
167            value_layer.dim(2)?,
168            query_layer.dim(0)?,
169            value_layer.dim(3)?,
170        );
171        let value_layer =
172            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
173        let attention_probs =
174            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
175        let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
176        let context_layer = context_layer.reshape(output_size)?;
177        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
178        context_layer.flatten_from(D::Minus2)
179    }
180}
181
182#[derive(Debug, Clone)]
183struct SelfAttention {
184    query_key_value: Linear,
185    core_attention: CoreAttention,
186    dense: Linear,
187    multi_query_attention: bool,
188    num_attention_heads_per_partition: usize,
189    num_multi_query_groups_per_partition: usize,
190    hidden_size_per_attention_head: usize,
191    kv_cache: Option<(Tensor, Tensor)>,
192}
193
194impl SelfAttention {
195    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
196        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
197        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
198        let qkv_hidden_size = if cfg.multi_query_attention {
199            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
200        } else {
201            3 * projection_size
202        };
203        let query_key_value = linear(
204            cfg.hidden_size,
205            qkv_hidden_size,
206            cfg.add_bias_linear || cfg.add_qkv_bias,
207            vb.pp("query_key_value"),
208        )?;
209        let core_attention = CoreAttention::new(layer_number, cfg)?;
210        let dense = linear(
211            cfg.hidden_size,
212            cfg.hidden_size,
213            cfg.add_bias_linear,
214            vb.pp("dense"),
215        )?;
216        Ok(Self {
217            query_key_value,
218            core_attention,
219            dense,
220            multi_query_attention: cfg.multi_query_attention,
221            num_attention_heads_per_partition: cfg.num_attention_heads,
222            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
223            hidden_size_per_attention_head: cfg.kv_channels,
224            kv_cache: None,
225        })
226    }
227
228    fn reset_kv_cache(&mut self) {
229        self.kv_cache = None
230    }
231
232    fn forward(
233        &mut self,
234        xs: &Tensor,
235        attention_mask: &Option<Tensor>,
236        rotary_emb: &RotaryEmbedding,
237    ) -> Result<Tensor> {
238        let mixed_x_layer = xs.apply(&self.query_key_value)?;
239        if !self.multi_query_attention {
240            candle::bail!("only multi_query_attention=true is supported")
241        }
242        let hpa = self.hidden_size_per_attention_head;
243        let query_layer =
244            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
245        let key_layer = mixed_x_layer.narrow(
246            D::Minus1,
247            self.num_attention_heads_per_partition * hpa,
248            self.num_multi_query_groups_per_partition * hpa,
249        )?;
250        let value_layer = mixed_x_layer.narrow(
251            D::Minus1,
252            self.num_attention_heads_per_partition * hpa
253                + self.num_multi_query_groups_per_partition * hpa,
254            self.num_multi_query_groups_per_partition * hpa,
255        )?;
256        let query_layer = query_layer.reshape((
257            query_layer.dim(0)?,
258            query_layer.dim(1)?,
259            self.num_attention_heads_per_partition,
260            hpa,
261        ))?;
262        let key_layer = key_layer.reshape((
263            key_layer.dim(0)?,
264            key_layer.dim(1)?,
265            self.num_multi_query_groups_per_partition,
266            hpa,
267        ))?;
268        let value_layer = value_layer.reshape((
269            value_layer.dim(0)?,
270            value_layer.dim(1)?,
271            self.num_multi_query_groups_per_partition,
272            hpa,
273        ))?;
274
275        // Rotary embeddings.
276        let seqlen_offset = match &self.kv_cache {
277            None => 0,
278            Some((prev_k, _)) => prev_k.dim(0)?,
279        };
280        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
281        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
282
283        // KV cache.
284        let (key_layer, value_layer) = match &self.kv_cache {
285            None => (key_layer, value_layer),
286            Some((prev_k, prev_v)) => {
287                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
288                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
289                (k, v)
290            }
291        };
292        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
293
294        // Repeat KV.
295        let ratio =
296            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
297        let key_layer = {
298            let (d0, d1, d2, d3) = key_layer.dims4()?;
299            key_layer
300                .unsqueeze(D::Minus2)?
301                .expand((d0, d1, d2, ratio, d3))?
302                .reshape((
303                    d0,
304                    d1,
305                    self.num_attention_heads_per_partition,
306                    self.hidden_size_per_attention_head,
307                ))?
308        };
309        let value_layer = {
310            let (d0, d1, d2, d3) = value_layer.dims4()?;
311            value_layer
312                .unsqueeze(D::Minus2)?
313                .expand((d0, d1, d2, ratio, d3))?
314                .reshape((
315                    d0,
316                    d1,
317                    self.num_attention_heads_per_partition,
318                    self.hidden_size_per_attention_head,
319                ))?
320        };
321
322        let context_layer =
323            self.core_attention
324                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
325        let output = context_layer.apply(&self.dense)?;
326        Ok(output)
327    }
328}
329
330#[allow(clippy::upper_case_acronyms)]
331#[derive(Debug, Clone)]
332struct MLP {
333    dense_h_to_4h: Linear,
334    dense_4h_to_h: Linear,
335}
336
337impl MLP {
338    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
339        let dense_h_to_4h = linear(
340            cfg.hidden_size,
341            cfg.ffn_hidden_size * 2,
342            cfg.add_bias_linear,
343            vb.pp("dense_h_to_4h"),
344        )?;
345        let dense_4h_to_h = linear(
346            cfg.ffn_hidden_size,
347            cfg.hidden_size,
348            cfg.add_bias_linear,
349            vb.pp("dense_4h_to_h"),
350        )?;
351        Ok(Self {
352            dense_4h_to_h,
353            dense_h_to_4h,
354        })
355    }
356}
357
358impl Module for MLP {
359    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
360        xs.apply(&self.dense_h_to_4h)?
361            .apply(&candle_nn::Activation::Swiglu)?
362            .apply(&self.dense_4h_to_h)
363    }
364}
365
366#[derive(Debug, Clone)]
367struct Block {
368    input_layernorm: candle_nn::LayerNorm,
369    self_attention: SelfAttention,
370    post_attention_layernorm: candle_nn::LayerNorm,
371    mlp: MLP,
372    apply_residual_connection_post_layernorm: bool,
373}
374
375impl Block {
376    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
377        let input_layernorm = if cfg.rmsnorm {
378            candle_nn::rms_norm(
379                cfg.hidden_size,
380                cfg.layernorm_epsilon,
381                vb.pp("input_layernorm"),
382            )?
383            .into_inner()
384        } else {
385            candle_nn::layer_norm(
386                cfg.hidden_size,
387                cfg.layernorm_epsilon,
388                vb.pp("input_layernorm"),
389            )?
390        };
391        let post_attention_layernorm = if cfg.rmsnorm {
392            candle_nn::rms_norm(
393                cfg.hidden_size,
394                cfg.layernorm_epsilon,
395                vb.pp("post_attention_layernorm"),
396            )?
397            .into_inner()
398        } else {
399            candle_nn::layer_norm(
400                cfg.hidden_size,
401                cfg.layernorm_epsilon,
402                vb.pp("post_attention_layernorm"),
403            )?
404        };
405        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
406        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
407        Ok(Self {
408            input_layernorm,
409            self_attention,
410            post_attention_layernorm,
411            mlp,
412            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
413        })
414    }
415
416    fn reset_kv_cache(&mut self) {
417        self.self_attention.reset_kv_cache()
418    }
419
420    fn forward(
421        &mut self,
422        xs: &Tensor,
423        attention_mask: &Option<Tensor>,
424        rotary_emb: &RotaryEmbedding,
425    ) -> Result<Tensor> {
426        let layernorm_output = xs.apply(&self.input_layernorm)?;
427        let attention_output =
428            self.self_attention
429                .forward(&layernorm_output, attention_mask, rotary_emb)?;
430        let residual = if self.apply_residual_connection_post_layernorm {
431            &layernorm_output
432        } else {
433            xs
434        };
435        let layernorm_input = (residual + attention_output)?;
436        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
437        let mlp_output = layernorm_output.apply(&self.mlp)?;
438        let residual = if self.apply_residual_connection_post_layernorm {
439            &layernorm_output
440        } else {
441            &layernorm_input
442        };
443        mlp_output + residual
444    }
445}
446
447#[derive(Debug, Clone)]
448struct Transformer {
449    layers: Vec<Block>,
450    final_layernorm: Option<candle_nn::LayerNorm>,
451    rotary_emb: RotaryEmbedding,
452}
453
454impl Transformer {
455    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
456        let vb_l = vb.pp("layers");
457        let mut layers = Vec::with_capacity(cfg.num_layers);
458	println!("transofrmer layers create");
459	let mut count = 0;
460        for layer_index in 0..cfg.num_layers {
461	    count += 1;
462	    println!("for layer index in {} total is {} ",count, cfg.num_layers);
463            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
464            layers.push(block)
465        }
466        let final_layernorm = if cfg.post_layer_norm {
467            let ln = if cfg.rmsnorm {
468                candle_nn::rms_norm(
469                    cfg.hidden_size,
470                    cfg.layernorm_epsilon,
471                    vb.pp("final_layernorm"),
472                )?
473                .into_inner()
474            } else {
475                candle_nn::layer_norm(
476                    cfg.hidden_size,
477                    cfg.layernorm_epsilon,
478                    vb.pp("final_layernorm"),
479                )?
480            };
481            Some(ln)
482        } else {
483            None
484        };
485        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
486        Ok(Self {
487            layers,
488            final_layernorm,
489            rotary_emb,
490        })
491    }
492
493    fn reset_kv_cache(&mut self) {
494        for block in self.layers.iter_mut() {
495            block.reset_kv_cache()
496        }
497    }
498
499    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
500        let mut xs = xs.clone();
501        for block in self.layers.iter_mut() {
502            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
503        }
504        match self.final_layernorm.as_ref() {
505            None => Ok(xs),
506            Some(ln) => xs.apply(ln),
507        }
508    }
509}
510
511#[derive(Debug, Clone)]
512struct Embedding {
513    word_embeddings: candle_nn::Embedding,
514    fp32_residual_connection: bool,
515}
516
517impl Embedding {
518    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
519        let word_embeddings = candle_nn::embedding(
520            cfg.padded_vocab_size,
521            cfg.hidden_size,
522            vb.pp("word_embeddings"),
523        )?;
524        Ok(Self {
525            word_embeddings,
526            fp32_residual_connection: cfg.fp32_residual_connection,
527        })
528    }
529}
530
531impl Module for Embedding {
532    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
533        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
534        if self.fp32_residual_connection {
535            xs.to_dtype(candle::DType::F32)
536        } else {
537            xs.contiguous()
538        }
539    }
540}
541
542#[derive(Debug, Clone)]
543pub struct Model {
544    embedding: Embedding,
545    encoder: Transformer,
546    output_layer: Linear,
547}
548
549fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
550    let mask: Vec<_> = (0..size)
551        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
552        .collect();
553    Tensor::from_slice(&mask, (size, size), device)
554}
555
556impl Model {
557    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
558        let vb = vb.pp("transformer");
559        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
560        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
561        let output_layer = linear(
562            cfg.hidden_size,
563            cfg.padded_vocab_size,
564            false,
565            vb.pp("output_layer"),
566        )?;
567	
568	
569        Ok(Self {
570            embedding,
571            encoder,
572            output_layer,
573        })
574    }
575
576    pub fn reset_kv_cache(&mut self) {
577        self.encoder.reset_kv_cache()
578    }
579
580    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
581        let (_b_size, seq_len) = xs.dims2()?;
582        let input_embeds = xs.apply(&self.embedding)?;
583        let attention_mask = if seq_len <= 1 {
584            None
585        } else {
586            Some(get_mask(seq_len, xs.device())?)
587        };
588        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
589        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
590        Ok(lm_logits)
591    }
592}