candle_transformers/models/falcon.rs

//! Falcon language model inference implementation
//!
//! See ["Falcon: a new approach to large language models"](https://huggingface.co/blog/falcon)
//!
//! Based on the implementation from [Huggingface Transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon)
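//!
//! # Example (inference sketch)
//!
//! A minimal, hypothetical usage sketch assuming a Falcon checkpoint stored as a safetensors
//! file; the weight path, dtype and token ids below are placeholders, not part of this module:
//!
//! ```ignore
//! use candle::{DType, Device, Tensor};
//! use candle_nn::VarBuilder;
//! use candle_transformers::models::falcon::{Config, Falcon};
//!
//! fn run() -> candle::Result<()> {
//!     let device = Device::Cpu;
//!     // Memory-map the checkpoint weights (placeholder path).
//!     let vb = unsafe {
//!         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
//!     };
//!     let config = Config::falcon7b();
//!     config.validate()?;
//!     let mut model = Falcon::load(vb, config)?;
//!     // Token ids with shape (batch_size, seq_len); the values here are arbitrary.
//!     let input_ids = Tensor::new(&[[11u32, 42, 7]], &device)?;
//!     // Logits for the last position, shape (batch_size, vocab_size).
//!     let logits = model.forward(&input_ids)?;
//!     println!("{:?}", logits.shape());
//!     Ok(())
//! }
//! ```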

use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use serde::Deserialize;

const MAX_SEQ_LEN: usize = 5000;

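// Some Falcon checkpoints name the layer-norm parameters `gamma`/`beta` instead of
// `weight`/`bias`, so fall back to those names when the usual ones are missing.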
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
        (Ok(weight), Ok(bias)) => (weight, bias),
        (Err(err), _) | (_, Err(err)) => {
            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                (weight, bias)
            } else {
                return Err(err);
            }
        }
    };
    Ok(LayerNorm::new(weight, bias, eps))
}

// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub layer_norm_epsilon: f64,
    pub initializer_range: f64,
    pub use_cache: bool,
    pub bos_token_id: u32,
    pub eos_token_id: u32,
    pub hidden_dropout: f64,
    pub attention_dropout: f64,
    pub n_head_kv: Option<usize>,
    pub alibi: bool,
    pub new_decoder_architecture: bool,
    pub multi_query: bool,
    pub parallel_attn: bool,
    pub bias: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 65024,
            hidden_size: 4544,
            num_hidden_layers: 32,
            num_attention_heads: 71,
            layer_norm_epsilon: 1e-5,
            initializer_range: 0.02,
            use_cache: true,
            bos_token_id: 11,
            eos_token_id: 11,
            hidden_dropout: 0.0,
            attention_dropout: 0.0,
            n_head_kv: None,
            alibi: false,
            new_decoder_architecture: false,
            multi_query: true,
            parallel_attn: true,
            bias: false,
        }
    }
}

impl Config {
    pub fn validate(&self) -> Result<()> {
        if self.alibi {
            candle::bail!("alibi is not supported");
        }
        if self.new_decoder_architecture {
            candle::bail!("new_decoder_architecture is not supported");
        }
        if self.n_head_kv.is_some() {
            candle::bail!("n_head_kv is not supported");
        }
        Ok(())
    }

    // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
    pub fn falcon7b() -> Self {
        // These values currently match the `Default` impl: the defaults mirror the Python
        // config's default arguments, whereas the values below come from the published json config.
        Self {
            vocab_size: 65024,
            hidden_size: 4544,
            num_hidden_layers: 32,
            num_attention_heads: 71,
            layer_norm_epsilon: 1e-5,
            initializer_range: 0.02,
            use_cache: true,
            bos_token_id: 11,
            eos_token_id: 11,
            hidden_dropout: 0.,
            attention_dropout: 0.,
            n_head_kv: None,
            alibi: false,
            new_decoder_architecture: false,
            multi_query: true,
            parallel_attn: true,
            bias: false,
        }
    }

    fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    fn rotary(&self) -> bool {
        !self.alibi
    }
}

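// Splits the last dimension into two halves [x1, x2] and returns [-x2, x1], the rotation
// used by rotary position embeddings.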
fn rotate_half(x: &Tensor) -> Result<Tensor> {
    let l = x.dim(D::Minus1)?;
    let x1 = x.narrow(D::Minus1, 0, l / 2)?;
    let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
    let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
    Ok(x21)
}

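// Rotary position embedding: precomputes cos/sin tables (cached per sequence length) and
// applies them to the query and key projections, offset by the current KV-cache length.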
#[derive(Debug, Clone)]
struct FalconRotaryEmbedding {
    inv_freq: Tensor,
    cache: Option<(usize, Tensor, Tensor)>,
}

impl FalconRotaryEmbedding {
    fn load(device: &Device, cfg: &Config) -> Result<Self> {
        let head_dim = cfg.head_dim();
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
            .collect();
        Ok(Self {
            inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
            cache: None,
        })
    }

    fn cos_sin(
        &mut self,
        seq_len: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<(Tensor, Tensor)> {
        match &self.cache {
            Some((s, cos, sin)) if *s == seq_len => {
                return Ok((cos.clone(), sin.clone()));
            }
            _ => {}
        }
        let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
        let inv_freq = self.inv_freq.to_dtype(dtype)?;
        let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        let cos = emb.cos()?;
        let sin = emb.sin()?;
        self.cache = Some((seq_len, cos.clone(), sin.clone()));
        Ok((cos, sin))
    }

    fn forward(
        &mut self,
        query: &Tensor,
        key: &Tensor,
        past_kv_len: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_batch, seq_len, _head_dim) = query.dims3()?;
        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
        let cos = cos.narrow(0, past_kv_len, seq_len)?;
        let sin = sin.narrow(0, past_kv_len, seq_len)?;
        let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
        let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
        Ok((qs, ks))
    }
}

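// Replaces the entries of `on_false` with `on_true` wherever `mask` is non-zero, leaving the
// remaining entries untouched.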
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?
        .to_dtype(on_false.dtype())?
        .broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

#[derive(Debug, Clone)]
struct FalconAttention {
    query_key_value: Linear,
    dense: Linear,
    maybe_rotary: Option<FalconRotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
    inv_norm_factor: f64,
    multi_query: bool,
    use_cache: bool,
    num_heads: usize,
    head_dim: usize,
    n_head_kv: usize,
}

impl FalconAttention {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let maybe_rotary = if cfg.rotary() {
            let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
            Some(rotary)
        } else {
            None
        };
        let head_dim = cfg.head_dim();
        let hidden_size = cfg.hidden_size;
        let qkv_out_dim = if cfg.multi_query {
            hidden_size + 2 * head_dim
        } else {
            3 * hidden_size
        };
        let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
        let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
        Ok(Self {
            query_key_value,
            dense,
            maybe_rotary,
            kv_cache: None,
            inv_norm_factor: 1. / (head_dim as f64).sqrt(),
            multi_query: cfg.multi_query,
            use_cache: cfg.use_cache,
            num_heads: cfg.num_attention_heads,
            n_head_kv: cfg.n_head_kv.unwrap_or(1),
            head_dim,
        })
    }

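    // Splits the fused QKV projection into separate query, key and value tensors. With
    // multi-query attention the projection holds `num_heads` query heads plus a single shared
    // key head and value head; otherwise each head gets its own q, k and v.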
    fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
        if !self.multi_query {
            let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
            let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
            let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
            let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
            Ok((q, k, v))
        } else {
            let fused_qkv =
                fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
            let d = fused_qkv.dim(D::Minus2)?;
            let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
            let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
            let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
            Ok((q, k, v))
        }
    }

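    // Attention forward pass: project to fused QKV, split heads, apply the rotary embedding,
    // append to the KV cache, then compute scaled dot-product attention (softmax in f32 for
    // numerical stability) and project the result back through the dense layer.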
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
        let fused_qkv = self.query_key_value.forward(x)?;
        let head_dim = self.head_dim;
        let (query, key, value) = self.split_heads(&fused_qkv)?;
        let (b_sz, seq_len, _, _) = query.dims4()?;
        let query = query
            .transpose(1, 2)?
            .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
        let key = key
            .transpose(1, 2)?
            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
        let value = value
            .transpose(1, 2)?
            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
        let (query, key) = if let Some(r) = &mut self.maybe_rotary {
            r.forward(&query, &key, past_kv_len)?
        } else {
            (query, key)
        };
        let (mut key, mut value) = (key, value);
        if self.use_cache {
            if let Some((cache_k, cache_v)) = &self.kv_cache {
                // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
                // arbitrarily large sizes.
                key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
                value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
            }
            self.kv_cache = Some((key.clone(), value.clone()))
        }
        let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
        let all_len = past_kv_len + seq_len;
        let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
        let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;

        let (key, value) = if self.n_head_kv == 1 {
            (
                key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
                value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
            )
        } else {
            (key, value)
        };

        // Only the non-alibi, non-flash-attention case is handled here.
        let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
        let attention_scores = match mask {
            None => attention_scores,
            Some(mask) => {
                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
                    .to_dtype(query.dtype())?;
                attention_scores.broadcast_add(&mask.squeeze(1)?)?
            }
        };

        let attention_scores =
            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
                .to_dtype(x.dtype())?;
        let attn_output = attention_scores
            .matmul(&value)?
            .reshape((b_sz, self.num_heads, seq_len, head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
        let attn_output = self.dense.forward(&attn_output)?;
        Ok(attn_output)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct FalconMlp {
    dense_h_to_4h: Linear,
    dense_4h_to_h: Linear,
}

impl FalconMlp {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let h = cfg.hidden_size;
        let b = cfg.bias;
        let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
        let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
        Ok(Self {
            dense_h_to_4h,
            dense_4h_to_h,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.dense_h_to_4h.forward(x)?.gelu()?;
        let x = self.dense_4h_to_h.forward(&x)?;
        Ok(x)
    }
}

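// A single Falcon transformer block. With `parallel_attn` (the falcon-7b configuration) the
// attention and MLP branches both read the same input layer-norm output and their results are
// summed with the residual; otherwise a post-attention layer-norm is applied in between.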
#[derive(Debug, Clone)]
struct FalconDecoderLayer {
    inp_layernorm: LayerNorm,
    self_attention: FalconAttention,
    post_attention_layernorm: Option<LayerNorm>,
    mlp: FalconMlp,
    parallel_attn: bool,
}

impl FalconDecoderLayer {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
        let inp_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_epsilon,
            vb.pp("input_layernorm"),
        )?;
        let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
        let post_attention_layernorm = if cfg.parallel_attn {
            None
        } else {
            let ln = layer_norm(
                cfg.hidden_size,
                cfg.layer_norm_epsilon,
                vb.pp("post_attention_layernorm"),
            )?;
            Some(ln)
        };
        Ok(Self {
            inp_layernorm,
            self_attention,
            post_attention_layernorm,
            mlp,
            parallel_attn: cfg.parallel_attn,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
        let residual = x.clone();
        let ln_attn = self.inp_layernorm.forward(x)?;
        let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
        let (residual, ln_mlp) = match &self.post_attention_layernorm {
            None => (residual, ln_attn),
            Some(pal) => {
                // This should include some dropout.
                let residual = (&attn_output + &residual)?;
                let ln_mlp = pal.forward(&residual)?;
                (residual, ln_mlp)
            }
        };
        let mlp_output = self.mlp.forward(&ln_mlp)?;

        let mlp_output = if self.parallel_attn {
            (mlp_output + attn_output)?
        } else {
            mlp_output
        };
        let output = (mlp_output + residual)?;
        Ok(output)
    }

    pub fn clear_kv_cache(&mut self) {
        self.self_attention.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Falcon {
    word_embeddings: Embedding,
    blocks: Vec<FalconDecoderLayer>,
    ln_f: LayerNorm,
    lm_head: Linear,
    config: Config,
}

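// Builds a (t, t) mask where entry (i, j) is 1 when j > i, i.e. the future positions a causal
// decoder is not allowed to attend to; those entries are later filled with -1e9.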
fn make_causal_mask(t: usize) -> Result<Tensor> {
    let mask: Vec<_> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
    Ok(mask)
}

fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
    // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?;
    let mask = make_causal_mask(seq_len)?;
    let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
    Ok(mask)
}

impl Falcon {
    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
        let word_embeddings = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            vb.pp("transformer.word_embeddings"),
        )?;
        let blocks = (0..cfg.num_hidden_layers)
            .map(|i| FalconDecoderLayer::load(vb.pp(format!("transformer.h.{i}")), &cfg))
            .collect::<Result<Vec<_>>>()?;
        let ln_f = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_epsilon,
            vb.pp("transformer.ln_f"),
        )?;
        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
        Ok(Self {
            word_embeddings,
            blocks,
            ln_f,
            lm_head,
            config: cfg,
        })
    }

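    // Runs the model on a batch of token ids with shape (batch_size, seq_len) and returns the
    // logits for the last position only, shape (batch_size, vocab_size). The per-layer KV caches
    // are reused, so subsequent calls only need to pass the newly generated tokens.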
    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let (b_sz, seq_len) = input_ids.dims2()?;
        let mut hidden_state = self.word_embeddings.forward(input_ids)?;
        let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
            Some((k, _)) => k.dim(1)?,
            None => 0,
        };
        let causal_mask = if seq_len <= 1 {
            None
        } else {
            Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
        };
        for block in self.blocks.iter_mut() {
            hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
        }
        let hidden_state = self.ln_f.forward(&hidden_state)?;
        let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
        let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
        Ok(logits)
    }

    pub fn clear_kv_cache(&mut self) {
        for block in self.blocks.iter_mut() {
            block.clear_kv_cache()
        }
    }
}