1#![allow(dead_code)] use crate::rag::providers::EmbeddingProvider as ProviderTrait;
9use anyhow::Result;
10use std::time::{Duration, Instant};
11
/// Aggregated metrics from one benchmark run of a single embedding provider.
#[derive(Debug, Clone)]
pub struct BenchmarkResults {
    /// Provider display name, taken from the provider's `get_info()`.
    pub provider_name: String,
    /// Number of input texts embedded during the run.
    pub total_texts: usize,
    /// Wall-clock time to embed all texts sequentially.
    pub total_duration: Duration,
    /// `total_duration` divided by `total_texts`.
    pub avg_embedding_time: Duration,
    /// Throughput: `total_texts / total_duration` in embeddings per second.
    pub embeddings_per_second: f64,
    /// Embedding vector length reported by the provider.
    pub dimension: usize,
    /// Process RSS in MB if measurable on this platform (Linux only today).
    pub memory_usage_mb: Option<f64>,
    /// 0.0–1.0 score of how closely cosine similarities match expectations.
    pub semantic_quality_score: Option<f64>,
}
24
/// A benchmark corpus: the texts to embed plus ground-truth similarity
/// expectations used to score semantic quality.
pub struct BenchmarkTestData {
    /// Input sentences handed to the provider one at a time.
    pub texts: Vec<&'static str>,
    /// `(index_a, index_b, expected_cosine_similarity)` triples; indices
    /// refer to positions in `texts`.
    pub semantic_pairs: Vec<(usize, usize, f32)>,
}
30
31impl BenchmarkTestData {
32 pub fn new_default() -> Self {
34 let texts = vec![
35 "React hooks useState for state management",
36 "useState React hook for managing component state",
37 "Python Django models for database operations",
38 "Django Python framework for web development",
39 "JavaScript async await for asynchronous programming",
40 "Node.js Express framework for web servers",
41 "Rust memory safety without garbage collection",
42 "C++ manual memory management with pointers",
43 "Machine learning with neural networks",
44 "Deep learning artificial intelligence models",
45 ];
46
47 let semantic_pairs = vec![
49 (0, 1, 0.8), (2, 3, 0.7), (8, 9, 0.8), (0, 2, 0.2), (4, 6, 0.1), ];
55
56 Self {
57 texts,
58 semantic_pairs,
59 }
60 }
61
62 pub fn extended() -> Self {
64 let texts = vec![
65 "React hooks useState for state management in functional components",
67 "useState React hook manages local component state efficiently",
68 "Vue.js reactive data binding with computed properties",
69 "Angular component lifecycle hooks and state management",
70 "PostgreSQL relational database with ACID transactions",
72 "MongoDB document database with flexible schema design",
73 "Redis in-memory data structure store for caching",
74 "SQLite lightweight embedded database for applications",
75 "Deep neural networks for computer vision tasks",
77 "Convolutional neural networks process image data effectively",
78 "Natural language processing with transformer models",
79 "BERT transformer model for text understanding tasks",
80 "RESTful API design principles and best practices",
82 "GraphQL flexible query language for API development",
83 "Microservices architecture pattern for scalable systems",
84 "Docker containerization for application deployment",
85 "Cooking pasta requires boiling water and salt",
87 "Weather forecast shows rain tomorrow afternoon",
88 "Basketball game ended with a score of 95-87",
89 "Garden flowers bloom beautifully in spring season",
90 ];
91
92 let semantic_pairs = vec![
93 (0, 1, 0.85), (9, 10, 0.80), (11, 12, 0.75), (4, 5, 0.65), (0, 2, 0.45), (13, 14, 0.60), (0, 16, 0.05), (4, 17, 0.05), (9, 18, 0.05), ];
106
107 Self {
108 texts,
109 semantic_pairs,
110 }
111 }
112}
113
114pub async fn benchmark_provider<T: ProviderTrait + Send + Sync + ?Sized>(
116 provider: &T,
117 test_data: &BenchmarkTestData,
118) -> Result<BenchmarkResults> {
119 let provider_info = provider.get_info();
120 let dimension = provider.get_dimension().await?;
121
122 let mut embeddings = Vec::new();
123 let start_time = Instant::now();
124
125 for text in &test_data.texts {
127 let embedding = provider.embed_text(text).await?;
128 embeddings.push(embedding);
129 }
130
131 let total_duration = start_time.elapsed();
132 let avg_embedding_time = total_duration / test_data.texts.len() as u32;
133 let embeddings_per_second = test_data.texts.len() as f64 / total_duration.as_secs_f64();
134
135 let semantic_quality_score = calculate_semantic_quality(&embeddings, test_data);
137
138 Ok(BenchmarkResults {
139 provider_name: provider_info.name,
140 total_texts: test_data.texts.len(),
141 total_duration,
142 avg_embedding_time,
143 embeddings_per_second,
144 dimension,
145 memory_usage_mb: Some(get_process_memory_mb()),
146 semantic_quality_score: Some(semantic_quality_score),
147 })
148}
149
150fn calculate_semantic_quality(embeddings: &[Vec<f32>], test_data: &BenchmarkTestData) -> f64 {
152 if test_data.semantic_pairs.is_empty() {
153 return 0.0;
154 }
155
156 let mut total_error = 0.0;
157
158 for (idx1, idx2, expected_sim) in &test_data.semantic_pairs {
159 if *idx1 < embeddings.len() && *idx2 < embeddings.len() {
160 let actual_sim = cosine_similarity(&embeddings[*idx1], &embeddings[*idx2]);
161 let error = (actual_sim - expected_sim).abs();
162 total_error += error;
163 }
164 }
165
166 let avg_error = total_error / test_data.semantic_pairs.len() as f32;
167 (1.0 - avg_error.min(1.0)) as f64
169}
170
/// Cosine similarity of two vectors.
///
/// Returns 0.0 when the lengths differ or either vector has zero magnitude;
/// otherwise `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
187
188pub async fn compare_providers(
190 providers: Vec<(&str, Box<dyn ProviderTrait + Send + Sync>)>,
191 test_data: &BenchmarkTestData,
192) -> Result<Vec<BenchmarkResults>> {
193 let mut results = Vec::new();
194
195 for (name, provider) in providers {
196 println!("Benchmarking provider: {}", name);
197 match benchmark_provider(provider.as_ref(), test_data).await {
198 Ok(result) => {
199 println!("ā
{} completed", name);
200 results.push(result);
201 }
202 Err(e) => {
203 println!("ā {} failed: {}", name, e);
204 }
205 }
206 }
207
208 Ok(results)
209}
210
211pub fn print_benchmark_results(results: &[BenchmarkResults]) {
213 println!("\nš Embedding Provider Benchmark Results");
214 println!("{}", "=".repeat(80));
215
216 for result in results {
217 println!("\nš§ Provider: {}", result.provider_name);
218 println!(" Texts processed: {}", result.total_texts);
219 println!(" Total time: {:.2}ms", result.total_duration.as_millis());
220 println!(
221 " Avg per embedding: {:.2}ms",
222 result.avg_embedding_time.as_millis()
223 );
224 println!(
225 " Throughput: {:.1} embeddings/sec",
226 result.embeddings_per_second
227 );
228 println!(" Embedding dimension: {}", result.dimension);
229
230 if let Some(quality) = result.semantic_quality_score {
231 println!(" Semantic quality: {:.3} (0.0-1.0)", quality);
232 }
233
234 if let Some(memory) = result.memory_usage_mb {
235 println!(" Memory usage: {:.1}MB", memory);
236 }
237 }
238
239 println!("\n{}", "=".repeat(80));
240}
241
/// Best-effort resident memory (RSS) of the current process in megabytes.
///
/// Only Linux is implemented (parsed from the `VmRSS` line of
/// `/proc/self/status`); macOS, Windows, and all other targets currently
/// return 0.0. Each `cfg` block is the function's tail expression on its
/// target, so exactly one compiles per platform.
fn get_process_memory_mb() -> f64 {
    #[cfg(target_os = "linux")]
    {
        use std::fs;
        if let Ok(status) = fs::read_to_string("/proc/self/status") {
            for line in status.lines() {
                if line.starts_with("VmRSS:") {
                    // Line format is "VmRSS:  <value> kB"; the second
                    // whitespace-separated field is the value in kB.
                    if let Some(kb_str) = line.split_whitespace().nth(1) {
                        if let Ok(kb) = kb_str.parse::<f64>() {
                            return kb / 1024.0; }
                    }
                    // VmRSS appears at most once; stop scanning even if
                    // the value failed to parse.
                    break;
                }
            }
        }
        // /proc unreadable or VmRSS missing/unparseable.
        0.0
    }

    #[cfg(target_os = "macos")]
    {
        // Not implemented (would need mach task_info).
        0.0
    }

    #[cfg(target_os = "windows")]
    {
        // Not implemented (would need GetProcessMemoryInfo).
        0.0
    }

    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        0.0
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rag::providers::hash::HashProvider;

    /// End-to-end smoke test: benchmark the cheap hash-based provider over
    /// the default corpus and check the result's basic invariants.
    #[tokio::test]
    async fn test_hash_provider_benchmark() {
        // 384 is the requested embedding dimension.
        let provider = HashProvider::new(384);
        let test_data = BenchmarkTestData::new_default();

        let result = benchmark_provider(&provider, &test_data).await.unwrap();

        assert_eq!(result.total_texts, test_data.texts.len());
        assert_eq!(result.dimension, 384);
        assert!(result.embeddings_per_second > 0.0);
        assert!(result.semantic_quality_score.is_some());
    }

    /// Sanity checks for `cosine_similarity`: identical vectors score ~1.0,
    /// orthogonal vectors score ~0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
    }
}