aurora_semantic/embeddings/
providers.rs

//! Embedding providers - ONNX-based local model inference using ONNX Runtime.
//!
//! This module provides embedding generation using Microsoft's ONNX Runtime via the `ort` crate.
//! Users download ONNX models and tokenizers themselves, then point the embedder at their location.
//!
//! # Supported Models
//!
//! Any ONNX model that takes `input_ids` and `attention_mask` (plus, optionally,
//! `position_ids` or `token_type_ids`) and outputs either pooled `[1, dim]` or
//! token-level `[1, seq, dim]` embeddings can be used, for example:
//! - `jina-embeddings-v2-base-code` (recommended for code)
//! - `all-MiniLM-L6-v2`
//! - `bge-small-en-v1.5`
//! - Any sentence-transformer model exported to ONNX
//!
//! # Model Directory Structure
//!
//! ```text
//! model_dir/
//! ├── model.onnx          # ONNX model file (also accepted: model_optimized.onnx,
//! │                       #   model-w-mean-pooling.onnx, model_quantized.onnx, encoder_model.onnx)
//! ├── tokenizer.json      # HuggingFace tokenizer
//! └── config.json         # Optional: model config
//! ```
//!
//! # GPU Acceleration
//!
//! GPU support is available via Cargo features:
//! - `cuda` - NVIDIA CUDA (requires CUDA toolkit)
//! - `tensorrt` - NVIDIA TensorRT (requires TensorRT)
//! - `directml` - DirectML (Windows, AMD/Intel/NVIDIA)
//! - `coreml` - CoreML (macOS, Apple Silicon)
//!
//! Build with GPU support:
//! ```bash
//! cargo build --release --features cuda
//! ```
35
36use std::path::Path;
37use std::sync::Arc;
38
39use ndarray::IxDyn;
40use ort::session::Session;
41use ort::value::Tensor;
42use parking_lot::Mutex;
43use serde::{Deserialize, Serialize};
44use tokenizers::Tokenizer;
45
46use crate::embeddings::Embedder;
47use crate::error::{Error, Result};
48
// ============================================================================
// JINA CODE EMBEDDINGS 1.5B SUPPORT
// ============================================================================
52
53/// Embedding task types for Jina Code Embeddings 1.5B.
54///
55/// Each task type has separate instruction prefixes for queries and passages,
56/// enabling asymmetric retrieval (queries and documents are embedded differently).
57///
58/// # Example
59/// ```rust,ignore
60/// use aurora_semantic::{JinaCodeEmbedder, EmbeddingTask};
61///
62/// // For indexing code (use passage prefix)
63/// let code_embedder = JinaCodeEmbedder::from_directory("./models/jina-code-1.5b")?
64///     .with_task(EmbeddingTask::NL2Code);
65/// let code_embedding = code_embedder.embed_passage("fn parse_json(s: &str) -> Value { }")?;
66///
67/// // For search queries (use query prefix)
68/// let query_embedding = code_embedder.embed_query("function to parse JSON")?;
69/// ```
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
71pub enum EmbeddingTask {
72    /// Natural language to code retrieval (default).
73    /// Use for: finding code snippets from natural language queries.
74    #[default]
75    NL2Code,
76    /// Code to code retrieval.
77    /// Use for: finding similar code snippets.
78    Code2Code,
79    /// Code to natural language.
80    /// Use for: finding comments/documentation for code.
81    Code2NL,
82    /// Code completion retrieval.
83    /// Use for: finding code completions.
84    Code2Completion,
85    /// Technical question answering.
86    /// Use for: finding answers to programming questions.
87    QA,
88}
89
90impl EmbeddingTask {
91    /// Get the instruction prefix for QUERY embeddings.
92    ///
93    /// Use this when embedding search queries.
94    pub fn query_prefix(&self) -> &'static str {
95        match self {
96            Self::NL2Code => "Find the most relevant code snippet given the following query:\n",
97            Self::Code2Code => "Find an equivalent code snippet given the following code snippet:\n",
98            Self::Code2NL => "Find the most relevant comment given the following code snippet:\n",
99            Self::Code2Completion => "Find the most relevant completion given the following start of code snippet:\n",
100            Self::QA => "Find the most relevant answer given the following question:\n",
101        }
102    }
103
104    /// Get the instruction prefix for PASSAGE/DOCUMENT embeddings.
105    ///
106    /// Use this when embedding code or documents to be indexed.
107    pub fn passage_prefix(&self) -> &'static str {
108        match self {
109            Self::NL2Code => "Candidate code snippet:\n",
110            Self::Code2Code => "Candidate code snippet:\n",
111            Self::Code2NL => "Candidate comment:\n",
112            Self::Code2Completion => "Candidate completion:\n",
113            Self::QA => "Candidate answer:\n",
114        }
115    }
116
117    /// Get the default instruction prefix (alias for query_prefix).
118    ///
119    /// For backward compatibility.
120    pub fn instruction_prefix(&self) -> &'static str {
121        self.query_prefix()
122    }
123
124    /// Get a human-readable name for this task.
125    pub fn name(&self) -> &'static str {
126        match self {
127            Self::NL2Code => "nl2code",
128            Self::Code2Code => "code2code",
129            Self::Code2NL => "code2nl",
130            Self::Code2Completion => "code2completion",
131            Self::QA => "qa",
132        }
133    }
134
135    /// Parse task from string name.
136    pub fn from_name(name: &str) -> Option<Self> {
137        match name.to_lowercase().as_str() {
138            "nl2code" | "text2code" | "natural-language-to-code" => Some(Self::NL2Code),
139            "code2code" | "code-to-code" | "similar-code" => Some(Self::Code2Code),
140            "code2nl" | "code-to-text" | "summarize" => Some(Self::Code2NL),
141            "code2completion" | "completion" | "autocomplete" => Some(Self::Code2Completion),
142            "qa" | "question-answering" | "technical-qa" => Some(Self::QA),
143            _ => None,
144        }
145    }
146}
147
148/// Matryoshka embedding dimensions supported by Jina Code 1.5B.
149///
150/// The model supports truncating embeddings to smaller dimensions with
151/// minimal performance loss due to Matryoshka Representation Learning.
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
153pub enum MatryoshkaDimension {
154    /// 128 dimensions - smallest, most storage efficient.
155    D128 = 128,
156    /// 256 dimensions.
157    D256 = 256,
158    /// 512 dimensions.
159    D512 = 512,
160    /// 1024 dimensions.
161    D1024 = 1024,
162    /// 1536 dimensions - full model output (default).
163    D1536 = 1536,
164}
165
166impl Default for MatryoshkaDimension {
167    fn default() -> Self {
168        Self::D1536
169    }
170}
171
172impl MatryoshkaDimension {
173    /// Get the dimension value.
174    pub fn value(&self) -> usize {
175        *self as usize
176    }
177
178    /// Get all available dimensions.
179    pub fn all() -> &'static [MatryoshkaDimension] {
180        &[
181            Self::D128,
182            Self::D256,
183            Self::D512,
184            Self::D1024,
185            Self::D1536,
186        ]
187    }
188
189    /// Create from a dimension value, rounding up to nearest supported dimension.
190    pub fn from_value(dim: usize) -> Self {
191        if dim <= 128 {
192            Self::D128
193        } else if dim <= 256 {
194            Self::D256
195        } else if dim <= 512 {
196            Self::D512
197        } else if dim <= 1024 {
198            Self::D1024
199        } else {
200            Self::D1536
201        }
202    }
203}
204
// ============================================================================
// EXECUTION PROVIDER INFO
// ============================================================================
208
209/// Information about the execution provider (CPU/GPU) being used.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ExecutionProviderInfo {
212    /// Name of the execution provider (e.g., "CPU", "CUDA", "DirectML", "CoreML").
213    pub name: String,
214    /// Whether GPU acceleration is being used.
215    pub is_gpu: bool,
216    /// Device ID if applicable.
217    pub device_id: Option<u32>,
218    /// Additional details about the provider.
219    pub details: Option<String>,
220}
221
222impl ExecutionProviderInfo {
223    /// Create CPU execution provider info.
224    pub fn cpu() -> Self {
225        Self {
226            name: "CPU".to_string(),
227            is_gpu: false,
228            device_id: None,
229            details: Some("Default CPU execution".to_string()),
230        }
231    }
232
233    /// Create CUDA execution provider info.
234    #[allow(dead_code)]
235    pub fn cuda(device_id: u32) -> Self {
236        Self {
237            name: "CUDA".to_string(),
238            is_gpu: true,
239            device_id: Some(device_id),
240            details: Some(format!("NVIDIA CUDA GPU (device {})", device_id)),
241        }
242    }
243
244    /// Create TensorRT execution provider info.
245    #[allow(dead_code)]
246    pub fn tensorrt(device_id: u32) -> Self {
247        Self {
248            name: "TensorRT".to_string(),
249            is_gpu: true,
250            device_id: Some(device_id),
251            details: Some(format!("NVIDIA TensorRT GPU (device {})", device_id)),
252        }
253    }
254
255    /// Create DirectML execution provider info.
256    #[allow(dead_code)]
257    pub fn directml(device_id: u32) -> Self {
258        Self {
259            name: "DirectML".to_string(),
260            is_gpu: true,
261            device_id: Some(device_id),
262            details: Some(format!("DirectML GPU (device {})", device_id)),
263        }
264    }
265
266    /// Create CoreML execution provider info.
267    #[allow(dead_code)]
268    pub fn coreml() -> Self {
269        Self {
270            name: "CoreML".to_string(),
271            is_gpu: true,
272            device_id: None,
273            details: Some("Apple CoreML (Neural Engine/GPU)".to_string()),
274        }
275    }
276
277    /// Get a human-readable description.
278    pub fn description(&self) -> String {
279        if self.is_gpu {
280            format!("{} (GPU accelerated)", self.name)
281        } else {
282            format!("{} (no GPU)", self.name)
283        }
284    }
285}
286
287/// ONNX-based embedding model using ONNX Runtime.
288///
289/// Loads an ONNX model and tokenizer from a directory and generates embeddings.
290///
291/// # Example
292///
293/// ```rust,ignore
294/// use aurora_semantic::OnnxEmbedder;
295///
296/// // Point to your downloaded model directory
297/// let embedder = OnnxEmbedder::from_directory("./models/jina-code")?;
298///
299/// let embedding = embedder.embed("fn main() { println!(\"Hello\"); }")?;
300/// println!("Embedding dimension: {}", embedding.len());
301/// ```
302pub struct OnnxEmbedder {
303    session: Arc<Mutex<Session>>,
304    tokenizer: Tokenizer,
305    dimension: usize,
306    max_length: usize,
307    execution_provider: ExecutionProviderInfo,
308}
309
310impl OnnxEmbedder {
311    /// Load an ONNX model from a directory.
312    ///
313    /// The directory must contain:
314    /// - An ONNX model file (model.onnx, model-w-mean-pooling.onnx, etc.)
315    /// - `tokenizer.json` - HuggingFace tokenizer file
316    pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
317        let model_dir = model_dir.as_ref();
318
319        // Check for model file with various common names
320        let model_names = [
321            "model.onnx",
322            "model_optimized.onnx",
323            "model-w-mean-pooling.onnx",
324            "model_quantized.onnx",
325            "encoder_model.onnx",
326        ];
327
328        let model_path = model_names
329            .iter()
330            .map(|name| model_dir.join(name))
331            .find(|p| p.exists())
332            .ok_or_else(|| {
333                Error::model_load(format!(
334                    "No ONNX model file found in {}. Expected one of: {:?}",
335                    model_dir.display(),
336                    model_names
337                ))
338            })?;
339
340        // Check tokenizer exists
341        let tokenizer_path = model_dir.join("tokenizer.json");
342        if !tokenizer_path.exists() {
343            return Err(Error::model_load(format!(
344                "tokenizer.json not found in {}",
345                model_dir.display()
346            )));
347        }
348
349        Self::new(&model_path, &tokenizer_path, None, 512)
350    }
351
352    /// Create a new ONNX embedder with explicit paths.
353    pub fn new<P: AsRef<Path>>(
354        model_path: P,
355        tokenizer_path: P,
356        dimension: Option<usize>,
357        max_length: usize,
358    ) -> Result<Self> {
359        let model_path = model_path.as_ref();
360        let tokenizer_path = tokenizer_path.as_ref();
361
362        // Load tokenizer
363        let tokenizer = Tokenizer::from_file(tokenizer_path)
364            .map_err(|e| Error::model_load(format!("Failed to load tokenizer: {}", e)))?;
365
366        // Load ONNX model with ort
367        tracing::info!("Loading ONNX model from: {}", model_path.display());
368
369        #[allow(unused_mut)]
370        let mut builder = Session::builder()
371            .map_err(|e| Error::model_load(format!("Failed to create session builder: {}", e)))?;
372
373        // Track which execution provider we're using
374        // Note: This variable is set by feature-gated code below
375        #[allow(unused_mut, unused_variables)]
376        let mut _execution_provider = ExecutionProviderInfo::cpu();
377
378        // Configure execution providers based on available features
379        #[cfg(feature = "cuda")]
380        {
381            use ort::execution_providers::CUDAExecutionProvider;
382            tracing::info!("CUDA support enabled, attempting GPU acceleration");
383            builder = builder
384                .with_execution_providers([CUDAExecutionProvider::default().build()])
385                .map_err(|e| Error::model_load(format!("Failed to configure CUDA: {}", e)))?;
386            _execution_provider = ExecutionProviderInfo::cuda(0);
387        }
388
389        #[cfg(feature = "tensorrt")]
390        {
391            use ort::execution_providers::TensorRTExecutionProvider;
392            tracing::info!("TensorRT support enabled, attempting GPU acceleration");
393            builder = builder
394                .with_execution_providers([TensorRTExecutionProvider::default().build()])
395                .map_err(|e| Error::model_load(format!("Failed to configure TensorRT: {}", e)))?;
396            _execution_provider = ExecutionProviderInfo::tensorrt(0);
397        }
398
399        #[cfg(feature = "directml")]
400        {
401            use ort::execution_providers::DirectMLExecutionProvider;
402            tracing::info!("DirectML support enabled, attempting GPU acceleration");
403            builder = builder
404                .with_execution_providers([DirectMLExecutionProvider::default().build()])
405                .map_err(|e| Error::model_load(format!("Failed to configure DirectML: {}", e)))?;
406            _execution_provider = ExecutionProviderInfo::directml(0);
407        }
408
409        #[cfg(feature = "coreml")]
410        {
411            use ort::execution_providers::CoreMLExecutionProvider;
412            tracing::info!("CoreML support enabled, attempting GPU acceleration");
413            builder = builder
414                .with_execution_providers([CoreMLExecutionProvider::default().build()])
415                .map_err(|e| Error::model_load(format!("Failed to configure CoreML: {}", e)))?;
416            _execution_provider = ExecutionProviderInfo::coreml();
417        }
418
419        let session = builder
420            .with_intra_threads(4)
421            .map_err(|e| Error::model_load(format!("Failed to set threads: {}", e)))?
422            .commit_from_file(model_path)
423            .map_err(|e| Error::model_load(format!("Failed to load ONNX model: {}", e)))?;
424
425        // Detect actual output dimension from model
426        let detected_dimension = if model_path.to_string_lossy().contains("jina-code-embeddings-1.5b-ONNX") {
427            1536 // Known dimension for this model
428        } else {
429            dimension.unwrap_or(768) // fallback for other models
430        };
431
432        let dimension = detected_dimension;
433
434        // IMPORTANT: Verify which execution provider is actually being used
435        // ONNX Runtime may silently fall back to CPU if GPU init fails
436        let actual_provider = detect_actual_execution_provider(&session);
437        let execution_provider = actual_provider;
438
439        tracing::info!(
440            "Loaded ONNX model (dim={}, max_len={}, provider={})",
441            dimension,
442            max_length,
443            execution_provider.description()
444        );
445
446        Ok(Self {
447            session: Arc::new(Mutex::new(session)),
448            tokenizer,
449            dimension,
450            max_length,
451            execution_provider,
452        })
453    }
454
455    /// Set the maximum sequence length.
456    pub fn with_max_length(mut self, max_length: usize) -> Self {
457        self.max_length = max_length;
458        self
459    }
460
461    /// Get information about the execution provider (CPU/GPU).
462    pub fn execution_provider(&self) -> &ExecutionProviderInfo {
463        &self.execution_provider
464    }
465
466    /// Check if GPU acceleration is being used.
467    pub fn is_gpu_accelerated(&self) -> bool {
468        self.execution_provider.is_gpu
469    }
470
471    /// Mean pooling over token embeddings.
472    fn mean_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
473        if shape.len() != 3 {
474            // Not token embeddings, return as-is
475            return data.to_vec();
476        }
477
478        let seq_len = shape[1] as usize;
479        let dim = shape[2] as usize;
480        let mut result = vec![0.0f32; dim];
481        let mut count = 0.0f32;
482
483        for i in 0..seq_len {
484            if i < attention_mask.len() && attention_mask[i] == 1 {
485                for j in 0..dim {
486                    result[j] += data[i * dim + j];
487                }
488                count += 1.0;
489            }
490        }
491
492        if count > 0.0 {
493            for val in &mut result {
494                *val /= count;
495            }
496        }
497
498        // L2 normalize
499        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
500        if norm > 0.0 {
501            for val in &mut result {
502                *val /= norm;
503            }
504        }
505
506        result
507    }
508}
509
510impl Embedder for OnnxEmbedder {
511    fn embed(&self, text: &str) -> Result<Vec<f32>> {
512        let results = self.embed_batch(&[text])?;
513        Ok(results.into_iter().next().unwrap_or_default())
514    }
515
516    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
517        if texts.is_empty() {
518            return Ok(Vec::new());
519        }
520
521        let mut all_embeddings = Vec::new();
522
523        // Process one at a time for simplicity
524        for text in texts {
525            // Tokenize
526            let encoding = self
527                .tokenizer
528                .encode(*text, true)
529                .map_err(|e| Error::embedding(format!("Tokenization failed: {}", e)))?;
530
531            let ids = encoding.get_ids();
532            let mask = encoding.get_attention_mask();
533            let seq_len = ids.len().min(self.max_length);
534
535            // Prepare input tensors
536            let input_ids: Vec<i64> = ids.iter().take(seq_len).map(|&id| id as i64).collect();
537            let attention_mask: Vec<i64> = mask.iter().take(seq_len).map(|&m| m as i64).collect();
538            let position_ids: Vec<i64> = (0..seq_len as i64).collect();
539
540            // Create ort tensors
541            let input_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), input_ids.clone())
542                .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
543            let input_ids_tensor = Tensor::from_array(input_ids_array)
544                .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
545
546            let attention_mask_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), attention_mask.clone())
547                .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
548            let attention_mask_tensor = Tensor::from_array(attention_mask_array)
549                .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
550
551            let position_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), position_ids)
552                .map_err(|e| Error::embedding(format!("Failed to create position tensor: {}", e)))?;
553            let position_ids_tensor = Tensor::from_array(position_ids_array)
554                .map_err(|e| Error::embedding(format!("Failed to create position tensor: {}", e)))?;
555
556            // Run inference
557            let mut session = self.session.lock();
558            
559            // Get first output name before running
560            let first_output_name = session.outputs.first()
561                .map(|o| o.name.clone())
562                .unwrap_or_else(|| "output".to_string());
563            
564            // Check model input requirements and run inference accordingly
565            let outputs = if session.inputs.iter().any(|input| input.name == "position_ids") {
566                session
567                    .run(ort::inputs![
568                        "input_ids" => input_ids_tensor,
569                        "attention_mask" => attention_mask_tensor,
570                        "position_ids" => position_ids_tensor,
571                    ])
572                    .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
573            } else if session.inputs.iter().any(|input| input.name == "token_type_ids") {
574                session
575                    .run(ort::inputs![
576                        "input_ids" => input_ids_tensor,
577                        "attention_mask" => attention_mask_tensor,
578                        "token_type_ids" => position_ids_tensor, // Reuse position_ids as token_type_ids
579                    ])
580                    .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
581            } else {
582                session
583                    .run(ort::inputs![
584                        "input_ids" => input_ids_tensor,
585                        "attention_mask" => attention_mask_tensor,
586                    ])
587                    .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
588            };
589
590            // Extract embeddings - try different output names
591            let output = if let Some(val) = outputs.get("last_hidden_state") {
592                val
593            } else if let Some(val) = outputs.get("sentence_embedding") {
594                val
595            } else {
596                outputs.get(&first_output_name)
597                    .ok_or_else(|| Error::embedding("No output found".to_string()))?
598            };
599
600            let (output_shape, output_data) = output
601                .try_extract_tensor::<f32>()
602                .map_err(|e| Error::embedding(format!("Failed to extract output: {}", e)))?;
603
604            let shape_vec: Vec<i64> = output_shape.iter().map(|&d| d as i64).collect();
605
606            let embedding = if shape_vec.len() == 2 {
607                // Direct sentence embedding [1, dim]
608                let emb: Vec<f32> = output_data.to_vec();
609                normalize_vector(emb)
610            } else if shape_vec.len() == 3 {
611                // Token embeddings [1, seq, dim] - need pooling
612                self.mean_pooling(output_data, &shape_vec, &attention_mask)
613            } else {
614                return Err(Error::embedding(format!(
615                    "Unexpected output shape: {:?}",
616                    shape_vec
617                )));
618            };
619
620            all_embeddings.push(embedding);
621        }
622
623        Ok(all_embeddings)
624    }
625
626    fn dimension(&self) -> usize {
627        self.dimension
628    }
629
630    fn name(&self) -> &'static str {
631        "onnx-runtime"
632    }
633
634    fn max_sequence_length(&self) -> usize {
635        self.max_length
636    }
637}
638
639/// Detect which execution provider is actually being used by the session.
640/// ONNX Runtime may silently fall back to CPU if GPU initialization fails.
641fn detect_actual_execution_provider(session: &Session) -> ExecutionProviderInfo {
642    // Get the actual execution providers from the session metadata
643    // Unfortunately, ort doesn't expose this directly, so we check what was requested
644    // and verify by looking at session metadata
645    
646    // Check session metadata for provider info
647    if let Ok(metadata) = session.metadata() {
648        let producer = metadata.producer().unwrap_or_default();
649        let description = metadata.description().unwrap_or_default();
650        
651        tracing::debug!("Session metadata - producer: {}, description: {}", producer, description);
652    }
653    
654    // The best way to detect is to check which providers are actually available
655    // and were successfully initialized
656    
657    #[cfg(feature = "cuda")]
658    {
659        // Check if CUDA is actually available at runtime
660        // This requires checking if the CUDA EP was successfully registered
661        if is_cuda_available() {
662            return ExecutionProviderInfo::cuda(0);
663        } else {
664            tracing::warn!("CUDA feature enabled but CUDA runtime not available - falling back to CPU");
665        }
666    }
667    
668    #[cfg(feature = "directml")]
669    {
670        if is_directml_available() {
671            return ExecutionProviderInfo::directml(0);
672        } else {
673            tracing::warn!("DirectML feature enabled but not available - falling back to CPU");
674        }
675    }
676    
677    #[cfg(feature = "coreml")]
678    {
679        return ExecutionProviderInfo::coreml();
680    }
681    
682    ExecutionProviderInfo::cpu()
683}
684
/// Best-effort check whether a CUDA runtime is usable on this machine.
///
/// This is a heuristic. It does not prove that ONNX Runtime's CUDA provider
/// initialized successfully.
/// - Windows: looks for the CUDA runtime DLL under the default toolkit install
///   roots and in every `PATH` directory, then falls back to running `nvidia-smi`.
///   Accepted DLLs are `cudart64_12.dll` (CUDA 12) and `cudart64_110.dll`
///   (CUDA 11.x); `cudart64_11.dll` is also accepted.
/// - Other platforms: returns whether `nvidia-smi` runs and exits successfully.
#[cfg(feature = "cuda")]
fn is_cuda_available() -> bool {
    #[cfg(target_os = "windows")]
    {
        // Default toolkit roots; each holds versioned `v<major>.<minor>` directories.
        const TOOLKIT_ROOTS: [&str; 2] = [
            "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA",
            "C:\\CUDA",
        ];

        for root in TOOLKIT_ROOTS {
            // A missing or unreadable root is simply skipped.
            if let Ok(entries) = std::fs::read_dir(root) {
                for entry in entries.flatten() {
                    let version_dir = entry.path();
                    let looks_versioned = entry.file_name().to_string_lossy().starts_with('v');
                    if looks_versioned
                        && version_dir.is_dir()
                        && contains_cudart(&version_dir.join("bin"))
                    {
                        tracing::info!("Found CUDA runtime at: {}", version_dir.display());
                        return true;
                    }
                }
            }
        }

        // `split_paths` uses the platform separator and handles quoted entries.
        if let Some(path_var) = std::env::var_os("PATH") {
            for dir in std::env::split_paths(&path_var) {
                if contains_cudart(&dir) {
                    tracing::info!("Found CUDA runtime in PATH: {}", dir.display());
                    return true;
                }
            }
        }

        if nvidia_smi_succeeds() {
            tracing::info!("nvidia-smi available, assuming CUDA works");
            return true;
        }

        false
    }

    #[cfg(not(target_os = "windows"))]
    {
        // No library probing here; a working driver is taken as sufficient evidence.
        nvidia_smi_succeeds()
    }
}

