1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{EmbeddedText, InputRecord};
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE as BERT_DTYPE};
6use hf_hub::api::sync::Api;
7use hf_hub::{Repo, RepoType};
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
12
13pub trait EmbeddingProvider {
15 fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>>;
17}
18
19#[derive(Debug, Clone)]
21pub struct OpenAiCompatibleEmbeddingProvider {
22 client: Client,
23 base_url: String,
24 api_key: String,
25 model: String,
26}
27
28impl OpenAiCompatibleEmbeddingProvider {
29 pub fn new(
31 base_url: impl Into<String>,
32 api_key: impl Into<String>,
33 model: impl Into<String>,
34 ) -> Self {
35 Self {
36 client: Client::new(),
37 base_url: base_url.into(),
38 api_key: api_key.into(),
39 model: model.into(),
40 }
41 }
42}
43
44impl EmbeddingProvider for OpenAiCompatibleEmbeddingProvider {
45 fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>> {
46 let inputs: Vec<String> = records
47 .iter()
48 .map(|record| record.text.clone().unwrap_or_default())
49 .collect();
50
51 let response: EmbeddingResponse = self
52 .client
53 .post(format!("{}/embeddings", self.base_url.trim_end_matches('/')))
54 .bearer_auth(&self.api_key)
55 .json(&EmbeddingRequest {
56 model: self.model.clone(),
57 input: inputs,
58 })
59 .send()
60 .map_err(|err| PopsamError::Provider(err.to_string()))?
61 .error_for_status()
62 .map_err(|err| PopsamError::Provider(err.to_string()))?
63 .json()
64 .map_err(|err| PopsamError::Provider(err.to_string()))?;
65
66 let mut by_index = response.data;
67 by_index.sort_by_key(|item| item.index);
68
69 if by_index.len() != records.len() {
70 return Err(PopsamError::Provider(format!(
71 "embedding API returned {} vectors for {} inputs",
72 by_index.len(),
73 records.len()
74 )));
75 }
76
77 Ok(records
78 .iter()
79 .zip(by_index)
80 .map(|(record, item)| EmbeddedText {
81 id: record.id.clone(),
82 text: record.text.clone(),
83 embedding: item.embedding,
84 })
85 .collect())
86 }
87}
88
89pub struct CandleEmbeddingProvider {
91 tokenizer: Tokenizer,
92 model: BertModel,
93 device: Device,
94 max_length: usize,
95}
96
/// Local filesystem locations of the three artifacts a Candle BERT embedder needs.
#[derive(Debug, Clone)]
pub struct CandleEmbeddingModelFiles {
    /// Path to the model's `config.json`, deserialized as a BERT config.
    pub config: PathBuf,
    /// Path to the Hugging Face `tokenizers` JSON file.
    pub tokenizer: PathBuf,
    /// Path to the safetensors weights file (memory-mapped when loaded).
    pub weights: PathBuf,
}
107
/// Identifies a model on the Hugging Face Hub together with the names of the
/// files to fetch from its repository.
#[derive(Debug, Clone)]
pub struct CandleEmbeddingModelSpec {
    /// Repository id, e.g. `owner/name`.
    pub model_id: String,
    /// Git revision (branch, tag, or commit) to resolve files at.
    pub revision: String,
    /// File name of the BERT config inside the repository.
    pub config_filename: String,
    /// File name of the tokenizer JSON inside the repository.
    pub tokenizer_filename: String,
    /// File name of the safetensors weights inside the repository.
    pub weights_filename: String,
}
122
123impl CandleEmbeddingModelSpec {
124 pub fn multilingual_default() -> Self {
126 Self {
127 model_id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
128 revision: "main".to_string(),
129 config_filename: "config.json".to_string(),
130 tokenizer_filename: "tokenizer.json".to_string(),
131 weights_filename: "model.safetensors".to_string(),
132 }
133 }
134}
135
136impl CandleEmbeddingProvider {
137 pub fn from_local_files(
139 files: &CandleEmbeddingModelFiles,
140 device: Device,
141 max_length: usize,
142 ) -> PopsamResult<Self> {
143 let mut tokenizer = Tokenizer::from_file(&files.tokenizer)
144 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
145 tokenizer.with_padding(Some(PaddingParams {
146 strategy: PaddingStrategy::BatchLongest,
147 ..Default::default()
148 }));
149 tokenizer
150 .with_truncation(Some(TruncationParams {
151 max_length,
152 ..Default::default()
153 }))
154 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
155
156 let config_text = std::fs::read_to_string(&files.config)?;
157 let config: BertConfig =
158 serde_json::from_str(&config_text).map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
159
160 let vb = unsafe {
161 VarBuilder::from_mmaped_safetensors(&[files.weights.clone()], BERT_DTYPE, &device)
162 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?
163 };
164 let model = BertModel::load(vb, &config).map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
165
166 Ok(Self {
167 tokenizer,
168 model,
169 device,
170 max_length,
171 })
172 }
173
174 pub fn from_hf_hub(spec: &CandleEmbeddingModelSpec, device: Device, max_length: usize) -> PopsamResult<Self> {
176 let api = Api::new().map_err(|err| PopsamError::ModelLoad(err.to_string()))?;
177 let repo = api.repo(Repo::with_revision(
178 spec.model_id.clone(),
179 RepoType::Model,
180 spec.revision.clone(),
181 ));
182 let files = CandleEmbeddingModelFiles {
183 config: repo
184 .get(&spec.config_filename)
185 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
186 tokenizer: repo
187 .get(&spec.tokenizer_filename)
188 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
189 weights: repo
190 .get(&spec.weights_filename)
191 .map_err(|err| PopsamError::ModelLoad(err.to_string()))?,
192 };
193 Self::from_local_files(&files, device, max_length)
194 }
195
196 pub fn cpu(multilingual_default: bool) -> PopsamResult<Self> {
200 let spec = if multilingual_default {
201 CandleEmbeddingModelSpec::multilingual_default()
202 } else {
203 CandleEmbeddingModelSpec::multilingual_default()
204 };
205 Self::from_hf_hub(&spec, Device::Cpu, 512)
206 }
207}
208
209impl EmbeddingProvider for CandleEmbeddingProvider {
210 fn embed(&self, records: &[InputRecord]) -> PopsamResult<Vec<EmbeddedText>> {
211 if records.is_empty() {
212 return Ok(Vec::new());
213 }
214
215 let texts = records
216 .iter()
217 .map(|record| record.text.clone().unwrap_or_default())
218 .collect::<Vec<_>>();
219 let encodings = self
220 .tokenizer
221 .encode_batch(texts, true)
222 .map_err(|err| PopsamError::Provider(err.to_string()))?;
223
224 let max_seq_len = encodings
225 .iter()
226 .map(|encoding| encoding.len())
227 .max()
228 .unwrap_or(0)
229 .min(self.max_length);
230
231 let mut input_ids = Vec::with_capacity(records.len() * max_seq_len);
232 let mut attention_mask = Vec::with_capacity(records.len() * max_seq_len);
233 let token_type_ids = vec![0_u32; records.len() * max_seq_len];
234
235 for encoding in &encodings {
236 let ids = encoding.get_ids();
237 let mask = encoding.get_attention_mask();
238 let pad_len = max_seq_len.saturating_sub(ids.len());
239
240 input_ids.extend_from_slice(ids);
241 input_ids.extend(std::iter::repeat_n(0_u32, pad_len));
242
243 attention_mask.extend_from_slice(mask);
244 attention_mask.extend(std::iter::repeat_n(0_u32, pad_len));
245 }
246
247 let input_ids = Tensor::new(input_ids.as_slice(), &self.device)
248 .map_err(|err| PopsamError::Provider(err.to_string()))?
249 .reshape((records.len(), max_seq_len))
250 .map_err(|err| PopsamError::Provider(err.to_string()))?;
251 let attention_mask = Tensor::new(attention_mask.as_slice(), &self.device)
252 .map_err(|err| PopsamError::Provider(err.to_string()))?
253 .reshape((records.len(), max_seq_len))
254 .map_err(|err| PopsamError::Provider(err.to_string()))?;
255 let token_type_ids = Tensor::new(token_type_ids.as_slice(), &self.device)
256 .map_err(|err| PopsamError::Provider(err.to_string()))?
257 .reshape((records.len(), max_seq_len))
258 .map_err(|err| PopsamError::Provider(err.to_string()))?;
259
260 let hidden = self
261 .model
262 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
263 .map_err(|err| PopsamError::Provider(err.to_string()))?;
264 let pooled = mean_pool(&hidden, &attention_mask)?;
265 let embeddings = pooled
266 .to_dtype(DType::F32)
267 .map_err(|err| PopsamError::Provider(err.to_string()))?
268 .to_vec2::<f32>()
269 .map_err(|err| PopsamError::Provider(err.to_string()))?;
270
271 Ok(records
272 .iter()
273 .zip(embeddings)
274 .map(|(record, embedding)| EmbeddedText {
275 id: record.id.clone(),
276 text: record.text.clone(),
277 embedding,
278 })
279 .collect())
280 }
281}
282
283fn mean_pool(hidden: &Tensor, attention_mask: &Tensor) -> PopsamResult<Tensor> {
284 let mask = attention_mask
285 .to_dtype(DType::F32)
286 .map_err(|err| PopsamError::Provider(err.to_string()))?
287 .unsqueeze(2)
288 .map_err(|err| PopsamError::Provider(err.to_string()))?;
289 let masked_hidden = hidden
290 .broadcast_mul(&mask)
291 .map_err(|err| PopsamError::Provider(err.to_string()))?;
292 let sum_hidden = masked_hidden
293 .sum(1)
294 .map_err(|err| PopsamError::Provider(err.to_string()))?;
295 let sum_mask = mask
296 .sum(1)
297 .map_err(|err| PopsamError::Provider(err.to_string()))?
298 .broadcast_maximum(&Tensor::new(&[1e-9_f32], hidden.device()).map_err(|err| PopsamError::Provider(err.to_string()))?)
299 .map_err(|err| PopsamError::Provider(err.to_string()))?;
300 sum_hidden
301 .broadcast_div(&sum_mask)
302 .map_err(|err| PopsamError::Provider(err.to_string()))
303}
304
305#[derive(Debug, Serialize)]
306struct EmbeddingRequest {
307 model: String,
308 input: Vec<String>,
309}
310
311#[derive(Debug, Deserialize)]
312struct EmbeddingResponse {
313 data: Vec<EmbeddingData>,
314}
315
316#[derive(Debug, Deserialize)]
317struct EmbeddingData {
318 index: usize,
319 embedding: Vec<f32>,
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn multilingual_default_points_to_sentence_transformers_model() {
        let CandleEmbeddingModelSpec {
            model_id,
            weights_filename,
            ..
        } = CandleEmbeddingModelSpec::multilingual_default();

        assert_eq!(
            model_id,
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        );
        assert_eq!(weights_filename, "model.safetensors");
    }
}