// ghostflow_nn/bert.rs

//! BERT (Bidirectional Encoder Representations from Transformers)
//!
//! Implements BERT as described in "BERT: Pre-training of Deep Bidirectional
//! Transformers for Language Understanding" (Devlin et al., 2018):
//! - Token embeddings
//! - Segment embeddings
//! - Position embeddings
//! - Transformer encoder layers
//! - Masked language modeling head
//! - Sequence and token classification heads
//!
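//! # Example
//!
//! A minimal end-to-end sketch using the tiny test configuration. The module
//! path `ghostflow_nn::bert` and the `Tensor::from_slice` constructor are
//! assumptions based on this crate's layout and the tests at the bottom of
//! this file.
//!
//! ```ignore
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::bert::{BertConfig, BertModel};
//!
//! let model = BertModel::new(BertConfig::bert_tiny(), true);
//! let input_ids = Tensor::from_slice(&[1.0, 2.0], &[1, 2]).unwrap();
//! let output = model.forward(&input_ids, None, None).unwrap();
//! assert_eq!(output.last_hidden_state.dims(), &[1, 2, 128]);
//! ```
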
use ghostflow_core::Tensor;
use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;

/// BERT configuration
#[derive(Debug, Clone)]
pub struct BertConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Hidden size (embedding dimension)
    pub hidden_size: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Intermediate size in feed-forward network
    pub intermediate_size: usize,
    /// Maximum sequence length
    pub max_position_embeddings: usize,
    /// Number of token types (segments)
    pub type_vocab_size: usize,
    /// Dropout probability
    pub dropout: f32,
    /// Layer norm epsilon
    pub layer_norm_eps: f32,
}

impl Default for BertConfig {
    fn default() -> Self {
        BertConfig {
            vocab_size: 30522,
            hidden_size: 768,
            num_layers: 12,
            num_heads: 12,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            dropout: 0.1,
            layer_norm_eps: 1e-12,
        }
    }
}

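// `hidden_size` must be divisible by `num_heads` for multi-head attention;
// all presets below satisfy this (e.g. 768 / 12 = 64 dims per head). Custom
// variants can be built with struct update syntax, for example (illustrative,
// not a preset from the paper):
//
//     let small = BertConfig { num_layers: 6, ..BertConfig::default() };
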
impl BertConfig {
    /// BERT-Base configuration
    pub fn bert_base() -> Self {
        Self::default()
    }
    
    /// BERT-Large configuration
    pub fn bert_large() -> Self {
        BertConfig {
            hidden_size: 1024,
            num_layers: 24,
            num_heads: 16,
            intermediate_size: 4096,
            ..Default::default()
        }
    }
    
    /// BERT-Tiny configuration (for testing)
    pub fn bert_tiny() -> Self {
        BertConfig {
            vocab_size: 1000,
            hidden_size: 128,
            num_layers: 2,
            num_heads: 2,
            intermediate_size: 512,
            max_position_embeddings: 128,
            ..Default::default()
        }
    }
}

/// BERT embeddings layer
pub struct BertEmbeddings {
    /// Token embeddings
    token_embeddings: Tensor,
    /// Position embeddings
    position_embeddings: Tensor,
    /// Token type (segment) embeddings
    token_type_embeddings: Tensor,
    /// Layer normalization
    layer_norm: LayerNorm,
    /// Configuration
    config: BertConfig,
}

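// For token id t at position s in segment k, the output embedding is
//     LayerNorm(token_embeddings[t] + position_embeddings[s] + token_type_embeddings[k])
// as described in the BERT paper. Note that `config.dropout` is not applied
// in this layer.
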
impl BertEmbeddings {
    /// Create new BERT embeddings
    pub fn new(config: BertConfig) -> Self {
        // Initialize embeddings
        let token_embeddings = Tensor::randn(&[config.vocab_size, config.hidden_size]);
        let position_embeddings = Tensor::randn(&[config.max_position_embeddings, config.hidden_size]);
        let token_type_embeddings = Tensor::randn(&[config.type_vocab_size, config.hidden_size]);
        
        let layer_norm = LayerNorm::new(&[config.hidden_size]);
        
        BertEmbeddings {
            token_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            config,
        }
    }
    
    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        if dims.len() != 2 {
            return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
        }
        
        let batch_size = dims[0];
        let seq_length = dims[1];
        
        // Get token embeddings
        let token_embeds = self.get_token_embeddings(input_ids)?;
        
        // Get position embeddings
        let position_embeds = self.get_position_embeddings(seq_length)?;
        
        // Get token type embeddings
        let token_type_embeds = if let Some(tt_ids) = token_type_ids {
            self.get_token_type_embeddings(tt_ids)?
        } else {
            // Default to all zeros (first segment)
            Tensor::zeros(&[batch_size, seq_length, self.config.hidden_size])
        };
        
        // Sum all embeddings
        let embeddings = self.sum_embeddings(&token_embeds, &position_embeds, &token_type_embeds)?;
        
        // Layer normalization
        Ok(self.layer_norm.forward(&embeddings))
    }
    
    /// Get token embeddings by lookup
    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embeddings.data_f32();
        
        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.config.hidden_size;
        
        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);
        
        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.vocab_size {
                return Err(format!("Token ID {} out of vocabulary range", idx));
            }
            
            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }
        
        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create token embeddings: {:?}", e))
    }
    
    /// Get position embeddings
    fn get_position_embeddings(&self, seq_length: usize) -> Result<Tensor, String> {
        let embed_data = self.position_embeddings.data_f32();
        let hidden_size = self.config.hidden_size;
        
        if seq_length > self.config.max_position_embeddings {
            return Err(format!("Sequence length {} exceeds maximum {}",
                             seq_length, self.config.max_position_embeddings));
        }
        
        let result = embed_data[..seq_length * hidden_size].to_vec();
        
        Tensor::from_slice(&result, &[seq_length, hidden_size])
            .map_err(|e| format!("Failed to create position embeddings: {:?}", e))
    }
    
    /// Get token type embeddings
    fn get_token_type_embeddings(&self, token_type_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = token_type_ids.data_f32();
        let embed_data = self.token_type_embeddings.data_f32();
        
        let dims = token_type_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.config.hidden_size;
        
        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);
        
        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.type_vocab_size {
                return Err(format!("Token type ID {} out of range", idx));
            }
            
            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }
        
        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create token type embeddings: {:?}", e))
    }
    
    /// Sum all embeddings
    ///
    /// Token and token-type embeddings are [batch, seq, hidden]; position
    /// embeddings are [seq, hidden] and broadcast across the batch.
    fn sum_embeddings(&self, token: &Tensor, position: &Tensor, token_type: &Tensor) -> Result<Tensor, String> {
        let token_data = token.data_f32();
        let pos_data = position.data_f32();
        let tt_data = token_type.data_f32();
        
        let dims = token.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = dims[2];
        
        let mut result = Vec::with_capacity(token_data.len());
        
        for b in 0..batch_size {
            for s in 0..seq_length {
                for h in 0..hidden_size {
                    // Row-major offsets: [b, s, h] for the batched tensors,
                    // [s, h] for the batch-agnostic position table.
                    let token_idx = b * seq_length * hidden_size + s * hidden_size + h;
                    let pos_idx = s * hidden_size + h;
                    
                    result.push(token_data[token_idx] + pos_data[pos_idx] + tt_data[token_idx]);
                }
            }
        }
        
        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to sum embeddings: {:?}", e))
    }
}

/// BERT pooler (for classification tasks)
pub struct BertPooler {
    dense: Linear,
}

impl BertPooler {
    /// Create new BERT pooler
    pub fn new(hidden_size: usize) -> Self {
        BertPooler {
            dense: Linear::new(hidden_size, hidden_size),
        }
    }
    
    /// Forward pass - pool first token (CLS)
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        // Extract first token: [batch, hidden_size]
        let first_token = self.extract_first_token(hidden_states)?;
        
        // Dense layer
        let pooled = self.dense.forward(&first_token);
        
        // Tanh activation (BERT uses tanh for the pooler)
        self.apply_tanh(&pooled)
    }
    
    /// Extract first token from sequence
    fn extract_first_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();
        
        if dims.len() != 3 {
            return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
        }
        
        let batch_size = dims[0];
        let hidden_size = dims[2];
        
        let mut result = Vec::with_capacity(batch_size * hidden_size);
        
        for b in 0..batch_size {
            let start = b * dims[1] * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&data[start..end]);
        }
        
        Tensor::from_slice(&result, &[batch_size, hidden_size])
            .map_err(|e| format!("Failed to extract first token: {:?}", e))
    }
    
    /// Apply tanh activation
    fn apply_tanh(&self, x: &Tensor) -> Result<Tensor, String> {
        let data = x.data_f32();
        let result: Vec<f32> = data.iter().map(|&v| v.tanh()).collect();
        
        Tensor::from_slice(&result, x.dims())
            .map_err(|e| format!("Failed to apply tanh: {:?}", e))
    }
}

/// BERT model
pub struct BertModel {
    /// Configuration
    config: BertConfig,
    /// Embeddings layer
    embeddings: BertEmbeddings,
    /// Transformer encoder
    encoder: TransformerEncoder,
    /// Pooler (optional, for classification)
    pooler: Option<BertPooler>,
}

impl BertModel {
    /// Create new BERT model
    pub fn new(config: BertConfig, with_pooler: bool) -> Self {
        let embeddings = BertEmbeddings::new(config.clone());
        
        let encoder = TransformerEncoder::new(
            config.hidden_size,
            config.num_heads,
            config.intermediate_size,
            config.num_layers,
            config.dropout,
        );
        
        let pooler = if with_pooler {
            Some(BertPooler::new(config.hidden_size))
        } else {
            None
        };
        
        BertModel {
            config,
            embeddings,
            encoder,
            pooler,
        }
    }
    
    /// Forward pass
    ///
    /// `_attention_mask` is accepted for API compatibility but is not yet
    /// applied by the encoder.
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>,
                   _attention_mask: Option<&Tensor>) -> Result<BertOutput, String> {
        // Get embeddings
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        
        // Encoder
        let sequence_output = self.encoder.forward(&embedding_output);
        
        // Pooler (if present)
        let pooled_output = if let Some(ref pooler) = self.pooler {
            Some(pooler.forward(&sequence_output)?)
        } else {
            None
        };
        
        Ok(BertOutput {
            last_hidden_state: sequence_output,
            pooler_output: pooled_output,
        })
    }
}

/// BERT output
pub struct BertOutput {
    /// Last hidden state: [batch, seq_len, hidden_size]
    pub last_hidden_state: Tensor,
    /// Pooled output (CLS token): [batch, hidden_size]
    pub pooler_output: Option<Tensor>,
}

/// BERT for Masked Language Modeling
pub struct BertForMaskedLM {
    bert: BertModel,
    /// Projects hidden states to vocabulary logits. Note: this is a single
    /// linear layer; the original BERT additionally applies a dense + GELU +
    /// LayerNorm transform before the output projection.
    mlm_head: Linear,
}

impl BertForMaskedLM {
    /// Create new BERT for MLM
    pub fn new(config: BertConfig) -> Self {
        let bert = BertModel::new(config.clone(), false);
        let mlm_head = Linear::new(config.hidden_size, config.vocab_size);
        
        BertForMaskedLM {
            bert,
            mlm_head,
        }
    }
    
    /// Forward pass, returning logits of shape [batch, seq_len, vocab_size]
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;
        Ok(self.mlm_head.forward(&output.last_hidden_state))
    }
}

/// BERT for Sequence Classification
pub struct BertForSequenceClassification {
    bert: BertModel,
    classifier: Linear,
    num_labels: usize,
}

impl BertForSequenceClassification {
    /// Create new BERT for classification
    pub fn new(config: BertConfig, num_labels: usize) -> Self {
        let bert = BertModel::new(config.clone(), true);
        let classifier = Linear::new(config.hidden_size, num_labels);
        
        BertForSequenceClassification {
            bert,
            classifier,
            num_labels,
        }
    }
    
    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;
        
        let pooled = output.pooler_output
            .ok_or_else(|| "Pooler output not available".to_string())?;
        
        Ok(self.classifier.forward(&pooled))
    }
}
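
// Note: sequence classification projects the pooled [CLS] vector, giving
// logits of shape [batch, num_labels]; token classification below instead
// projects every position, giving [batch, seq_len, num_labels].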

/// BERT for Token Classification (NER, POS tagging)
pub struct BertForTokenClassification {
    bert: BertModel,
    classifier: Linear,
    num_labels: usize,
}

impl BertForTokenClassification {
    /// Create new BERT for token classification
    pub fn new(config: BertConfig, num_labels: usize) -> Self {
        let bert = BertModel::new(config.clone(), false);
        let classifier = Linear::new(config.hidden_size, num_labels);
        
        BertForTokenClassification {
            bert,
            classifier,
            num_labels,
        }
    }
    
    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;
        Ok(self.classifier.forward(&output.last_hidden_state))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_bert_config() {
        let config = BertConfig::bert_base();
        assert_eq!(config.hidden_size, 768);
        assert_eq!(config.num_layers, 12);
        
        let config = BertConfig::bert_large();
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.num_layers, 24);
    }
    
    #[test]
    fn test_bert_embeddings() {
        let config = BertConfig::bert_tiny();
        let embeddings = BertEmbeddings::new(config);
        
        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = embeddings.forward(&input_ids, None).unwrap();
        
        assert_eq!(output.dims(), &[2, 2, 128]); // batch=2, seq=2, hidden=128
    }
    
    #[test]
    fn test_bert_model() {
        let config = BertConfig::bert_tiny();
        let bert = BertModel::new(config, true);
        
        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = bert.forward(&input_ids, None, None).unwrap();
        
        assert_eq!(output.last_hidden_state.dims(), &[2, 2, 128]);
        assert!(output.pooler_output.is_some());
        assert_eq!(output.pooler_output.unwrap().dims(), &[2, 128]);
    }
    
    #[test]
    fn test_bert_for_classification() {
        let config = BertConfig::bert_tiny();
        let bert = BertForSequenceClassification::new(config, 2);
        
        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = bert.forward(&input_ids, None).unwrap();
        
        assert_eq!(output.dims(), &[2, 2]); // batch=2, num_labels=2
    }
}