// aurora_semantic/embeddings/providers.rs
use std::path::Path;
37use std::sync::Arc;
38
39use ndarray::IxDyn;
40use ort::session::Session;
41use ort::value::Tensor;
42use parking_lot::Mutex;
43use serde::{Deserialize, Serialize};
44use tokenizers::Tokenizer;
45
46use crate::embeddings::Embedder;
47use crate::error::{Error, Result};
48
/// Describes which ONNX Runtime execution provider a session was
/// configured with (CPU fallback or a GPU/accelerator backend).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionProviderInfo {
    /// Provider name, e.g. "CPU", "CUDA", "TensorRT", "DirectML", "CoreML".
    pub name: String,
    /// Whether this provider targets a GPU/accelerator rather than the CPU.
    pub is_gpu: bool,
    /// Device index for device-addressable providers; `None` for CPU and CoreML.
    pub device_id: Option<u32>,
    /// Optional free-form, human-readable description of the provider.
    pub details: Option<String>,
}
61
62impl ExecutionProviderInfo {
63 pub fn cpu() -> Self {
65 Self {
66 name: "CPU".to_string(),
67 is_gpu: false,
68 device_id: None,
69 details: Some("Default CPU execution".to_string()),
70 }
71 }
72
73 #[allow(dead_code)]
75 pub fn cuda(device_id: u32) -> Self {
76 Self {
77 name: "CUDA".to_string(),
78 is_gpu: true,
79 device_id: Some(device_id),
80 details: Some(format!("NVIDIA CUDA GPU (device {})", device_id)),
81 }
82 }
83
84 #[allow(dead_code)]
86 pub fn tensorrt(device_id: u32) -> Self {
87 Self {
88 name: "TensorRT".to_string(),
89 is_gpu: true,
90 device_id: Some(device_id),
91 details: Some(format!("NVIDIA TensorRT GPU (device {})", device_id)),
92 }
93 }
94
95 #[allow(dead_code)]
97 pub fn directml(device_id: u32) -> Self {
98 Self {
99 name: "DirectML".to_string(),
100 is_gpu: true,
101 device_id: Some(device_id),
102 details: Some(format!("DirectML GPU (device {})", device_id)),
103 }
104 }
105
106 #[allow(dead_code)]
108 pub fn coreml() -> Self {
109 Self {
110 name: "CoreML".to_string(),
111 is_gpu: true,
112 device_id: None,
113 details: Some("Apple CoreML (Neural Engine/GPU)".to_string()),
114 }
115 }
116
117 pub fn description(&self) -> String {
119 if self.is_gpu {
120 format!("{} (GPU accelerated)", self.name)
121 } else {
122 format!("{} (no GPU)", self.name)
123 }
124 }
125}
126
/// Text embedder backed by an ONNX Runtime session and a HuggingFace
/// tokenizer loaded from disk.
pub struct OnnxEmbedder {
    // Session behind Arc<Mutex<..>>: inference takes the lock per call
    // (see `embed_batch`), so one embedder can be shared across threads.
    session: Arc<Mutex<Session>>,
    // Tokenizer paired with the model (loaded from tokenizer.json).
    tokenizer: Tokenizer,
    // Reported embedding dimension. NOTE(review): not validated against
    // the model's actual output width — defaults to 768 in `new`.
    dimension: usize,
    // Maximum number of tokens fed to the model per text (inputs truncated).
    max_length: usize,
    // Which execution provider was configured at session-build time.
    execution_provider: ExecutionProviderInfo,
}
149
impl OnnxEmbedder {
    /// Loads an embedder from a model directory.
    ///
    /// Tries a list of well-known ONNX export file names and requires a
    /// `tokenizer.json` next to the model, then delegates to [`Self::new`]
    /// with the default dimension and a max sequence length of 512.
    ///
    /// # Errors
    /// Returns a model-load error when no known model file or no
    /// `tokenizer.json` exists in `model_dir`.
    pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
        let model_dir = model_dir.as_ref();

        // Candidate export names, most-preferred first; the first one
        // that exists on disk wins.
        let model_names = [
            "model.onnx",
            "model_optimized.onnx",
            "model-w-mean-pooling.onnx",
            "model_quantized.onnx",
            "encoder_model.onnx",
        ];

