Skip to main content

post_cortex_embeddings/embeddings/backends/
bert.rs

1// Copyright (c) 2025 Julius ML
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12
13//! BERT backend: HuggingFace Hub model load, tokenize, forward pass,
14//! masked mean pooling, and L2 normalization.
15
16use anyhow::Result;
17use async_trait::async_trait;
18use candle_core::{DType, Device, Tensor};
19use candle_nn::VarBuilder;
20use candle_transformers::models::bert::{BertModel, Config as BertConfig};
21use hf_hub::api::tokio::Api;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info};
25
26use crate::embeddings::backend::EmbeddingBackend;
27use crate::embeddings::config::EmbeddingModelType;
28
/// Upper bound on the number of tokens fed to the encoder for one input text.
///
/// 512 matches the position-embedding limit of most BERT variants; longer
/// tokenizations are cut down to this length before the forward pass.
const MAX_SEQ_LENGTH: usize = 512;
31
32/// BERT-based embedding backend backed by `candle_transformers::BertModel`.
33pub struct BertBackend {
34    model: Arc<BertModel>,
35    tokenizer: Arc<Tokenizer>,
36    device: Device,
37    model_type: EmbeddingModelType,
38}
39
40impl BertBackend {
41    /// Download (if needed) and load a BERT model + tokenizer from HuggingFace Hub.
42    pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
43        info!("Loading BERT backend for model: {:?}", model_type);
44
45        let device = Device::Cpu;
46        let model_id = model_type.model_id();
47
48        let api = Api::new().map_err(|e| anyhow::anyhow!("Failed to create API: {}", e))?;
49        let repo = api.model(model_id.to_string());
50
51        let model_path = repo
52            .get("model.safetensors")
53            .await
54            .map_err(|e| anyhow::anyhow!("Failed to get model: {}", e))?;
55        let config_path = repo
56            .get("config.json")
57            .await
58            .map_err(|e| anyhow::anyhow!("Failed to get config: {}", e))?;
59        let tokenizer_path = repo
60            .get("tokenizer.json")
61            .await
62            .map_err(|e| anyhow::anyhow!("Failed to get tokenizer: {}", e))?;
63
64        let tokenizer = Tokenizer::from_file(tokenizer_path)
65            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
66
67        let bert_config_text = std::fs::read_to_string(config_path)
68            .map_err(|e| anyhow::anyhow!("Failed to read BERT config: {}", e))?;
69        let bert_config: BertConfig = serde_json::from_str(&bert_config_text)
70            .map_err(|e| anyhow::anyhow!("Failed to parse BERT config: {}", e))?;
71
72        // SAFETY: `from_mmaped_safetensors` is `unsafe fn` because mapping
73        // an external file means the kernel can change the bytes under us
74        // (the safetensors file could be truncated or replaced mid-read).
75        // In our pipeline the file is downloaded by `hf-hub`, lives in
76        // the user-local model cache, and is never modified after load —
77        // standard candle convention.
78        #[allow(unsafe_code)]
79        let vb = unsafe {
80            VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)?
81        };
82        let model = BertModel::load(vb, &bert_config)?;
83
84        info!(
85            "BERT model loaded successfully, embedding dimension: {}",
86            model_type.embedding_dimension()
87        );
88
89        Ok(Self {
90            model: Arc::new(model),
91            tokenizer: Arc::new(tokenizer),
92            device,
93            model_type,
94        })
95    }
96
97    /// L2-normalize a batch of embeddings (critical for cosine similarity).
98    fn l2_normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
99        // embeddings shape: [batch_size, embedding_dim]
100        let squared = embeddings.sqr()?;
101        let sum_squared = squared.sum_keepdim(1)?;
102        let l2_norm = sum_squared.sqrt()?;
103
104        if tracing::enabled!(tracing::Level::DEBUG) {
105            let l2_norm_values = l2_norm.to_vec2::<f32>()?;
106            debug!(
107                "L2 normalization - batch size: {}, first norm: {:.6}",
108                l2_norm_values.len(),
109                l2_norm_values
110                    .first()
111                    .and_then(|v| v.first())
112                    .unwrap_or(&0.0)
113            );
114        }
115
116        // Clamp norm to a small epsilon to avoid division-by-zero
117        let epsilon = 1e-12_f32;
118        let l2_norm_safe = l2_norm.clamp(epsilon, f32::MAX)?;
119        let normalized = embeddings.broadcast_div(&l2_norm_safe)?;
120
121        debug!("L2 normalization completed successfully");
122        Ok(normalized)
123    }
124}
125
126#[async_trait]
127impl EmbeddingBackend for BertBackend {
128    fn embedding_dimension(&self) -> usize {
129        self.model_type.embedding_dimension()
130    }
131
132    fn is_bert_based(&self) -> bool {
133        true
134    }
135
136    async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
137        if texts.is_empty() {
138            return Ok(Vec::new());
139        }
140
141        // Tokenize
142        let mut tokenized = Vec::with_capacity(texts.len());
143        for text in &texts {
144            let encoding = self
145                .tokenizer
146                .encode(text.clone(), true)
147                .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
148            tokenized.push(encoding);
149        }
150
151        let max_len = tokenized
152            .iter()
153            .map(|enc| enc.len())
154            .max()
155            .unwrap_or(0)
156            .min(MAX_SEQ_LENGTH);
157
158        let mut input_ids = Vec::new();
159        let mut attention_mask = Vec::new();
160
161        for encoding in tokenized {
162            let ids = encoding.get_ids();
163            let mask = encoding.get_attention_mask();
164
165            let truncate_len = ids.len().min(max_len);
166            input_ids.extend_from_slice(&ids[..truncate_len]);
167            attention_mask.extend_from_slice(&mask[..truncate_len]);
168
169            if truncate_len < max_len {
170                input_ids.extend(vec![0u32; max_len - truncate_len]);
171                attention_mask.extend(vec![0u32; max_len - truncate_len]);
172            }
173        }
174
175        // Convert u32 → i64 for BERT compatibility. The O(n) conversion is
176        // negligible compared to BERT's O(n²) attention (~0.1ms / 512 tokens).
177        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
178        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
179
180        let input_tensor = Tensor::from_vec(input_ids_i64, (texts.len(), max_len), &self.device)?;
181        let mask_tensor =
182            Tensor::from_vec(attention_mask_i64, (texts.len(), max_len), &self.device)?;
183
184        // Pass None for token_type_ids — XLM-R/MiniLM ignore them.
185        let outputs = self.model.forward(&input_tensor, &mask_tensor, None)?;
186
187        // Masked mean pooling — standard for Sentence Transformers.
188        let mask_f32 = mask_tensor.to_dtype(DType::F32)?;
189        // mask_f32 is [batch, seq_len] → expand to [batch, seq_len, hidden_size]
190        let mask_expanded = mask_f32.unsqueeze(2)?.broadcast_as(outputs.shape())?;
191
192        let masked_outputs = outputs.broadcast_mul(&mask_expanded)?;
193        let sum_embeddings = masked_outputs.sum(1)?;
194        let token_counts = mask_f32.sum(1)?.unsqueeze(1)?;
195        let token_counts_safe = token_counts.clamp(1e-9f64, f64::MAX)?;
196
197        let embeddings = sum_embeddings.broadcast_div(&token_counts_safe)?;
198        let embeddings_normalized = Self::l2_normalize_embeddings(&embeddings)?;
199        let embeddings_vec = embeddings_normalized.to_vec2::<f32>()?;
200
201        if tracing::enabled!(tracing::Level::DEBUG) {
202            for (i, emb) in embeddings_vec.iter().enumerate() {
203                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
204                debug!("Embedding {} norm after L2 normalization: {:.6}", i, norm);
205            }
206        }
207
208        Ok(embeddings_vec)
209    }
210}