Skip to main content

ferrum_models/architectures/
bert.rs

1//! BERT architecture using Candle's built-in implementation
2//! BERT is an encoder model used for embeddings and classification tasks
3
4use candle_core::{DType, Device as CandleDevice, Tensor};
5use candle_nn::VarBuilder;
6use candle_transformers::models::bert::{BertModel, Config as BertConfig, HiddenAct};
7use ferrum_types::{FerrumError, Result};
8use parking_lot::Mutex;
9use tracing::{debug, info};
10
11/// BERT model wrapper for embeddings
12pub struct BertModelWrapper {
13    model: Mutex<BertModel>,
14    config: BertConfig,
15    device: CandleDevice,
16    dtype: DType,
17}
18
19impl BertModelWrapper {
20    /// Create from VarBuilder and config
21    pub fn from_varbuilder(
22        vb: VarBuilder,
23        config: &crate::definition::ModelDefinition,
24        device: CandleDevice,
25        dtype: DType,
26    ) -> Result<Self> {
27        info!("🔨 Creating BERT model from weights...");
28
29        // Build Candle's BERT config
30        let bert_config = BertConfig {
31            vocab_size: config.vocab_size,
32            hidden_size: config.hidden_size,
33            num_hidden_layers: config.num_hidden_layers,
34            num_attention_heads: config.num_attention_heads,
35            intermediate_size: config.intermediate_size,
36            hidden_act: HiddenAct::Gelu,
37            hidden_dropout_prob: 0.1,
38            max_position_embeddings: config.max_position_embeddings,
39            type_vocab_size: 2,
40            initializer_range: 0.02,
41            layer_norm_eps: config.norm_eps,
42            pad_token_id: 0,
43            position_embedding_type:
44                candle_transformers::models::bert::PositionEmbeddingType::Absolute,
45            use_cache: true,
46            classifier_dropout: None,
47            model_type: Some("bert".to_string()),
48        };
49
50        debug!(
51            "BERT config: hidden={}, layers={}, heads={}",
52            bert_config.hidden_size, bert_config.num_hidden_layers, bert_config.num_attention_heads,
53        );
54
55        // Load model
56        let model = BertModel::load(vb, &bert_config)
57            .map_err(|e| FerrumError::model(format!("Failed to create BERT model: {}", e)))?;
58
59        info!("✅ BERT model created successfully");
60
61        Ok(Self {
62            model: Mutex::new(model),
63            config: bert_config,
64            device,
65            dtype,
66        })
67    }
68
69    /// Load from config.json path
70    pub fn from_config_json(
71        vb: VarBuilder,
72        config_path: &std::path::Path,
73        device: CandleDevice,
74        dtype: DType,
75    ) -> Result<Self> {
76        info!("🔨 Loading BERT model from config: {:?}", config_path);
77
78        let config_content = std::fs::read_to_string(config_path)
79            .map_err(|e| FerrumError::model(format!("Failed to read config: {}", e)))?;
80
81        let bert_config: BertConfig = serde_json::from_str(&config_content)
82            .map_err(|e| FerrumError::model(format!("Failed to parse BERT config: {}", e)))?;
83
84        debug!(
85            "BERT config: hidden={}, layers={}, heads={}",
86            bert_config.hidden_size, bert_config.num_hidden_layers, bert_config.num_attention_heads,
87        );
88
89        let model = BertModel::load(vb, &bert_config)
90            .map_err(|e| FerrumError::model(format!("Failed to create BERT model: {}", e)))?;
91
92        info!("✅ BERT model created successfully");
93
94        Ok(Self {
95            model: Mutex::new(model),
96            config: bert_config,
97            device,
98            dtype,
99        })
100    }
101
102    /// Forward pass to get embeddings
103    /// Returns the pooled output (CLS token representation) for sentence embeddings
104    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
105        let model = self.model.lock();
106
107        let output = model
108            .forward(input_ids, token_type_ids, None)
109            .map_err(|e| FerrumError::model(format!("BERT forward failed: {}", e)))?;
110
111        Ok(output)
112    }
113
114    /// Get sentence embedding (mean pooling over sequence)
115    pub fn get_sentence_embedding(
116        &self,
117        input_ids: &Tensor,
118        token_type_ids: &Tensor,
119        attention_mask: Option<&Tensor>,
120    ) -> Result<Tensor> {
121        let hidden_states = self.forward(input_ids, token_type_ids)?;
122
123        // Mean pooling over sequence dimension (dim 1)
124        let embedding = if let Some(mask) = attention_mask {
125            // Expand mask to hidden size
126            let mask = mask
127                .unsqueeze(2)
128                .map_err(|e| FerrumError::model(format!("unsqueeze failed: {}", e)))?
129                .broadcast_as(hidden_states.shape())
130                .map_err(|e| FerrumError::model(format!("broadcast_as failed: {}", e)))?
131                .to_dtype(hidden_states.dtype())
132                .map_err(|e| FerrumError::model(format!("to_dtype failed: {}", e)))?;
133
134            // Masked mean
135            let masked = hidden_states
136                .broadcast_mul(&mask)
137                .map_err(|e| FerrumError::model(format!("broadcast_mul failed: {}", e)))?;
138            let sum = masked
139                .sum(1)
140                .map_err(|e| FerrumError::model(format!("sum failed: {}", e)))?;
141            let count = mask
142                .sum(1)
143                .map_err(|e| FerrumError::model(format!("mask sum failed: {}", e)))?
144                .clamp(1e-9, f64::MAX)
145                .map_err(|e| FerrumError::model(format!("clamp failed: {}", e)))?;
146            sum.broadcast_div(&count)
147                .map_err(|e| FerrumError::model(format!("broadcast_div failed: {}", e)))?
148        } else {
149            // Simple mean over sequence dimension
150            hidden_states
151                .mean(1)
152                .map_err(|e| FerrumError::model(format!("mean failed: {}", e)))?
153        };
154
155        Ok(embedding)
156    }
157
158    /// Get CLS token embedding
159    pub fn get_cls_embedding(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
160        let hidden_states = self.forward(input_ids, token_type_ids)?;
161
162        // Get first token (CLS) - shape [batch, seq, hidden] -> [batch, hidden]
163        hidden_states
164            .narrow(1, 0, 1)
165            .map_err(|e| FerrumError::model(format!("Failed to narrow: {}", e)))?
166            .squeeze(1)
167            .map_err(|e| FerrumError::model(format!("Failed to squeeze: {}", e)))
168    }
169
170    /// Get config reference
171    pub fn config(&self) -> &BertConfig {
172        &self.config
173    }
174
175    /// Get device
176    pub fn device(&self) -> &CandleDevice {
177        &self.device
178    }
179
180    /// Get dtype
181    pub fn dtype(&self) -> DType {
182        self.dtype
183    }
184
185    /// Get hidden size
186    pub fn hidden_size(&self) -> usize {
187        self.config.hidden_size
188    }
189}