1use crate::models::with_tracing::{linear, linear_no_bias, Linear};
17use candle::{DType, Device, Module, Result, Tensor, D};
18use candle_nn::{Activation, LayerNorm, VarBuilder};
19use serde::Deserialize;
20use std::sync::Arc;
21
22#[derive(Debug, Clone, PartialEq, Deserialize)]
24pub struct Config {
25 pub(crate) vocab_size: usize,
26 pub(crate) intermediate_size: usize,
27 pub(crate) hidden_size: usize,
28 pub(crate) num_hidden_layers: usize,
29 pub(crate) num_attention_heads: usize,
30 pub(crate) num_key_value_heads: usize,
31 pub(crate) hidden_act: Activation,
32 pub(crate) partial_rotary_factor: f64,
33 pub(crate) rope_theta: f64,
34 pub(crate) max_position_embeddings: usize,
35 pub(crate) layer_norm_eps: f64,
36 pub(crate) use_cache: bool,
37 #[serde(default)]
38 pub(crate) use_qkv_bias: bool, #[serde(default)]
40 pub(crate) use_flash_attn: bool, }
42
43impl Config {
44 pub fn stablelm_3b_4e1t(use_flash_attn: bool) -> Self {
45 Self {
46 vocab_size: 50304,
47 intermediate_size: 6912,
48 hidden_size: 2560,
49 num_hidden_layers: 32,
50 num_attention_heads: 32,
51 num_key_value_heads: 32,
52 hidden_act: Activation::Silu,
53 partial_rotary_factor: 0.25,
54 rope_theta: 10_000.,
55 max_position_embeddings: 4096,
56 layer_norm_eps: 1e-5,
57 use_qkv_bias: false,
58 use_cache: true,
59 use_flash_attn,
60 }
61 }
62
63 pub fn head_dim(&self) -> usize {
64 self.hidden_size / self.num_attention_heads
65 }
66
67 pub fn rotary_ndims(&self) -> usize {
68 (self.head_dim() as f64 * self.partial_rotary_factor) as usize
69 }
70
71 pub fn num_kv_groups(&self) -> usize {
72 self.num_attention_heads / self.num_key_value_heads
73 }
74
75 pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
76 self.use_flash_attn = use_flash_attn
77 }
78}
79
80#[derive(Debug)]
81pub(crate) struct RotaryEmbedding {
82 sin: Tensor,
83 cos: Tensor,
84}
85
86fn rotate_half(xs: &Tensor) -> Result<Tensor> {
87 let xs = xs.chunk(2, D::Minus1)?;
88 Tensor::cat(&[&xs[1].neg()?, &xs[0]], D::Minus1)
89}
90
91impl RotaryEmbedding {
92 pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
93 let dim = cfg.rotary_ndims();
94 let max_seq_len = cfg.max_position_embeddings;
95 let inv_freq: Vec<_> = (0..dim)
96 .step_by(2)
97 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
98 .collect();
99 let inv_freq_len = inv_freq.len();
100 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
101 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
102 .to_dtype(dtype)?
103 .reshape((max_seq_len, 1))?;
104 let freqs = t.matmul(&inv_freq)?;
105 let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
106 Ok(Self {
107 sin: freqs.sin()?,
108 cos: freqs.cos()?,
109 })
110 }
111
112 pub(crate) fn apply_rotary_emb_qkv(
113 &self,
114 q: &Tensor,
115 k: &Tensor,
116 seqlen_offset: usize,
117 ) -> Result<(Tensor, Tensor)> {
118 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
119 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
120 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
121 let cos = cos.unsqueeze(0)?.unsqueeze(0)?; let sin = sin.unsqueeze(0)?.unsqueeze(0)?; let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
124 let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
125 Ok((q_embed, k_embed))
126 }
127}
128
129#[derive(Debug)]
130#[allow(clippy::upper_case_acronyms)]
131struct MLP {
132 gate_proj: Linear,
133 up_proj: Linear,
134 down_proj: Linear,
135 act_fn: Activation,
136 span: tracing::Span,
137}
138
139impl MLP {
140 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
141 let hidden_sz = cfg.hidden_size;
142 let intermediate_sz = cfg.intermediate_size;
143 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
144 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
145 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
146 Ok(Self {
147 gate_proj,
148 up_proj,
149 down_proj,
150 act_fn: cfg.hidden_act,
151 span: tracing::span!(tracing::Level::TRACE, "mlp"),
152 })
153 }
154}
155
156impl Module for MLP {
157 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
158 let _enter = self.span.enter();
159 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
160 let rhs = xs.apply(&self.up_proj)?;
161 (lhs * rhs)?.apply(&self.down_proj)
162 }
163}
164
165#[cfg(feature = "flash-attn")]
166fn flash_attn(
167 q: &Tensor,
168 k: &Tensor,
169 v: &Tensor,
170 softmax_scale: f32,
171 causal: bool,
172) -> Result<Tensor> {
173 candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
174}
175
176#[cfg(not(feature = "flash-attn"))]
177fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
178 unimplemented!("compile with '--features flash-attn'")
179}
180
181#[derive(Debug)]
182struct Attention {
183 q_proj: Linear,
184 k_proj: Linear,
185 v_proj: Linear,
186 o_proj: Linear,
187 num_heads: usize,
188 num_kv_heads: usize,
189 num_kv_groups: usize,
190 head_dim: usize,
191 hidden_size: usize,
192 rotary_emb: Arc<RotaryEmbedding>,
193 kv_cache: Option<(Tensor, Tensor)>,
194 use_cache: bool,
195 rotary_ndims: usize,
196 use_flash_attn: bool,
197 span: tracing::Span,
198}
199
200impl Attention {
201 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
202 let hidden_sz = cfg.hidden_size;
203 let head_dim = cfg.head_dim();
204 let num_heads = cfg.num_attention_heads;
205 let num_kv_heads = cfg.num_key_value_heads;
206 let linear_layer = if cfg.use_qkv_bias {
207 linear
208 } else {
209 linear_no_bias
210 };
211
212 let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
213 let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
214 let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
215 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
216 Ok(Self {
217 q_proj,
218 k_proj,
219 v_proj,
220 o_proj,
221 num_heads,
222 num_kv_heads,
223 num_kv_groups: cfg.num_kv_groups(),
224 head_dim,
225 hidden_size: hidden_sz,
226 rotary_emb,
227 kv_cache: None,
228 use_cache: cfg.use_cache,
229 rotary_ndims: cfg.rotary_ndims(),
230 use_flash_attn: cfg.use_flash_attn,
231 span: tracing::span!(tracing::Level::TRACE, "attn"),
232 })
233 }
234
235 fn forward(
236 &mut self,
237 xs: &Tensor,
238 attention_mask: Option<&Tensor>,
239 seqlen_offset: usize,
240 ) -> Result<Tensor> {
241 let _enter = self.span.enter();
242 let (b_sz, q_len, _) = xs.dims3()?;
243
244 let query_states = self.q_proj.forward(xs)?;
245 let key_states = self.k_proj.forward(xs)?;
246 let value_states = self.v_proj.forward(xs)?;
247
248 let query_states = query_states
249 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
250 .transpose(1, 2)?;
251 let key_states = key_states
252 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
253 .transpose(1, 2)?;
254 let value_states = value_states
255 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
256 .transpose(1, 2)?;
257
258 let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
259 let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
260 let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
261 let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
262 let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
263 let (query_rot, key_rot) =
264 self.rotary_emb
265 .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
266 let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
267 let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
268
269 let (key_states, value_states) = match &self.kv_cache {
270 None => (key_states, value_states),
271 Some((prev_k, prev_v)) => {
272 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
273 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
274 (key_states, value_states)
275 }
276 };
277 if self.use_cache {
278 self.kv_cache = Some((key_states.clone(), value_states.clone()));
279 }
280
281 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
282 let value_states =
283 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
284
285 let attn_output = if self.use_flash_attn {
286 let q = query_states.transpose(1, 2)?;
288 let k = key_states.transpose(1, 2)?;
289 let v = value_states.transpose(1, 2)?;
290 let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
291 flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
292 } else {
293 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
294 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
295
296 let attn_weights = match attention_mask {
297 None => attn_weights,
298 Some(mask) => attn_weights.broadcast_add(mask)?,
299 };
300 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
301 attn_weights.matmul(&value_states)?
302 };
303 attn_output
304 .transpose(1, 2)?
305 .reshape((b_sz, q_len, self.hidden_size))?
306 .apply(&self.o_proj)
307 }
308}
309
310#[derive(Debug)]
311struct DecoderLayer {
312 self_attn: Attention,
313 mlp: MLP,
314 input_layernorm: LayerNorm,
315 post_attention_layernorm: LayerNorm,
316 span: tracing::Span,
317}
318
319impl DecoderLayer {
320 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
321 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
322 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
323 let input_layernorm = candle_nn::layer_norm(
324 cfg.hidden_size,
325 cfg.layer_norm_eps,
326 vb.pp("input_layernorm"),
327 )?;
328 let post_attention_layernorm = candle_nn::layer_norm(
329 cfg.hidden_size,
330 cfg.layer_norm_eps,
331 vb.pp("post_attention_layernorm"),
332 )?;
333 Ok(Self {
334 self_attn,
335 mlp,
336 input_layernorm,
337 post_attention_layernorm,
338 span: tracing::span!(tracing::Level::TRACE, "layer"),
339 })
340 }
341
342 fn forward(
343 &mut self,
344 xs: &Tensor,
345 attention_mask: Option<&Tensor>,
346 seqlen_offset: usize,
347 ) -> Result<Tensor> {
348 let _enter = self.span.enter();
349 let residual = xs;
350 let xs = self.input_layernorm.forward(xs)?;
351 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
352 let xs = (xs + residual)?;
353 let residual = &xs;
354 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
355 residual + xs
356 }
357}
358
359#[derive(Debug)]
360pub struct Model {
361 embed_tokens: candle_nn::Embedding,
362 layers: Vec<DecoderLayer>,
363 norm: LayerNorm,
364 lm_head: Linear,
365 device: Device,
366 dtype: DType,
367 span: tracing::Span,
368}
369
370impl Model {
371 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
372 let vb_m = vb.pp("model");
373 let embed_tokens =
374 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
375 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
376 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
377 let vb_l = vb_m.pp("layers");
378 for layer_idx in 0..cfg.num_hidden_layers {
379 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
380 layers.push(layer)
381 }
382 let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
383 let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
384 Ok(Self {
385 embed_tokens,
386 layers,
387 norm,
388 lm_head,
389 device: vb.device().clone(),
390 dtype: vb.dtype(),
391 span: tracing::span!(tracing::Level::TRACE, "model"),
392 })
393 }
394
395 fn prepare_decoder_attention_mask(
396 &self,
397 b_size: usize,
398 tgt_len: usize,
399 seqlen_offset: usize,
400 ) -> Result<Tensor> {
401 let mask: Vec<_> = (0..tgt_len)
403 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
404 .collect();
405 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
406 let mask = if seqlen_offset > 0 {
407 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
408 Tensor::cat(&[&mask0, &mask], D::Minus1)?
409 } else {
410 mask
411 };
412 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
413 .to_dtype(self.dtype)
414 }
415
416 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
417 let _enter = self.span.enter();
418 let (b_size, seq_len) = input_ids.dims2()?;
419 let attention_mask = if seq_len <= 1 {
420 None
421 } else {
422 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
423 Some(mask)
424 };
425 let mut xs = self.embed_tokens.forward(input_ids)?;
426 for layer in self.layers.iter_mut() {
427 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
428 }
429 xs.narrow(1, seq_len - 1, 1)?
430 .apply(&self.norm)?
431 .apply(&self.lm_head)
432 }
433}