Skip to main content

popsam_core/
embedding.rs

1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{EmbeddedText, InputRecord};
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE as BERT_DTYPE};
6use hf_hub::api::sync::Api;
7use hf_hub::{Repo, RepoType};
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
12
13/// Trait implemented by embedding backends that can turn raw texts into vectors.
14pub trait EmbeddingProvider {
15    /// Embeds a batch of records and returns the corresponding vectors.
16    fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>>;
17}
18
19/// Embedding provider for OpenAI-compatible `/embeddings` HTTP APIs.
20#[derive(Debug, Clone)]
21pub struct OpenAiCompatibleEmbeddingProvider {
22    client: Client,
23    base_url: String,
24    api_key: String,
25    model: String,
26}
27
28impl OpenAiCompatibleEmbeddingProvider {
29    /// Creates a new OpenAI-compatible embedding provider.
30    pub fn new(
31        base_url: impl Into<String>,
32        api_key: impl Into<String>,
33        model: impl Into<String>,
34    ) -> Self {
35        Self {
36            client: Client::new(),
37            base_url: base_url.into(),
38            api_key: api_key.into(),
39            model: model.into(),
40        }
41    }
42}
43
44impl EmbeddingProvider for OpenAiCompatibleEmbeddingProvider {
45    fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>> {
46        let inputs: Vec<String> = records
47            .iter()
48            .map(|record| record.text.clone().unwrap_or_default())
49            .collect();
50
51        let response: EmbeddingResponse = self
52            .client
53            .post(format!("{}/embeddings", self.base_url.trim_end_matches('/')))
54            .bearer_auth(&self.api_key)
55            .json(&EmbeddingRequest {
56                model: self.model.clone(),
57                input: inputs,
58            })
59            .send()
60            .map_err(|err| PopsamError::Provider(err.to_string()))?
61            .error_for_status()
62            .map_err(|err| PopsamError::Provider(err.to_string()))?
63            .json()
64            .map_err(|err| PopsamError::Provider(err.to_string()))?;
65
66        let mut by_index = response.data;
67        by_index.sort_by_key(|item| item.index);
68
69        if by_index.len() != records.len() {
70            return Err(PopsamError::Provider(format!(
71                "embedding API returned {} vectors for {} inputs",
72                by_index.len(),
73                records.len()
74            )));
75        }
76
77        Ok(records
78            .iter()
79            .zip(by_index)
80            .map(|(record, item)| EmbeddedText {
81                id: record.id.clone(),
82                text: record.text.clone(),
83                embedding: item.embedding,
84            })
85            .collect())
86    }
87}
88
89/// Local embedding provider backed by Candle and a BERT-style sentence model.
90pub struct CandleEmbeddingProvider {
91    tokenizer: Tokenizer,
92    model: BertModel,
93    device: Device,
94    max_length: usize,
95}
96
/// Paths of the three on-disk files a Candle embedding model is loaded from.
#[derive(Clone, Debug)]
pub struct CandleEmbeddingModelFiles {
    /// Hugging Face `config.json` describing the BERT architecture.
    pub config: PathBuf,
    /// Tokenizer definition in JSON form.
    pub tokenizer: PathBuf,
    /// Model weights stored as `safetensors`.
    pub weights: PathBuf,
}
107
/// Model specification used to download a sentence embedding model from Hugging Face.
///
/// The filenames are fetched from repository `model_id` at `revision` by
/// `CandleEmbeddingProvider::from_hf_hub`.
#[derive(Clone, Debug)]
pub struct CandleEmbeddingModelSpec {
    /// Hugging Face model ID.
    pub model_id: String,
    /// Revision, branch, or tag to resolve.
    pub revision: String,
    /// Config filename inside the repository.
    pub config_filename: String,
    /// Tokenizer filename inside the repository.
    pub tokenizer_filename: String,
    /// Weights filename inside the repository.
    pub weights_filename: String,
}

impl CandleEmbeddingModelSpec {
    /// Hub repository of the default multilingual sentence-transformer.
    const MULTILINGUAL_MODEL_ID: &'static str =
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2";

