candle-transformers 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
//! Implementation of DistilBert, a distilled version of BERT.
//!
//! See:
//! - ["DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"](https://arxiv.org/abs/1910.01108)
//!
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;

/// Floating-point element type exported alongside this model.
///
/// NOTE(review): not referenced by the code in this file; presumably intended for callers
/// (e.g. when constructing the `VarBuilder` passed to `load`) — confirm against call sites.
pub const DTYPE: DType = DType::F32;

/// Returns `on_false` with every position selected by `mask` replaced by `on_true`.
///
/// `mask` is the condition of [`Tensor::where_cond`]: where it is non-zero the result
/// holds the scalar `on_true`, elsewhere the corresponding value of `on_false`. The scalar
/// is materialised as an f32 tensor broadcast to `mask`'s shape, so `on_false` must be an
/// F32 tensor of that same shape.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let fill_scalar = Tensor::new(on_true, on_false.device())?;
    let fill = fill_scalar.broadcast_as(mask.shape())?;
    mask.where_cond(&fill, on_false)
}

/// Activation function used by the feed-forward layers and the MLM prediction head.
///
/// Deserialized from the lowercase `activation` string of the config (`"gelu"` or
/// `"relu"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    /// Gaussian Error Linear Unit (`"gelu"`).
    Gelu,
    /// Rectified Linear Unit (`"relu"`).
    Relu,
}

/// Applies a [`HiddenAct`] activation element-wise, inside a `hidden-act` tracing span.
struct HiddenActLayer {
    act: HiddenAct,
    span: tracing::Span,
}

impl HiddenActLayer {
    /// Wraps `act` in a layer usable as a [`Module`].
    fn new(act: HiddenAct) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
        Self { act, span }
    }
}

impl Module for HiddenActLayer {
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        let _enter = self.span.enter();
        match self.act {
            // `"gelu"` in the reference implementation maps to `GELUActivation`, the exact
            // erf-based GELU, not the tanh approximation computed by `Tensor::gelu`. Use
            // `gelu_erf` so activations match the checkpoints (as `bert.rs` does).
            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
            HiddenAct::Gelu => xs.gelu_erf(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}

/// Position-embedding scheme named in the config.
///
/// Only `Absolute` exists, and it is the `Default` used when the config omits the field.
/// The value is not consulted by `Embeddings`, which always adds a learned position table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    /// Learned absolute position embeddings (`"absolute"` in the config).
    #[default]
    Absolute,
}

/// DistilBERT hyper-parameters, deserialized with serde from the model configuration.
///
/// Some fields are accepted for compatibility with the configuration format but are not
/// read by the code in this file: `initializer_range`, `pad_token_id`,
/// `position_embedding_type` and `use_cache`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Number of rows in the word-embedding table (token vocabulary size).
    pub vocab_size: usize,
    /// Hidden size of the embeddings and of every transformer layer.
    pub dim: usize,
    /// Number of transformer blocks, loaded from `layer.{i}` under the transformer prefix.
    n_layers: usize,
    /// Number of attention heads. Each head gets `dim / n_heads` features (integer
    /// division), so `dim` is presumably meant to be divisible by `n_heads`.
    n_heads: usize,
    /// Inner width of the position-wise feed-forward network.
    hidden_dim: usize,
    /// Activation used by the feed-forward network and the MLM prediction head.
    activation: HiddenAct,
    /// Number of learned position embeddings; `Embeddings` can only look up positions
    /// below this value.
    max_position_embeddings: usize,
    /// Initializer standard deviation from the config; not read in this file.
    initializer_range: f64,
    /// Padding token id; exposed to callers, not read in this file.
    pub pad_token_id: usize,
    /// Position-embedding scheme; optional in the config (defaults to `Absolute`).
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    /// Optional in the config (defaults to `false` when absent); not read in this file.
    #[serde(default)]
    use_cache: bool,
    /// Fallback weight-name prefix used by `DistilBertModel::load`
    /// (`{model_type}.embeddings`, `{model_type}.transformer`).
    model_type: Option<String>,
}

impl Default for Config {
    /// Hyper-parameters of the standard DistilBERT base model, i.e. the defaults of the
    /// reference `DistilBertConfig`.
    fn default() -> Self {
        Self {
            vocab_size: 30522,
            dim: 768,
            // DistilBERT halves BERT-base's 12 layers. With 12 here, `load` would look for
            // `layer.6`..`layer.11`, which the published base checkpoints do not contain.
            n_layers: 6,
            n_heads: 12,
            hidden_dim: 3072,
            activation: HiddenAct::Gelu,
            max_position_embeddings: 512,
            initializer_range: 0.02,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            model_type: Some("distilbert".to_string()),
        }
    }
}

