aurora_semantic/embeddings/
providers.rs

1//! Embedding providers - ONNX-based local model inference using ONNX Runtime.
2//!
3//! This module provides embedding generation using Microsoft's ONNX Runtime via the `ort` crate.
4//! Users download ONNX models and tokenizers, then point to their location.
5//!
6//! # Supported Models
7//!
8//! Any ONNX model that outputs embeddings can be used:
9//! - `jina-embeddings-v2-base-code` (recommended for code)
10//! - `all-MiniLM-L6-v2`
11//! - `bge-small-en-v1.5`
12//! - Any sentence-transformer model exported to ONNX
13//!
14//! # Model Directory Structure
15//!
16//! ```text
17//! model_dir/
18//! ├── model.onnx          # ONNX model file (or model-w-mean-pooling.onnx)
19//! ├── tokenizer.json      # HuggingFace tokenizer
20//! └── config.json         # Optional: model config
21//! ```
22//!
23//! # GPU Acceleration
24//!
25//! GPU support is available via Cargo features:
26//! - `cuda` - NVIDIA CUDA (requires CUDA toolkit)
27//! - `tensorrt` - NVIDIA TensorRT (requires TensorRT)
28//! - `directml` - DirectML (Windows, AMD/Intel/NVIDIA)
29//! - `coreml` - CoreML (macOS, Apple Silicon)
30//!
31//! Build with GPU support:
32//! ```bash
33//! cargo build --release --features cuda
34//! ```
35
36use std::path::Path;
37use std::sync::Arc;
38
39use ndarray::IxDyn;
40use ort::session::Session;
41use ort::value::Tensor;
42use parking_lot::Mutex;
43use serde::{Deserialize, Serialize};
44use tokenizers::Tokenizer;
45
46use crate::embeddings::Embedder;
47use crate::error::{Error, Result};
48
/// Information about the execution provider (CPU/GPU) being used.
///
/// Derives `Serialize`/`Deserialize` so it can be included in serialized
/// status/diagnostic output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionProviderInfo {
    /// Name of the execution provider (e.g., "CPU", "CUDA", "DirectML", "CoreML").
    pub name: String,
    /// Whether GPU acceleration is being used.
    pub is_gpu: bool,
    /// Device ID if applicable (e.g., CUDA device index); `None` for CPU and CoreML.
    pub device_id: Option<u32>,
    /// Additional human-readable details about the provider.
    pub details: Option<String>,
}
61
62impl ExecutionProviderInfo {
63    /// Create CPU execution provider info.
64    pub fn cpu() -> Self {
65        Self {
66            name: "CPU".to_string(),
67            is_gpu: false,
68            device_id: None,
69            details: Some("Default CPU execution".to_string()),
70        }
71    }
72
73    /// Create CUDA execution provider info.
74    #[allow(dead_code)]
75    pub fn cuda(device_id: u32) -> Self {
76        Self {
77            name: "CUDA".to_string(),
78            is_gpu: true,
79            device_id: Some(device_id),
80            details: Some(format!("NVIDIA CUDA GPU (device {})", device_id)),
81        }
82    }
83
84    /// Create TensorRT execution provider info.
85    #[allow(dead_code)]
86    pub fn tensorrt(device_id: u32) -> Self {
87        Self {
88            name: "TensorRT".to_string(),
89            is_gpu: true,
90            device_id: Some(device_id),
91            details: Some(format!("NVIDIA TensorRT GPU (device {})", device_id)),
92        }
93    }
94
95    /// Create DirectML execution provider info.
96    #[allow(dead_code)]
97    pub fn directml(device_id: u32) -> Self {
98        Self {
99            name: "DirectML".to_string(),
100            is_gpu: true,
101            device_id: Some(device_id),
102            details: Some(format!("DirectML GPU (device {})", device_id)),
103        }
104    }
105
106    /// Create CoreML execution provider info.
107    #[allow(dead_code)]
108    pub fn coreml() -> Self {
109        Self {
110            name: "CoreML".to_string(),
111            is_gpu: true,
112            device_id: None,
113            details: Some("Apple CoreML (Neural Engine/GPU)".to_string()),
114        }
115    }
116
117    /// Get a human-readable description.
118    pub fn description(&self) -> String {
119        if self.is_gpu {
120            format!("{} (GPU accelerated)", self.name)
121        } else {
122            format!("{} (no GPU)", self.name)
123        }
124    }
125}
126
/// ONNX-based embedding model using ONNX Runtime.
///
/// Loads an ONNX model and tokenizer from a directory and generates embeddings.
///
/// # Example
///
/// ```rust,ignore
/// use aurora_semantic::OnnxEmbedder;
///
/// // Point to your downloaded model directory
/// let embedder = OnnxEmbedder::from_directory("./models/jina-code")?;
///
/// let embedding = embedder.embed("fn main() { println!(\"Hello\"); }")?;
/// println!("Embedding dimension: {}", embedding.len());
/// ```
pub struct OnnxEmbedder {
    /// ONNX Runtime session. Wrapped in a `Mutex` because inference takes the
    /// session mutably (see the `let mut session = self.session.lock()` in
    /// `embed_batch`); `Arc` allows cheap sharing of the loaded model.
    session: Arc<Mutex<Session>>,
    /// HuggingFace tokenizer loaded from `tokenizer.json`.
    tokenizer: Tokenizer,
    /// Reported embedding dimension. Defaults to 768 when not supplied;
    /// NOTE(review): this is never auto-detected from the model output.
    dimension: usize,
    /// Maximum token sequence length; longer inputs are truncated.
    max_length: usize,
    /// Which execution provider (CPU/GPU) was configured at load time.
    execution_provider: ExecutionProviderInfo,
}
149
impl OnnxEmbedder {
    /// Load an ONNX model from a directory.
    ///
    /// The directory must contain:
    /// - An ONNX model file (model.onnx, model-w-mean-pooling.onnx, etc.)
    /// - `tokenizer.json` - HuggingFace tokenizer file
    ///
    /// Uses a default max sequence length of 512 tokens and leaves the
    /// dimension unset (falls back to 768 in [`OnnxEmbedder::new`]).
    ///
    /// # Errors
    ///
    /// Returns a model-load error when no recognized model filename or no
    /// `tokenizer.json` is present in the directory.
    pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
        let model_dir = model_dir.as_ref();