        let model_path = model_names
            .iter()
            .map(|name| model_dir.join(name))
            .find(|p| p.exists())
            .ok_or_else(|| {
                Error::model_load(format!(
                    "No ONNX model file found in {}. Expected one of: {:?}",
                    model_dir.display(),
                    model_names
                ))
            })?;

        let tokenizer_path = model_dir.join("tokenizer.json");
        if !tokenizer_path.exists() {
            return Err(Error::model_load(format!(
                "tokenizer.json not found in {}",
                model_dir.display()
            )));
        }

        Self::new(&model_path, &tokenizer_path, None, 512)
    }

    /// Creates an embedder from explicit model and tokenizer paths.
    ///
    /// `dimension` defaults to 768 when `None`. GPU execution providers
    /// are registered according to enabled cargo features (`cuda`,
    /// `tensorrt`, `directml`, `coreml`); with none enabled, inference
    /// runs on CPU.
    ///
    /// # Errors
    /// Fails when the tokenizer or the ONNX session cannot be loaded, or
    /// when a configured execution provider cannot be registered.
    pub fn new<P: AsRef<Path>>(
        model_path: P,
        tokenizer_path: P,
        dimension: Option<usize>,
        max_length: usize,
    ) -> Result<Self> {
        let model_path = model_path.as_ref();
        let tokenizer_path = tokenizer_path.as_ref();

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| Error::model_load(format!("Failed to load tokenizer: {}", e)))?;

        tracing::info!("Loading ONNX model from: {}", model_path.display());

        // `mut` is only used inside the cfg-gated provider blocks below,
        // hence the allow when no GPU feature is enabled.
        #[allow(unused_mut)]
        let mut builder = Session::builder()
            .map_err(|e| Error::model_load(format!("Failed to create session builder: {}", e)))?;

        #[allow(unused_mut)]
        let mut execution_provider = ExecutionProviderInfo::cpu();

        // NOTE(review): if several GPU features are enabled at once, each
        // block registers its provider and `execution_provider` reports
        // only the LAST one configured — confirm that is the intent.
        #[cfg(feature = "cuda")]
        {
            use ort::execution_providers::CUDAExecutionProvider;
            tracing::info!("CUDA support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([CUDAExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure CUDA: {}", e)))?;
            execution_provider = ExecutionProviderInfo::cuda(0);
        }

        #[cfg(feature = "tensorrt")]
        {
            use ort::execution_providers::TensorRTExecutionProvider;
            tracing::info!("TensorRT support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([TensorRTExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure TensorRT: {}", e)))?;
            execution_provider = ExecutionProviderInfo::tensorrt(0);
        }

        #[cfg(feature = "directml")]
        {
            use ort::execution_providers::DirectMLExecutionProvider;
            tracing::info!("DirectML support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([DirectMLExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure DirectML: {}", e)))?;
            execution_provider = ExecutionProviderInfo::directml(0);
        }

        #[cfg(feature = "coreml")]
        {
            use ort::execution_providers::CoreMLExecutionProvider;
            tracing::info!("CoreML support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([CoreMLExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure CoreML: {}", e)))?;
            execution_provider = ExecutionProviderInfo::coreml();
        }

        let session = builder
            .with_intra_threads(4)
            .map_err(|e| Error::model_load(format!("Failed to set threads: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| Error::model_load(format!("Failed to load ONNX model: {}", e)))?;

        // 768 matches BERT-base-style encoders; not checked against the
        // model's real output width.
        let dimension = dimension.unwrap_or(768);

        tracing::info!(
            "Loaded ONNX model (dim={}, max_len={}, provider={})",
            dimension,
            max_length,
            execution_provider.description()
        );

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer,
            dimension,
            max_length,
            execution_provider,
        })
    }

    /// Builder-style override of the maximum token sequence length.
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

    /// The execution provider this session was configured with.
    pub fn execution_provider(&self) -> &ExecutionProviderInfo {
        &self.execution_provider
    }

    /// `true` when a GPU/accelerator provider was configured.
    pub fn is_gpu_accelerated(&self) -> bool {
        self.execution_provider.is_gpu
    }

    /// Mean-pools token embeddings over the attention-masked positions,
    /// then L2-normalizes the result.
    ///
    /// `shape` is expected to be `[batch, seq_len, dim]`; only the first
    /// batch entry is read (callers always run a single sequence — see
    /// the `[1, seq_len]` tensors built in `embed_batch`). Any rank other
    /// than 3 is returned unchanged.
    fn mean_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
        if shape.len() != 3 {
            return data.to_vec();
        }

        let seq_len = shape[1] as usize;
        let dim = shape[2] as usize;
        let mut result = vec![0.0f32; dim];
        let mut count = 0.0f32;

        // Sum embeddings of real tokens only (mask == 1); padding is skipped.
        for i in 0..seq_len {
            if i < attention_mask.len() && attention_mask[i] == 1 {
                for j in 0..dim {
                    result[j] += data[i * dim + j];
                }
                count += 1.0;
            }
        }

        // Average; a fully-masked input yields the zero vector.
        if count > 0.0 {
            for val in &mut result {
                *val /= count;
            }
        }

        // L2-normalize. NOTE(review): duplicates the file-level
        // `normalize_vector` helper; kept inline here.
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut result {
                *val /= norm;
            }
        }

        result
    }
}
337
338impl Embedder for OnnxEmbedder {
339 fn embed(&self, text: &str) -> Result<Vec<f32>> {
340 let results = self.embed_batch(&[text])?;
341 Ok(results.into_iter().next().unwrap_or_default())
342 }
343
344 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
345 if texts.is_empty() {
346 return Ok(Vec::new());
347 }
348
349 let mut all_embeddings = Vec::new();
350
351 for text in texts {
353 let encoding = self
355 .tokenizer
356 .encode(*text, true)
357 .map_err(|e| Error::embedding(format!("Tokenization failed: {}", e)))?;
358
359 let ids = encoding.get_ids();
360 let mask = encoding.get_attention_mask();
361 let seq_len = ids.len().min(self.max_length);
362
363 let input_ids: Vec<i64> = ids.iter().take(seq_len).map(|&id| id as i64).collect();
365 let attention_mask: Vec<i64> = mask.iter().take(seq_len).map(|&m| m as i64).collect();
366 let token_type_ids: Vec<i64> = vec![0i64; seq_len];
367
368 let input_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), input_ids.clone())
370 .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
371 let input_ids_tensor = Tensor::from_array(input_ids_array)
372 .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
373
374 let attention_mask_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), attention_mask.clone())
375 .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
376 let attention_mask_tensor = Tensor::from_array(attention_mask_array)
377 .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
378
379 let token_type_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), token_type_ids)
380 .map_err(|e| Error::embedding(format!("Failed to create token_type tensor: {}", e)))?;
381 let token_type_ids_tensor = Tensor::from_array(token_type_ids_array)
382 .map_err(|e| Error::embedding(format!("Failed to create token_type tensor: {}", e)))?;
383
384 let mut session = self.session.lock();
386
387 let first_output_name = session.outputs.first()
389 .map(|o| o.name.clone())
390 .unwrap_or_else(|| "output".to_string());
391
392 let outputs = session
393 .run(ort::inputs![
394 "input_ids" => input_ids_tensor,
395 "attention_mask" => attention_mask_tensor,
396 "token_type_ids" => token_type_ids_tensor,
397 ])
398 .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?;
399
400 let output = if let Some(val) = outputs.get("last_hidden_state") {
402 val
403 } else if let Some(val) = outputs.get("sentence_embedding") {
404 val
405 } else {
406 outputs.get(&first_output_name)
407 .ok_or_else(|| Error::embedding("No output found".to_string()))?
408 };
409
410 let (output_shape, output_data) = output
411 .try_extract_tensor::<f32>()
412 .map_err(|e| Error::embedding(format!("Failed to extract output: {}", e)))?;
413
414 let shape_vec: Vec<i64> = output_shape.iter().map(|&d| d as i64).collect();
415
416 let embedding = if shape_vec.len() == 2 {
417 let emb: Vec<f32> = output_data.to_vec();
419 normalize_vector(emb)
420 } else if shape_vec.len() == 3 {
421 self.mean_pooling(output_data, &shape_vec, &attention_mask)
423 } else {
424 return Err(Error::embedding(format!(
425 "Unexpected output shape: {:?}",
426 shape_vec
427 )));
428 };
429
430 all_embeddings.push(embedding);
431 }
432
433 Ok(all_embeddings)
434 }
435
436 fn dimension(&self) -> usize {
437 self.dimension
438 }
439
440 fn name(&self) -> &'static str {
441 "onnx-runtime"
442 }
443
444 fn max_sequence_length(&self) -> usize {
445 self.max_length
446 }
447}
448
/// L2-normalizes `v` in place and returns it; an all-zero vector is
/// passed through unchanged.
fn normalize_vector(mut v: Vec<f32>) -> Vec<f32> {
    let magnitude = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if magnitude > 0.0 {
        v.iter_mut().for_each(|x| *x /= magnitude);
    }
    v
}