/// Token embeddings plus learned absolute position embeddings, then layer normalization.
struct Embeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl Embeddings {
    /// Loads `word_embeddings`, `position_embeddings` and `LayerNorm` (eps `1e-12`) from `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let hidden = config.dim;
        let word_embeddings =
            candle_nn::embedding(config.vocab_size, hidden, vb.pp("word_embeddings"))?;
        let position_embeddings = candle_nn::embedding(
            config.max_position_embeddings,
            hidden,
            vb.pp("position_embeddings"),
        )?;
        let norm = layer_norm(hidden, 1e-12, vb.pp("LayerNorm"))?;
        let span = tracing::span!(tracing::Level::TRACE, "embeddings");
        Ok(Self {
            word_embeddings,
            position_embeddings,
            layer_norm: norm,
            span,
        })
    }

    /// Embeds the rank-2 `(batch, seq_len)` tensor `input_ids`.
    ///
    /// Position `i` of every sequence receives position embedding `i`, so `seq_len` must
    /// not exceed `max_position_embeddings`.
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_batch, seq_len) = input_ids.dims2()?;
        let token_embeddings = self.word_embeddings.forward(input_ids)?;
        let positions: Vec<u32> = (0..seq_len as u32).collect();
        let positions = Tensor::new(positions.as_slice(), input_ids.device())?;
        let position_embeddings = self.position_embeddings.forward(&positions)?;
        // (seq_len, dim) position table broadcast over the batch dimension.
        token_embeddings
            .broadcast_add(&position_embeddings)?
            .apply(&self.layer_norm)
    }
}

/// Multi-head self-attention with separate query, key, value and output projections.
struct MultiHeadSelfAttention {
    q_lin: Linear,
    k_lin: Linear,
    v_lin: Linear,
    out_lin: Linear,
    n_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
}

impl MultiHeadSelfAttention {
    /// Loads the `q_lin`, `v_lin`, `k_lin` and `out_lin` projections from `vb`.
    ///
    /// Each head gets `config.dim / config.n_heads` features (integer division).
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let n_heads = config.n_heads;
        let attention_head_size = config.dim / n_heads;
        let inner = n_heads * attention_head_size;
        // Struct-literal fields are evaluated in the order written; this keeps the
        // q, v, k, out fetch order, which may matter for initializing VarBuilders.
        Ok(Self {
            q_lin: linear(config.dim, inner, vb.pp("q_lin"))?,
            v_lin: linear(config.dim, inner, vb.pp("v_lin"))?,
            k_lin: linear(config.dim, inner, vb.pp("k_lin"))?,
            out_lin: linear(inner, config.dim, vb.pp("out_lin"))?,
            n_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
        })
    }
}

impl MultiHeadSelfAttention {
    /// Scaled dot-product self-attention over `hidden_states`.
    ///
    /// * `hidden_states` - rank-3 `(batch, seq_len, dim)` input.
    /// * `attention_mask` - must broadcast to the `(batch, n_heads, seq_len, seq_len)`
    ///   score tensor; positions where it is non-zero are excluded from attention.
    ///
    /// Returns the projected context, shaped like `hidden_states`.
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (bs, q_length, _dim) = hidden_states.dims3()?;

        let dim_per_head = self.attention_head_size;
        let q = self.q_lin.forward(hidden_states)?;
        let k = self.k_lin.forward(hidden_states)?;
        let v = self.v_lin.forward(hidden_states)?;

        // (bs, q_length, n_heads * dim_per_head) -> (bs, n_heads, q_length, dim_per_head)
        let q = q
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let k = k
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let v = v
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;

        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
        let mask = attention_mask.broadcast_as(scores.shape())?;

        // Masked scores get the most negative finite f32 instead of -inf, as the reference
        // implementation does with `torch.finfo(dtype).min`. A row whose every position is
        // masked then softmaxes to uniform weights instead of NaN; all other rows are
        // unaffected because exp(f32::MIN - max) is 0.
        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::MIN)?;
        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
        // The softmax runs in F32. Cast back to `v`'s dtype so the matmul below also works
        // for half-precision models; this is a no-op for F32 weights.
        let weights = weights.to_dtype(v.dtype())?;

        let context = weights.matmul(&v.contiguous()?)?;
        // (bs, n_heads, q_length, dim_per_head) -> (bs, q_length, n_heads * dim_per_head)
        let context = context
            .transpose(1, 2)?
            .reshape((bs, q_length, self.n_heads * dim_per_head))?
            .contiguous()?;
        self.out_lin.forward(&context)
    }
}

