post_cortex_embeddings/embeddings/backends/
bert.rs1use anyhow::Result;
17use async_trait::async_trait;
18use candle_core::{DType, Device, Tensor};
19use candle_nn::VarBuilder;
20use candle_transformers::models::bert::{BertModel, Config as BertConfig};
21use hf_hub::api::tokio::Api;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info};
25
26use crate::embeddings::backend::EmbeddingBackend;
27use crate::embeddings::config::EmbeddingModelType;
28
29const MAX_SEQ_LENGTH: usize = 512;
31
32pub struct BertBackend {
34 model: Arc<BertModel>,
35 tokenizer: Arc<Tokenizer>,
36 device: Device,
37 model_type: EmbeddingModelType,
38}
39
40impl BertBackend {
41 pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
43 info!("Loading BERT backend for model: {:?}", model_type);
44
45 let device = Device::Cpu;
46 let model_id = model_type.model_id();
47
48 let api = Api::new().map_err(|e| anyhow::anyhow!("Failed to create API: {}", e))?;
49 let repo = api.model(model_id.to_string());
50
51 let model_path = repo
52 .get("model.safetensors")
53 .await
54 .map_err(|e| anyhow::anyhow!("Failed to get model: {}", e))?;
55 let config_path = repo
56 .get("config.json")
57 .await
58 .map_err(|e| anyhow::anyhow!("Failed to get config: {}", e))?;
59 let tokenizer_path = repo
60 .get("tokenizer.json")
61 .await
62 .map_err(|e| anyhow::anyhow!("Failed to get tokenizer: {}", e))?;
63
64 let tokenizer = Tokenizer::from_file(tokenizer_path)
65 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
66
67 let bert_config_text = std::fs::read_to_string(config_path)
68 .map_err(|e| anyhow::anyhow!("Failed to read BERT config: {}", e))?;
69 let bert_config: BertConfig = serde_json::from_str(&bert_config_text)
70 .map_err(|e| anyhow::anyhow!("Failed to parse BERT config: {}", e))?;
71
72 #[allow(unsafe_code)]
79 let vb = unsafe {
80 VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)?
81 };
82 let model = BertModel::load(vb, &bert_config)?;
83
84 info!(
85 "BERT model loaded successfully, embedding dimension: {}",
86 model_type.embedding_dimension()
87 );
88
89 Ok(Self {
90 model: Arc::new(model),
91 tokenizer: Arc::new(tokenizer),
92 device,
93 model_type,
94 })
95 }
96
97 fn l2_normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
99 let squared = embeddings.sqr()?;
101 let sum_squared = squared.sum_keepdim(1)?;
102 let l2_norm = sum_squared.sqrt()?;
103
104 if tracing::enabled!(tracing::Level::DEBUG) {
105 let l2_norm_values = l2_norm.to_vec2::<f32>()?;
106 debug!(
107 "L2 normalization - batch size: {}, first norm: {:.6}",
108 l2_norm_values.len(),
109 l2_norm_values
110 .first()
111 .and_then(|v| v.first())
112 .unwrap_or(&0.0)
113 );
114 }
115
116 let epsilon = 1e-12_f32;
118 let l2_norm_safe = l2_norm.clamp(epsilon, f32::MAX)?;
119 let normalized = embeddings.broadcast_div(&l2_norm_safe)?;
120
121 debug!("L2 normalization completed successfully");
122 Ok(normalized)
123 }
124}
125
126#[async_trait]
127impl EmbeddingBackend for BertBackend {
128 fn embedding_dimension(&self) -> usize {
129 self.model_type.embedding_dimension()
130 }
131
132 fn is_bert_based(&self) -> bool {
133 true
134 }
135
136 async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
137 if texts.is_empty() {
138 return Ok(Vec::new());
139 }
140
141 let mut tokenized = Vec::with_capacity(texts.len());
143 for text in &texts {
144 let encoding = self
145 .tokenizer
146 .encode(text.clone(), true)
147 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
148 tokenized.push(encoding);
149 }
150
151 let max_len = tokenized
152 .iter()
153 .map(|enc| enc.len())
154 .max()
155 .unwrap_or(0)
156 .min(MAX_SEQ_LENGTH);
157
158 let mut input_ids = Vec::new();
159 let mut attention_mask = Vec::new();
160
161 for encoding in tokenized {
162 let ids = encoding.get_ids();
163 let mask = encoding.get_attention_mask();
164
165 let truncate_len = ids.len().min(max_len);
166 input_ids.extend_from_slice(&ids[..truncate_len]);
167 attention_mask.extend_from_slice(&mask[..truncate_len]);
168
169 if truncate_len < max_len {
170 input_ids.extend(vec![0u32; max_len - truncate_len]);
171 attention_mask.extend(vec![0u32; max_len - truncate_len]);
172 }
173 }
174
175 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
178 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
179
180 let input_tensor = Tensor::from_vec(input_ids_i64, (texts.len(), max_len), &self.device)?;
181 let mask_tensor =
182 Tensor::from_vec(attention_mask_i64, (texts.len(), max_len), &self.device)?;
183
184 let outputs = self.model.forward(&input_tensor, &mask_tensor, None)?;
186
187 let mask_f32 = mask_tensor.to_dtype(DType::F32)?;
189 let mask_expanded = mask_f32.unsqueeze(2)?.broadcast_as(outputs.shape())?;
191
192 let masked_outputs = outputs.broadcast_mul(&mask_expanded)?;
193 let sum_embeddings = masked_outputs.sum(1)?;
194 let token_counts = mask_f32.sum(1)?.unsqueeze(1)?;
195 let token_counts_safe = token_counts.clamp(1e-9f64, f64::MAX)?;
196
197 let embeddings = sum_embeddings.broadcast_div(&token_counts_safe)?;
198 let embeddings_normalized = Self::l2_normalize_embeddings(&embeddings)?;
199 let embeddings_vec = embeddings_normalized.to_vec2::<f32>()?;
200
201 if tracing::enabled!(tracing::Level::DEBUG) {
202 for (i, emb) in embeddings_vec.iter().enumerate() {
203 let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
204 debug!("Embedding {} norm after L2 normalization: {:.6}", i, norm);
205 }
206 }
207
208 Ok(embeddings_vec)
209 }
210}