    /// Returns the default multilingual sentence-transformer model specification.
    ///
    /// Targets the `main` revision of
    /// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` with its
    /// `config.json`, `tokenizer.json`, and `model.safetensors` files.
    pub fn multilingual_default() -> Self {
        Self {
            model_id: String::from(Self::MULTILINGUAL_MODEL_ID),
            revision: String::from("main"),
            config_filename: String::from("config.json"),
            tokenizer_filename: String::from("tokenizer.json"),
            weights_filename: String::from("model.safetensors"),
        }
    }
}
135
136impl CandleEmbeddingProvider {
137    /// Loads a local Candle embedding provider from explicit model files.
138    pub fn from_local_files(
139        files: &CandleEmbeddingModelFiles,
140        device: Device,
141        max_length: usize,
142    ) -> PopsamResult<Self> {
143        let mut tokenizer = Tokenizer::from_file(&files.tokenizer)
144            .map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
145        tokenizer.with_padding(Some(PaddingParams {
146            strategy: PaddingStrategy::BatchLongest,
147            ..Default::default()
148        }));
149        tokenizer
150            .with_truncation(Some(TruncationParams {
151                max_length,
152                ..Default::default()
153            }))
154            .map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
155
156        let config_text = std::fs::read_to_string(&files.config)?;
157        let config: BertConfig =
158            serde_json::from_str(&config_text).map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
159
160        let vb = unsafe {
161            VarBuilder::from_mmaped_safetensors(&[files.weights.clone()], BERT_DTYPE, &device)
162                .map_err(|err| PopsamError::ModelLoad(err.to_string()))?
163        };
164        let model = BertModel::load(vb, &config).map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
165
166        Ok(Self {
167            tokenizer,
168            model,
169            device,
170            max_length,
171        })
172    }
173
174    /// Downloads the configured model files from Hugging Face and loads them locally.
175    pub fn from_hf_hub(spec: &CandleEmbeddingModelSpec, device: Device, max_length: usize) -> PopsamResult<Self> {
176        let api = Api::new().map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
177        let repo = api.repo(Repo::with_revision(
178            spec.model_id.clone(),
179            RepoType::Model,
180            spec.revision.clone(),
181        ));
182        let files = CandleEmbeddingModelFiles {
183            config: repo
184                .get(&spec.config_filename)
185                .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
186            tokenizer: repo
187                .get(&spec.tokenizer_filename)
188                .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
189            weights: repo
190                .get(&spec.weights_filename)
191                .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
192        };
193        Self::from_local_files(&files, device, max_length)
194    }
195
196    /// Convenience constructor for the default CPU-backed local model.
197    ///
198    /// The current implementation always resolves to the multilingual default model.
199    pub fn cpu(multilingual_default: bool) -> PopsamResult<Self> {
200        let spec = if multilingual_default {
201            CandleEmbeddingModelSpec::multilingual_default()
202        } else {
203            CandleEmbeddingModelSpec::multilingual_default()
204        };
205        Self::from_hf_hub(&spec, Device::Cpu, 512)
206    }
207}
208
209impl EmbeddingProvider for CandleEmbeddingProvider {
210    fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>> {
211        if records.is_empty() {
212            return Ok(Vec::new());
213        }
214
215        let texts = records
216            .iter()
217            .map(|record| record.text.clone().unwrap_or_default())
218            .collect::<Vec<_>>();
219        let encodings = self
220            .tokenizer
221            .encode_batch(texts, true)
222            .map_err(|err| PopsamError::Provider(err.to_string()))?;
223
224        let max_seq_len = encodings
225            .iter()
226            .map(|encoding| encoding.len())
227            .max()
228            .unwrap_or(0)
229            .min(self.max_length);
230
231        let mut input_ids = Vec::with_capacity(records.len() * max_seq_len);
232        let mut attention_mask = Vec::with_capacity(records.len() * max_seq_len);
233        let token_type_ids = vec![0_u32; records.len() * max_seq_len];
234
235        for encoding in &encodings {
236            let ids = encoding.get_ids();
237            let mask = encoding.get_attention_mask();
238            let pad_len = max_seq_len.saturating_sub(ids.len());
239
240            input_ids.extend_from_slice(ids);
241            input_ids.extend(std::iter::repeat_n(0_u32, pad_len));
242
243            attention_mask.extend_from_slice(mask);
244            attention_mask.extend(std::iter::repeat_n(0_u32, pad_len));
245        }
246
247        let input_ids = Tensor::new(input_ids.as_slice(), &self.device)
248            .map_err(|err| PopsamError::Provider(err.to_string()))?
249            .reshape((records.len(), max_seq_len))
250            .map_err(|err| PopsamError::Provider(err.to_string()))?;
251        let attention_mask = Tensor::new(attention_mask.as_slice(), &self.device)
252            .map_err(|err| PopsamError::Provider(err.to_string()))?
253            .reshape((records.len(), max_seq_len))
254            .map_err(|err| PopsamError::Provider(err.to_string()))?;
255        let token_type_ids = Tensor::new(token_type_ids.as_slice(), &self.device)
256            .map_err(|err| PopsamError::Provider(err.to_string()))?
257            .reshape((records.len(), max_seq_len))
258            .map_err(|err| PopsamError::Provider(err.to_string()))?;
259
260        let hidden = self
261            .model
262            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
263            .map_err(|err| PopsamError::Provider(err.to_string()))?;
264        let pooled = mean_pool(&hidden, &attention_mask)?;
265        let embeddings = pooled
266            .to_dtype(DType::F32)
267            .map_err(|err| PopsamError::Provider(err.to_string()))?
268            .to_vec2::<f32>()
269            .map_err(|err| PopsamError::Provider(err.to_string()))?;
270
271        Ok(records
272            .iter()
273            .zip(embeddings)
274            .map(|(record, embedding)| EmbeddedText {
275                id: record.id.clone(),
276                text: record.text.clone(),
277                embedding,
278            })
279            .collect())
280    }
281}
282
283fn mean_pool(hidden: &Tensor, attention_mask: &Tensor) -> PopsamResult<Tensor> {
284    let mask = attention_mask
285        .to_dtype(DType::F32)
286        .map_err(|err| PopsamError::Provider(err.to_string()))?
287        .unsqueeze(2)
288        .map_err(|err| PopsamError::Provider(err.to_string()))?;
289    let masked_hidden = hidden
290        .broadcast_mul(&mask)
291        .map_err(|err| PopsamError::Provider(err.to_string()))?;
292    let sum_hidden = masked_hidden
293        .sum(1)
294        .map_err(|err| PopsamError::Provider(err.to_string()))?;
295    let sum_mask = mask
296        .sum(1)
297        .map_err(|err| PopsamError::Provider(err.to_string()))?
298        .broadcast_maximum(&Tensor::new(&[1e-9_f32], hidden.device()).map_err(|err| PopsamError::Provider(err.to_string()))?)
299        .map_err(|err| PopsamError::Provider(err.to_string()))?;
300    sum_hidden
301        .broadcast_div(&sum_mask)
302        .map_err(|err| PopsamError::Provider(err.to_string()))
303}
304
305#[derive(Debug, Serialize)]
306struct EmbeddingRequest {
307    model: String,
308    input: Vec<String>,
309}
310
311#[derive(Debug, Deserialize)]
312struct EmbeddingResponse {
313    data: Vec<EmbeddingData>,
314}
315
316#[derive(Debug, Deserialize)]
317struct EmbeddingData {
318    index: usize,
319    embedding: Vec<f32>,
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of the default spec, since `from_hf_hub` fetches all three
    /// filenames at the given revision.
    #[test]
    fn multilingual_default_points_to_sentence_transformers_model() {
        let spec = CandleEmbeddingModelSpec::multilingual_default();
        assert_eq!(
            spec.model_id,
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        );
        assert_eq!(spec.revision, "main");
        assert_eq!(spec.config_filename, "config.json");
        assert_eq!(spec.tokenizer_filename, "tokenizer.json");
        assert_eq!(spec.weights_filename, "model.safetensors");
    }
}