// aurora_semantic/embeddings/mod.rs

1//! Embeddings module for generating semantic vectors.
2//!
3//! This module provides functionality to generate embedding vectors
4//! from source code chunks using various embedding providers:
5//!
6//! - **API-based**: Jina AI, OpenAI, Cohere, Voyage AI
7//! - **Local**: HuggingFace models via Candle
8//! - **Custom**: Implement the `Embedder` trait
9
10pub mod pooling;
11mod providers;
12
13pub use pooling::cosine_similarity;
14pub use providers::*;
15
16use crate::error::Result;
17use crate::types::Chunk;
18
19/// Trait for embedding generators.
20///
21/// Implement this trait to add support for custom embedding providers.
22///
23/// # Example
24///
25/// ```rust,ignore
26/// use aurora_semantic::{Embedder, Result};
27///
28/// struct MyEmbedder {
29///     dimension: usize,
30/// }
31///
32/// impl Embedder for MyEmbedder {
33///     fn embed(&self, text: &str) -> Result<Vec<f32>> {
34///         // Your embedding logic here
35///         Ok(vec![0.0; self.dimension])
36///     }
37///
38///     fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
39///         texts.iter().map(|t| self.embed(t)).collect()
40///     }
41///
42///     fn dimension(&self) -> usize {
43///         self.dimension
44///     }
45///
46///     fn name(&self) -> &'static str {
47///         "my-embedder"
48///     }
49/// }
50/// ```
51pub trait Embedder: Send + Sync {
52    /// Generate an embedding for a single text.
53    fn embed(&self, text: &str) -> Result<Vec<f32>>;
54
55    /// Generate embeddings for multiple texts in batch.
56    ///
57    /// Default implementation calls `embed` for each text.
58    /// Override for more efficient batch processing.
59    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
60        texts.iter().map(|t| self.embed(t)).collect()
61    }
62
63    /// Generate an embedding for a search query.
64    ///
65    /// For asymmetric retrieval models (like Jina Code 1.5B), this uses
66    /// query-specific instruction prefixes. For symmetric models, this
67    /// just calls `embed()`.
68    ///
69    /// Use this when embedding user search queries.
70    fn embed_for_query(&self, text: &str) -> Result<Vec<f32>> {
71        // Default: same as embed()
72        // JinaCodeEmbedder overrides this to use query prefix
73        self.embed(text)
74    }
75
76    /// Get the embedding dimension.
77    fn dimension(&self) -> usize;
78
79    /// Get the name of this embedder.
80    fn name(&self) -> &'static str;
81
82    /// Get the maximum sequence length supported.
83    fn max_sequence_length(&self) -> usize {
84        512 // Default, override for models with different limits
85    }
86}
87
88/// Embed chunks and return (chunk_id, embedding) pairs.
89#[allow(dead_code)]
90pub fn embed_chunks<E: Embedder>(
91    embedder: &E,
92    chunks: &[Chunk],
93    batch_size: usize,
94) -> Result<Vec<(String, Vec<f32>)>> {
95    let mut results = Vec::with_capacity(chunks.len());
96
97    for batch in chunks.chunks(batch_size) {
98        let texts: Vec<&str> = batch.iter().map(|c| c.content.as_str()).collect();
99        let embeddings = embedder.embed_batch(&texts)?;
100
101        for (chunk, embedding) in batch.iter().zip(embeddings.into_iter()) {
102            results.push((chunk.id.0.to_string(), embedding));
103        }
104    }
105
106    Ok(results)
107}
108
/// Preprocess code for embedding.
///
/// Normalizes whitespace to improve embedding quality: runs of blank
/// lines collapse to a single newline, runs of other whitespace
/// collapse to one space, indentation after a newline is dropped, and
/// the result is trimmed at both ends.
pub fn preprocess_code(content: &str) -> String {
    let mut result = String::with_capacity(content.len());

    // The character class we last emitted: Some('\n'), Some(' '), some
    // other char, or None before anything has been written.
    let mut last: Option<char> = None;

    for c in content.chars() {
        if c == '\n' {
            // Collapse consecutive newlines into one.
            if last != Some('\n') {
                result.push('\n');
                last = Some('\n');
            }
        } else if c.is_whitespace() {
            // Emit at most one space, and none right after a newline
            // (this strips leading indentation).
            if last != Some(' ') && last != Some('\n') {
                result.push(' ');
                last = Some(' ');
            }
        } else {
            result.push(c);
            last = Some(c);
        }
    }

    result.trim().to_string()
}
140
141/// Create a context-enriched text for embedding.
142///
143/// This adds metadata to help the embedding model understand the code context.
144pub fn create_embedding_text(chunk: &Chunk) -> String {
145    let mut text = String::new();
146
147    // Add chunk type context
148    let type_name = match chunk.chunk_type {
149        crate::types::ChunkType::Function => "function",
150        crate::types::ChunkType::Class => "class",
151        crate::types::ChunkType::Struct => "struct",
152        crate::types::ChunkType::Enum => "enum",
153        crate::types::ChunkType::Interface => "interface",
154        crate::types::ChunkType::Implementation => "implementation",
155        crate::types::ChunkType::Module => "module",
156        crate::types::ChunkType::Imports => "imports",
157        crate::types::ChunkType::Constant => "constant",
158        crate::types::ChunkType::TypeDef => "type definition",
159        crate::types::ChunkType::Block => "code block",
160        crate::types::ChunkType::Comment => "documentation",
161    };
162
163    // Add symbol name if available
164    if let Some(ref name) = chunk.symbol_name {
165        text.push_str(&format!("{} {} ", type_name, name));
166    } else {
167        text.push_str(&format!("{} ", type_name));
168    }
169
170    // Add parent context if available
171    if let Some(ref parent) = chunk.parent_symbol {
172        text.push_str(&format!("in {} ", parent));
173    }
174
175    // Add the actual code
176    text.push_str(&preprocess_code(&chunk.content));
177
178    text
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChunkId, ChunkType, DocumentId};

    #[test]
    fn test_preprocess_code() {
        let cleaned = preprocess_code("fn   foo()  {\n\n\n    bar()\n}");
        // Runs of spaces collapse to one space, blank lines to one newline.
        assert!(!cleaned.contains("  "));
        assert!(!cleaned.contains("\n\n"));
    }

    #[test]
    fn test_create_embedding_text() {
        let chunk = Chunk {
            id: ChunkId::new(),
            document_id: DocumentId::new(),
            content: String::from("fn add(a: i32, b: i32) -> i32 { a + b }"),
            chunk_type: ChunkType::Function,
            start_line: 1,
            end_line: 1,
            start_byte: 0,
            end_byte: 38,
            symbol_name: Some(String::from("add")),
            parent_symbol: None,
        };

        let enriched = create_embedding_text(&chunk);
        // Type + symbol prefix comes first, followed by the code itself.
        assert!(enriched.starts_with("function add"));
        assert!(enriched.contains("fn add"));
    }
}
213}