aurora_semantic/embeddings/
mod.rs1pub mod pooling;
11mod providers;
12
13pub use pooling::cosine_similarity;
14pub use providers::*;
15
16use crate::error::Result;
17use crate::types::Chunk;
18
19pub trait Embedder: Send + Sync {
52 fn embed(&self, text: &str) -> Result<Vec<f32>>;
54
55 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
60 texts.iter().map(|t| self.embed(t)).collect()
61 }
62
63 fn dimension(&self) -> usize;
65
66 fn name(&self) -> &'static str;
68
69 fn max_sequence_length(&self) -> usize {
71 512 }
73}
74
75#[allow(dead_code)]
77pub fn embed_chunks<E: Embedder>(
78 embedder: &E,
79 chunks: &[Chunk],
80 batch_size: usize,
81) -> Result<Vec<(String, Vec<f32>)>> {
82 let mut results = Vec::with_capacity(chunks.len());
83
84 for batch in chunks.chunks(batch_size) {
85 let texts: Vec<&str> = batch.iter().map(|c| c.content.as_str()).collect();
86 let embeddings = embedder.embed_batch(&texts)?;
87
88 for (chunk, embedding) in batch.iter().zip(embeddings.into_iter()) {
89 results.push((chunk.id.0.to_string(), embedding));
90 }
91 }
92
93 Ok(results)
94}
95
/// Normalizes whitespace in source code before it is embedded.
///
/// - Runs of newlines (including lines containing only whitespace) collapse
///   to a single `'\n'`.
/// - Runs of other whitespace collapse to a single `' '`.
/// - Whitespace at the start of a line (indentation) is dropped.
/// - Whitespace at the end of a line is dropped, including the `'\r'` of
///   CRLF line endings, so `"a\r\nb"` and `"a\nb"` normalize identically.
/// - Leading and trailing whitespace of the whole text is removed.
pub fn preprocess_code(content: &str) -> String {
    let mut out = String::with_capacity(content.len());

    // Starting in the "just saw a newline" state suppresses leading
    // whitespace, so no separate left-trim pass is needed.
    let mut prev_was_space = false;
    let mut prev_was_newline = true;

    for c in content.chars() {
        if c == '\n' {
            if !prev_was_newline {
                // Drop the collapsed space that trailed this line
                // (e.g. produced by the '\r' of "\r\n").
                if out.ends_with(' ') {
                    out.pop();
                }
                out.push('\n');
                prev_was_newline = true;
            }
            prev_was_space = false;
        } else if c.is_whitespace() {
            if !prev_was_space && !prev_was_newline {
                out.push(' ');
                prev_was_space = true;
            }
        } else {
            out.push(c);
            prev_was_space = false;
            prev_was_newline = false;
        }
    }

    // At most one trailing ' ' or '\n' can remain; strip it in place instead
    // of allocating a trimmed copy.
    let trimmed_len = out.trim_end().len();
    out.truncate(trimmed_len);
    out
}
127
128pub fn create_embedding_text(chunk: &Chunk) -> String {
132 let mut text = String::new();
133
134 let type_name = match chunk.chunk_type {
136 crate::types::ChunkType::Function => "function",
137 crate::types::ChunkType::Class => "class",
138 crate::types::ChunkType::Struct => "struct",
139 crate::types::ChunkType::Enum => "enum",
140 crate::types::ChunkType::Interface => "interface",
141 crate::types::ChunkType::Implementation => "implementation",
142 crate::types::ChunkType::Module => "module",
143 crate::types::ChunkType::Imports => "imports",
144 crate::types::ChunkType::Constant => "constant",
145 crate::types::ChunkType::TypeDef => "type definition",
146 crate::types::ChunkType::Block => "code block",
147 crate::types::ChunkType::Comment => "documentation",
148 };
149
150 if let Some(ref name) = chunk.symbol_name {
152 text.push_str(&format!("{} {} ", type_name, name));
153 } else {
154 text.push_str(&format!("{} ", type_name));
155 }
156
157 if let Some(ref parent) = chunk.parent_symbol {
159 text.push_str(&format!("in {} ", parent));
160 }
161
162 text.push_str(&preprocess_code(&chunk.content));
164
165 text
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChunkId, ChunkType, DocumentId};

    /// Builds a single-line chunk with no parent symbol.
    ///
    /// Only `content`, `chunk_type` and `symbol_name` vary between tests. The
    /// line/byte positions are placeholders that the functions tested here do
    /// not read.
    fn make_chunk(content: &str, chunk_type: ChunkType, symbol_name: Option<&str>) -> Chunk {
        Chunk {
            id: ChunkId::new(),
            document_id: DocumentId::new(),
            content: content.to_string(),
            chunk_type,
            start_line: 1,
            end_line: 1,
            start_byte: 0,
            // NOTE(review): assumes `end_byte` is an exclusive offset. `as _`
            // adapts to whatever integer type the field uses.
            end_byte: content.len() as _,
            symbol_name: symbol_name.map(str::to_string),
            parent_symbol: None,
        }
    }

    /// Test double that embeds each text as a one-element vector holding its
    /// byte length.
    struct LenEmbedder;

    impl Embedder for LenEmbedder {
        fn embed(&self, text: &str) -> crate::error::Result<Vec<f32>> {
            Ok(vec![text.len() as f32])
        }

        fn dimension(&self) -> usize {
            1
        }

        fn name(&self) -> &'static str {
            "len"
        }
    }

    #[test]
    fn test_preprocess_code() {
        let input = "fn foo() {\n\n\n    bar()\n}";
        assert_eq!(preprocess_code(input), "fn foo() {\nbar()\n}");
        assert_eq!(preprocess_code("a\t\t b"), "a b");
        assert_eq!(preprocess_code(""), "");
    }

    #[test]
    fn test_create_embedding_text() {
        let chunk = make_chunk(
            "fn add(a: i32, b: i32) -> i32 { a + b }",
            ChunkType::Function,
            Some("add"),
        );
        assert_eq!(
            create_embedding_text(&chunk),
            "function add fn add(a: i32, b: i32) -> i32 { a + b }"
        );
    }

    #[test]
    fn test_create_embedding_text_without_symbol() {
        let chunk = make_chunk("x  =\n\n  1", ChunkType::Block, None);
        assert_eq!(create_embedding_text(&chunk), "code block x =\n1");
    }

    #[test]
    fn test_embed_chunks_keeps_order_across_batches() {
        let chunks: Vec<Chunk> = ["a", "bb", "ccc"]
            .iter()
            .map(|c| make_chunk(c, ChunkType::Constant, None))
            .collect();

        let out = embed_chunks(&LenEmbedder, &chunks, 2).unwrap();

        assert_eq!(out.len(), chunks.len());
        for ((id, embedding), chunk) in out.iter().zip(&chunks) {
            assert_eq!(id, &chunk.id.0.to_string());
            assert_eq!(embedding, &vec![chunk.content.len() as f32]);
        }
    }
}