1use candle::{DType, Device, Result, Tensor, D};
8use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
9use serde::Deserialize;
10
/// Sequence length used to size the cached rotary cos/sin tables built by
/// `FalconRotaryEmbedding::forward`.
const MAX_SEQ_LEN: usize = 5000;
12
13fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
14 let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
15 (Ok(weight), Ok(bias)) => (weight, bias),
16 (Err(err), _) | (_, Err(err)) => {
17 if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
18 (weight, bias)
19 } else {
20 return Err(err);
21 }
22 }
23 };
24 Ok(LayerNorm::new(weight, bias, eps))
25}
26
/// Hyper-parameters of a Falcon model; deserializable with serde.
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
    /// Size of the token embedding table and number of `lm_head` outputs.
    pub vocab_size: usize,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Number of decoder blocks.
    pub num_hidden_layers: usize,
    /// Number of query heads; `hidden_size / num_attention_heads` is the per-head width.
    pub num_attention_heads: usize,
    /// Epsilon passed to every layer norm.
    pub layer_norm_epsilon: f64,
    /// Not read by the model code here.
    pub initializer_range: f64,
    /// When true, each attention layer keeps and extends a key/value cache across calls.
    pub use_cache: bool,
    /// Not read by the model code here.
    pub bos_token_id: u32,
    /// Not read by the model code here.
    pub eos_token_id: u32,
    /// Not read by the model code here (no dropout is applied).
    pub hidden_dropout: f64,
    /// Not read by the model code here (no dropout is applied).
    pub attention_dropout: f64,
    /// Explicit number of key/value heads; `validate` rejects any value.
    pub n_head_kv: Option<usize>,
    /// ALiBi flag; `validate` rejects `true`. Rotary embeddings are used when it is `false`.
    pub alibi: bool,
    /// `validate` rejects `true`.
    pub new_decoder_architecture: bool,
    /// When true, the fused QKV projection produces one key head and one value head
    /// shared by all query heads.
    pub multi_query: bool,
    /// When true, the MLP reads the input layer norm output and its result is summed with
    /// the attention output; otherwise a post-attention layer norm feeds the MLP.
    pub parallel_attn: bool,
    /// Whether the attention and MLP linear layers carry bias terms.
    pub bias: bool,
}
48
49impl Default for Config {
50 fn default() -> Self {
51 Self {
52 vocab_size: 65024,
53 hidden_size: 4544,
54 num_hidden_layers: 32,
55 num_attention_heads: 71,
56 layer_norm_epsilon: 1e-5,
57 initializer_range: 0.02,
58 use_cache: true,
59 bos_token_id: 11,
60 eos_token_id: 11,
61 hidden_dropout: 0.0,
62 attention_dropout: 0.0,
63 n_head_kv: None,
64 alibi: false,
65 new_decoder_architecture: false,
66 multi_query: true,
67 parallel_attn: true,
68 bias: false,
69 }
70 }
71}
72
73impl Config {
74 pub fn validate(&self) -> Result<()> {
75 if self.alibi {
76 candle::bail!("alibi is not supported");
77 }
78 if self.new_decoder_architecture {
79 candle::bail!("new_decoder_architecture is not supported");
80 }
81 if self.n_head_kv.is_some() {
82 candle::bail!("n_head_kv is not supported");
83 }
84 Ok(())
85 }
86
87 pub fn falcon7b() -> Self {
89 Self {
92 vocab_size: 65024,
93 hidden_size: 4544,
94 num_hidden_layers: 32,
95 num_attention_heads: 71,
96 layer_norm_epsilon: 1e-5,
97 initializer_range: 0.02,
98 use_cache: true,
99 bos_token_id: 11,
100 eos_token_id: 11,
101 hidden_dropout: 0.,
102 attention_dropout: 0.,
103 n_head_kv: None,
104 alibi: false,
105 new_decoder_architecture: false,
106 multi_query: true,
107 parallel_attn: true,
108 bias: false,
109 }
110 }
111
112 fn head_dim(&self) -> usize {
113 self.hidden_size / self.num_attention_heads
114 }
115
116 fn rotary(&self) -> bool {
117 !self.alibi
118 }
119}
120
121fn rotate_half(x: &Tensor) -> Result<Tensor> {
122 let l = x.dim(D::Minus1)?;
123 let x1 = x.narrow(D::Minus1, 0, l / 2)?;
124 let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
125 let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
126 Ok(x21)
127}
128
/// Rotary position embedding state shared by the query and key projections of one
/// attention layer.
#[derive(Debug, Clone)]
struct FalconRotaryEmbedding {
    /// Inverse frequencies `1 / 10000^(i / head_dim)` for even `i < head_dim` (f32).
    inv_freq: Tensor,
    /// Cached `(table length, cos, sin)` tables from the last `cos_sin` computation.
    cache: Option<(usize, Tensor, Tensor)>,
}
134
135impl FalconRotaryEmbedding {
136 fn load(device: &Device, cfg: &Config) -> Result<Self> {
137 let head_dim = cfg.head_dim();
138 let inv_freq: Vec<_> = (0..head_dim)
139 .step_by(2)
140 .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
141 .collect();
142 Ok(Self {
143 inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
144 cache: None,
145 })
146 }
147
148 fn cos_sin(
149 &mut self,
150 seq_len: usize,
151 device: &Device,
152 dtype: DType,
153 ) -> Result<(Tensor, Tensor)> {
154 match &self.cache {
155 Some((s, cos, sin)) if *s == seq_len => {
156 return Ok((cos.clone(), sin.clone()));
157 }
158 _ => {}
159 }
160 let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
161 let inv_freq = self.inv_freq.to_dtype(dtype)?;
162 let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
163 let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
164 let cos = emb.cos()?;
165 let sin = emb.sin()?;
166 self.cache = Some((seq_len, cos.clone(), sin.clone()));
167 Ok((cos, sin))
168 }
169
170 fn forward(
171 &mut self,
172 query: &Tensor,
173 key: &Tensor,
174 past_kv_len: usize,
175 ) -> Result<(Tensor, Tensor)> {
176 let (_batch, seq_len, _head_dim) = query.dims3()?;
177 let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
178 let cos = cos.narrow(0, past_kv_len, seq_len)?;
179 let sin = sin.narrow(0, past_kv_len, seq_len)?;
180 let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
181 let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
182 Ok((qs, ks))
183 }
184}
185
186fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
187 let shape = mask.shape();
188 let on_true = Tensor::new(on_true, on_false.device())?
189 .to_dtype(on_false.dtype())?
190 .broadcast_as(shape.dims())?;
191 let m = mask.where_cond(&on_true, on_false)?;
192 Ok(m)
193}
194
/// Self-attention layer of a Falcon decoder block, using a fused QKV projection.
#[derive(Debug, Clone)]
struct FalconAttention {
    /// Fused projection producing queries, keys and values in one matmul.
    query_key_value: Linear,
    /// Output projection applied to the concatenated heads.
    dense: Linear,
    /// Rotary embedding state; `None` when the config disables rotary (ALiBi).
    maybe_rotary: Option<FalconRotaryEmbedding>,
    /// Keys and values of all previously processed positions, when `use_cache` is set.
    kv_cache: Option<(Tensor, Tensor)>,
    /// Attention score scale, `1 / sqrt(head_dim)`.
    inv_norm_factor: f64,
    /// Whether all query heads share a single key/value head.
    multi_query: bool,
    /// Whether `kv_cache` is maintained across calls.
    use_cache: bool,
    /// Number of query heads.
    num_heads: usize,
    /// Width of a single head.
    head_dim: usize,
    /// Number of key/value heads used when reshaping keys and values.
    n_head_kv: usize,
}
208
209impl FalconAttention {
210 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
211 let maybe_rotary = if cfg.rotary() {
212 let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
213 Some(rotary)
214 } else {
215 None
216 };
217 let head_dim = cfg.head_dim();
218 let hidden_size = cfg.hidden_size;
219 let qkv_out_dim = if cfg.multi_query {
220 hidden_size + 2 * head_dim
221 } else {
222 3 * hidden_size
223 };
224 let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
225 let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
226 Ok(Self {
227 query_key_value,
228 dense,
229 maybe_rotary,
230 kv_cache: None,
231 inv_norm_factor: 1. / (head_dim as f64).sqrt(),
232 multi_query: cfg.multi_query,
233 use_cache: cfg.use_cache,
234 num_heads: cfg.num_attention_heads,
235 n_head_kv: cfg.n_head_kv.unwrap_or(1),
236 head_dim,
237 })
238 }
239
240 fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
241 let (b_sz, seq_len, _) = fused_qkv.dims3()?;
242 if !self.multi_query {
243 let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
244 let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
245 let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
246 let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
247 Ok((q, k, v))
248 } else {
249 let fused_qkv =
250 fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
251 let d = fused_qkv.dim(D::Minus2)?;
252 let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
253 let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
254 let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
255 Ok((q, k, v))
256 }
257 }
258
259 fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
260 let fused_qkv = self.query_key_value.forward(x)?;
261 let head_dim = self.head_dim;
262 let (query, key, value) = self.split_heads(&fused_qkv)?;
263 let (b_sz, seq_len, _, _) = query.dims4()?;
264 let query = query
265 .transpose(1, 2)?
266 .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
267 let key = key
268 .transpose(1, 2)?
269 .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
270 let value = value
271 .transpose(1, 2)?
272 .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
273 let (query, key) = if let Some(r) = &mut self.maybe_rotary {
274 r.forward(&query, &key, past_kv_len)?
275 } else {
276 (query, key)
277 };
278 let (mut key, mut value) = (key, value);
279 if self.use_cache {
280 if let Some((cache_k, cache_v)) = &self.kv_cache {
281 key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
284 value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
285 }
286 self.kv_cache = Some((key.clone(), value.clone()))
287 }
288 let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
289 let all_len = past_kv_len + seq_len;
290 let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
291 let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
292
293 let (key, value) = if self.n_head_kv == 1 {
294 (
295 key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
296 value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
297 )
298 } else {
299 (key, value)
300 };
301
302 let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
304 let attention_scores = match mask {
305 None => attention_scores,
306 Some(mask) => {
307 let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
308 .to_dtype(query.dtype())?;
309 attention_scores.broadcast_add(&mask.squeeze(1)?)?
310 }
311 };
312
313 let attention_scores =
314 candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
315 .to_dtype(x.dtype())?;
316 let attn_output = attention_scores
317 .matmul(&value)?
318 .reshape((b_sz, self.num_heads, seq_len, head_dim))?
319 .transpose(1, 2)?
320 .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
321 let attn_output = self.dense.forward(&attn_output)?;
322 Ok(attn_output)
323 }
324
325 fn clear_kv_cache(&mut self) {
326 self.kv_cache = None
327 }
328}
329
/// Feed-forward block of a Falcon decoder layer: `hidden -> 4 * hidden -> hidden`.
#[derive(Debug, Clone)]
struct FalconMlp {
    /// Up projection, `hidden_size -> 4 * hidden_size`.
    dense_h_to_4h: Linear,
    /// Down projection, `4 * hidden_size -> hidden_size`.
    dense_4h_to_h: Linear,
}
335
336impl FalconMlp {
337 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
338 let h = cfg.hidden_size;
339 let b = cfg.bias;
340 let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
341 let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
342 Ok(Self {
343 dense_h_to_4h,
344 dense_4h_to_h,
345 })
346 }
347
348 fn forward(&self, x: &Tensor) -> Result<Tensor> {
349 let x = self.dense_h_to_4h.forward(x)?.gelu()?;
350 let x = self.dense_4h_to_h.forward(&x)?;
351 Ok(x)
352 }
353}
354
/// One Falcon decoder block: input layer norm, self-attention and MLP with residuals.
#[derive(Debug, Clone)]
struct FalconDecoderLayer {
    /// Layer norm applied to the block input before attention.
    inp_layernorm: LayerNorm,
    /// Self-attention sub-layer.
    self_attention: FalconAttention,
    /// Layer norm feeding the MLP; `None` when `parallel_attn` is set.
    post_attention_layernorm: Option<LayerNorm>,
    /// Feed-forward sub-layer.
    mlp: FalconMlp,
    /// Whether the attention output is added to the MLP output (parallel layout).
    parallel_attn: bool,
}
363
364impl FalconDecoderLayer {
365 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
366 let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
367 let inp_layernorm = layer_norm(
368 cfg.hidden_size,
369 cfg.layer_norm_epsilon,
370 vb.pp("input_layernorm"),
371 )?;
372 let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
373 let post_attention_layernorm = if cfg.parallel_attn {
374 None
375 } else {
376 let ln = layer_norm(
377 cfg.hidden_size,
378 cfg.layer_norm_epsilon,
379 vb.pp("post_attention_layernorm"),
380 )?;
381 Some(ln)
382 };
383 Ok(Self {
384 inp_layernorm,
385 self_attention,
386 post_attention_layernorm,
387 mlp,
388 parallel_attn: cfg.parallel_attn,
389 })
390 }
391
392 fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
393 let residual = x.clone();
394 let ln_attn = self.inp_layernorm.forward(x)?;
395 let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
396 let (residual, ln_mlp) = match &self.post_attention_layernorm {
397 None => (residual, ln_attn),
398 Some(pal) => {
399 let residual = (&attn_output + &residual)?;
401 let ln_mlp = pal.forward(&residual)?;
402 (residual, ln_mlp)
403 }
404 };
405 let mlp_output = self.mlp.forward(&ln_mlp)?;
406
407 let mlp_output = if self.parallel_attn {
408 (mlp_output + attn_output)?
409 } else {
410 mlp_output
411 };
412 let output = (mlp_output + residual)?;
413 Ok(output)
414 }
415
416 pub fn clear_kv_cache(&mut self) {
417 self.self_attention.clear_kv_cache()
418 }
419}
420
/// Falcon causal language model: embeddings, decoder blocks, final layer norm and LM head.
#[derive(Debug, Clone)]
pub struct Falcon {
    /// Token embedding table (`vocab_size x hidden_size`).
    word_embeddings: Embedding,
    /// Decoder blocks, applied in order.
    blocks: Vec<FalconDecoderLayer>,
    /// Final layer norm applied before the LM head.
    ln_f: LayerNorm,
    /// Projection from hidden states to vocabulary logits (no bias).
    lm_head: Linear,
    /// Configuration the model was loaded with.
    config: Config,
}
429
430fn make_causal_mask(t: usize) -> Result<Tensor> {
431 let mask: Vec<_> = (0..t)
432 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
433 .collect();
434 let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
435 Ok(mask)
436}
437
438fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
439 let mask = make_causal_mask(seq_len)?;
441 let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
442 Ok(mask)
443}
444
445impl Falcon {
446 pub fn config(&self) -> &Config {
447 &self.config
448 }
449
450 pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
451 let word_embeddings = embedding(
452 cfg.vocab_size,
453 cfg.hidden_size,
454 vb.pp("transformer.word_embeddings"),
455 )?;
456 let blocks = (0..cfg.num_hidden_layers)
457 .map(|i| FalconDecoderLayer::load(vb.pp(format!("transformer.h.{i}")), &cfg))
458 .collect::<Result<Vec<_>>>()?;
459 let ln_f = layer_norm(
460 cfg.hidden_size,
461 cfg.layer_norm_epsilon,
462 vb.pp("transformer.ln_f"),
463 )?;
464 let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
465 Ok(Self {
466 word_embeddings,
467 blocks,
468 ln_f,
469 lm_head,
470 config: cfg,
471 })
472 }
473
474 pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
475 let (b_sz, seq_len) = input_ids.dims2()?;
476 let mut hidden_state = self.word_embeddings.forward(input_ids)?;
477 let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
478 Some((k, _)) => k.dim(1)?,
479 None => 0,
480 };
481 let causal_mask = if seq_len <= 1 {
482 None
483 } else {
484 Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
485 };
486 for block in self.blocks.iter_mut() {
487 hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
488 }
489 let hidden_state = self.ln_f.forward(&hidden_state)?;
490 let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
491 let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
492 Ok(logits)
493 }
494
495 pub fn clear_kv_cache(&mut self) {
496 for block in self.blocks.iter_mut() {
497 block.clear_kv_cache()
498 }
499 }
500}