pylate_rs/
modernbert.rs

1use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
2use candle_nn::{
3    embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
4    Linear, Module, VarBuilder,
5};
6use serde::Deserialize;
7
8use core::f32;
9use std::collections::HashMap;
10use std::sync::Arc;
11
// This module has been adapted from the `candle` library so that it fits the PyLate format.
13
14#[derive(Debug, Clone, PartialEq, Deserialize)]
15pub struct Config {
16    pub vocab_size: usize,
17    pub hidden_size: usize,
18    pub num_hidden_layers: usize,
19    pub num_attention_heads: usize,
20    pub intermediate_size: usize,
21    pub max_position_embeddings: usize,
22    pub layer_norm_eps: f64,
23    pub pad_token_id: u32,
24    pub global_attn_every_n_layers: usize,
25    pub global_rope_theta: f64,
26    pub local_attention: usize,
27    pub local_rope_theta: f64,
28    #[serde(default)]
29    #[serde(flatten)]
30    pub classifier_config: Option<ClassifierConfig>,
31}
32
33#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
34#[serde(rename_all = "lowercase")]
35pub enum ClassifierPooling {
36    #[default]
37    CLS,
38    MEAN,
39}
40
41#[derive(Debug, Clone, PartialEq, Deserialize)]
42pub struct ClassifierConfig {
43    pub id2label: HashMap<String, String>,
44    pub label2id: HashMap<String, String>,
45    pub classifier_pooling: ClassifierPooling,
46}
47
48#[derive(Debug, Clone)]
49struct RotaryEmbedding {
50    sin: Tensor,
51    cos: Tensor,
52}
53
54impl RotaryEmbedding {
55    fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
56        let dim = config.hidden_size / config.num_attention_heads;
57        let inv_freq: Vec<_> = (0..dim)
58            .step_by(2)
59            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
60            .collect();
61        let inv_freq_len = inv_freq.len();
62        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
63        let max_seq_len = config.max_position_embeddings;
64        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
65            .to_dtype(dtype)?
66            .reshape((max_seq_len, 1))?;
67        let freqs = t.matmul(&inv_freq)?;
68        Ok(Self {
69            sin: freqs.sin()?,
70            cos: freqs.cos()?,
71        })
72    }
73
74    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
75        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
76        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
77        Ok((q_embed, k_embed))
78    }
79}
80
81#[derive(Clone)]
82struct ModernBertAttention {
83    qkv: Linear,
84    proj: Linear,
85    num_attention_heads: usize,
86    attention_head_size: usize,
87    rotary_emb: Arc<RotaryEmbedding>,
88}
89
90impl ModernBertAttention {
91    fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
92        let num_attention_heads = config.num_attention_heads;
93        let attention_head_size = config.hidden_size / config.num_attention_heads;
94
95        let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
96        let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
97
98        Ok(Self {
99            qkv,
100            proj,
101            num_attention_heads,
102            attention_head_size,
103            rotary_emb,
104        })
105    }
106
107    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
108        let xs = hidden_states.clone();
109        let (b, seq_len, d) = xs.dims3()?;
110        let qkv = xs
111            .apply(&self.qkv)?
112            .reshape((
113                b,
114                seq_len,
115                3,
116                self.num_attention_heads,
117                self.attention_head_size,
118            ))?
119            .permute((2, 0, 3, 1, 4))?;
120
121        let q = qkv.get(0)?;
122        let k = qkv.get(1)?;
123        let v = qkv.get(2)?;
124
125        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
126
127        let scale = (self.attention_head_size as f64).powf(-0.5);
128        let q = (q * scale)?;
129
130        let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
131
132        let att = att.broadcast_add(attention_mask)?;
133        let att = softmax(&att, D::Minus1)?;
134
135        let xs = att.matmul(&v)?;
136
137        let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
138        let xs = xs.apply(&self.proj)?;
139        let xs = xs.reshape((b, seq_len, d))?;
140
141        Ok(xs)
142    }
143}
144
145#[derive(Clone)]
146pub struct ModernBertMLP {
147    wi: Linear,
148    wo: Linear,
149}
150
151impl ModernBertMLP {
152    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
153        let wi = linear_no_bias(
154            config.hidden_size,
155            config.intermediate_size * 2,
156            vb.pp("Wi"),
157        )?;
158        let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
159        Ok(Self { wi, wo })
160    }
161}
162
163impl Module for ModernBertMLP {
164    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
165        let xs = xs.apply(&self.wi)?;
166        let xs = xs.chunk(2, D::Minus1)?;
167        let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU
168        Ok(xs)
169    }
170}
171
172#[derive(Clone)]
173pub struct ModernBertLayer {
174    attn: ModernBertAttention,
175    mlp: ModernBertMLP,
176    attn_norm: Option<LayerNorm>,
177    mlp_norm: LayerNorm,
178    uses_local_attention: bool,
179}
180
181impl ModernBertLayer {
182    fn load(
183        vb: VarBuilder,
184        config: &Config,
185        rotary_emb: Arc<RotaryEmbedding>,
186        uses_local_attention: bool,
187    ) -> Result<Self> {
188        let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
189        let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
190        let attn_norm = layer_norm_no_bias(
191            config.hidden_size,
192            config.layer_norm_eps,
193            vb.pp("attn_norm"),
194        )
195        .ok();
196        let mlp_norm =
197            layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
198        Ok(Self {
199            attn,
200            mlp,
201            attn_norm,
202            mlp_norm,
203            uses_local_attention,
204        })
205    }
206
207    fn forward(
208        &self,
209        xs: &Tensor,
210        global_attention_mask: &Tensor,
211        local_attention_mask: &Tensor,
212    ) -> Result<Tensor> {
213        let residual = xs.clone();
214        let mut xs = xs.clone();
215        if let Some(norm) = &self.attn_norm {
216            xs = xs.apply(norm)?;
217        }
218
219        let attention_mask = if self.uses_local_attention {
220            &global_attention_mask.broadcast_add(local_attention_mask)?
221        } else {
222            global_attention_mask
223        };
224        let xs = self.attn.forward(&xs, attention_mask)?;
225        let xs = (xs + residual)?;
226        let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
227        let xs = (xs + mlp_out)?;
228        Ok(xs)
229    }
230}
231
232#[derive(Clone)]
233pub struct ModernBertHead {
234    dense: Linear,
235    norm: LayerNorm,
236}
237
238impl ModernBertHead {
239    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
240        let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
241        let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
242        Ok(Self { dense, norm })
243    }
244}
245
246impl Module for ModernBertHead {
247    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
248        let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
249        Ok(xs)
250    }
251}
252
253#[derive(Clone)]
254pub struct ModernBertDecoder {
255    decoder: Linear,
256}
257
258impl ModernBertDecoder {
259    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
260        // The decoder weights are tied with the embeddings layer weights
261        let decoder_weights = vb.get(
262            (config.vocab_size, config.hidden_size),
263            "embeddings.tok_embeddings.weight",
264        )?;
265        let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
266        let decoder = Linear::new(decoder_weights, Some(decoder_bias));
267        Ok(Self { decoder })
268    }
269}
270
271impl Module for ModernBertDecoder {
272    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
273        let xs = xs.apply(&self.decoder)?;
274        Ok(xs)
275    }
276}
277
278// Global attention mask calculated from padded token inputs
279fn prepare_4d_attention_mask(
280    mask: &Tensor,
281    dtype: DType,
282    tgt_len: Option<usize>,
283) -> Result<Tensor> {
284    let bsz = mask.dim(0)?;
285    let src_len = mask.dim(1)?;
286    let tgt_len = tgt_len.unwrap_or(src_len);
287
288    let expanded_mask = mask
289        .unsqueeze(1)?
290        .unsqueeze(2)?
291        .expand((bsz, 1, tgt_len, src_len))?
292        .to_dtype(dtype)?;
293
294    let inverted_mask = (1.0 - expanded_mask)?;
295
296    (inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
297}
298
299// Attention mask caused by the sliding window
300fn get_local_attention_mask(
301    seq_len: usize,
302    max_distance: usize,
303    device: &Device,
304) -> Result<Tensor> {
305    let mask: Vec<_> = (0..seq_len)
306        .flat_map(|i| {
307            (0..seq_len).map(move |j| {
308                if (j as i32 - i as i32).abs() > max_distance as i32 {
309                    f32::NEG_INFINITY
310                } else {
311                    0.
312                }
313            })
314        })
315        .collect();
316    Tensor::from_slice(&mask, (seq_len, seq_len), device)
317}
318
319// ModernBERT backbone
320#[derive(Clone)]
321pub struct ModernBert {
322    word_embeddings: Embedding,
323    norm: LayerNorm,
324    layers: Vec<ModernBertLayer>,
325    final_norm: LayerNorm,
326    local_attention_size: usize,
327}
328
329impl ModernBert {
330    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
331        let word_embeddings = embedding(
332            config.vocab_size,
333            config.hidden_size,
334            vb.pp("embeddings.tok_embeddings"),
335        )?;
336        let norm = layer_norm_no_bias(
337            config.hidden_size,
338            config.layer_norm_eps,
339            vb.pp("embeddings.norm"),
340        )?;
341        let global_rotary_emb = Arc::new(RotaryEmbedding::new(
342            vb.dtype(),
343            config,
344            config.global_rope_theta,
345            vb.device(),
346        )?);
347        let local_rotary_emb = Arc::new(RotaryEmbedding::new(
348            vb.dtype(),
349            config,
350            config.local_rope_theta,
351            vb.device(),
352        )?);
353
354        let mut layers = Vec::with_capacity(config.num_hidden_layers);
355        for layer_id in 0..config.num_hidden_layers {
356            let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
357            layers.push(ModernBertLayer::load(
358                vb.pp(format!("layers.{layer_id}")),
359                config,
360                if layer_uses_local_attention {
361                    local_rotary_emb.clone()
362                } else {
363                    global_rotary_emb.clone()
364                },
365                layer_uses_local_attention,
366            )?);
367        }
368
369        let final_norm = layer_norm_no_bias(
370            config.hidden_size,
371            config.layer_norm_eps,
372            vb.pp("final_norm"),
373        )?;
374
375        Ok(Self {
376            word_embeddings,
377            norm,
378            layers,
379            final_norm,
380            local_attention_size: config.local_attention,
381        })
382    }
383
384    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
385        let seq_len = xs.shape().dims()[1];
386        let global_attention_mask =
387            prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
388        let local_attention_mask =
389            get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
390        let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
391        for layer in self.layers.iter() {
392            xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
393        }
394        let xs = xs.apply(&self.final_norm)?;
395        Ok(xs)
396    }
397}
398
399// ModernBERT for the fill-mask task
400#[derive(Clone)]
401pub struct ModernBertForMaskedLM {
402    model: ModernBert,
403    decoder: ModernBertDecoder,
404    head: ModernBertHead,
405}
406
407impl ModernBertForMaskedLM {
408    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
409        let model = ModernBert::load(vb.clone(), config)?;
410        let decoder = ModernBertDecoder::load(vb.clone(), config)?;
411        let head = ModernBertHead::load(vb.pp("head"), config)?;
412        Ok(Self {
413            model,
414            decoder,
415            head,
416        })
417    }
418
419    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
420        let xs = self
421            .model
422            .forward(xs, mask)?
423            .apply(&self.head)?
424            .apply(&self.decoder)?;
425        Ok(xs)
426    }
427}
428
429#[derive(Clone)]
430pub struct ModernBertClassifier {
431    classifier: Linear,
432}
433
434impl ModernBertClassifier {
435    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
436        // The decoder weights are tied with the embeddings layer weights
437        let classifier = linear(
438            config.hidden_size,
439            config
440                .classifier_config
441                .as_ref()
442                .map(|cc| cc.id2label.len())
443                .unwrap_or_default(),
444            vb.pp("classifier"),
445        )?;
446        Ok(Self { classifier })
447    }
448}
449
450impl Module for ModernBertClassifier {
451    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
452        let xs = xs.apply(&self.classifier)?;
453        softmax(&xs, D::Minus1)
454    }
455}
456
457#[derive(Clone)]
458pub struct ModernBertForSequenceClassification {
459    model: ModernBert,
460    head: ModernBertHead,
461    classifier: ModernBertClassifier,
462    classifier_pooling: ClassifierPooling,
463}
464
465impl ModernBertForSequenceClassification {
466    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
467        let model = ModernBert::load(vb.clone(), config)?;
468        let classifier = ModernBertClassifier::load(vb.clone(), config)?;
469        let head = ModernBertHead::load(vb.pp("head"), config)?;
470        Ok(Self {
471            model,
472            head,
473            classifier,
474            classifier_pooling: config
475                .classifier_config
476                .as_ref()
477                .map(|cc| cc.classifier_pooling)
478                .unwrap_or_default(),
479        })
480    }
481
482    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
483        let output = self.model.forward(xs, mask)?;
484        let last_hidden_state = match self.classifier_pooling {
485            ClassifierPooling::CLS => output.i((.., .., 0))?,
486            ClassifierPooling::MEAN => {
487                let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
488                let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
489                sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
490            },
491        };
492        let xs = self
493            .head
494            .forward(&last_hidden_state)?
495            .apply(&self.classifier)?;
496        Ok(xs)
497    }
498}