use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
    Linear, Module, VarBuilder,
};
use serde::Deserialize;

use core::f32;
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub layer_norm_eps: f64,
    pub pad_token_id: u32,
    pub global_attn_every_n_layers: usize,
    pub global_rope_theta: f64,
    pub local_attention: usize,
    pub local_rope_theta: f64,
    #[serde(default)]
    #[serde(flatten)]
    pub classifier_config: Option<ClassifierConfig>,
}

#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
#[serde(rename_all = "lowercase")]
pub enum ClassifierPooling {
    #[default]
    CLS,
    MEAN,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ClassifierConfig {
    pub id2label: HashMap<String, String>,
    pub label2id: HashMap<String, String>,
    pub classifier_pooling: ClassifierPooling,
}

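/// Precomputed sine/cosine tables for rotary position embeddings (RoPE),
/// built once per rope theta and shared across layers through an `Arc`.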
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
        let dim = config.hidden_size / config.num_attention_heads;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let max_seq_len = config.max_position_embeddings;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
        Ok((q_embed, k_embed))
    }
}

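/// Multi-head self-attention with a single fused QKV projection (`Wqkv`) and
/// an output projection (`Wo`); queries and keys are rotated with RoPE before
/// the scaled dot-product.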
#[derive(Clone)]
struct ModernBertAttention {
    qkv: Linear,
    proj: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
}

impl ModernBertAttention {
    fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
        let num_attention_heads = config.num_attention_heads;
        let attention_head_size = config.hidden_size / config.num_attention_heads;

        let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
        let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;

        Ok(Self {
            qkv,
            proj,
            num_attention_heads,
            attention_head_size,
            rotary_emb,
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let xs = hidden_states.clone();
        let (b, seq_len, d) = xs.dims3()?;
        let qkv = xs
            .apply(&self.qkv)?
            .reshape((
                b,
                seq_len,
                3,
                self.num_attention_heads,
                self.attention_head_size,
            ))?
            .permute((2, 0, 3, 1, 4))?;

        let q = qkv.get(0)?;
        let k = qkv.get(1)?;
        let v = qkv.get(2)?;

        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;

        let scale = (self.attention_head_size as f64).powf(-0.5);
        let q = (q * scale)?;

        let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;

        let att = att.broadcast_add(attention_mask)?;
        let att = softmax(&att, D::Minus1)?;

        let xs = att.matmul(&v)?;

        let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
        let xs = xs.apply(&self.proj)?;
        let xs = xs.reshape((b, seq_len, d))?;

        Ok(xs)
    }
}

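/// Gated feed-forward block: `Wi` produces `2 * intermediate_size` features
/// that are split in two, the GELU of the first half gates the second half,
/// and `Wo` projects back to `hidden_size`.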
#[derive(Clone)]
pub struct ModernBertMLP {
    wi: Linear,
    wo: Linear,
}

impl ModernBertMLP {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let wi = linear_no_bias(
            config.hidden_size,
            config.intermediate_size * 2,
            vb.pp("Wi"),
        )?;
        let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
        Ok(Self { wi, wo })
    }
}

impl Module for ModernBertMLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.wi)?;
        let xs = xs.chunk(2, D::Minus1)?;
        let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?;
        Ok(xs)
    }
}

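/// A single transformer block: pre-norm attention and MLP with residual
/// connections. Layers flagged with `uses_local_attention` add a
/// sliding-window mask on top of the global padding mask, and a layer whose
/// checkpoint has no `attn_norm` weights skips that normalization.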
#[derive(Clone)]
pub struct ModernBertLayer {
    attn: ModernBertAttention,
    mlp: ModernBertMLP,
    attn_norm: Option<LayerNorm>,
    mlp_norm: LayerNorm,
    uses_local_attention: bool,
}

impl ModernBertLayer {
    fn load(
        vb: VarBuilder,
        config: &Config,
        rotary_emb: Arc<RotaryEmbedding>,
        uses_local_attention: bool,
    ) -> Result<Self> {
        let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
        let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
        let attn_norm = layer_norm_no_bias(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("attn_norm"),
        )
        .ok();
        let mlp_norm =
            layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
        Ok(Self {
            attn,
            mlp,
            attn_norm,
            mlp_norm,
            uses_local_attention,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        global_attention_mask: &Tensor,
        local_attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let residual = xs.clone();
        let mut xs = xs.clone();
        if let Some(norm) = &self.attn_norm {
            xs = xs.apply(norm)?;
        }

        let attention_mask = if self.uses_local_attention {
            &global_attention_mask.broadcast_add(local_attention_mask)?
        } else {
            global_attention_mask
        };
        let xs = self.attn.forward(&xs, attention_mask)?;
        let xs = (xs + residual)?;
        let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
        let xs = (xs + mlp_out)?;
        Ok(xs)
    }
}

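/// Prediction head applied before the MLM decoder and the classifier:
/// dense projection, GELU, then layer norm.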
#[derive(Clone)]
pub struct ModernBertHead {
    dense: Linear,
    norm: LayerNorm,
}

impl ModernBertHead {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
        Ok(Self { dense, norm })
    }
}

impl Module for ModernBertHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
        Ok(xs)
    }
}

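/// MLM output projection. The weight matrix is tied to the token embedding
/// table (`embeddings.tok_embeddings.weight`); only the bias is a separate
/// tensor (`decoder.bias`).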
#[derive(Clone)]
pub struct ModernBertDecoder {
    decoder: Linear,
}

impl ModernBertDecoder {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let decoder_weights = vb.get(
            (config.vocab_size, config.hidden_size),
            "embeddings.tok_embeddings.weight",
        )?;
        let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
        let decoder = Linear::new(decoder_weights, Some(decoder_bias));
        Ok(Self { decoder })
    }
}

impl Module for ModernBertDecoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.decoder)?;
        Ok(xs)
    }
}

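/// Expands a `(batch, seq_len)` padding mask of ones and zeros into an
/// additive `(batch, 1, tgt_len, src_len)` mask where padded positions hold
/// a large negative value and attended positions hold zero.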
fn prepare_4d_attention_mask(
    mask: &Tensor,
    dtype: DType,
    tgt_len: Option<usize>,
) -> Result<Tensor> {
    let bsz = mask.dim(0)?;
    let src_len = mask.dim(1)?;
    let tgt_len = tgt_len.unwrap_or(src_len);

    let expanded_mask = mask
        .unsqueeze(1)?
        .unsqueeze(2)?
        .expand((bsz, 1, tgt_len, src_len))?
        .to_dtype(dtype)?;

    let inverted_mask = (1.0 - expanded_mask)?;

    (inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
}

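/// Builds a `(seq_len, seq_len)` sliding-window mask: positions farther than
/// `max_distance` tokens apart are set to negative infinity and all other
/// positions to zero, so it can be added to the global attention mask.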
fn get_local_attention_mask(
    seq_len: usize,
    max_distance: usize,
    device: &Device,
) -> Result<Tensor> {
    let mask: Vec<_> = (0..seq_len)
        .flat_map(|i| {
            (0..seq_len).map(move |j| {
                if (j as i32 - i as i32).abs() > max_distance as i32 {
                    f32::NEG_INFINITY
                } else {
                    0.
                }
            })
        })
        .collect();
    Tensor::from_slice(&mask, (seq_len, seq_len), device)
}

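/// ModernBERT backbone: token embeddings, embedding layer norm, a stack of
/// transformer layers alternating between global and local attention (every
/// `global_attn_every_n_layers`-th layer is global), and a final layer norm.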
#[derive(Clone)]
pub struct ModernBert {
    word_embeddings: Embedding,
    norm: LayerNorm,
    layers: Vec<ModernBertLayer>,
    final_norm: LayerNorm,
    local_attention_size: usize,
}

impl ModernBert {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("embeddings.tok_embeddings"),
        )?;
        let norm = layer_norm_no_bias(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("embeddings.norm"),
        )?;
        let global_rotary_emb = Arc::new(RotaryEmbedding::new(
            vb.dtype(),
            config,
            config.global_rope_theta,
            vb.device(),
        )?);
        let local_rotary_emb = Arc::new(RotaryEmbedding::new(
            vb.dtype(),
            config,
            config.local_rope_theta,
            vb.device(),
        )?);

        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for layer_id in 0..config.num_hidden_layers {
            let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
            layers.push(ModernBertLayer::load(
                vb.pp(format!("layers.{layer_id}")),
                config,
                if layer_uses_local_attention {
                    local_rotary_emb.clone()
                } else {
                    global_rotary_emb.clone()
                },
                layer_uses_local_attention,
            )?);
        }

        let final_norm = layer_norm_no_bias(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("final_norm"),
        )?;

        Ok(Self {
            word_embeddings,
            norm,
            layers,
            final_norm,
            local_attention_size: config.local_attention,
        })
    }

    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let seq_len = xs.shape().dims()[1];
        let global_attention_mask =
            prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
        let local_attention_mask =
            get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
        let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
        }
        let xs = xs.apply(&self.final_norm)?;
        Ok(xs)
    }
}

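/// Masked-language-model wrapper: backbone, prediction head, and the
/// embedding-tied decoder producing vocabulary logits.
///
/// A minimal usage sketch, assuming safetensors weights whose tensor names
/// match the paths used by the `load` functions in this file and a `Config`
/// parsed from the matching `config.json`:
///
/// ```ignore
/// let device = Device::Cpu;
/// let vb = unsafe {
///     VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
/// };
/// let model = ModernBertForMaskedLM::load(vb, &config)?;
/// // input_ids: (batch, seq_len) token ids, attention_mask: (batch, seq_len) of ones and zeros.
/// let logits = model.forward(&input_ids, &attention_mask)?; // (batch, seq_len, vocab_size)
/// ```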
#[derive(Clone)]
pub struct ModernBertForMaskedLM {
    model: ModernBert,
    decoder: ModernBertDecoder,
    head: ModernBertHead,
}

impl ModernBertForMaskedLM {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let model = ModernBert::load(vb.clone(), config)?;
        let decoder = ModernBertDecoder::load(vb.clone(), config)?;
        let head = ModernBertHead::load(vb.pp("head"), config)?;
        Ok(Self {
            model,
            decoder,
            head,
        })
    }

    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let xs = self
            .model
            .forward(xs, mask)?
            .apply(&self.head)?
            .apply(&self.decoder)?;
        Ok(xs)
    }
}

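/// Classification layer: a linear projection from `hidden_size` to the number
/// of labels in the classifier config, followed by a softmax over the labels.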
#[derive(Clone)]
pub struct ModernBertClassifier {
    classifier: Linear,
}

impl ModernBertClassifier {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let classifier = linear(
            config.hidden_size,
            config
                .classifier_config
                .as_ref()
                .map(|cc| cc.id2label.len())
                .unwrap_or_default(),
            vb.pp("classifier"),
        )?;
        Ok(Self { classifier })
    }
}

impl Module for ModernBertClassifier {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.classifier)?;
        softmax(&xs, D::Minus1)
    }
}

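/// Sequence-classification wrapper: runs the backbone, pools the hidden
/// states (CLS token or mask-weighted mean, per `classifier_pooling`), then
/// applies the prediction head and the classifier.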
#[derive(Clone)]
pub struct ModernBertForSequenceClassification {
    model: ModernBert,
    head: ModernBertHead,
    classifier: ModernBertClassifier,
    classifier_pooling: ClassifierPooling,
}

impl ModernBertForSequenceClassification {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let model = ModernBert::load(vb.clone(), config)?;
        let classifier = ModernBertClassifier::load(vb.clone(), config)?;
        let head = ModernBertHead::load(vb.pp("head"), config)?;
        Ok(Self {
            model,
            head,
            classifier,
            classifier_pooling: config
                .classifier_config
                .as_ref()
                .map(|cc| cc.classifier_pooling)
                .unwrap_or_default(),
        })
    }

    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let output = self.model.forward(xs, mask)?;
        let last_hidden_state = match self.classifier_pooling {
            // CLS pooling keeps the hidden state of the first token: (batch, hidden_size).
            ClassifierPooling::CLS => output.i((.., 0))?,
            ClassifierPooling::MEAN => {
                let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
                let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
                sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
            }
        };
        let xs = self
            .head
            .forward(&last_hidden_state)?
            .apply(&self.classifier)?;
        Ok(xs)
    }
}