Skip to main content

oxirs_embed/
utils_math.rs

1//! Numeric/vector math utilities: dot products, norms, distances, dataset statistics,
2//! embedding analysis, graph analysis, progress tracking, and performance utilities.
3
4use crate::utils_types::{
5    BenchmarkComparison, BenchmarkConfig, BenchmarkResult, BenchmarkSummary, DatasetStatistics,
6    EmbeddingDistributionStats, GraphMetrics, MemoryStats, RegressionAnalysis, SimilarityStats,
7};
8use anyhow::{anyhow, Result};
9use scirs2_core::ndarray_ext::{Array1, Array2};
10use scirs2_core::random::Random;
11use std::collections::{BTreeMap, HashMap, HashSet};
12use std::time::{Duration, Instant};
13
14/// Compute dataset statistics
15pub fn compute_dataset_statistics(triples: &[(String, String, String)]) -> DatasetStatistics {
16    let mut entities = HashSet::new();
17    let mut relations = HashSet::new();
18    let mut entity_frequency = HashMap::new();
19    let mut relation_frequency = HashMap::new();
20
21    for (subject, predicate, object) in triples {
22        entities.insert(subject.clone());
23        entities.insert(object.clone());
24        relations.insert(predicate.clone());
25
26        *entity_frequency.entry(subject.clone()).or_insert(0) += 1;
27        *entity_frequency.entry(object.clone()).or_insert(0) += 1;
28        *relation_frequency.entry(predicate.clone()).or_insert(0) += 1;
29    }
30
31    let num_entities = entities.len();
32    let num_relations = relations.len();
33    let num_triples = triples.len();
34
35    let avg_degree = if num_entities > 0 {
36        (num_triples * 2) as f64 / num_entities as f64
37    } else {
38        0.0
39    };
40
41    let max_possible_edges = num_entities * num_entities;
42    let density = if max_possible_edges > 0 {
43        num_triples as f64 / max_possible_edges as f64
44    } else {
45        0.0
46    };
47
48    DatasetStatistics {
49        num_triples,
50        num_entities,
51        num_relations,
52        entity_frequency,
53        relation_frequency,
54        avg_degree,
55        density,
56    }
57}
58
59/// Embedding dimension analysis utilities
60pub mod embedding_analysis {
61    use super::*;
62
63    /// Analyze embedding distribution
64    pub fn analyze_embedding_distribution(embeddings: &Array2<f64>) -> EmbeddingDistributionStats {
65        let flat_values: Vec<f64> = embeddings.iter().cloned().collect();
66
67        let mean = flat_values.iter().sum::<f64>() / flat_values.len() as f64;
68        let variance =
69            flat_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / flat_values.len() as f64;
70        let std_dev = variance.sqrt();
71
72        let mut sorted_values = flat_values.clone();
73        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
74
75        let min_val = sorted_values[0];
76        let max_val = sorted_values[sorted_values.len() - 1];
77        let median = sorted_values[sorted_values.len() / 2];
78
79        EmbeddingDistributionStats {
80            mean,
81            std_dev,
82            variance,
83            min: min_val,
84            max: max_val,
85            median,
86            num_parameters: embeddings.len(),
87        }
88    }
89
90    /// Compute embedding norms
91    pub fn compute_embedding_norms(embeddings: &Array2<f64>) -> Vec<f64> {
92        embeddings
93            .rows()
94            .into_iter()
95            .map(|row| row.dot(&row).sqrt())
96            .collect()
97    }
98
99    /// Analyze embedding similarities
100    pub fn analyze_embedding_similarities(
101        embeddings: &Array2<f64>,
102        sample_size: usize,
103    ) -> SimilarityStats {
104        let num_embeddings = embeddings.nrows();
105        let mut similarities = Vec::new();
106
107        let sample_size = sample_size.min(num_embeddings * (num_embeddings - 1) / 2);
108        let mut rng = Random::default();
109
110        for _ in 0..sample_size {
111            let i = rng.random_range(0..num_embeddings);
112            let j = rng.random_range(0..num_embeddings);
113
114            if i != j {
115                let emb_i = embeddings.row(i);
116                let emb_j = embeddings.row(j);
117                let similarity = cosine_similarity_array(&emb_i.to_owned(), &emb_j.to_owned());
118                similarities.push(similarity);
119            }
120        }
121
122        similarities.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
123
124        let mean_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
125        let min_similarity = similarities[0];
126        let max_similarity = similarities[similarities.len() - 1];
127        let median_similarity = similarities[similarities.len() / 2];
128
129        SimilarityStats {
130            mean_similarity,
131            min_similarity,
132            max_similarity,
133            median_similarity,
134            num_comparisons: similarities.len(),
135        }
136    }
137
138    /// Cosine similarity between two ndarray vectors
139    fn cosine_similarity_array(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
140        let dot_product = a.dot(b);
141        let norm_a = a.dot(a).sqrt();
142        let norm_b = b.dot(b).sqrt();
143
144        if norm_a > 1e-10 && norm_b > 1e-10 {
145            dot_product / (norm_a * norm_b)
146        } else {
147            0.0
148        }
149    }
150}
151
152/// Graph analysis utilities
153pub mod graph_analysis {
154    use super::*;
155
156    /// Compute graph metrics for knowledge graph
157    pub fn compute_graph_metrics(triples: &[(String, String, String)]) -> GraphMetrics {
158        let estimated_entities = triples.len();
159        let estimated_relations = triples.len() / 10;
160
161        let mut entity_degrees: HashMap<String, usize> = HashMap::with_capacity(estimated_entities);
162        let mut relation_counts: HashMap<String, usize> =
163            HashMap::with_capacity(estimated_relations);
164        let mut entities = HashSet::with_capacity(estimated_entities);
165
166        for (subject, predicate, object) in triples {
167            entities.insert(subject.clone());
168            entities.insert(object.clone());
169
170            *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
171            *entity_degrees.entry(object.clone()).or_insert(0) += 1;
172            *relation_counts.entry(predicate.clone()).or_insert(0) += 1;
173        }
174
175        let num_entities = entities.len();
176        let num_relations = relation_counts.len();
177        let num_triples = triples.len();
178
179        let degrees: Vec<usize> = entity_degrees.values().cloned().collect();
180        let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
181        let max_degree = degrees.iter().max().cloned().unwrap_or(0);
182        let min_degree = degrees.iter().min().cloned().unwrap_or(0);
183
184        GraphMetrics {
185            num_entities,
186            num_relations,
187            num_triples,
188            avg_degree,
189            max_degree,
190            min_degree,
191            density: num_triples as f64 / (num_entities * num_entities) as f64,
192        }
193    }
194}
195
/// Progress tracking utility
///
/// Prints a progress line to stdout from [`ProgressTracker::update`] at most once
/// per `update_interval` (one second as created by [`ProgressTracker::new`]).
#[derive(Debug)]
pub struct ProgressTracker {
    /// Number of items expected in total.
    total: usize,
    /// Most recent value passed to `update`.
    current: usize,
    /// Creation time; throughput is measured from here.
    start_time: Instant,
    /// When progress was last printed (initially the creation time).
    last_update: Instant,
    /// Minimum time between two progress lines.
    update_interval: Duration,
}