/// Position-wise feed-forward network: `lin2(activation(lin1(x)))`.
// FFN is the conventional name for this layer; keep the acronym despite clippy.
#[allow(clippy::upper_case_acronyms)]
struct FFN {
    lin1: Linear,
    lin2: Linear,
    activation: HiddenActLayer,
    span: tracing::Span,
}

impl FFN {
    /// Loads `lin1` (`dim -> hidden_dim`) and `lin2` (`hidden_dim -> dim`) from `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (dim, hidden_dim) = (config.dim, config.hidden_dim);
        let lin1 = linear(dim, hidden_dim, vb.pp("lin1"))?;
        let lin2 = linear(hidden_dim, dim, vb.pp("lin2"))?;
        let activation = HiddenActLayer::new(config.activation);
        let span = tracing::span!(tracing::Level::TRACE, "ffn");
        Ok(Self {
            lin1,
            lin2,
            activation,
            span,
        })
    }
}

impl Module for FFN {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let expanded = self.lin1.forward(hidden_states)?;
        let activated = self.activation.forward(&expanded)?;
        self.lin2.forward(&activated)
    }
}

/// One encoder layer: self-attention and a feed-forward network, each followed by a
/// residual connection and layer normalization.
struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    sa_layer_norm: LayerNorm,
    ffn: FFN,
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}

impl TransformerBlock {
    /// Loads `attention`, `sa_layer_norm`, `ffn` and `output_layer_norm` from `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        const EPS: f64 = 1e-12;
        Ok(Self {
            attention: MultiHeadSelfAttention::load(vb.pp("attention"), config)?,
            sa_layer_norm: layer_norm(config.dim, EPS, vb.pp("sa_layer_norm"))?,
            ffn: FFN::load(vb.pp("ffn"), config)?,
            output_layer_norm: layer_norm(config.dim, EPS, vb.pp("output_layer_norm"))?,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }
}

impl TransformerBlock {
    /// Runs the layer on `hidden_states`; `attention_mask` is passed to the attention
    /// sub-layer unchanged.
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        // Attention sub-layer, residual connection, then normalization.
        let attended = self
            .attention
            .forward(hidden_states, attention_mask)?
            .broadcast_add(hidden_states)?
            .apply(&self.sa_layer_norm)?;
        // Feed-forward sub-layer, residual connection, then normalization.
        let ffn_out = attended.apply(&self.ffn)?;
        (&ffn_out + &attended)?.apply(&self.output_layer_norm)
    }
}

// Stack of `TransformerBlock`s, loaded from `layer.0`, `layer.1`, ... and applied in order.
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
    layers: Vec<TransformerBlock>,
    span: tracing::Span,
}

impl Transformer {
    /// Loads `config.n_layers` blocks from `layer.{index}` under `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let mut layers = Vec::with_capacity(config.n_layers);
        for index in 0..config.n_layers {
            layers.push(TransformerBlock::load(vb.pp(format!("layer.{index}")), config)?);
        }
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(Self { layers, span })
    }
}

impl Transformer {
    /// Feeds `hidden_states` through every block, sharing `attention_mask` across them.
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        // A plain loop rather than a fold keeps it easy to add per-layer debugging.
        let mut xs = hidden_states.clone();
        for block in &self.layers {
            xs = block.forward(&xs, attention_mask)?;
        }
        Ok(xs)
    }
}

/// The DistilBERT encoder: embeddings followed by the transformer stack.
pub struct DistilBertModel {
    embeddings: Embeddings,
    transformer: Transformer,
    pub device: Device,
    span: tracing::Span,
}

impl DistilBertModel {
    /// Loads the encoder from `vb`.
    ///
    /// Weights are first read from `embeddings.*` and `transformer.*`. If either part fails
    /// to load and `config.model_type` is set, both parts are retried under
    /// `{model_type}.embeddings.*` and `{model_type}.transformer.*`. If the retry also
    /// fails, the error from the first attempt is returned.
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        // Attempts both parts under an optional prefix; both loads always run.
        let load_parts = |prefix: Option<&str>| {
            let path = |name: &str| match prefix {
                Some(prefix) => format!("{prefix}.{name}"),
                None => name.to_string(),
            };
            (
                Embeddings::load(vb.pp(path("embeddings")), config),
                Transformer::load(vb.pp(path("transformer")), config),
            )
        };