        // Check for model file with various common names; the first name in
        // this list that exists on disk wins (priority order).
        let model_names = [
            "model.onnx",
            "model_optimized.onnx",
            "model-w-mean-pooling.onnx",
            "model_quantized.onnx",
            "encoder_model.onnx",
        ];

        let model_path = model_names
            .iter()
            .map(|name| model_dir.join(name))
            .find(|p| p.exists())
            .ok_or_else(|| {
                Error::model_load(format!(
                    "No ONNX model file found in {}. Expected one of: {:?}",
                    model_dir.display(),
                    model_names
                ))
            })?;

        // Check tokenizer exists before paying the cost of session creation.
        let tokenizer_path = model_dir.join("tokenizer.json");
        if !tokenizer_path.exists() {
            return Err(Error::model_load(format!(
                "tokenizer.json not found in {}",
                model_dir.display()
            )));
        }

        Self::new(&model_path, &tokenizer_path, None, 512)
    }

    /// Create a new ONNX embedder with explicit paths.
    ///
    /// # Parameters
    /// - `model_path`: path to the `.onnx` model file.
    /// - `tokenizer_path`: path to a HuggingFace `tokenizer.json`.
    /// - `dimension`: embedding dimension to report; 768 when `None`.
    ///   NOTE(review): the dimension is not detected from the model — callers
    ///   must override it for models with a different output width.
    /// - `max_length`: maximum token sequence length (inputs are truncated).
    ///
    /// # Errors
    ///
    /// Returns a model-load error if the tokenizer fails to parse, the
    /// session builder or an execution provider cannot be configured, or the
    /// ONNX model fails to load.
    pub fn new<P: AsRef<Path>>(
        model_path: P,
        tokenizer_path: P,
        dimension: Option<usize>,
        max_length: usize,
    ) -> Result<Self> {
        let model_path = model_path.as_ref();
        let tokenizer_path = tokenizer_path.as_ref();

        // Load tokenizer
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| Error::model_load(format!("Failed to load tokenizer: {}", e)))?;

        // Load ONNX model with ort
        tracing::info!("Loading ONNX model from: {}", model_path.display());

        // `mut` is only exercised when a GPU feature is compiled in, hence
        // the allow(unused_mut) for the default CPU-only build.
        #[allow(unused_mut)]
        let mut builder = Session::builder()
            .map_err(|e| Error::model_load(format!("Failed to create session builder: {}", e)))?;

        // Track which execution provider we're using
        #[allow(unused_mut)]
        let mut execution_provider = ExecutionProviderInfo::cpu();

        // Configure execution providers based on available features.
        // NOTE(review): if several GPU features are enabled simultaneously,
        // each block overwrites `execution_provider`, so the last cfg block
        // below determines the reported label even though every enabled
        // provider is registered with the builder.
        #[cfg(feature = "cuda")]
        {
            use ort::execution_providers::CUDAExecutionProvider;
            tracing::info!("CUDA support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([CUDAExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure CUDA: {}", e)))?;
            execution_provider = ExecutionProviderInfo::cuda(0);
        }

        #[cfg(feature = "tensorrt")]
        {
            use ort::execution_providers::TensorRTExecutionProvider;
            tracing::info!("TensorRT support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([TensorRTExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure TensorRT: {}", e)))?;
            execution_provider = ExecutionProviderInfo::tensorrt(0);
        }

        #[cfg(feature = "directml")]
        {
            use ort::execution_providers::DirectMLExecutionProvider;
            tracing::info!("DirectML support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([DirectMLExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure DirectML: {}", e)))?;
            execution_provider = ExecutionProviderInfo::directml(0);
        }

        #[cfg(feature = "coreml")]
        {
            use ort::execution_providers::CoreMLExecutionProvider;
            tracing::info!("CoreML support enabled, attempting GPU acceleration");
            builder = builder
                .with_execution_providers([CoreMLExecutionProvider::default().build()])
                .map_err(|e| Error::model_load(format!("Failed to configure CoreML: {}", e)))?;
            execution_provider = ExecutionProviderInfo::coreml();
        }

        // 4 intra-op threads is a fixed choice here (not configurable yet).
        let session = builder
            .with_intra_threads(4)
            .map_err(|e| Error::model_load(format!("Failed to set threads: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| Error::model_load(format!("Failed to load ONNX model: {}", e)))?;

        // Default dimension when the caller does not supply one.
        // NOTE(review): no auto-detection from the model output happens;
        // `dimension()` will simply report this value.
        let dimension = dimension.unwrap_or(768);

        tracing::info!(
            "Loaded ONNX model (dim={}, max_len={}, provider={})",
            dimension,
            max_length,
            execution_provider.description()
        );

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer,
            dimension,
            max_length,
            execution_provider,
        })
    }

    /// Set the maximum sequence length (builder style).
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

    /// Get information about the execution provider (CPU/GPU).
    pub fn execution_provider(&self) -> &ExecutionProviderInfo {
        &self.execution_provider
    }

    /// Check if GPU acceleration is being used.
    pub fn is_gpu_accelerated(&self) -> bool {
        self.execution_provider.is_gpu
    }

    /// Mean pooling over token embeddings, followed by L2 normalization.
    ///
    /// `data` is the flattened output tensor, `shape` its dimensions
    /// (expected `[batch, seq, dim]`), and `attention_mask` the per-token
    /// mask (1 = real token, 0 = padding); only unmasked tokens contribute.
    ///
    /// NOTE(review): the `data[i * dim + j]` indexing assumes batch size 1
    /// (`shape[0] == 1`); the current caller only ever passes single-item
    /// batches. If every token is masked, a zero vector is returned.
    fn mean_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
        if shape.len() != 3 {
            // Not token embeddings, return as-is (no pooling, no normalization).
            return data.to_vec();
        }

        let seq_len = shape[1] as usize;
        let dim = shape[2] as usize;
        let mut result = vec![0.0f32; dim];
        let mut count = 0.0f32;

        // Sum embeddings of attended (mask == 1) tokens.
        for i in 0..seq_len {
            if i < attention_mask.len() && attention_mask[i] == 1 {
                for j in 0..dim {
                    result[j] += data[i * dim + j];
                }
                count += 1.0;
            }
        }

        // Average over the number of attended tokens.
        if count > 0.0 {
            for val in &mut result {
                *val /= count;
            }
        }

        // L2 normalize so cosine similarity reduces to a dot product.
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut result {
                *val /= norm;
            }
        }

        result
    }
}
337
impl Embedder for OnnxEmbedder {
    /// Embed a single text; thin wrapper over `embed_batch`.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed_batch(&[text])?;
        // embed_batch returns one vector per input; an empty result yields an
        // empty embedding here rather than an error.
        Ok(results.into_iter().next().unwrap_or_default())
    }

    /// Embed a batch of texts — one ONNX inference call per text.
    ///
    /// Each text is tokenized, truncated to `max_length` tokens, run through
    /// the model, and reduced to a single L2-normalized vector (mean pooling
    /// when the model emits per-token embeddings).
    ///
    /// # Errors
    ///
    /// Returns an embedding error on tokenization, tensor construction,
    /// inference, or output-extraction failure, or when the output tensor
    /// has a rank other than 2 or 3.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut all_embeddings = Vec::new();

        // Process one at a time for simplicity (no padding/batching logic).
        for text in texts {
            // Tokenize with special tokens added (second argument = true).
            let encoding = self
                .tokenizer
                .encode(*text, true)
                .map_err(|e| Error::embedding(format!("Tokenization failed: {}", e)))?;

            let ids = encoding.get_ids();
            let mask = encoding.get_attention_mask();
            // Truncate to the configured maximum sequence length.
            let seq_len = ids.len().min(self.max_length);

            // Prepare input tensors (i64, as ONNX transformer models expect).
            let input_ids: Vec<i64> = ids.iter().take(seq_len).map(|&id| id as i64).collect();
            let attention_mask: Vec<i64> = mask.iter().take(seq_len).map(|&m| m as i64).collect();
            // All zeros: single-segment input.
            let token_type_ids: Vec<i64> = vec![0i64; seq_len];

            // Create ort tensors with shape [1, seq_len] (batch of one).
            // NOTE(review): `input_ids.clone()` is redundant — `input_ids` is
            // not used again after this point.
            let input_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), input_ids.clone())
                .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
            let input_ids_tensor = Tensor::from_array(input_ids_array)
                .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;

            // `attention_mask` is cloned because it is still needed below for
            // mean pooling.
            let attention_mask_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), attention_mask.clone())
                .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
            let attention_mask_tensor = Tensor::from_array(attention_mask_array)
                .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;

            let token_type_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), token_type_ids)
                .map_err(|e| Error::embedding(format!("Failed to create token_type tensor: {}", e)))?;
            let token_type_ids_tensor = Tensor::from_array(token_type_ids_array)
                .map_err(|e| Error::embedding(format!("Failed to create token_type tensor: {}", e)))?;

            // Run inference; the lock serializes access to the session.
            let mut session = self.session.lock();

            // Get first output name before running (used as a fallback when
            // the common transformer output names are absent).
            let first_output_name = session.outputs.first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "output".to_string());

            // NOTE(review): `token_type_ids` is always supplied; models whose
            // ONNX graph does not declare that input may reject the call —
            // confirm against the target models.
            let outputs = session
                .run(ort::inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                    "token_type_ids" => token_type_ids_tensor,
                ])
                .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?;

            // Extract embeddings - try the standard transformer output names
            // first, then fall back to the model's first declared output.
            let output = if let Some(val) = outputs.get("last_hidden_state") {
                val
            } else if let Some(val) = outputs.get("sentence_embedding") {
                val
            } else {
                outputs.get(&first_output_name)
                    .ok_or_else(|| Error::embedding("No output found".to_string()))?
            };

            let (output_shape, output_data) = output
                .try_extract_tensor::<f32>()
                .map_err(|e| Error::embedding(format!("Failed to extract output: {}", e)))?;

            let shape_vec: Vec<i64> = output_shape.iter().map(|&d| d as i64).collect();

            // Rank decides the reduction strategy.
            let embedding = if shape_vec.len() == 2 {
                // Direct sentence embedding [1, dim] — just normalize.
                let emb: Vec<f32> = output_data.to_vec();
                normalize_vector(emb)
            } else if shape_vec.len() == 3 {
                // Token embeddings [1, seq, dim] - need pooling (which also
                // normalizes).
                self.mean_pooling(output_data, &shape_vec, &attention_mask)
            } else {
                return Err(Error::embedding(format!(
                    "Unexpected output shape: {:?}",
                    shape_vec
                )));
            };

            all_embeddings.push(embedding);
        }

        Ok(all_embeddings)
    }

    /// Configured (not auto-detected) embedding dimension.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Provider identifier for diagnostics.
    fn name(&self) -> &'static str {
        "onnx-runtime"
    }

    /// Maximum number of tokens fed to the model per text.
    fn max_sequence_length(&self) -> usize {
        self.max_length
    }
}
448
/// Scale `v` to unit (L2) length.
///
/// An all-zero (or empty) vector is returned unchanged, avoiding a division
/// by zero.
fn normalize_vector(mut v: Vec<f32>) -> Vec<f32> {
    let norm = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    v
}
459
460/// Simple hash-based embedder for testing (no model required).
461///
462/// Generates deterministic embeddings based on text hash.
463/// Useful for testing the search pipeline without loading a real model.
464pub struct HashEmbedder {
465    dimension: usize,
466}
467
468impl HashEmbedder {
469    /// Create a new hash embedder with the specified dimension.
470    pub fn new(dimension: usize) -> Self {
471        Self { dimension }
472    }
473
474    fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
475        use std::collections::hash_map::DefaultHasher;
476        use std::hash::{Hash, Hasher};
477
478        let mut result = vec![0.0f32; self.dimension];
479        let hash = {
480            let mut hasher = DefaultHasher::new();
481            text.hash(&mut hasher);
482            hasher.finish()
483        };
484
485        // Generate pseudo-random values from hash
486        let mut seed = hash;
487        for val in result.iter_mut() {
488            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
489            *val = ((seed >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
490        }
491
492        normalize_vector(result)
493    }
494}
495
496impl Embedder for HashEmbedder {
497    fn embed(&self, text: &str) -> Result<Vec<f32>> {
498        Ok(self.hash_to_embedding(text))
499    }
500
501    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
502        Ok(texts.iter().map(|t| self.hash_to_embedding(t)).collect())
503    }
504
505    fn dimension(&self) -> usize {
506        self.dimension
507    }
508
509    fn name(&self) -> &'static str {
510        "hash"
511    }
512}
513
514/// Configuration for loading an embedding model.
515#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
516pub struct ModelConfig {
517    /// Path to model directory or ONNX file.
518    pub model_path: std::path::PathBuf,
519    /// Path to tokenizer.json (optional if in model_path).
520    pub tokenizer_path: Option<std::path::PathBuf>,
521    /// Embedding dimension (auto-detect if None).
522    pub dimension: Option<usize>,
523    /// Maximum sequence length.
524    pub max_length: usize,
525}
526
527impl Default for ModelConfig {
528    fn default() -> Self {
529        Self {
530            model_path: std::path::PathBuf::new(),
531            tokenizer_path: None,
532            dimension: None,
533            max_length: 512,
534        }
535    }
536}
537
538impl ModelConfig {
539    /// Create config for a model directory.
540    pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
541        Self {
542            model_path: path.as_ref().to_path_buf(),
543            tokenizer_path: None,
544            dimension: None,
545            max_length: 512,
546        }
547    }
548
549    /// Set maximum sequence length.
550    pub fn with_max_length(mut self, max_length: usize) -> Self {
551        self.max_length = max_length;
552        self
553    }
554
555    /// Set embedding dimension.
556    pub fn with_dimension(mut self, dimension: usize) -> Self {
557        self.dimension = Some(dimension);
558        self
559    }
560
561    /// Load the embedder from this config.
562    pub fn load(&self) -> Result<OnnxEmbedder> {
563        if self.model_path.is_dir() {
564            let mut embedder = OnnxEmbedder::from_directory(&self.model_path)?;
565            embedder.max_length = self.max_length;
566            if let Some(dim) = self.dimension {
567                embedder.dimension = dim;
568            }
569            Ok(embedder)
570        } else {
571            let tokenizer_path = self
572                .tokenizer_path
573                .clone()
574                .unwrap_or_else(|| self.model_path.with_file_name("tokenizer.json"));
575
576            OnnxEmbedder::new(
577                &self.model_path,
578                &tokenizer_path,
579                self.dimension,
580                self.max_length,
581            )
582        }
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedder() {
        let embedder = HashEmbedder::new(384);

        let first = embedder.embed("test code").unwrap();
        assert_eq!(first.len(), 384);

        // Output must be unit length.
        let norm = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Deterministic: identical input reproduces the vector exactly.
        assert_eq!(first, embedder.embed("test code").unwrap());

        // Distinct inputs should not collide.
        assert_ne!(first, embedder.embed("other code").unwrap());
    }

    #[test]
    fn test_batch_embedding() {
        let embedder = HashEmbedder::new(128);

        let embeddings = embedder.embed_batch(&["hello", "world", "test"]).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 128));
    }
}