aurora_semantic/embeddings/
mod.rs

1//! Embeddings module for generating semantic vectors.
2//!
3//! This module provides functionality to generate embedding vectors
4//! from source code chunks using various embedding providers:
5//!
6//! - **API-based**: Jina AI, OpenAI, Cohere, Voyage AI
7//! - **Local**: HuggingFace models via Candle
8//! - **Custom**: Implement the `Embedder` trait
9
10pub mod pooling;
11mod providers;
12
13pub use pooling::cosine_similarity;
14pub use providers::*;
15
16use crate::error::Result;
17use crate::types::Chunk;
18
19/// Trait for embedding generators.
20///
21/// Implement this trait to add support for custom embedding providers.
22///
23/// # Example
24///
25/// ```rust,ignore
26/// use aurora_semantic::{Embedder, Result};
27///
28/// struct MyEmbedder {
29///     dimension: usize,
30/// }
31///
32/// impl Embedder for MyEmbedder {
33///     fn embed(&self, text: &str) -> Result<Vec<f32>> {
34///         // Your embedding logic here
35///         Ok(vec![0.0; self.dimension])
36///     }
37///
38///     fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
39///         texts.iter().map(|t| self.embed(t)).collect()
40///     }
41///
42///     fn dimension(&self) -> usize {
43///         self.dimension
44///     }
45///
46///     fn name(&self) -> &'static str {
47///         "my-embedder"
48///     }
49/// }
50/// ```
51pub trait Embedder: Send + Sync {
52    /// Generate an embedding for a single text.
53    fn embed(&self, text: &str) -> Result<Vec<f32>>;
54
55    /// Generate embeddings for multiple texts in batch.
56    ///
57    /// Default implementation calls `embed` for each text.
58    /// Override for more efficient batch processing.
59    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
60        texts.iter().map(|t| self.embed(t)).collect()
61    }
62
63    /// Get the embedding dimension.
64    fn dimension(&self) -> usize;
65
66    /// Get the name of this embedder.
67    fn name(&self) -> &'static str;
68
69    /// Get the maximum sequence length supported.
70    fn max_sequence_length(&self) -> usize {
71        512 // Default, override for models with different limits
72    }
73}
74
75/// Embed chunks and return (chunk_id, embedding) pairs.
76#[allow(dead_code)]
77pub fn embed_chunks<E: Embedder>(
78    embedder: &E,
79    chunks: &[Chunk],
80    batch_size: usize,
81) -> Result<Vec<(String, Vec<f32>)>> {
82    let mut results = Vec::with_capacity(chunks.len());
83
84    for batch in chunks.chunks(batch_size) {
85        let texts: Vec<&str> = batch.iter().map(|c| c.content.as_str()).collect();
86        let embeddings = embedder.embed_batch(&texts)?;
87
88        for (chunk, embedding) in batch.iter().zip(embeddings.into_iter()) {
89            results.push((chunk.id.0.to_string(), embedding));
90        }
91    }
92
93    Ok(results)
94}
95
/// Normalizes the whitespace of a code snippet before it is embedded.
///
/// Rules applied while scanning the input once:
///
/// - consecutive `'\n'` characters collapse into a single newline;
/// - any other run of whitespace collapses into a single space;
/// - whitespace directly following a newline is dropped, even when other
///   dropped whitespace sits between it and the newline;
/// - every non-whitespace character is copied unchanged.
///
/// Leading and trailing whitespace is trimmed from the final result.
pub fn preprocess_code(content: &str) -> String {
    /// Kind of character most recently written to the output.
    #[derive(Clone, Copy)]
    enum Last {
        /// Nothing written yet, or a non-whitespace character.
        Text,
        /// A single collapsed space.
        Space,
        /// A newline (whitespace skipped after it does not change this).
        Newline,
    }

    let mut normalized = String::with_capacity(content.len());
    let mut last = Last::Text;

    for ch in content.chars() {
        match (ch, last) {
            // Repeated newline: already represented in the output.
            ('\n', Last::Newline) => {}
            ('\n', _) => {
                normalized.push('\n');
                last = Last::Newline;
            }
            // First whitespace after real text becomes one space.
            (ws, Last::Text) if ws.is_whitespace() => {
                normalized.push(' ');
                last = Last::Space;
            }
            // Whitespace after a space or a newline is dropped.
            (ws, _) if ws.is_whitespace() => {}
            (other, _) => {
                normalized.push(other);
                last = Last::Text;
            }
        }
    }

    normalized.trim().to_string()
}
127
128/// Create a context-enriched text for embedding.
129///
130/// This adds metadata to help the embedding model understand the code context.
131pub fn create_embedding_text(chunk: &Chunk) -> String {
132    let mut text = String::new();
133
134    // Add chunk type context
135    let type_name = match chunk.chunk_type {
136        crate::types::ChunkType::Function => "function",
137        crate::types::ChunkType::Class => "class",
138        crate::types::ChunkType::Struct => "struct",
139        crate::types::ChunkType::Enum => "enum",
140        crate::types::ChunkType::Interface => "interface",
141        crate::types::ChunkType::Implementation => "implementation",
142        crate::types::ChunkType::Module => "module",
143        crate::types::ChunkType::Imports => "imports",
144        crate::types::ChunkType::Constant => "constant",
145        crate::types::ChunkType::TypeDef => "type definition",
146        crate::types::ChunkType::Block => "code block",
147        crate::types::ChunkType::Comment => "documentation",
148    };
149
150    // Add symbol name if available
151    if let Some(ref name) = chunk.symbol_name {
152        text.push_str(&format!("{} {} ", type_name, name));
153    } else {
154        text.push_str(&format!("{} ", type_name));
155    }
156
157    // Add parent context if available
158    if let Some(ref parent) = chunk.parent_symbol {
159        text.push_str(&format!("in {} ", parent));
160    }
161
162    // Add the actual code
163    text.push_str(&preprocess_code(&chunk.content));
164
165    text
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChunkId, ChunkType, DocumentId};

    /// Runs of spaces and blank lines collapse, and whitespace right after a
    /// newline is dropped.
    #[test]
    fn test_preprocess_code() {
        let input = "fn   foo()  {\n\n\n    bar()\n}";
        let output = preprocess_code(input);
        assert!(!output.contains("  ")); // No double spaces
        assert!(!output.contains("\n\n")); // No double newlines
        // Pin the exact normalized form so any change to the rules is noticed.
        assert_eq!(output, "fn foo() {\nbar()\n}");
    }

    /// Empty, whitespace-only and tab-separated inputs.
    #[test]
    fn test_preprocess_code_edge_cases() {
        assert_eq!(preprocess_code(""), "");
        assert_eq!(preprocess_code(" \n\t \n"), "");
        // Tabs count as whitespace; surrounding whitespace is trimmed.
        assert_eq!(preprocess_code("\ta\t\tb\n"), "a b");
    }

    /// The header is `<type label> <symbol name> ` followed by the code.
    #[test]
    fn test_create_embedding_text() {
        let chunk = Chunk {
            id: ChunkId::new(),
            document_id: DocumentId::new(),
            content: "fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
            chunk_type: ChunkType::Function,
            start_line: 1,
            end_line: 1,
            start_byte: 0,
            end_byte: 38,
            symbol_name: Some("add".to_string()),
            parent_symbol: None,
        };

        let text = create_embedding_text(&chunk);
        assert!(text.starts_with("function add"));
        assert!(text.contains("fn add"));
        // No parent symbol, so there is no "in ..." clause between header and code.
        assert_eq!(text, "function add fn add(a: i32, b: i32) -> i32 { a + b }");
    }
}