
axonml_llm/bert.rs

//! BERT Model Implementation
//!
//! Bidirectional Encoder Representations from Transformers.

use axonml_autograd::Variable;
use axonml_nn::{Module, Linear, Dropout, Parameter};
use axonml_tensor::Tensor;

use crate::config::BertConfig;
use crate::embedding::BertEmbedding;
use crate::error::{LLMError, LLMResult};
use crate::transformer::{TransformerEncoder, LayerNorm};

/// BERT model (encoder-only transformer).
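///
/// Data flow: `BertEmbedding` (token + position + type embeddings) feeds a
/// post-norm `TransformerEncoder`; an optional `BertPooler` then transforms
/// the final hidden state of the first ([CLS]) token.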
#[derive(Debug)]
pub struct Bert {
    /// Configuration
    pub config: BertConfig,
    /// Embeddings
    pub embeddings: BertEmbedding,
    /// Transformer encoder
    pub encoder: TransformerEncoder,
    /// Pooler (optional [CLS] token transformation)
    pub pooler: Option<BertPooler>,
}

/// BERT pooler for sequence classification.
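///
/// Computes `pooled = tanh(W * h_[CLS] + b)`: a dense projection of the
/// first token's hidden state followed by a tanh nonlinearity, as in the
/// `forward` implementation below.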
#[derive(Debug)]
pub struct BertPooler {
    /// Dense layer
    pub dense: Linear,
}

impl BertPooler {
    /// Creates a new BERT pooler.
    pub fn new(hidden_size: usize) -> Self {
        Self {
            dense: Linear::new(hidden_size, hidden_size),
        }
    }
}

impl Module for BertPooler {
    fn forward(&self, input: &Variable) -> Variable {
        // Take the first token ([CLS]) representation
        let input_data = input.data();
        let shape = input_data.shape();
        let batch_size = shape[0];
        let hidden_size = shape[2];

        // Extract [CLS] token: input[:, 0, :]
        let cls_output = input.slice(&[0..batch_size, 0..1, 0..hidden_size]);
        let cls_output = cls_output.reshape(&[batch_size, hidden_size]);

        // Apply dense + tanh
        let pooled = self.dense.forward(&cls_output);
        pooled.tanh()
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.dense.parameters()
    }
}

impl Bert {
    /// Creates a new BERT model.
    pub fn new(config: &BertConfig) -> Self {
        Self::with_pooler(config, true)
    }

    /// Creates a new BERT model with an optional pooler.
    pub fn with_pooler(config: &BertConfig, add_pooler: bool) -> Self {
        let embeddings = BertEmbedding::new(
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            config.hidden_size,
            config.layer_norm_eps,
            config.hidden_dropout_prob,
        );

        let encoder = TransformerEncoder::new(
            config.num_hidden_layers,
            config.hidden_size,
            config.num_attention_heads,
            config.intermediate_size,
            config.hidden_dropout_prob,
            config.layer_norm_eps,
            &config.hidden_act,
            false, // post-norm (LayerNorm after the residual), as in the original BERT
        );

        let pooler = if add_pooler {
            Some(BertPooler::new(config.hidden_size))
        } else {
            None
        };

        Self {
            config: config.clone(),
            embeddings,
            encoder,
            pooler,
        }
    }

    /// Forward pass returning both the sequence output and the pooled output.
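    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the unit tests below (shapes follow the
    /// `tiny` config used there):
    ///
    /// ```ignore
    /// let config = BertConfig::tiny();
    /// let model = Bert::new(&config);
    /// let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();
    /// let (seq_out, pooled_out) = model.forward_with_pooling(&input_ids, None, None);
    /// assert_eq!(seq_out.data().shape(), &[1, 4, config.hidden_size]);
    /// assert_eq!(pooled_out.unwrap().data().shape(), &[1, config.hidden_size]);
    /// ```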
    pub fn forward_with_pooling(
        &self,
        input_ids: &Tensor<u32>,
        token_type_ids: Option<&Tensor<u32>>,
        attention_mask: Option<&Tensor<f32>>,
    ) -> (Variable, Option<Variable>) {
        // Get embeddings
        let hidden_states = self.embeddings.forward_with_ids(input_ids, token_type_ids, None);

        // Encode
        let sequence_output = self.encoder.forward_with_mask(&hidden_states, attention_mask);

        // Pool if a pooler is configured
        let pooled_output = self.pooler.as_ref().map(|p| p.forward(&sequence_output));

        (sequence_output, pooled_output)
    }

    /// Forward pass returning only the sequence output (final hidden states).
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let (sequence_output, _) = self.forward_with_pooling(input_ids, None, None);
        sequence_output
    }
}

impl Module for Bert {
    fn forward(&self, input: &Variable) -> Variable {
        // When called through the Module trait, the input is assumed to
        // already be embedded, so only the encoder runs.
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.embeddings.parameters());
        params.extend(self.encoder.parameters());
        if let Some(ref pooler) = self.pooler {
            params.extend(pooler.parameters());
        }
        params
    }

    fn train(&mut self) {
        self.embeddings.train();
        self.encoder.train();
    }

    fn eval(&mut self) {
        self.embeddings.eval();
        self.encoder.eval();
    }
}

/// BERT for sequence classification.
#[derive(Debug)]
pub struct BertForSequenceClassification {
    /// Base BERT model
    pub bert: Bert,
    /// Dropout
    pub dropout: Dropout,
    /// Classification head
    pub classifier: Linear,
    /// Number of labels
    pub num_labels: usize,
}

impl BertForSequenceClassification {
    /// Creates a new BERT model for sequence classification.
    pub fn new(config: &BertConfig, num_labels: usize) -> Self {
        Self {
            bert: Bert::new(config),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: Linear::new(config.hidden_size, num_labels),
            num_labels,
        }
    }

    /// Forward pass for classification.
    ///
    /// # Errors
    /// Returns an error if the BERT model does not have a pooler configured.
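    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the unit tests below:
    ///
    /// ```ignore
    /// let config = BertConfig::tiny();
    /// let model = BertForSequenceClassification::new(&config, 2);
    /// let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();
    /// let logits = model.forward_classification(&input_ids).unwrap(); // [batch, num_labels]
    /// assert_eq!(logits.data().shape(), &[1, 2]);
    /// ```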
    pub fn forward_classification(&self, input_ids: &Tensor<u32>) -> LLMResult<Variable> {
        let (_, pooled_output) = self.bert.forward_with_pooling(input_ids, None, None);

        if let Some(pooled) = pooled_output {
            let pooled = self.dropout.forward(&pooled);
            Ok(self.classifier.forward(&pooled))
        } else {
            Err(LLMError::InvalidConfig(
                "BERT model must have pooler for sequence classification".to_string()
            ))
        }
    }
}

impl Module for BertForSequenceClassification {
    fn forward(&self, input: &Variable) -> Variable {
        let sequence_output = self.bert.forward(input);

        // Get [CLS] token
        let seq_data = sequence_output.data();
        let shape = seq_data.shape();
        let batch_size = shape[0];
        let hidden_size = shape[2];

        let cls_output = sequence_output.slice(&[0..batch_size, 0..1, 0..hidden_size]);
        let cls_output = cls_output.reshape(&[batch_size, hidden_size]);

        let cls_output = self.dropout.forward(&cls_output);
        self.classifier.forward(&cls_output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.bert.parameters();
        params.extend(self.classifier.parameters());
        params
    }

    fn train(&mut self) {
        self.bert.train();
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.bert.eval();
        self.dropout.eval();
    }
}

/// BERT for masked language modeling.
#[derive(Debug)]
pub struct BertForMaskedLM {
    /// Base BERT model
    pub bert: Bert,
    /// MLM head
    pub cls: BertLMPredictionHead,
}

/// BERT LM prediction head.
#[derive(Debug)]
pub struct BertLMPredictionHead {
    /// Transform layer
    pub transform: BertPredictionHeadTransform,
    /// Output projection (a full implementation ties these weights to the
    /// input embeddings; here the decoder is a separate `Linear` layer)
    pub decoder: Linear,
}

/// Transform layer for BERT prediction head.
#[derive(Debug)]
pub struct BertPredictionHeadTransform {
    /// Dense layer
    pub dense: Linear,
    /// Layer norm
    pub layer_norm: LayerNorm,
    /// Activation
    pub activation: String,
}

impl BertPredictionHeadTransform {
    /// Creates a new prediction head transform.
    pub fn new(hidden_size: usize, layer_norm_eps: f32, activation: &str) -> Self {
        Self {
            dense: Linear::new(hidden_size, hidden_size),
            layer_norm: LayerNorm::new(hidden_size, layer_norm_eps),
            activation: activation.to_string(),
        }
    }
}

impl Module for BertPredictionHeadTransform {
    fn forward(&self, input: &Variable) -> Variable {
        let x = self.dense.forward(input);
        let x = match self.activation.as_str() {
            "gelu" => x.gelu(),
            "relu" => x.relu(),
            _ => x.gelu(), // default to GELU for unrecognized activation names
        };
        self.layer_norm.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.dense.parameters();
        params.extend(self.layer_norm.parameters());
        params
    }
}

impl BertLMPredictionHead {
    /// Creates a new LM prediction head.
    pub fn new(hidden_size: usize, vocab_size: usize, layer_norm_eps: f32, activation: &str) -> Self {
        Self {
            transform: BertPredictionHeadTransform::new(hidden_size, layer_norm_eps, activation),
            decoder: Linear::new(hidden_size, vocab_size),
        }
    }
}

impl Module for BertLMPredictionHead {
    fn forward(&self, input: &Variable) -> Variable {
        let x = self.transform.forward(input);
        self.decoder.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.transform.parameters();
        params.extend(self.decoder.parameters());
        params
    }
}

impl BertForMaskedLM {
    /// Creates a new BERT model for masked language modeling.
    pub fn new(config: &BertConfig) -> Self {
        let bert = Bert::with_pooler(config, false); // no pooler needed for MLM
        let cls = BertLMPredictionHead::new(
            config.hidden_size,
            config.vocab_size,
            config.layer_norm_eps,
            &config.hidden_act,
        );

        Self { bert, cls }
    }

    /// Forward pass for masked language modeling, returning per-token vocabulary logits.
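    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the unit tests below:
    ///
    /// ```ignore
    /// let config = BertConfig::tiny();
    /// let model = BertForMaskedLM::new(&config);
    /// let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();
    /// let logits = model.forward_mlm(&input_ids);
    /// assert_eq!(logits.data().shape(), &[1, 4, config.vocab_size]);
    /// ```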
    pub fn forward_mlm(&self, input_ids: &Tensor<u32>) -> Variable {
        let sequence_output = self.bert.forward_ids(input_ids);
        self.cls.forward(&sequence_output)
    }
}

impl Module for BertForMaskedLM {
    fn forward(&self, input: &Variable) -> Variable {
        let sequence_output = self.bert.forward(input);
        self.cls.forward(&sequence_output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.bert.parameters();
        params.extend(self.cls.parameters());
        params
    }

    fn train(&mut self) {
        self.bert.train();
    }

    fn eval(&mut self) {
        self.bert.eval();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bert_tiny() {
        let config = BertConfig::tiny();
        let model = Bert::new(&config);

        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let output = model.forward_ids(&input_ids);

        assert_eq!(output.data().shape(), &[2, 4, config.hidden_size]);
    }

    #[test]
    fn test_bert_pooler() {
        let config = BertConfig::tiny();
        let model = Bert::new(&config);

        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let (seq_out, pooled_out) = model.forward_with_pooling(&input_ids, None, None);

        assert_eq!(seq_out.data().shape(), &[2, 4, config.hidden_size]);
        assert!(pooled_out.is_some());
        assert_eq!(pooled_out.unwrap().data().shape(), &[2, config.hidden_size]);
    }

    #[test]
    fn test_bert_for_classification() {
        let config = BertConfig::tiny();
        let model = BertForSequenceClassification::new(&config, 2);

        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let logits = model.forward_classification(&input_ids).unwrap();

        assert_eq!(logits.data().shape(), &[2, 2]); // [batch, num_labels]
    }

    #[test]
    fn test_bert_for_mlm() {
        let config = BertConfig::tiny();
        let model = BertForMaskedLM::new(&config);

        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let logits = model.forward_mlm(&input_ids);

        assert_eq!(logits.data().shape(), &[2, 4, config.vocab_size]);
    }

    #[test]
    fn test_bert_parameter_count() {
        let config = BertConfig::tiny();
        let model = Bert::new(&config);
        let params = model.parameters();

        // The model should expose a non-empty parameter list
        assert!(!params.is_empty());
    }
}