/// Deterministic fallback embedder: derives a pseudo-random unit vector
/// from a hash of the input text. Requires no model files, so it suits
/// tests and environments without ONNX assets.
pub struct HashEmbedder {
    // Length of every produced vector.
    dimension: usize,
}

impl HashEmbedder {
    /// Creates an embedder that produces vectors of `dimension` elements.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    /// Expands a 64-bit hash of `text` into an L2-normalized vector by
    /// iterating an LCG seeded with the hash.
    fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let mut state = hasher.finish();

        let mut embedding = Vec::with_capacity(self.dimension);
        for _ in 0..self.dimension {
            // Knuth's 64-bit LCG constants; the high 32 bits are mapped
            // into [-1, 1].
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            embedding.push(((state >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0);
        }

        normalize_vector(embedding)
    }
}
495
496impl Embedder for HashEmbedder {
497 fn embed(&self, text: &str) -> Result<Vec<f32>> {
498 Ok(self.hash_to_embedding(text))
499 }
500
501 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
502 Ok(texts.iter().map(|t| self.hash_to_embedding(t)).collect())
503 }
504
505 fn dimension(&self) -> usize {
506 self.dimension
507 }
508
509 fn name(&self) -> &'static str {
510 "hash"
511 }
512}
513
514#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
516pub struct ModelConfig {
517 pub model_path: std::path::PathBuf,
519 pub tokenizer_path: Option<std::path::PathBuf>,
521 pub dimension: Option<usize>,
523 pub max_length: usize,
525}
526
527impl Default for ModelConfig {
528 fn default() -> Self {
529 Self {
530 model_path: std::path::PathBuf::new(),
531 tokenizer_path: None,
532 dimension: None,
533 max_length: 512,
534 }
535 }
536}
537
538impl ModelConfig {
539 pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
541 Self {
542 model_path: path.as_ref().to_path_buf(),
543 tokenizer_path: None,
544 dimension: None,
545 max_length: 512,
546 }
547 }
548
549 pub fn with_max_length(mut self, max_length: usize) -> Self {
551 self.max_length = max_length;
552 self
553 }
554
555 pub fn with_dimension(mut self, dimension: usize) -> Self {
557 self.dimension = Some(dimension);
558 self
559 }
560
561 pub fn load(&self) -> Result<OnnxEmbedder> {
563 if self.model_path.is_dir() {
564 let mut embedder = OnnxEmbedder::from_directory(&self.model_path)?;
565 embedder.max_length = self.max_length;
566 if let Some(dim) = self.dimension {
567 embedder.dimension = dim;
568 }
569 Ok(embedder)
570 } else {
571 let tokenizer_path = self
572 .tokenizer_path
573 .clone()
574 .unwrap_or_else(|| self.model_path.with_file_name("tokenizer.json"));
575
576 OnnxEmbedder::new(
577 &self.model_path,
578 &tokenizer_path,
579 self.dimension,
580 self.max_length,
581 )
582 }
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedder() {
        let embedder = HashEmbedder::new(384);

        // Correct dimensionality.
        let first = embedder.embed("test code").unwrap();
        assert_eq!(first.len(), 384);

        // Output is (approximately) unit-length.
        let magnitude: f32 = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);

        // Same input reproduces the same vector; a different input does not.
        assert_eq!(first, embedder.embed("test code").unwrap());
        assert_ne!(first, embedder.embed("other code").unwrap());
    }

    #[test]
    fn test_batch_embedding() {
        let embedder = HashEmbedder::new(128);

        let embeddings = embedder.embed_batch(&["hello", "world", "test"]).unwrap();

        // One vector per input, each of the configured dimension.
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|emb| emb.len() == 128));
    }
}