impl ProgressTracker {
    /// Create a new progress tracker for `total` items.
    pub fn new(total: usize) -> Self {
        let now = Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            update_interval: Duration::from_secs(1),
        }
    }

    /// Record that `current` items are done.
    ///
    /// Prints a progress line if at least `update_interval` has passed since the
    /// last one was printed.
    pub fn update(&mut self, current: usize) {
        self.current = current;
        let now = Instant::now();
        if now.duration_since(self.last_update) >= self.update_interval {
            self.print_progress();
            self.last_update = now;
        }
    }

    /// Completion percentage. An empty job (`total == 0`) counts as 100 % done
    /// instead of producing NaN.
    fn percentage(&self) -> f64 {
        if self.total == 0 {
            100.0
        } else {
            (self.current as f64 / self.total as f64) * 100.0
        }
    }

    /// Items per second, or 0 when no measurable time has elapsed (avoids `inf`).
    fn rate(items: usize, elapsed_secs: f64) -> f64 {
        if elapsed_secs > 0.0 {
            items as f64 / elapsed_secs
        } else {
            0.0
        }
    }

    fn print_progress(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        println!(
            "Progress: {}/{} ({:.1}%) - {:.1} items/sec",
            self.current,
            self.total,
            self.percentage(),
            Self::rate(self.current, elapsed)
        );
    }

    /// Print the final statistics: `total` items over the time since creation.
    pub fn finish(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        let rate = Self::rate(self.total, elapsed);
        println!(
            "Completed: {} items in {:.2}s ({:.1} items/sec)",
            self.total, elapsed, rate
        );
    }
}
249
/// High-precision timer for micro-benchmarking
///
/// Laps are cumulative: each [`PrecisionTimer::lap`] records the time elapsed
/// since the last [`PrecisionTimer::start`] (or since construction), not the
/// time since the previous lap.
#[derive(Debug, Clone)]
pub struct PrecisionTimer {
    /// Reference point for `lap` and `stop`.
    start_time: Instant,
    /// Cumulative durations recorded by `lap`, oldest first.
    lap_times: Vec<Duration>,
}

