// ferrum_models/architectures/bert.rs

use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig, HiddenAct};
use ferrum_types::{FerrumError, Result};
use parking_lot::Mutex;
use tracing::{debug, info};
11pub struct BertModelWrapper {
13 model: Mutex<BertModel>,
14 config: BertConfig,
15 device: CandleDevice,
16 dtype: DType,
17}
18
19impl BertModelWrapper {
20 pub fn from_varbuilder(
22 vb: VarBuilder,
23 config: &crate::definition::ModelDefinition,
24 device: CandleDevice,
25 dtype: DType,
26 ) -> Result<Self> {
27 info!("🔨 Creating BERT model from weights...");
28
29 let bert_config = BertConfig {
31 vocab_size: config.vocab_size,
32 hidden_size: config.hidden_size,
33 num_hidden_layers: config.num_hidden_layers,
34 num_attention_heads: config.num_attention_heads,
35 intermediate_size: config.intermediate_size,
36 hidden_act: HiddenAct::Gelu,
37 hidden_dropout_prob: 0.1,
38 max_position_embeddings: config.max_position_embeddings,
39 type_vocab_size: 2,
40 initializer_range: 0.02,
41 layer_norm_eps: config.norm_eps,
42 pad_token_id: 0,
43 position_embedding_type:
44 candle_transformers::models::bert::PositionEmbeddingType::Absolute,
45 use_cache: true,
46 classifier_dropout: None,
47 model_type: Some("bert".to_string()),
48 };
49
50 debug!(
51 "BERT config: hidden={}, layers={}, heads={}",
52 bert_config.hidden_size, bert_config.num_hidden_layers, bert_config.num_attention_heads,
53 );
54
55 let model = BertModel::load(vb, &bert_config)
57 .map_err(|e| FerrumError::model(format!("Failed to create BERT model: {}", e)))?;
58
59 info!("✅ BERT model created successfully");
60
61 Ok(Self {
62 model: Mutex::new(model),
63 config: bert_config,
64 device,
65 dtype,
66 })
67 }
68
69 pub fn from_config_json(
71 vb: VarBuilder,
72 config_path: &std::path::Path,
73 device: CandleDevice,
74 dtype: DType,
75 ) -> Result<Self> {
76 info!("🔨 Loading BERT model from config: {:?}", config_path);
77
78 let config_content = std::fs::read_to_string(config_path)
79 .map_err(|e| FerrumError::model(format!("Failed to read config: {}", e)))?;
80
81 let bert_config: BertConfig = serde_json::from_str(&config_content)
82 .map_err(|e| FerrumError::model(format!("Failed to parse BERT config: {}", e)))?;
83
84 debug!(
85 "BERT config: hidden={}, layers={}, heads={}",
86 bert_config.hidden_size, bert_config.num_hidden_layers, bert_config.num_attention_heads,
87 );
88
89 let model = BertModel::load(vb, &bert_config)
90 .map_err(|e| FerrumError::model(format!("Failed to create BERT model: {}", e)))?;
91
92 info!("✅ BERT model created successfully");
93
94 Ok(Self {
95 model: Mutex::new(model),
96 config: bert_config,
97 device,
98 dtype,
99 })
100 }
101
102 pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
105 let model = self.model.lock();
106
107 let output = model
108 .forward(input_ids, token_type_ids, None)
109 .map_err(|e| FerrumError::model(format!("BERT forward failed: {}", e)))?;
110
111 Ok(output)
112 }
113
114 pub fn get_sentence_embedding(
116 &self,
117 input_ids: &Tensor,
118 token_type_ids: &Tensor,
119 attention_mask: Option<&Tensor>,
120 ) -> Result<Tensor> {
121 let hidden_states = self.forward(input_ids, token_type_ids)?;
122
123 let embedding = if let Some(mask) = attention_mask {
125 let mask = mask
127 .unsqueeze(2)
128 .map_err(|e| FerrumError::model(format!("unsqueeze failed: {}", e)))?
129 .broadcast_as(hidden_states.shape())
130 .map_err(|e| FerrumError::model(format!("broadcast_as failed: {}", e)))?
131 .to_dtype(hidden_states.dtype())
132 .map_err(|e| FerrumError::model(format!("to_dtype failed: {}", e)))?;
133
134 let masked = hidden_states
136 .broadcast_mul(&mask)
137 .map_err(|e| FerrumError::model(format!("broadcast_mul failed: {}", e)))?;
138 let sum = masked
139 .sum(1)
140 .map_err(|e| FerrumError::model(format!("sum failed: {}", e)))?;
141 let count = mask
142 .sum(1)
143 .map_err(|e| FerrumError::model(format!("mask sum failed: {}", e)))?
144 .clamp(1e-9, f64::MAX)
145 .map_err(|e| FerrumError::model(format!("clamp failed: {}", e)))?;
146 sum.broadcast_div(&count)
147 .map_err(|e| FerrumError::model(format!("broadcast_div failed: {}", e)))?
148 } else {
149 hidden_states
151 .mean(1)
152 .map_err(|e| FerrumError::model(format!("mean failed: {}", e)))?
153 };
154
155 Ok(embedding)
156 }
157
158 pub fn get_cls_embedding(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
160 let hidden_states = self.forward(input_ids, token_type_ids)?;
161
162 hidden_states
164 .narrow(1, 0, 1)
165 .map_err(|e| FerrumError::model(format!("Failed to narrow: {}", e)))?
166 .squeeze(1)
167 .map_err(|e| FerrumError::model(format!("Failed to squeeze: {}", e)))
168 }
169
170 pub fn config(&self) -> &BertConfig {
172 &self.config
173 }
174
175 pub fn device(&self) -> &CandleDevice {
177 &self.device
178 }
179
180 pub fn dtype(&self) -> DType {
182 self.dtype
183 }
184
185 pub fn hidden_size(&self) -> usize {
187 self.config.hidden_size
188 }
189}