1use axonml_autograd::Variable;
6use axonml_nn::{Module, Linear, Dropout, Parameter};
7use axonml_tensor::Tensor;
8
9use crate::config::BertConfig;
10use crate::embedding::BertEmbedding;
11use crate::error::{LLMError, LLMResult};
12use crate::transformer::{TransformerEncoder, LayerNorm};
13
/// Core BERT model: token embeddings, a transformer encoder stack, and an
/// optional pooler head over the first ([CLS]) token.
#[derive(Debug)]
pub struct Bert {
    // Hyperparameters this instance was built from (cloned in `with_pooler`).
    pub config: BertConfig,
    // Token / position / segment embedding layer.
    pub embeddings: BertEmbedding,
    // Stack of transformer encoder layers.
    pub encoder: TransformerEncoder,
    // `None` when constructed via `with_pooler(config, false)`
    // (e.g. for masked-LM use, which needs no pooled output).
    pub pooler: Option<BertPooler>,
}
26
/// Pooling head: projects the [CLS] hidden state through a dense layer
/// followed by tanh (see the `Module` impl below).
#[derive(Debug)]
pub struct BertPooler {
    // Square projection: hidden_size -> hidden_size.
    pub dense: Linear,
}
33
34impl BertPooler {
35 pub fn new(hidden_size: usize) -> Self {
37 Self {
38 dense: Linear::new(hidden_size, hidden_size),
39 }
40 }
41}
42
43impl Module for BertPooler {
44 fn forward(&self, input: &Variable) -> Variable {
45 let input_data = input.data();
47 let shape = input_data.shape();
48 let batch_size = shape[0];
49 let hidden_size = shape[2];
50
51 let cls_output = input.slice(&[0..batch_size, 0..1, 0..hidden_size]);
53 let cls_output = cls_output.reshape(&[batch_size, hidden_size]);
54
55 let pooled = self.dense.forward(&cls_output);
57 pooled.tanh()
58 }
59
60 fn parameters(&self) -> Vec<Parameter> {
61 self.dense.parameters()
62 }
63}
64
65impl Bert {
66 pub fn new(config: &BertConfig) -> Self {
68 Self::with_pooler(config, true)
69 }
70
71 pub fn with_pooler(config: &BertConfig, add_pooler: bool) -> Self {
73 let embeddings = BertEmbedding::new(
74 config.vocab_size,
75 config.max_position_embeddings,
76 config.type_vocab_size,
77 config.hidden_size,
78 config.layer_norm_eps,
79 config.hidden_dropout_prob,
80 );
81
82 let encoder = TransformerEncoder::new(
83 config.num_hidden_layers,
84 config.hidden_size,
85 config.num_attention_heads,
86 config.intermediate_size,
87 config.hidden_dropout_prob,
88 config.layer_norm_eps,
89 &config.hidden_act,
90 false, );
92
93 let pooler = if add_pooler {
94 Some(BertPooler::new(config.hidden_size))
95 } else {
96 None
97 };
98
99 Self {
100 config: config.clone(),
101 embeddings,
102 encoder,
103 pooler,
104 }
105 }
106
107 pub fn forward_with_pooling(
109 &self,
110 input_ids: &Tensor<u32>,
111 token_type_ids: Option<&Tensor<u32>>,
112 attention_mask: Option<&Tensor<f32>>,
113 ) -> (Variable, Option<Variable>) {
114 let hidden_states = self.embeddings.forward_with_ids(input_ids, token_type_ids, None);
116
117 let sequence_output = self.encoder.forward_with_mask(&hidden_states, attention_mask);
119
120 let pooled_output = self.pooler.as_ref().map(|p| p.forward(&sequence_output));
122
123 (sequence_output, pooled_output)
124 }
125
126 pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
128 let (sequence_output, _) = self.forward_with_pooling(input_ids, None, None);
129 sequence_output
130 }
131}
132
133impl Module for Bert {
134 fn forward(&self, input: &Variable) -> Variable {
135 self.encoder.forward(input)
137 }
138
139 fn parameters(&self) -> Vec<Parameter> {
140 let mut params = Vec::new();
141 params.extend(self.embeddings.parameters());
142 params.extend(self.encoder.parameters());
143 if let Some(ref pooler) = self.pooler {
144 params.extend(pooler.parameters());
145 }
146 params
147 }
148
149 fn train(&mut self) {
150 self.embeddings.train();
151 self.encoder.train();
152 }
153
154 fn eval(&mut self) {
155 self.embeddings.eval();
156 self.encoder.eval();
157 }
158}
159
/// BERT with a sequence-classification head: pooled [CLS] representation
/// -> dropout -> linear layer producing `num_labels` logits.
#[derive(Debug)]
pub struct BertForSequenceClassification {
    // Backbone built with a pooler (see `new`).
    pub bert: Bert,
    // Applied to the pooled output before classification.
    pub dropout: Dropout,
    // hidden_size -> num_labels projection.
    pub classifier: Linear,
    // Number of output classes.
    pub num_labels: usize,
}
172
173impl BertForSequenceClassification {
174 pub fn new(config: &BertConfig, num_labels: usize) -> Self {
176 Self {
177 bert: Bert::new(config),
178 dropout: Dropout::new(config.hidden_dropout_prob),
179 classifier: Linear::new(config.hidden_size, num_labels),
180 num_labels,
181 }
182 }
183
184 pub fn forward_classification(&self, input_ids: &Tensor<u32>) -> LLMResult<Variable> {
189 let (_, pooled_output) = self.bert.forward_with_pooling(input_ids, None, None);
190
191 if let Some(pooled) = pooled_output {
192 let pooled = self.dropout.forward(&pooled);
193 Ok(self.classifier.forward(&pooled))
194 } else {
195 Err(LLMError::InvalidConfig(
196 "BERT model must have pooler for sequence classification".to_string()
197 ))
198 }
199 }
200}
201
202impl Module for BertForSequenceClassification {
203 fn forward(&self, input: &Variable) -> Variable {
204 let sequence_output = self.bert.forward(input);
205
206 let seq_data = sequence_output.data();
208 let shape = seq_data.shape();
209 let batch_size = shape[0];
210 let hidden_size = shape[2];
211
212 let cls_output = sequence_output.slice(&[0..batch_size, 0..1, 0..hidden_size]);
213 let cls_output = cls_output.reshape(&[batch_size, hidden_size]);
214
215 let cls_output = self.dropout.forward(&cls_output);
216 self.classifier.forward(&cls_output)
217 }
218
219 fn parameters(&self) -> Vec<Parameter> {
220 let mut params = self.bert.parameters();
221 params.extend(self.classifier.parameters());
222 params
223 }
224
225 fn train(&mut self) {
226 self.bert.train();
227 self.dropout.train();
228 }
229
230 fn eval(&mut self) {
231 self.bert.eval();
232 self.dropout.eval();
233 }
234}
235
/// BERT with a masked-language-modeling head: per-position vocabulary
/// logits over the encoder's sequence output. Built without a pooler.
#[derive(Debug)]
pub struct BertForMaskedLM {
    // Poolerless backbone (see `BertForMaskedLM::new`).
    pub bert: Bert,
    // Transform + decoder head mapping hidden states to vocab logits.
    pub cls: BertLMPredictionHead,
}
244
/// LM prediction head: a hidden-state transform followed by a decoder
/// projection to vocabulary size.
#[derive(Debug)]
pub struct BertLMPredictionHead {
    // dense -> activation -> layer norm, applied before decoding.
    pub transform: BertPredictionHeadTransform,
    // hidden_size -> vocab_size projection.
    pub decoder: Linear,
}
253
/// Pre-decoder transform for the LM head: dense projection, configurable
/// activation, then layer normalization.
#[derive(Debug)]
pub struct BertPredictionHeadTransform {
    // Square projection: hidden_size -> hidden_size.
    pub dense: Linear,
    pub layer_norm: LayerNorm,
    // Activation name ("gelu" or "relu"; anything else falls back to gelu
    // in the Module impl).
    pub activation: String,
}
264
265impl BertPredictionHeadTransform {
266 pub fn new(hidden_size: usize, layer_norm_eps: f32, activation: &str) -> Self {
268 Self {
269 dense: Linear::new(hidden_size, hidden_size),
270 layer_norm: LayerNorm::new(hidden_size, layer_norm_eps),
271 activation: activation.to_string(),
272 }
273 }
274}
275
276impl Module for BertPredictionHeadTransform {
277 fn forward(&self, input: &Variable) -> Variable {
278 let x = self.dense.forward(input);
279 let x = match self.activation.as_str() {
280 "gelu" => x.gelu(),
281 "relu" => x.relu(),
282 _ => x.gelu(),
283 };
284 self.layer_norm.forward(&x)
285 }
286
287 fn parameters(&self) -> Vec<Parameter> {
288 let mut params = self.dense.parameters();
289 params.extend(self.layer_norm.parameters());
290 params
291 }
292}
293
294impl BertLMPredictionHead {
295 pub fn new(hidden_size: usize, vocab_size: usize, layer_norm_eps: f32, activation: &str) -> Self {
297 Self {
298 transform: BertPredictionHeadTransform::new(hidden_size, layer_norm_eps, activation),
299 decoder: Linear::new(hidden_size, vocab_size),
300 }
301 }
302}
303
304impl Module for BertLMPredictionHead {
305 fn forward(&self, input: &Variable) -> Variable {
306 let x = self.transform.forward(input);
307 self.decoder.forward(&x)
308 }
309
310 fn parameters(&self) -> Vec<Parameter> {
311 let mut params = self.transform.parameters();
312 params.extend(self.decoder.parameters());
313 params
314 }
315}
316
317impl BertForMaskedLM {
318 pub fn new(config: &BertConfig) -> Self {
320 let bert = Bert::with_pooler(config, false); let cls = BertLMPredictionHead::new(
322 config.hidden_size,
323 config.vocab_size,
324 config.layer_norm_eps,
325 &config.hidden_act,
326 );
327
328 Self { bert, cls }
329 }
330
331 pub fn forward_mlm(&self, input_ids: &Tensor<u32>) -> Variable {
333 let sequence_output = self.bert.forward_ids(input_ids);
334 self.cls.forward(&sequence_output)
335 }
336}
337
338impl Module for BertForMaskedLM {
339 fn forward(&self, input: &Variable) -> Variable {
340 let sequence_output = self.bert.forward(input);
341 self.cls.forward(&sequence_output)
342 }
343
344 fn parameters(&self) -> Vec<Parameter> {
345 let mut params = self.bert.parameters();
346 params.extend(self.cls.parameters());
347 params
348 }
349
350 fn train(&mut self) {
351 self.bert.train();
352 }
353
354 fn eval(&mut self) {
355 self.bert.eval();
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: 2 sequences of 4 token ids each.
    fn sample_ids() -> Tensor<u32> {
        Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap()
    }

    #[test]
    fn test_bert_tiny() {
        let cfg = BertConfig::tiny();
        let bert = Bert::new(&cfg);

        let out = bert.forward_ids(&sample_ids());
        assert_eq!(out.data().shape(), &[2, 4, cfg.hidden_size]);
    }

    #[test]
    fn test_bert_pooler() {
        let cfg = BertConfig::tiny();
        let bert = Bert::new(&cfg);

        let (seq, pooled) = bert.forward_with_pooling(&sample_ids(), None, None);

        assert_eq!(seq.data().shape(), &[2, 4, cfg.hidden_size]);
        assert!(pooled.is_some());
        assert_eq!(pooled.unwrap().data().shape(), &[2, cfg.hidden_size]);
    }

    #[test]
    fn test_bert_for_classification() {
        let cfg = BertConfig::tiny();
        let clf = BertForSequenceClassification::new(&cfg, 2);

        let logits = clf.forward_classification(&sample_ids()).unwrap();
        assert_eq!(logits.data().shape(), &[2, 2]);
    }

    #[test]
    fn test_bert_for_mlm() {
        let cfg = BertConfig::tiny();
        let mlm = BertForMaskedLM::new(&cfg);

        let logits = mlm.forward_mlm(&sample_ids());
        assert_eq!(logits.data().shape(), &[2, 4, cfg.vocab_size]);
    }

    #[test]
    fn test_bert_parameter_count() {
        let bert = Bert::new(&BertConfig::tiny());
        assert!(!bert.parameters().is_empty());
    }
}