impl Default for PrecisionTimer {
    fn default() -> Self {
        Self::new()
    }
}

impl PrecisionTimer {
    /// Create a timer that starts counting immediately, with no laps recorded.
    pub fn new() -> Self {
        Self {
            start_time: Instant::now(),
            lap_times: Vec::new(),
        }
    }

    /// Restart timing from now and discard all recorded laps.
    pub fn start(&mut self) {
        self.start_time = Instant::now();
        self.lap_times.clear();
    }

    /// Record and return the time elapsed since the timer was (re)started.
    pub fn lap(&mut self) -> Duration {
        let lap_duration = self.start_time.elapsed();
        self.lap_times.push(lap_duration);
        lap_duration
    }

    /// Return the time elapsed since the timer was (re)started.
    ///
    /// This only reads the clock; the timer keeps running and laps are kept.
    pub fn stop(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// All cumulative lap durations recorded since the last `start`.
    pub fn lap_times(&self) -> &[Duration] {
        &self.lap_times
    }
}
293
294/// Benchmarking framework for embedding operations
295pub struct EmbeddingBenchmark {
296    config: BenchmarkConfig,
297    results: BTreeMap<String, BenchmarkResult>,
298}
299
300impl EmbeddingBenchmark {
301    pub fn new(config: BenchmarkConfig) -> Self {
302        Self {
303            config,
304            results: BTreeMap::new(),
305        }
306    }
307
308    /// Benchmark a function with comprehensive timing and memory analysis
309    pub fn benchmark<F, T>(&mut self, name: &str, mut operation: F) -> Result<T>
310    where
311        F: FnMut() -> Result<T>,
312    {
313        // Warmup phase
314        for _ in 0..self.config.warmup_iterations {
315            let _ = operation()?;
316        }
317
318        let mut durations = Vec::with_capacity(self.config.measurement_iterations);
319        let mut memory_snapshots = Vec::new();
320        let mut result = None;
321
322        for i in 0..self.config.measurement_iterations {
323            let memory_before = self.get_memory_usage();
324            let start = Instant::now();
325
326            let op_result = operation()?;
327
328            let duration = start.elapsed();
329            let memory_after = self.get_memory_usage();
330
331            durations.push(duration);
332
333            if self.config.enable_memory_profiling {
334                memory_snapshots.push((memory_before, memory_after));
335            }
336
337            if i == 0 {
338                result = Some(op_result);
339            }
340        }
341
342        let total_duration: Duration = durations.iter().sum();
343        let avg_duration = total_duration / durations.len() as u32;
344        let min_duration = *durations
345            .iter()
346            .min()
347            .expect("durations should not be empty");
348        let max_duration = *durations
349            .iter()
350            .max()
351            .expect("durations should not be empty");
352
353        let variance: f64 = durations
354            .iter()
355            .map(|d| {
356                let diff = d.as_nanos() as f64 - avg_duration.as_nanos() as f64;
357                diff * diff
358            })
359            .sum::<f64>()
360            / durations.len() as f64;
361        let std_deviation = Duration::from_nanos(variance.sqrt() as u64);
362
363        let ops_per_second = 1_000_000_000.0 / avg_duration.as_nanos() as f64;
364
365        let memory_stats = if self.config.enable_memory_profiling && !memory_snapshots.is_empty() {
366            let peak_memory = memory_snapshots
367                .iter()
368                .map(|(_, after)| after.peak_memory_bytes)
369                .max()
370                .unwrap_or(0);
371
372            let avg_memory = memory_snapshots
373                .iter()
374                .map(|(before, after)| (before.avg_memory_bytes + after.avg_memory_bytes) / 2)
375                .sum::<usize>()
376                / memory_snapshots.len();
377
378            MemoryStats {
379                peak_memory_bytes: peak_memory,
380                avg_memory_bytes: avg_memory,
381                allocations: memory_snapshots.len(),
382                deallocations: 0,
383            }
384        } else {
385            MemoryStats {
386                peak_memory_bytes: 0,
387                avg_memory_bytes: 0,
388                allocations: 0,
389                deallocations: 0,
390            }
391        };
392
393        let benchmark_result = BenchmarkResult {
394            operation: name.to_string(),
395            iterations: self.config.measurement_iterations,
396            total_duration,
397            avg_duration,
398            min_duration,
399            max_duration,
400            std_deviation,
401            ops_per_second,
402            memory_stats,
403            custom_metrics: HashMap::new(),
404        };
405
406        self.results.insert(name.to_string(), benchmark_result);
407        result.ok_or_else(|| anyhow!("Failed to capture benchmark result"))
408    }
409
410    /// Generate comprehensive benchmark report
411    pub fn generate_report(&self) -> BenchmarkSuite {
412        let total_duration = self.results.values().map(|r| r.total_duration).sum();
413        let total_operations = self.results.len();
414        let overall_throughput =
415            self.results.values().map(|r| r.ops_per_second).sum::<f64>() / total_operations as f64;
416        let efficiency_score = self.calculate_efficiency_score();
417        let bottlenecks = self.identify_bottlenecks();
418
419        let summary = BenchmarkSummary {
420            total_duration,
421            total_operations,
422            overall_throughput,
423            efficiency_score,
424            bottlenecks,
425        };
426
427        BenchmarkSuite {
428            results: self.results.clone(),
429            summary,
430            config: self.config.clone(),
431        }
432    }
433
434    fn calculate_efficiency_score(&self) -> f64 {
435        if self.results.is_empty() {
436            return 0.0;
437        }
438        let consistency_scores: Vec<f64> = self
439            .results
440            .values()
441            .map(|result| {
442                let cv =
443                    result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
444                1.0 / (1.0 + cv)
445            })
446            .collect();
447        consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64
448    }
449
450    fn identify_bottlenecks(&self) -> Vec<String> {
451        let mut bottlenecks = Vec::new();
452        for (name, result) in &self.results {
453            let cv = result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
454            if cv > 0.2 {
455                bottlenecks.push(format!("High variance in {}: {:.2}% CV", name, cv * 100.0));
456            }
457        }
458
459        let avg_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
460            / self.results.len() as f64;
461
462        for (name, result) in &self.results {
463            if result.ops_per_second < avg_throughput * 0.5 {
464                bottlenecks.push(format!(
465                    "Slow operation {}: {:.0} ops/sec",
466                    name, result.ops_per_second
467                ));
468            }
469        }
470        bottlenecks
471    }
472
473    fn get_memory_usage(&self) -> MemoryStats {
474        MemoryStats {
475            peak_memory_bytes: 0,
476            avg_memory_bytes: 0,
477            allocations: 0,
478            deallocations: 0,
479        }
480    }
481}
482
483/// Benchmark suite result
484#[derive(Debug, Clone)]
485pub struct BenchmarkSuite {
486    pub results: BTreeMap<String, BenchmarkResult>,
487    pub summary: BenchmarkSummary,
488    pub config: BenchmarkConfig,
489}
490
491/// Utility functions for performance analysis
492pub mod analysis {
493    use super::*;
494
495    /// Compare two benchmark results
496    pub fn compare_benchmarks(
497        baseline: &BenchmarkResult,
498        comparison: &BenchmarkResult,
499    ) -> BenchmarkComparison {
500        let throughput_improvement =
501            (comparison.ops_per_second - baseline.ops_per_second) / baseline.ops_per_second;
502
503        let latency_improvement = (baseline.avg_duration.as_nanos() as f64
504            - comparison.avg_duration.as_nanos() as f64)
505            / baseline.avg_duration.as_nanos() as f64;
506
507        let consistency_improvement = {
508            let baseline_cv =
509                baseline.std_deviation.as_nanos() as f64 / baseline.avg_duration.as_nanos() as f64;
510            let comparison_cv = comparison.std_deviation.as_nanos() as f64
511                / comparison.avg_duration.as_nanos() as f64;
512            (baseline_cv - comparison_cv) / baseline_cv
513        };
514
515        BenchmarkComparison {
516            baseline_name: baseline.operation.clone(),
517            comparison_name: comparison.operation.clone(),
518            throughput_improvement,
519            latency_improvement,
520            consistency_improvement,
521            is_improvement: throughput_improvement > 0.0 && latency_improvement > 0.0,
522        }
523    }
524
525    /// Generate performance regression analysis
526    pub fn analyze_regression(
527        historical_results: &[BenchmarkResult],
528        current_result: &BenchmarkResult,
529    ) -> RegressionAnalysis {
530        if historical_results.is_empty() {
531            return RegressionAnalysis::default();
532        }
533
534        let historical_avg_throughput = historical_results
535            .iter()
536            .map(|r| r.ops_per_second)
537            .sum::<f64>()
538            / historical_results.len() as f64;
539
540        let throughput_change =
541            (current_result.ops_per_second - historical_avg_throughput) / historical_avg_throughput;
542        let is_regression = throughput_change < -0.05;
543
544        RegressionAnalysis {
545            throughput_change,
546            is_regression,
547            confidence_level: 0.95,
548            analysis_notes: if is_regression {
549                vec!["Performance regression detected".to_string()]
550            } else {
551                vec!["Performance within expected range".to_string()]
552            },
553        }
554    }
555}
556
557/// Type alias for batch processor function
558type ProcessorFn<T> = Box<dyn Fn(&[T]) -> Result<()> + Send + Sync>;
559
560/// Memory-efficient batch processor for large datasets
561pub struct BatchProcessor<T> {
562    batch_size: usize,
563    current_batch: Vec<T>,
564    processor_fn: ProcessorFn<T>,
565}
566
567impl<T> BatchProcessor<T> {
568    pub fn new<F>(batch_size: usize, processor_fn: F) -> Self
569    where
570        F: Fn(&[T]) -> Result<()> + Send + Sync + 'static,
571    {
572        Self {
573            batch_size,
574            current_batch: Vec::with_capacity(batch_size),
575            processor_fn: Box::new(processor_fn),
576        }
577    }
578
579    pub fn add(&mut self, item: T) -> Result<()> {
580        self.current_batch.push(item);
581        if self.current_batch.len() >= self.batch_size {
582            return self.flush();
583        }
584        Ok(())
585    }
586
587    pub fn flush(&mut self) -> Result<()> {
588        if !self.current_batch.is_empty() {
589            (self.processor_fn)(&self.current_batch)?;
590            self.current_batch.clear();
591        }
592        Ok(())
593    }
594}
595
/// Enhanced memory monitoring for embedding operations
///
/// A manual counter: callers report allocations and deallocations and the
/// monitor tracks current usage, the high-water mark, and how many of each
/// were recorded. It does not observe the allocator on its own.
#[derive(Debug, Clone, Default)]
pub struct MemoryMonitor {
    peak_usage: usize,
    current_usage: usize,
    allocations: usize,
    deallocations: usize,
}

impl MemoryMonitor {
    /// Create a monitor with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record an allocation of `size` units (presumably bytes — the monitor
    /// itself is unit-agnostic) and update the peak.
    pub fn record_allocation(&mut self, size: usize) {
        // Saturate rather than overflow, mirroring `record_deallocation`.
        self.current_usage = self.current_usage.saturating_add(size);
        self.allocations += 1;
        self.peak_usage = self.peak_usage.max(self.current_usage);
    }

    /// Record a deallocation of `size` units; usage never drops below zero.
    pub fn record_deallocation(&mut self, size: usize) {
        self.current_usage = self.current_usage.saturating_sub(size);
        self.deallocations += 1;
    }

    /// Highest `current_usage` observed so far.
    pub fn peak_usage(&self) -> usize {
        self.peak_usage
    }

    /// Usage after all recorded allocations and deallocations.
    pub fn current_usage(&self) -> usize {
        self.current_usage
    }

    /// Number of `record_allocation` calls so far.
    pub fn allocation_count(&self) -> usize {
        self.allocations
    }

    /// Number of `record_deallocation` calls so far.
    pub fn deallocation_count(&self) -> usize {
        self.deallocations
    }
}