use ghostflow_core::Tensor;
use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;

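/// Configuration for a BERT model. The defaults match the published
/// BERT-Base checkpoint (12 layers, 768 hidden units, 12 attention heads).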
#[derive(Debug, Clone)]
pub struct BertConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub type_vocab_size: usize,
    pub dropout: f32,
    pub layer_norm_eps: f32,
}

impl Default for BertConfig {
    fn default() -> Self {
        BertConfig {
            vocab_size: 30522,
            hidden_size: 768,
            num_layers: 12,
            num_heads: 12,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            dropout: 0.1,
            layer_norm_eps: 1e-12,
        }
    }
}

impl BertConfig {
    /// BERT-Base: identical to the defaults above.
    pub fn bert_base() -> Self {
        Self::default()
    }

    /// BERT-Large: 24 layers, 1024 hidden units, 16 attention heads.
    pub fn bert_large() -> Self {
        BertConfig {
            hidden_size: 1024,
            num_layers: 24,
            num_heads: 16,
            intermediate_size: 4096,
            ..Default::default()
        }
    }

    /// A deliberately small configuration, used by the tests below.
    pub fn bert_tiny() -> Self {
        BertConfig {
            vocab_size: 1000,
            hidden_size: 128,
            num_layers: 2,
            num_heads: 2,
            intermediate_size: 512,
            max_position_embeddings: 128,
            ..Default::default()
        }
    }
}

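/// BERT input embeddings: the sum of token, position, and token-type
/// (segment) embeddings, followed by layer normalization.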
pub struct BertEmbeddings {
    token_embeddings: Tensor,
    position_embeddings: Tensor,
    token_type_embeddings: Tensor,
    layer_norm: LayerNorm,
    config: BertConfig,
}

impl BertEmbeddings {
    pub fn new(config: BertConfig) -> Self {
        // All three embedding tables are randomly initialized.
        let token_embeddings = Tensor::randn(&[config.vocab_size, config.hidden_size]);
        let position_embeddings = Tensor::randn(&[config.max_position_embeddings, config.hidden_size]);
        let token_type_embeddings = Tensor::randn(&[config.type_vocab_size, config.hidden_size]);

        let layer_norm = LayerNorm::new(&[config.hidden_size]);

        BertEmbeddings {
            token_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            config,
        }
    }

    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        if dims.len() != 2 {
            return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let seq_length = dims[1];

        let token_embeds = self.get_token_embeddings(input_ids)?;

        // Position embeddings are [seq_length, hidden_size] and are broadcast
        // across the batch dimension in sum_embeddings.
        let position_embeds = self.get_position_embeddings(seq_length)?;

        // When token_type_ids are omitted, default every token to segment 0,
        // matching BERT's behavior for single-segment inputs.
        let token_type_embeds = match token_type_ids {
            Some(tt_ids) => self.get_token_type_embeddings(tt_ids)?,
            None => {
                let zeros = Tensor::zeros(&[batch_size, seq_length]);
                self.get_token_type_embeddings(&zeros)?
            }
        };

        let embeddings = self.sum_embeddings(&token_embeds, &position_embeds, &token_type_embeds)?;

        Ok(self.layer_norm.forward(&embeddings))
    }

    // Gathers one embedding row per token ID into a [batch, seq, hidden] tensor.
    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embeddings.data_f32();

        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.config.hidden_size;

        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);

        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.vocab_size {
                return Err(format!("Token ID {} out of vocabulary range", idx));
            }

            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create token embeddings: {:?}", e))
    }

    // Returns the first `seq_length` rows of the position-embedding table.
    fn get_position_embeddings(&self, seq_length: usize) -> Result<Tensor, String> {
        let embed_data = self.position_embeddings.data_f32();
        let hidden_size = self.config.hidden_size;

        if seq_length > self.config.max_position_embeddings {
            return Err(format!("Sequence length {} exceeds maximum {}",
                seq_length, self.config.max_position_embeddings));
        }

        let result = embed_data[..seq_length * hidden_size].to_vec();

        Tensor::from_slice(&result, &[seq_length, hidden_size])
            .map_err(|e| format!("Failed to create position embeddings: {:?}", e))
    }

    // Gathers one segment-embedding row per token-type ID into a
    // [batch, seq, hidden] tensor.
    fn get_token_type_embeddings(&self, token_type_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = token_type_ids.data_f32();
        let embed_data = self.token_type_embeddings.data_f32();

        let dims = token_type_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.config.hidden_size;

        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);

        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.type_vocab_size {
                return Err(format!("Token type ID {} out of range", idx));
            }

            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create token type embeddings: {:?}", e))
    }

    // Element-wise sum of token, position, and token-type embeddings; the
    // [seq, hidden] position embeddings are broadcast across the batch.
    fn sum_embeddings(&self, token: &Tensor, position: &Tensor, token_type: &Tensor) -> Result<Tensor, String> {
        let token_data = token.data_f32();
        let pos_data = position.data_f32();
        let tt_data = token_type.data_f32();

        let dims = token.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(token_data.len());

        for b in 0..batch_size {
            for s in 0..seq_length {
                for h in 0..hidden_size {
                    let token_idx = b * seq_length * hidden_size + s * hidden_size + h;
                    let pos_idx = s * hidden_size + h;

                    result.push(token_data[token_idx] + pos_data[pos_idx] + tt_data[token_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to sum embeddings: {:?}", e))
    }
}

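/// BERT pooler: takes the hidden state of the first ([CLS]) token, applies
/// a dense projection, and squashes it with tanh, as in the original BERT.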
pub struct BertPooler {
    dense: Linear,
}

impl BertPooler {
    pub fn new(hidden_size: usize) -> Self {
        BertPooler {
            dense: Linear::new(hidden_size, hidden_size),
        }
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        let first_token = self.extract_first_token(hidden_states)?;
        let pooled = self.dense.forward(&first_token);
        self.apply_tanh(&pooled)
    }

    // Slices out the hidden state of token 0 for each batch element,
    // yielding a [batch, hidden] tensor.
    fn extract_first_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();

        if dims.len() != 3 {
            return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(batch_size * hidden_size);

        for b in 0..batch_size {
            let start = b * dims[1] * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, hidden_size])
            .map_err(|e| format!("Failed to extract first token: {:?}", e))
    }

    fn apply_tanh(&self, x: &Tensor) -> Result<Tensor, String> {
        let data = x.data_f32();
        let result: Vec<f32> = data.iter().map(|&v| v.tanh()).collect();

        Tensor::from_slice(&result, x.dims())
            .map_err(|e| format!("Failed to apply tanh: {:?}", e))
    }
}

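/// The core BERT model: embeddings, a stack of transformer encoder layers,
/// and an optional pooler for sentence-level tasks.
///
/// A minimal usage sketch (shapes as exercised by the tests below):
///
/// ```ignore
/// let bert = BertModel::new(BertConfig::bert_tiny(), true);
/// let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let output = bert.forward(&input_ids, None, None).unwrap();
/// assert_eq!(output.last_hidden_state.dims(), &[2, 2, 128]);
/// ```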
pub struct BertModel {
    config: BertConfig,
    embeddings: BertEmbeddings,
    encoder: TransformerEncoder,
    pooler: Option<BertPooler>,
}

impl BertModel {
    pub fn new(config: BertConfig, with_pooler: bool) -> Self {
        let embeddings = BertEmbeddings::new(config.clone());

        let encoder = TransformerEncoder::new(
            config.hidden_size,
            config.num_heads,
            config.intermediate_size,
            config.num_layers,
            config.dropout,
        );

        let pooler = if with_pooler {
            Some(BertPooler::new(config.hidden_size))
        } else {
            None
        };

        BertModel {
            config,
            embeddings,
            encoder,
            pooler,
        }
    }

    /// Runs the full model. `_attention_mask` is accepted for API
    /// compatibility but is not yet applied by the encoder.
    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>,
                   _attention_mask: Option<&Tensor>) -> Result<BertOutput, String> {
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;

        let sequence_output = self.encoder.forward(&embedding_output);

        let pooled_output = if let Some(ref pooler) = self.pooler {
            Some(pooler.forward(&sequence_output)?)
        } else {
            None
        };

        Ok(BertOutput {
            last_hidden_state: sequence_output,
            pooler_output: pooled_output,
        })
    }
}

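/// Output of [`BertModel::forward`]: per-token hidden states and, when a
/// pooler is configured, the pooled [CLS] representation.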
pub struct BertOutput {
    pub last_hidden_state: Tensor,
    pub pooler_output: Option<Tensor>,
}

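/// BERT with a masked-language-modeling head: a linear projection from the
/// hidden states to vocabulary logits. (The original BERT ties this head to
/// the input embeddings; here it is an independent linear layer.)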
pub struct BertForMaskedLM {
    bert: BertModel,
    mlm_head: Linear,
}

impl BertForMaskedLM {
    pub fn new(config: BertConfig) -> Self {
        let bert = BertModel::new(config.clone(), false);
        let mlm_head = Linear::new(config.hidden_size, config.vocab_size);

        BertForMaskedLM {
            bert,
            mlm_head,
        }
    }

    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;
        Ok(self.mlm_head.forward(&output.last_hidden_state))
    }
}

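/// BERT with a sequence-classification head: the pooled [CLS] representation
/// is projected to `num_labels` logits, one set per input sequence.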
pub struct BertForSequenceClassification {
    bert: BertModel,
    classifier: Linear,
    num_labels: usize,
}

impl BertForSequenceClassification {
    pub fn new(config: BertConfig, num_labels: usize) -> Self {
        let bert = BertModel::new(config.clone(), true);
        let classifier = Linear::new(config.hidden_size, num_labels);

        BertForSequenceClassification {
            bert,
            classifier,
            num_labels,
        }
    }

    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;

        let pooled = output.pooler_output
            .ok_or_else(|| "Pooler output not available".to_string())?;

        Ok(self.classifier.forward(&pooled))
    }
}

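/// BERT with a token-classification head (e.g. NER): every token's hidden
/// state is projected to `num_labels` logits.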
pub struct BertForTokenClassification {
    bert: BertModel,
    classifier: Linear,
    num_labels: usize,
}

impl BertForTokenClassification {
    pub fn new(config: BertConfig, num_labels: usize) -> Self {
        let bert = BertModel::new(config.clone(), false);
        let classifier = Linear::new(config.hidden_size, num_labels);

        BertForTokenClassification {
            bert,
            classifier,
            num_labels,
        }
    }

    pub fn forward(&self, input_ids: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor, String> {
        let output = self.bert.forward(input_ids, token_type_ids, None)?;
        Ok(self.classifier.forward(&output.last_hidden_state))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bert_config() {
        let config = BertConfig::bert_base();
        assert_eq!(config.hidden_size, 768);
        assert_eq!(config.num_layers, 12);

        let config = BertConfig::bert_large();
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.num_layers, 24);
    }

    #[test]
    fn test_bert_embeddings() {
        let config = BertConfig::bert_tiny();
        let embeddings = BertEmbeddings::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = embeddings.forward(&input_ids, None).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]);
    }

    #[test]
    fn test_bert_model() {
        let config = BertConfig::bert_tiny();
        let bert = BertModel::new(config, true);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = bert.forward(&input_ids, None, None).unwrap();

        assert_eq!(output.last_hidden_state.dims(), &[2, 2, 128]);
        assert!(output.pooler_output.is_some());
        assert_eq!(output.pooler_output.unwrap().dims(), &[2, 128]);
    }

    #[test]
    fn test_bert_for_classification() {
        let config = BertConfig::bert_tiny();
        let bert = BertForSequenceClassification::new(config, 2);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = bert.forward(&input_ids, None).unwrap();

        assert_eq!(output.dims(), &[2, 2]);
    }
}