post_cortex_embeddings/embeddings/backends/
bert.rs1use anyhow::Result;
17use async_trait::async_trait;
18use candle_core::{DType, Device, Tensor};
19use candle_nn::VarBuilder;
20use candle_transformers::models::bert::{BertModel, Config as BertConfig};
21use hf_hub::api::tokio::Api;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info};
25
26use crate::embeddings::backend::EmbeddingBackend;
27use crate::embeddings::config::EmbeddingModelType;
28
/// Maximum number of tokens per sequence handed to the model.
///
/// `process_batch` truncates any longer tokenization to this many tokens.
const MAX_SEQ_LENGTH: usize = 512;
31
32pub struct BertBackend {
34 model: Arc<BertModel>,
35 tokenizer: Arc<Tokenizer>,
36 device: Device,
37 model_type: EmbeddingModelType,
38}
39
40impl BertBackend {
41 pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
43 info!("Loading BERT backend for model: {:?}", model_type);
44
45 let device = Device::Cpu;
46 let model_id = model_type.model_id();
47
48 let api = Api::new().map_err(|e| anyhow::anyhow!("Failed to create API: {}", e))?;
49 let repo = api.model(model_id.to_string());
50
51 let model_path = repo
52 .get("model.safetensors")
53 .await
54 .map_err(|e| anyhow::anyhow!("Failed to get model: {}", e))?;
55 let config_path = repo
56 .get("config.json")
57 .await
58 .map_err(|e| anyhow::anyhow!("Failed to get config: {}", e))?;
59 let tokenizer_path = repo
60 .get("tokenizer.json")
61 .await
62 .map_err(|e| anyhow::anyhow!("Failed to get tokenizer: {}", e))?;
63
64 let tokenizer = Tokenizer::from_file(tokenizer_path)
65 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
66
67 let bert_config_text = std::fs::read_to_string(config_path)
68 .map_err(|e| anyhow::anyhow!("Failed to read BERT config: {}", e))?;
69 let bert_config: BertConfig = serde_json::from_str(&bert_config_text)
70 .map_err(|e| anyhow::anyhow!("Failed to parse BERT config: {}", e))?;
71
72 #[allow(unsafe_code)]
79 let vb =
80 unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
81 let model = BertModel::load(vb, &bert_config)?;
82
83 info!(
84 "BERT model loaded successfully, embedding dimension: {}",
85 model_type.embedding_dimension()
86 );
87
88 Ok(Self {
89 model: Arc::new(model),
90 tokenizer: Arc::new(tokenizer),
91 device,
92 model_type,
93 })
94 }
95
96 fn l2_normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
98 let squared = embeddings.sqr()?;
100 let sum_squared = squared.sum_keepdim(1)?;
101 let l2_norm = sum_squared.sqrt()?;
102
103 if tracing::enabled!(tracing::Level::DEBUG) {
104 let l2_norm_values = l2_norm.to_vec2::<f32>()?;
105 debug!(
106 "L2 normalization - batch size: {}, first norm: {:.6}",
107 l2_norm_values.len(),
108 l2_norm_values
109 .first()
110 .and_then(|v| v.first())
111 .unwrap_or(&0.0)
112 );
113 }
114
115 let epsilon = 1e-12_f32;
117 let l2_norm_safe = l2_norm.clamp(epsilon, f32::MAX)?;
118 let normalized = embeddings.broadcast_div(&l2_norm_safe)?;
119
120 debug!("L2 normalization completed successfully");
121 Ok(normalized)
122 }
123}
124
125#[async_trait]
126impl EmbeddingBackend for BertBackend {
127 fn embedding_dimension(&self) -> usize {
128 self.model_type.embedding_dimension()
129 }
130
131 fn is_bert_based(&self) -> bool {
132 true
133 }
134
135 async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
136 if texts.is_empty() {
137 return Ok(Vec::new());
138 }
139
140 let mut tokenized = Vec::with_capacity(texts.len());
142 for text in &texts {
143 let encoding = self
144 .tokenizer
145 .encode(text.clone(), true)
146 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
147 tokenized.push(encoding);
148 }
149
150 let max_len = tokenized
151 .iter()
152 .map(|enc| enc.len())
153 .max()
154 .unwrap_or(0)
155 .min(MAX_SEQ_LENGTH);
156
157 let mut input_ids = Vec::new();
158 let mut attention_mask = Vec::new();
159
160 for encoding in tokenized {
161 let ids = encoding.get_ids();
162 let mask = encoding.get_attention_mask();
163
164 let truncate_len = ids.len().min(max_len);
165 input_ids.extend_from_slice(&ids[..truncate_len]);
166 attention_mask.extend_from_slice(&mask[..truncate_len]);
167
168 if truncate_len < max_len {
169 input_ids.extend(vec![0u32; max_len - truncate_len]);
170 attention_mask.extend(vec![0u32; max_len - truncate_len]);
171 }
172 }
173
174 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
177 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
178
179 let input_tensor = Tensor::from_vec(input_ids_i64, (texts.len(), max_len), &self.device)?;
180 let mask_tensor =
181 Tensor::from_vec(attention_mask_i64, (texts.len(), max_len), &self.device)?;
182
183 let outputs = self.model.forward(&input_tensor, &mask_tensor, None)?;
185
186 let mask_f32 = mask_tensor.to_dtype(DType::F32)?;
188 let mask_expanded = mask_f32.unsqueeze(2)?.broadcast_as(outputs.shape())?;
190
191 let masked_outputs = outputs.broadcast_mul(&mask_expanded)?;
192 let sum_embeddings = masked_outputs.sum(1)?;
193 let token_counts = mask_f32.sum(1)?.unsqueeze(1)?;
194 let token_counts_safe = token_counts.clamp(1e-9f64, f64::MAX)?;
195
196 let embeddings = sum_embeddings.broadcast_div(&token_counts_safe)?;
197 let embeddings_normalized = Self::l2_normalize_embeddings(&embeddings)?;
198 let embeddings_vec = embeddings_normalized.to_vec2::<f32>()?;
199
200 if tracing::enabled!(tracing::Level::DEBUG) {
201 for (i, emb) in embeddings_vec.iter().enumerate() {
202 let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
203 debug!("Embedding {} norm after L2 normalization: {:.6}", i, norm);
204 }
205 }
206
207 Ok(embeddings_vec)
208 }
209}