aurora_semantic/embeddings/
mod.rs1pub mod pooling;
11mod providers;
12
13pub use pooling::cosine_similarity;
14pub use providers::*;
15
16use crate::error::Result;
17use crate::types::Chunk;
18
19pub trait Embedder: Send + Sync {
52 fn embed(&self, text: &str) -> Result<Vec<f32>>;
54
55 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
60 texts.iter().map(|t| self.embed(t)).collect()
61 }
62
63 fn embed_for_query(&self, text: &str) -> Result<Vec<f32>> {
71 self.embed(text)
74 }
75
76 fn dimension(&self) -> usize;
78
79 fn name(&self) -> &'static str;
81
82 fn max_sequence_length(&self) -> usize {
84 512 }
86}
87
88#[allow(dead_code)]
90pub fn embed_chunks<E: Embedder>(
91 embedder: &E,
92 chunks: &[Chunk],
93 batch_size: usize,
94) -> Result<Vec<(String, Vec<f32>)>> {
95 let mut results = Vec::with_capacity(chunks.len());
96
97 for batch in chunks.chunks(batch_size) {
98 let texts: Vec<&str> = batch.iter().map(|c| c.content.as_str()).collect();
99 let embeddings = embedder.embed_batch(&texts)?;
100
101 for (chunk, embedding) in batch.iter().zip(embeddings.into_iter()) {
102 results.push((chunk.id.0.to_string(), embedding));
103 }
104 }
105
106 Ok(results)
107}
108
/// Normalises the whitespace of a code snippet before it is embedded.
///
/// * Every run of non-newline whitespace (spaces, tabs, `\r`, ...) becomes a
///   single space.
/// * Consecutive newlines, including those separated only by whitespace,
///   become a single `\n`.
/// * Whitespace directly after a newline (indentation) is removed.
/// * Whitespace directly before a newline is removed, so no line ends in a
///   space and CRLF input yields plain `\n`.
/// * Leading and trailing whitespace of the whole text is removed.
///
/// Returns a new `String`; an empty or whitespace-only input yields `""`.
pub fn preprocess_code(content: &str) -> String {
    let mut result = String::with_capacity(content.len());

    for c in content.chars() {
        if c == '\n' {
            // A space sitting at the end of the buffer is trailing
            // whitespace for the line that is ending now; drop it.
            if result.ends_with(' ') {
                result.pop();
            }
            // Skip leading newlines and collapse blank lines.
            if !result.is_empty() && !result.ends_with('\n') {
                result.push('\n');
            }
        } else if c.is_whitespace() {
            // Only ' ' and '\n' are ever pushed as whitespace, so the buffer
            // tail tells us whether a separator is already in place.
            let separator_pending =
                result.is_empty() || result.ends_with(' ') || result.ends_with('\n');
            if !separator_pending {
                result.push(' ');
            }
        } else {
            result.push(c);
        }
    }

    // Trim the end in place instead of allocating a second string.
    let trimmed_len = result.trim_end().len();
    result.truncate(trimmed_len);
    result
}
140
141pub fn create_embedding_text(chunk: &Chunk) -> String {
145 let mut text = String::new();
146
147 let type_name = match chunk.chunk_type {
149 crate::types::ChunkType::Function => "function",
150 crate::types::ChunkType::Class => "class",
151 crate::types::ChunkType::Struct => "struct",
152 crate::types::ChunkType::Enum => "enum",
153 crate::types::ChunkType::Interface => "interface",
154 crate::types::ChunkType::Implementation => "implementation",
155 crate::types::ChunkType::Module => "module",
156 crate::types::ChunkType::Imports => "imports",
157 crate::types::ChunkType::Constant => "constant",
158 crate::types::ChunkType::TypeDef => "type definition",
159 crate::types::ChunkType::Block => "code block",
160 crate::types::ChunkType::Comment => "documentation",
161 };
162
163 if let Some(ref name) = chunk.symbol_name {
165 text.push_str(&format!("{} {} ", type_name, name));
166 } else {
167 text.push_str(&format!("{} ", type_name));
168 }
169
170 if let Some(ref parent) = chunk.parent_symbol {
172 text.push_str(&format!("in {} ", parent));
173 }
174
175 text.push_str(&preprocess_code(&chunk.content));
177
178 text
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChunkId, ChunkType, DocumentId};

    /// Blank lines and indentation are collapsed; single spaces between
    /// tokens are kept.
    #[test]
    fn test_preprocess_code() {
        let input = "fn foo() {\n\n\n    bar()\n}";
        let output = preprocess_code(input);

        // Check for a *double* space: single spaces legitimately remain
        // between tokens such as `fn foo() {`.
        assert!(!output.contains("  "));
        assert!(!output.contains("\n\n"));
        assert_eq!(output, "fn foo() {\nbar()\n}");
    }

    /// The embedding text starts with the chunk kind and symbol name and
    /// ends with the normalised content.
    #[test]
    fn test_create_embedding_text() {
        let chunk = Chunk {
            id: ChunkId::new(),
            document_id: DocumentId::new(),
            content: "fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
            chunk_type: ChunkType::Function,
            start_line: 1,
            end_line: 1,
            start_byte: 0,
            end_byte: 38,
            symbol_name: Some("add".to_string()),
            parent_symbol: None,
        };

        let text = create_embedding_text(&chunk);
        assert!(text.starts_with("function add"));
        assert!(text.contains("fn add"));
        assert_eq!(text, "function add fn add(a: i32, b: i32) -> i32 { a + b }");
    }
}