Skip to main content

post_cortex_embeddings/embeddings/backends/
bert.rs

1// Copyright (c) 2025 Julius ML
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12
13//! BERT backend: HuggingFace Hub model load, tokenize, forward pass,
14//! masked mean pooling, and L2 normalization.
15
16use anyhow::Result;
17use async_trait::async_trait;
18use candle_core::{DType, Device, Tensor};
19use candle_nn::VarBuilder;
20use candle_transformers::models::bert::{BertModel, Config as BertConfig};
21use hf_hub::api::tokio::Api;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info};
25
26use crate::embeddings::backend::EmbeddingBackend;
27use crate::embeddings::config::EmbeddingModelType;
28
/// Upper bound on the number of tokens fed to the model per sequence.
///
/// Longer encodings are cut to this length before the forward pass; 512 is
/// the position-embedding limit shared by most BERT variants.
const MAX_SEQ_LENGTH: usize = 512;
31
32/// BERT-based embedding backend backed by `candle_transformers::BertModel`.
33pub struct BertBackend {
34    model: Arc<BertModel>,
35    tokenizer: Arc<Tokenizer>,
36    device: Device,
37    model_type: EmbeddingModelType,
38}
39
40impl BertBackend {
41    /// Download (if needed) and load a BERT model + tokenizer from HuggingFace Hub.
42    pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
43        info!("Loading BERT backend for model: {:?}", model_type);
44
45        let device = Device::Cpu;
46        let model_id = model_type.model_id();
47
48        let api = Api::new().map_err(|e| anyhow::anyhow!("Failed to create API: {}", e))?;
49        let repo = api.model(model_id.to_string());
50
51        let model_path = repo
52            .get("model.safetensors")
53            .await
54            .map_err(|e| anyhow::anyhow!("Failed to get model: {}", e))?;
55        let config_path = repo
56            .get("config.json")
57            .await
58            .map_err(|e| anyhow::anyhow!("Failed to get config: {}", e))?;
59        let tokenizer_path = repo
60            .get("tokenizer.json")
61            .await
62            .map_err(|e| anyhow::anyhow!("Failed to get tokenizer: {}", e))?;
63
64        let tokenizer = Tokenizer::from_file(tokenizer_path)
65            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
66
67        let bert_config_text = std::fs::read_to_string(config_path)
68            .map_err(|e| anyhow::anyhow!("Failed to read BERT config: {}", e))?;
69        let bert_config: BertConfig = serde_json::from_str(&bert_config_text)
70            .map_err(|e| anyhow::anyhow!("Failed to parse BERT config: {}", e))?;
71
72        // SAFETY: `from_mmaped_safetensors` is `unsafe fn` because mapping
73        // an external file means the kernel can change the bytes under us
74        // (the safetensors file could be truncated or replaced mid-read).
75        // In our pipeline the file is downloaded by `hf-hub`, lives in
76        // the user-local model cache, and is never modified after load —
77        // standard candle convention.
78        #[allow(unsafe_code)]
79        let vb =
80            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
81        let model = BertModel::load(vb, &bert_config)?;
82
83        info!(
84            "BERT model loaded successfully, embedding dimension: {}",
85            model_type.embedding_dimension()
86        );
87
88        Ok(Self {
89            model: Arc::new(model),
90            tokenizer: Arc::new(tokenizer),
91            device,
92            model_type,
93        })
94    }
95
96    /// L2-normalize a batch of embeddings (critical for cosine similarity).
97    fn l2_normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
98        // embeddings shape: [batch_size, embedding_dim]
99        let squared = embeddings.sqr()?;
100        let sum_squared = squared.sum_keepdim(1)?;
101        let l2_norm = sum_squared.sqrt()?;
102
103        if tracing::enabled!(tracing::Level::DEBUG) {
104            let l2_norm_values = l2_norm.to_vec2::<f32>()?;
105            debug!(
106                "L2 normalization - batch size: {}, first norm: {:.6}",
107                l2_norm_values.len(),
108                l2_norm_values
109                    .first()
110                    .and_then(|v| v.first())
111                    .unwrap_or(&0.0)
112            );
113        }
114
115        // Clamp norm to a small epsilon to avoid division-by-zero
116        let epsilon = 1e-12_f32;
117        let l2_norm_safe = l2_norm.clamp(epsilon, f32::MAX)?;
118        let normalized = embeddings.broadcast_div(&l2_norm_safe)?;
119
120        debug!("L2 normalization completed successfully");
121        Ok(normalized)
122    }
123}
124
125#[async_trait]
126impl EmbeddingBackend for BertBackend {
127    fn embedding_dimension(&self) -> usize {
128        self.model_type.embedding_dimension()
129    }
130
131    fn is_bert_based(&self) -> bool {
132        true
133    }
134
135    async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
136        if texts.is_empty() {
137            return Ok(Vec::new());
138        }
139
140        // Tokenize
141        let mut tokenized = Vec::with_capacity(texts.len());
142        for text in &texts {
143            let encoding = self
144                .tokenizer
145                .encode(text.clone(), true)
146                .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
147            tokenized.push(encoding);
148        }
149
150        let max_len = tokenized
151            .iter()
152            .map(|enc| enc.len())
153            .max()
154            .unwrap_or(0)
155            .min(MAX_SEQ_LENGTH);
156
157        let mut input_ids = Vec::new();
158        let mut attention_mask = Vec::new();
159
160        for encoding in tokenized {
161            let ids = encoding.get_ids();
162            let mask = encoding.get_attention_mask();
163
164            let truncate_len = ids.len().min(max_len);
165            input_ids.extend_from_slice(&ids[..truncate_len]);
166            attention_mask.extend_from_slice(&mask[..truncate_len]);
167
168            if truncate_len < max_len {
169                input_ids.extend(vec![0u32; max_len - truncate_len]);
170                attention_mask.extend(vec![0u32; max_len - truncate_len]);
171            }
172        }
173
174        // Convert u32 → i64 for BERT compatibility. The O(n) conversion is
175        // negligible compared to BERT's O(n²) attention (~0.1ms / 512 tokens).
176        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
177        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
178
179        let input_tensor = Tensor::from_vec(input_ids_i64, (texts.len(), max_len), &self.device)?;
180        let mask_tensor =
181            Tensor::from_vec(attention_mask_i64, (texts.len(), max_len), &self.device)?;
182
183        // Pass None for token_type_ids — XLM-R/MiniLM ignore them.
184        let outputs = self.model.forward(&input_tensor, &mask_tensor, None)?;
185
186        // Masked mean pooling — standard for Sentence Transformers.
187        let mask_f32 = mask_tensor.to_dtype(DType::F32)?;
188        // mask_f32 is [batch, seq_len] → expand to [batch, seq_len, hidden_size]
189        let mask_expanded = mask_f32.unsqueeze(2)?.broadcast_as(outputs.shape())?;
190
191        let masked_outputs = outputs.broadcast_mul(&mask_expanded)?;
192        let sum_embeddings = masked_outputs.sum(1)?;
193        let token_counts = mask_f32.sum(1)?.unsqueeze(1)?;
194        let token_counts_safe = token_counts.clamp(1e-9f64, f64::MAX)?;
195
196        let embeddings = sum_embeddings.broadcast_div(&token_counts_safe)?;
197        let embeddings_normalized = Self::l2_normalize_embeddings(&embeddings)?;
198        let embeddings_vec = embeddings_normalized.to_vec2::<f32>()?;
199
200        if tracing::enabled!(tracing::Level::DEBUG) {
201            for (i, emb) in embeddings_vec.iter().enumerate() {
202                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
203                debug!("Embedding {} norm after L2 normalization: {:.6}", i, norm);
204            }
205        }
206
207        Ok(embeddings_vec)
208    }
209}