        let (embeddings, transformer) = match load_parts(None) {
            (Ok(embeddings), Ok(transformer)) => (embeddings, transformer),
            (Err(err), _) | (_, Err(err)) => {
                let retried = config
                    .model_type
                    .as_deref()
                    .map(|model_type| load_parts(Some(model_type)));
                match retried {
                    Some((Ok(embeddings), Ok(transformer))) => (embeddings, transformer),
                    _ => return Err(err),
                }
            }
        };

        Ok(Self {
            embeddings,
            transformer,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    /// Encodes `input_ids` and returns the hidden states of the last layer.
    ///
    /// `attention_mask` is shared by every attention layer; its non-zero entries mark
    /// positions that are masked out.
    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedded = self.embeddings.forward(input_ids)?;
        self.transformer.forward(&embedded, attention_mask)
    }
}

/// `vocab_layer_norm(activation(vocab_transform(x)))`, applied to hidden states before the
/// vocabulary projection.
struct DistilBertPredictionHeadTransform {
    dense: Linear,
    activation: HiddenActLayer,
    layer_norm: LayerNorm,
}

impl DistilBertPredictionHeadTransform {
    /// Loads `vocab_transform` (`dim -> dim`) and `vocab_layer_norm` from `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        Ok(Self {
            dense: linear(config.dim, config.dim, vb.pp("vocab_transform"))?,
            activation: HiddenActLayer::new(config.activation),
            layer_norm: layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?,
        })
    }
}

impl Module for DistilBertPredictionHeadTransform {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        hidden_states
            .apply(&self.dense)?
            .apply(&self.activation)?
            .apply(&self.layer_norm)
    }
}

// Vocabulary head: the prediction transform followed by a linear projection onto the
// vocabulary, whose weight is shared with the word embeddings.
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
pub struct DistilBertLMPredictionHead {
    transform: DistilBertPredictionHeadTransform,
    decoder: Linear,
}

impl DistilBertLMPredictionHead {
    /// Loads the head from `vb`.
    ///
    /// `vb` must contain the transform weights (`vocab_transform`, `vocab_layer_norm`),
    /// the projection weight `distilbert.embeddings.word_embeddings.weight` of shape
    /// `(vocab_size, dim)`, and the projection bias `vocab_projector.bias`.
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;

        // The projection weight is tied to the word embeddings; the bias is stored separately.
        let weight = vb.pp("distilbert.embeddings.word_embeddings").get_with_hints(
            (config.vocab_size, config.dim),
            "weight",
            candle_nn::init::DEFAULT_KAIMING_NORMAL,
        )?;

        // Init hint U(-1/sqrt(dim), 1/sqrt(dim)); only relevant for VarBuilder backends that
        // create fresh variables rather than reading stored ones.
        let bound = 1. / (config.dim as f64).sqrt();
        let bias = vb.pp("vocab_projector").get_with_hints(
            config.vocab_size,
            "bias",
            candle_nn::Init::Uniform {
                lo: -bound,
                up: bound,
            },
        )?;

        Ok(Self {
            transform,
            decoder: Linear::from_weights(weight, Some(bias)),
        })
    }
}

impl Module for DistilBertLMPredictionHead {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let transformed = self.transform.forward(hidden_states)?;
        self.decoder.forward(&transformed)
    }
}

// Masked-LM head: a thin wrapper around `DistilBertLMPredictionHead`.
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
pub struct DistilBertOnlyMLMHead {
    predictions: DistilBertLMPredictionHead,
}

impl DistilBertOnlyMLMHead {
    /// Loads the prediction head from `vb` itself (no extra prefix).
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        DistilBertLMPredictionHead::load(vb, config).map(|predictions| Self { predictions })
    }
}

impl Module for DistilBertOnlyMLMHead {
    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
        let Self { predictions } = self;
        predictions.forward(sequence_output)
    }
}

/// DistilBERT encoder with a masked-language-modeling head on top.
pub struct DistilBertForMaskedLM {
    pub bert: DistilBertModel,
    cls: DistilBertOnlyMLMHead,
}

impl DistilBertForMaskedLM {
    /// Loads the encoder from `distilbert.*` and the MLM head from the root of `vb`.
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
        let cls = DistilBertOnlyMLMHead::load(vb, config)?;
        Ok(Self { bert, cls })
    }

    /// Returns vocabulary logits (last dimension `vocab_size`) for every input position.
    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        self.bert
            .forward(input_ids, attention_mask)
            .and_then(|hidden| self.cls.forward(&hidden))
    }
}