/// Returns `true` if `dir` contains a known CUDA runtime DLL.
#[cfg(all(feature = "cuda", target_os = "windows"))]
fn contains_cudart(dir: &std::path::Path) -> bool {
    const CUDART_DLLS: [&str; 3] = ["cudart64_12.dll", "cudart64_110.dll", "cudart64_11.dll"];
    CUDART_DLLS.iter().any(|dll| dir.join(dll).exists())
}

/// Returns `true` if `nvidia-smi` can be spawned and exits successfully.
#[cfg(feature = "cuda")]
fn nvidia_smi_succeeds() -> bool {
    std::process::Command::new("nvidia-smi")
        .output()
        .map(|output| output.status.success())
        .unwrap_or(false)
}
755
/// Whether the DirectML provider can be used on this platform.
///
/// Only the target OS is checked: this is `true` for Windows builds (DirectML is
/// usually available on Windows 10+) and `false` everywhere else. No GPU probing
/// is done.
#[cfg(feature = "directml")]
fn is_directml_available() -> bool {
    cfg!(target_os = "windows")
}
768
/// Scale `v` to unit L2 length and return it.
///
/// If the norm is not strictly positive (an all-zero vector, an empty vector, or
/// one containing NaN), the vector is returned unchanged.
fn normalize_vector(mut v: Vec<f32>) -> Vec<f32> {
    let magnitude = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if magnitude > 0.0 {
        v.iter_mut().for_each(|x| *x /= magnitude);
    }
    v
}
779
/// Simple hash-based embedder for testing (no model required).
///
/// Produces deterministic embeddings derived from a hash of the input text.
/// This lets you exercise the search pipeline without loading a real model.
/// The vectors carry no semantic meaning.
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    /// Length of every generated embedding.
    dimension: usize,
}
787
788impl HashEmbedder {
789    /// Create a new hash embedder with the specified dimension.
790    pub fn new(dimension: usize) -> Self {
791        Self { dimension }
792    }
793
794    fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
795        use std::collections::hash_map::DefaultHasher;
796        use std::hash::{Hash, Hasher};
797
798        let mut result = vec![0.0f32; self.dimension];
799        let hash = {
800            let mut hasher = DefaultHasher::new();
801            text.hash(&mut hasher);
802            hasher.finish()
803        };
804
805        // Generate pseudo-random values from hash
806        let mut seed = hash;
807        for val in result.iter_mut() {
808            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
809            *val = ((seed >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
810        }
811
812        normalize_vector(result)
813    }
814}
815
816impl Embedder for HashEmbedder {
817    fn embed(&self, text: &str) -> Result<Vec<f32>> {
818        Ok(self.hash_to_embedding(text))
819    }
820
821    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
822        Ok(texts.iter().map(|t| self.hash_to_embedding(t)).collect())
823    }
824
825    fn dimension(&self) -> usize {
826        self.dimension
827    }
828
829    fn name(&self) -> &'static str {
830        "hash"
831    }
832}
833
// ============================================================================
// JINA CODE EMBEDDER - Specialized for jina-code-embeddings-1.5b
// ============================================================================
837
838/// Jina Code Embeddings 1.5B specialized embedder.
839///
840/// This embedder is optimized for the jina-code-embeddings-1.5b model with:
841/// - Task-specific instruction prefixes (NL2Code, Code2Code, etc.)
842/// - Last-token pooling (required for this decoder-based model)
843/// - Matryoshka dimension truncation (128, 256, 512, 1024, 1536)
844/// - Automatic handling of the 32768 context window
845///
846/// # Asymmetric vs Symmetric Embedding Mode
847///
848/// For optimal retrieval quality, use **different modes** for indexing vs querying:
849///
850/// - **Passage mode** (default): Use for indexing code/documents - adds passage prefix
851/// - **Query mode**: Use for search queries - adds query prefix
852///
853/// # Example
854///
855/// ```rust,ignore
856/// use aurora_semantic::{JinaCodeEmbedder, EmbeddingTask, MatryoshkaDimension, EmbeddingMode};
857///
858/// // For INDEXING: use Passage mode (default)
859/// let indexer = JinaCodeEmbedder::from_directory("./models/jina-code-1.5b")?
860///     .with_task(EmbeddingTask::NL2Code)
861///     .with_mode(EmbeddingMode::Passage);  // Default, can omit
862///
863/// // For SEARCHING: use Query mode
864/// let searcher = JinaCodeEmbedder::from_directory("./models/jina-code-1.5b")?
865///     .with_task(EmbeddingTask::NL2Code)
866///     .with_mode(EmbeddingMode::Query);
867/// ```
868pub struct JinaCodeEmbedder {
869    /// Inner ONNX embedder.
870    inner: OnnxEmbedder,
871    /// Current embedding task.
872    task: EmbeddingTask,
873    /// Output dimension (for Matryoshka truncation).
874    output_dimension: MatryoshkaDimension,
875    /// Embedding mode (Query or Passage).
876    mode: EmbeddingMode,
877}
878
879/// Controls whether `embed()` uses query or passage instruction prefix.
880///
881/// For optimal asymmetric retrieval with Jina Code 1.5B:
882/// - Use **Passage** mode when indexing code/documents
883/// - Use **Query** mode when embedding search queries
884#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
885pub enum EmbeddingMode {
886    /// Use passage prefix for indexing code/documents (default).
887    /// Prefix: "Candidate code snippet:\n"
888    #[default]
889    Passage,
890    /// Use query prefix for search queries.
891    /// Prefix: "Find the most relevant code snippet...\n"
892    Query,
893}
894
895impl JinaCodeEmbedder {
896    /// Default max sequence length for Jina Code 1.5B.
897    pub const DEFAULT_MAX_LENGTH: usize = 32768;
898    /// Default dimension for Jina Code 1.5B.
899    pub const DEFAULT_DIMENSION: usize = 1536;
900
901    /// Load Jina Code Embeddings 1.5B from a model directory.
902    ///
903    /// The directory should contain:
904    /// - `model.onnx` - The ONNX model file
905    /// - `tokenizer.json` - The HuggingFace tokenizer
906    pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
907        let inner = OnnxEmbedder::from_directory(&model_dir)?;
908        Ok(Self {
909            inner,
910            task: EmbeddingTask::default(),
911            output_dimension: MatryoshkaDimension::default(),
912            mode: EmbeddingMode::default(),
913        })
914    }
915
916    /// Create from an existing OnnxEmbedder.
917    pub fn from_onnx_embedder(inner: OnnxEmbedder) -> Self {
918        Self {
919            inner,
920            task: EmbeddingTask::default(),
921            output_dimension: MatryoshkaDimension::default(),
922            mode: EmbeddingMode::default(),
923        }
924    }
925
926    /// Set the embedding task (determines instruction prefix).
927    pub fn with_task(mut self, task: EmbeddingTask) -> Self {
928        self.task = task;
929        self
930    }
931
932    /// Set the output dimension (Matryoshka truncation).
933    ///
934    /// Smaller dimensions reduce storage and speed up similarity search
935    /// with minimal quality loss.
936    pub fn with_dimension(mut self, dimension: MatryoshkaDimension) -> Self {
937        self.output_dimension = dimension;
938        self
939    }
940
941    /// Set the maximum sequence length.
942    pub fn with_max_length(mut self, max_length: usize) -> Self {
943        self.inner.max_length = max_length;
944        self
945    }
946
947    /// Set the embedding mode (Query or Passage).
948    ///
949    /// - **Passage** (default): Use for indexing code - adds passage prefix
950    /// - **Query**: Use for search queries - adds query prefix
951    pub fn with_mode(mut self, mode: EmbeddingMode) -> Self {
952        self.mode = mode;
953        self
954    }
955
956    /// Get the current embedding mode.
957    pub fn mode(&self) -> EmbeddingMode {
958        self.mode
959    }
960
961    /// Get the current task.
962    pub fn task(&self) -> EmbeddingTask {
963        self.task
964    }
965
966    /// Get the output dimension.
967    pub fn output_dimension(&self) -> MatryoshkaDimension {
968        self.output_dimension
969    }
970
971    /// Get information about the execution provider (CPU/GPU).
972    pub fn execution_provider(&self) -> &ExecutionProviderInfo {
973        self.inner.execution_provider()
974    }
975
976    /// Check if GPU acceleration is being used.
977    pub fn is_gpu_accelerated(&self) -> bool {
978        self.inner.is_gpu_accelerated()
979    }
980
981    // ========================================================================
982    // ASYMMETRIC EMBEDDING METHODS
983    // These are the recommended methods for Jina Code 1.5B retrieval
984    // ========================================================================
985
986    /// Embed a search query with the query instruction prefix.
987    ///
988    /// Use this for user queries when searching the index.
989    ///
990    /// # Example
991    /// ```rust,ignore
992    /// let query_embedding = embedder.embed_query("function to parse JSON")?;
993    /// ```
994    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
995        let prefixed = format!("{}{}", self.task.query_prefix(), text);
996        let embedding = self.inner.embed(&prefixed)?;
997        Ok(self.truncate_embedding(embedding))
998    }
999
1000    /// Embed a code snippet/passage with the passage instruction prefix.
1001    ///
1002    /// Use this when indexing code to build the search index.
1003    ///
1004    /// # Example
1005    /// ```rust,ignore
1006    /// let code_embedding = embedder.embed_passage("fn parse_json(s: &str) -> Value { ... }")?;
1007    /// ```
1008    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
1009        let prefixed = format!("{}{}", self.task.passage_prefix(), text);
1010        let embedding = self.inner.embed(&prefixed)?;
1011        Ok(self.truncate_embedding(embedding))
1012    }
1013
1014    /// Embed multiple queries in batch.
1015    pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1016        let prefixed: Vec<String> = texts.iter()
1017            .map(|t| format!("{}{}", self.task.query_prefix(), t))
1018            .collect();
1019        let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1020        
1021        let embeddings = self.inner.embed_batch(&refs)?;
1022        Ok(embeddings.into_iter()
1023            .map(|e| self.truncate_embedding(e))
1024            .collect())
1025    }
1026
1027    /// Embed multiple code passages in batch.
1028    ///
1029    /// Use this for efficient bulk indexing.
1030    pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1031        let prefixed: Vec<String> = texts.iter()
1032            .map(|t| format!("{}{}", self.task.passage_prefix(), t))
1033            .collect();
1034        let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1035        
1036        let embeddings = self.inner.embed_batch(&refs)?;
1037        Ok(embeddings.into_iter()
1038            .map(|e| self.truncate_embedding(e))
1039            .collect())
1040    }
1041
1042    // ========================================================================
1043    // INTERNAL HELPERS
1044    // ========================================================================
1045
1046    /// Apply task instruction prefix based on current mode.
1047    /// - Passage mode: uses passage prefix (for indexing)
1048    /// - Query mode: uses query prefix (for searching)
1049    fn apply_task_prefix(&self, text: &str) -> String {
1050        let prefix = match self.mode {
1051            EmbeddingMode::Passage => self.task.passage_prefix(),
1052            EmbeddingMode::Query => self.task.query_prefix(),
1053        };
1054        format!("{}{}", prefix, text)
1055    }
1056
1057    /// Truncate embedding to the configured Matryoshka dimension.
1058    fn truncate_embedding(&self, embedding: Vec<f32>) -> Vec<f32> {
1059        let target_dim = self.output_dimension.value();
1060        if embedding.len() <= target_dim {
1061            normalize_vector(embedding)
1062        } else {
1063            let truncated: Vec<f32> = embedding.into_iter().take(target_dim).collect();
1064            normalize_vector(truncated)
1065        }
1066    }
1067
1068    /// Perform last-token pooling on token embeddings.
1069    fn last_token_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
1070        if shape.len() != 3 {
1071            // Not token embeddings, return as-is
1072            return data.to_vec();
1073        }
1074
1075        let seq_len = shape[1] as usize;
1076        let dim = shape[2] as usize;
1077
1078        // Find the last valid token (last position where attention_mask == 1)
1079        let last_valid_pos = attention_mask.iter()
1080            .enumerate()
1081            .rev()
1082            .find(|(_, &mask)| mask == 1)
1083            .map(|(i, _)| i)
1084            .unwrap_or(seq_len.saturating_sub(1));
1085
1086        // Extract the embedding at the last valid position
1087        let start = last_valid_pos * dim;
1088        let end = start + dim;
1089
1090        if end <= data.len() {
1091            let result: Vec<f32> = data[start..end].to_vec();
1092            normalize_vector(result)
1093        } else {
1094            // Fallback to mean pooling if something goes wrong
1095            self.inner.mean_pooling(data, shape, attention_mask)
1096        }
1097    }
1098}
1099
1100impl Embedder for JinaCodeEmbedder {
1101    fn embed(&self, text: &str) -> Result<Vec<f32>> {
1102        let prefixed = self.apply_task_prefix(text);
1103        let embedding = self.inner.embed(&prefixed)?;
1104        Ok(self.truncate_embedding(embedding))
1105    }
1106
1107    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1108        let prefixed: Vec<String> = texts.iter()
1109            .map(|t| self.apply_task_prefix(t))
1110            .collect();
1111        let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1112        
1113        let embeddings = self.inner.embed_batch(&refs)?;
1114        Ok(embeddings.into_iter()
1115            .map(|e| self.truncate_embedding(e))
1116            .collect())
1117    }
1118
1119    /// Override for asymmetric retrieval - always use query prefix for search queries.
1120    fn embed_for_query(&self, text: &str) -> Result<Vec<f32>> {
1121        // Always use query prefix for search queries, regardless of mode
1122        let prefixed = format!("{}{}", self.task.query_prefix(), text);
1123        let embedding = self.inner.embed(&prefixed)?;
1124        Ok(self.truncate_embedding(embedding))
1125    }
1126
1127    fn dimension(&self) -> usize {
1128        self.output_dimension.value()
1129    }
1130
1131    fn name(&self) -> &'static str {
1132        "jina-code-1.5b"
1133    }
1134
1135    fn max_sequence_length(&self) -> usize {
1136        self.inner.max_sequence_length()
1137    }
1138}
1139
1140// ============================================================================
1141// JINA CODE MODEL CONFIG
1142// ============================================================================
1143
1144/// Configuration for loading a Jina Code Embeddings model.
1145#[derive(Debug, Clone, Serialize, Deserialize)]
1146pub struct JinaCodeConfig {
1147    /// Path to model directory.
1148    pub model_path: std::path::PathBuf,
1149    /// Embedding task.
1150    pub task: EmbeddingTask,
1151    /// Output dimension (Matryoshka).
1152    pub dimension: MatryoshkaDimension,
1153    /// Maximum sequence length.
1154    pub max_length: usize,
1155}
1156
1157impl Default for JinaCodeConfig {
1158    fn default() -> Self {
1159        Self {
1160            model_path: std::path::PathBuf::new(),
1161            task: EmbeddingTask::default(),
1162            dimension: MatryoshkaDimension::default(),
1163            max_length: JinaCodeEmbedder::DEFAULT_MAX_LENGTH,
1164        }
1165    }
1166}
1167
1168impl JinaCodeConfig {
1169    /// Create config for a Jina Code model directory.
1170    pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
1171        Self {
1172            model_path: path.as_ref().to_path_buf(),
1173            ..Default::default()
1174        }
1175    }
1176
1177    /// Set the embedding task.
1178    pub fn with_task(mut self, task: EmbeddingTask) -> Self {
1179        self.task = task;
1180        self
1181    }
1182
1183    /// Set the output dimension.
1184    pub fn with_dimension(mut self, dimension: MatryoshkaDimension) -> Self {
1185        self.dimension = dimension;
1186        self
1187    }
1188
1189    /// Set the maximum sequence length.
1190    pub fn with_max_length(mut self, max_length: usize) -> Self {
1191        self.max_length = max_length;
1192        self
1193    }
1194
1195    /// Load the embedder from this config.
1196    pub fn load(&self) -> Result<JinaCodeEmbedder> {
1197        JinaCodeEmbedder::from_directory(&self.model_path)
1198            .map(|e| e
1199                .with_task(self.task)
1200                .with_dimension(self.dimension)
1201                .with_max_length(self.max_length))
1202    }
1203}
1204
1205// ============================================================================
1206// MODEL CONFIG
1207// ============================================================================
1208
1209/// Configuration for loading an embedding model.
1210#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
1211pub struct ModelConfig {
1212    /// Path to model directory or ONNX file.
1213    pub model_path: std::path::PathBuf,
1214    /// Path to tokenizer.json (optional if in model_path).
1215    pub tokenizer_path: Option<std::path::PathBuf>,
1216    /// Embedding dimension (auto-detect if None).
1217    pub dimension: Option<usize>,
1218    /// Maximum sequence length.
1219    pub max_length: usize,
1220}
1221
1222impl Default for ModelConfig {
1223    fn default() -> Self {
1224        Self {
1225            model_path: std::path::PathBuf::new(),
1226            tokenizer_path: None,
1227            dimension: None,
1228            max_length: 512,
1229        }
1230    }
1231}
1232
1233impl ModelConfig {
1234    /// Create config for a model directory.
1235    pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
1236        Self {
1237            model_path: path.as_ref().to_path_buf(),
1238            tokenizer_path: None,
1239            dimension: None,
1240            max_length: 512,
1241        }
1242    }
1243
1244    /// Set maximum sequence length.
1245    pub fn with_max_length(mut self, max_length: usize) -> Self {
1246        self.max_length = max_length;
1247        self
1248    }
1249
1250    /// Set embedding dimension.
1251    pub fn with_dimension(mut self, dimension: usize) -> Self {
1252        self.dimension = Some(dimension);
1253        self
1254    }
1255
1256    /// Load the embedder from this config.
1257    pub fn load(&self) -> Result<OnnxEmbedder> {
1258        if self.model_path.is_dir() {
1259            let mut embedder = OnnxEmbedder::from_directory(&self.model_path)?;
1260            embedder.max_length = self.max_length;
1261            if let Some(dim) = self.dimension {
1262                embedder.dimension = dim;
1263            }
1264            Ok(embedder)
1265        } else {
1266            let tokenizer_path = self
1267                .tokenizer_path
1268                .clone()
1269                .unwrap_or_else(|| self.model_path.with_file_name("tokenizer.json"));
1270
1271            OnnxEmbedder::new(
1272                &self.model_path,
1273                &tokenizer_path,
1274                self.dimension,
1275                self.max_length,
1276            )
1277        }
1278    }
1279}
1280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedder() {
        let embedder = HashEmbedder::new(384);

        let embedding = embedder.embed("test code").unwrap();
        assert_eq!(embedding.len(), 384);

        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Same text = same embedding
        let embedding2 = embedder.embed("test code").unwrap();
        assert_eq!(embedding, embedding2);

        // Different text = different embedding
        let embedding3 = embedder.embed("other code").unwrap();
        assert_ne!(embedding, embedding3);
    }

    #[test]
    fn test_batch_embedding() {
        let embedder = HashEmbedder::new(128);

        let texts = vec!["hello", "world", "test"];
        let embeddings = embedder.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    #[test]
    fn test_batch_matches_single_embeddings() {
        let embedder = HashEmbedder::new(64);
        let texts = ["fn main() {}", "struct Foo;"];

        let batch = embedder.embed_batch(&texts).unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, from_batch) in texts.iter().zip(&batch) {
            assert_eq!(&embedder.embed(text).unwrap(), from_batch);
        }
    }

    #[test]
    fn test_empty_batch() {
        let embedder = HashEmbedder::new(32);
        let empty: [&str; 0] = [];
        assert!(embedder.embed_batch(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_hash_embedder_metadata() {
        let embedder = HashEmbedder::new(256);
        assert_eq!(embedder.dimension(), 256);
        assert_eq!(embedder.name(), "hash");
    }

    #[test]
    fn test_config_defaults() {
        let jina = JinaCodeConfig::from_directory("models/jina");
        assert_eq!(jina.model_path, std::path::PathBuf::from("models/jina"));
        assert_eq!(jina.max_length, JinaCodeEmbedder::DEFAULT_MAX_LENGTH);

        let onnx = ModelConfig::from_directory("models/minilm");
        assert_eq!(onnx.model_path, std::path::PathBuf::from("models/minilm"));
        assert_eq!(onnx.max_length, 512);
        assert!(onnx.tokenizer_path.is_none());
        assert!(onnx.dimension.is_none());

        let sized = ModelConfig::from_directory("m").with_dimension(384).with_max_length(256);
        assert_eq!(sized.dimension, Some(384));
        assert_eq!(sized.max_length, 256);
    }

    #[test]
    fn test_embedding_mode_default_is_passage() {
        assert_eq!(EmbeddingMode::default(), EmbeddingMode::Passage);
    }
}