1use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{Array1, Array2};
6#[allow(unused_imports)]
7use scirs2_core::random::{Random, Rng};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::Path;
12
13pub mod data_loader {
15 use super::*;
16 use std::io::{BufRead, BufReader};
17
18 pub fn load_triples_from_tsv<P: AsRef<Path>>(
20 file_path: P,
21 ) -> Result<Vec<(String, String, String)>> {
22 let file = fs::File::open(file_path)?;
23 let reader = BufReader::new(file);
24 let mut triples = Vec::new();
25
26 for (line_num, line) in reader.lines().enumerate() {
27 let line = line?;
28 if line.trim().is_empty() || line.starts_with('#') {
29 continue; }
31
32 if line_num == 0
34 && (line.contains("subject")
35 || line.contains("predicate")
36 || line.contains("object"))
37 {
38 continue;
39 }
40
41 let parts: Vec<&str> = line.split('\t').collect();
42 if parts.len() >= 3 {
43 let subject = parts[0].trim().to_string();
44 let predicate = parts[1].trim().to_string();
45 let object = parts[2].trim().to_string();
46 triples.push((subject, predicate, object));
47 } else {
48 eprintln!(
49 "Warning: Invalid triple format at line {}: {}",
50 line_num + 1,
51 line
52 );
53 }
54 }
55
56 Ok(triples)
57 }
58
59 pub fn load_triples_from_csv<P: AsRef<Path>>(
61 file_path: P,
62 ) -> Result<Vec<(String, String, String)>> {
63 let file = fs::File::open(file_path)?;
64 let reader = BufReader::new(file);
65 let mut triples = Vec::new();
66 let mut is_first_line = true;
67
68 for (line_num, line) in reader.lines().enumerate() {
69 let line = line?;
70 if is_first_line {
71 is_first_line = false;
72 if line.to_lowercase().contains("subject")
74 && line.to_lowercase().contains("predicate")
75 {
76 continue;
77 }
78 }
79
80 if line.trim().is_empty() {
81 continue;
82 }
83
84 let parts: Vec<&str> = line.split(',').collect();
85 if parts.len() >= 3 {
86 let subject = parts[0].trim().trim_matches('"').to_string();
87 let predicate = parts[1].trim().trim_matches('"').to_string();
88 let object = parts[2].trim().trim_matches('"').to_string();
89 triples.push((subject, predicate, object));
90 } else {
91 eprintln!(
92 "Warning: Invalid triple format at line {}: {}",
93 line_num + 1,
94 line
95 );
96 }
97 }
98
99 Ok(triples)
100 }
101
102 pub fn load_triples_from_ntriples<P: AsRef<Path>>(
104 file_path: P,
105 ) -> Result<Vec<(String, String, String)>> {
106 let file = fs::File::open(file_path)?;
107 let reader = BufReader::new(file);
108 let mut triples = Vec::new();
109
110 for (line_num, line) in reader.lines().enumerate() {
111 let line = line?;
112 let line = line.trim();
113
114 if line.is_empty() || line.starts_with('#') {
115 continue;
116 }
117
118 if let Some(triple) = parse_ntriple_line(line) {
120 triples.push(triple);
121 } else {
122 eprintln!(
123 "Warning: Failed to parse N-Triple at line {}: {}",
124 line_num + 1,
125 line
126 );
127 }
128 }
129
130 Ok(triples)
131 }
132
133 fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
135 let line = line.trim_end_matches(" .");
136 let parts: Vec<&str> = line.split_whitespace().collect();
137
138 if parts.len() >= 3 {
139 let subject = clean_uri_or_literal(parts[0]);
140 let predicate = clean_uri_or_literal(parts[1]);
141 let object = clean_uri_or_literal(&parts[2..].join(" "));
142
143 Some((subject, predicate, object))
144 } else {
145 None
146 }
147 }
148
149 fn clean_uri_or_literal(term: &str) -> String {
151 if term.starts_with('<') && term.ends_with('>') {
152 term[1..term.len() - 1].to_string()
153 } else if term.starts_with('"') && term.contains('"') {
154 let end_quote = term.rfind('"').unwrap_or(term.len());
156 term[1..end_quote].to_string()
157 } else {
158 term.to_string()
159 }
160 }
161
162 pub fn load_triples_from_jsonl<P: AsRef<Path>>(
164 file_path: P,
165 ) -> Result<Vec<(String, String, String)>> {
166 let file = fs::File::open(file_path)?;
167 let reader = BufReader::new(file);
168 let mut triples = Vec::new();
169
170 for (line_num, line) in reader.lines().enumerate() {
171 let line = line?;
172 if line.trim().is_empty() {
173 continue;
174 }
175
176 match serde_json::from_str::<serde_json::Value>(&line) {
177 Ok(json) => {
178 if let (Some(subject), Some(predicate), Some(object)) = (
179 json["subject"].as_str(),
180 json["predicate"].as_str(),
181 json["object"].as_str(),
182 ) {
183 triples.push((
184 subject.to_string(),
185 predicate.to_string(),
186 object.to_string(),
187 ));
188 } else {
189 eprintln!(
190 "Warning: Invalid JSON structure at line {}: {}",
191 line_num + 1,
192 line
193 );
194 }
195 }
196 Err(e) => {
197 eprintln!(
198 "Warning: Failed to parse JSON at line {}: {} - Error: {}",
199 line_num + 1,
200 line,
201 e
202 );
203 }
204 }
205 }
206
207 Ok(triples)
208 }
209
210 pub fn save_triples_to_tsv<P: AsRef<Path>>(
212 triples: &[(String, String, String)],
213 file_path: P,
214 ) -> Result<()> {
215 let mut content = String::new();
216 content.push_str("subject\tpredicate\tobject\n");
217
218 for (subject, predicate, object) in triples {
219 content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
220 }
221
222 fs::write(file_path, content)?;
223 Ok(())
224 }
225
226 pub fn save_triples_to_jsonl<P: AsRef<Path>>(
228 triples: &[(String, String, String)],
229 file_path: P,
230 ) -> Result<()> {
231 use std::io::Write;
232 let mut file = fs::File::create(file_path)?;
233
234 for (subject, predicate, object) in triples {
235 let json = serde_json::json!({
236 "subject": subject,
237 "predicate": predicate,
238 "object": object
239 });
240 writeln!(file, "{json}")?;
241 }
242
243 Ok(())
244 }
245
246 pub fn load_triples_auto_detect<P: AsRef<Path>>(
248 file_path: P,
249 ) -> Result<Vec<(String, String, String)>> {
250 let path = file_path.as_ref();
251 let extension = path
252 .extension()
253 .and_then(|ext| ext.to_str())
254 .unwrap_or("")
255 .to_lowercase();
256
257 match extension.as_str() {
258 "tsv" => load_triples_from_tsv(path),
259 "csv" => load_triples_from_csv(path),
260 "nt" | "ntriples" => load_triples_from_ntriples(path),
261 "jsonl" | "ndjson" => load_triples_from_jsonl(path),
262 _ => {
263 eprintln!(
265 "Warning: Unknown file extension '{extension}', attempting auto-detection"
266 );
267
268 if let Ok(triples) = load_triples_from_tsv(path) {
270 if !triples.is_empty() {
271 return Ok(triples);
272 }
273 }
274
275 if let Ok(triples) = load_triples_from_ntriples(path) {
277 if !triples.is_empty() {
278 return Ok(triples);
279 }
280 }
281
282 if let Ok(triples) = load_triples_from_jsonl(path) {
284 if !triples.is_empty() {
285 return Ok(triples);
286 }
287 }
288
289 load_triples_from_csv(path)
291 }
292 }
293 }
294}
295
296pub mod dataset_splitter {
298 use super::*;
299
300 pub fn split_dataset(
302 triples: Vec<(String, String, String)>,
303 train_ratio: f64,
304 val_ratio: f64,
305 seed: Option<u64>,
306 ) -> Result<DatasetSplit> {
307 if train_ratio + val_ratio >= 1.0 {
308 return Err(anyhow!(
309 "Train and validation ratios must sum to less than 1.0"
310 ));
311 }
312
313 let mut rng = if let Some(s) = seed {
314 Random::seed(s)
315 } else {
316 Random::seed(42) };
318
319 let mut shuffled_triples = triples;
320 for i in (1..shuffled_triples.len()).rev() {
322 let j = rng.random_range(0..i + 1);
323 shuffled_triples.swap(i, j);
324 }
325
326 let total = shuffled_triples.len();
327 let train_end = (total as f64 * train_ratio) as usize;
328 let val_end = train_end + (total as f64 * val_ratio) as usize;
329
330 let train_triples = shuffled_triples[..train_end].to_vec();
331 let val_triples = shuffled_triples[train_end..val_end].to_vec();
332 let test_triples = shuffled_triples[val_end..].to_vec();
333
334 Ok(DatasetSplit {
335 train: train_triples,
336 validation: val_triples,
337 test: test_triples,
338 })
339 }
340
341 pub fn split_dataset_no_leakage(
343 triples: Vec<(String, String, String)>,
344 train_ratio: f64,
345 val_ratio: f64,
346 seed: Option<u64>,
347 ) -> Result<DatasetSplit> {
348 let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
350 HashMap::with_capacity(triples.len() / 2); for triple in &triples {
353 let entities = [&triple.0, &triple.2];
354 for entity in entities {
355 entity_triples
356 .entry(entity.clone())
357 .or_default()
358 .push(triple.clone());
359 }
360 }
361
362 let entities: Vec<String> = entity_triples.keys().cloned().collect();
364 let dummy_string = "dummy".to_string();
365 let entity_split = split_dataset(
366 entities
367 .into_iter()
368 .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
369 .collect(),
370 train_ratio,
371 val_ratio,
372 seed,
373 )?;
374
375 let train_entities: HashSet<String> =
376 entity_split.train.into_iter().map(|(e, _, _)| e).collect();
377 let val_entities: HashSet<String> = entity_split
378 .validation
379 .into_iter()
380 .map(|(e, _, _)| e)
381 .collect();
382 let test_entities: HashSet<String> =
383 entity_split.test.into_iter().map(|(e, _, _)| e).collect();
384
385 let estimated_capacity = triples.len() / 3;
387 let mut train_triples = Vec::with_capacity(estimated_capacity);
388 let mut val_triples = Vec::with_capacity(estimated_capacity);
389 let mut test_triples = Vec::with_capacity(estimated_capacity);
390
391 for (entity, entity_triple_list) in entity_triples {
392 if train_entities.contains(&entity) {
393 train_triples.extend(entity_triple_list);
394 } else if val_entities.contains(&entity) {
395 val_triples.extend(entity_triple_list);
396 } else if test_entities.contains(&entity) {
397 test_triples.extend(entity_triple_list);
398 }
399 }
400
401 train_triples.sort();
403 train_triples.dedup();
404 val_triples.sort();
405 val_triples.dedup();
406 test_triples.sort();
407 test_triples.dedup();
408
409 Ok(DatasetSplit {
410 train: train_triples,
411 validation: val_triples,
412 test: test_triples,
413 })
414 }
415}
416
/// A train/validation/test partition of `(subject, predicate, object)`
/// triples, as produced by the functions in the `dataset_splitter` module.
#[derive(Debug, Clone)]
pub struct DatasetSplit {
    /// Triples assigned to the training set.
    pub train: Vec<(String, String, String)>,
    /// Triples assigned to the validation set.
    pub validation: Vec<(String, String, String)>,
    /// Triples assigned to the test set.
    pub test: Vec<(String, String, String)>,
}
424
/// Summary statistics for a triple dataset, computed by
/// [`compute_dataset_statistics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Total number of triples.
    pub num_triples: usize,
    /// Number of distinct entities (subjects and objects combined).
    pub num_entities: usize,
    /// Number of distinct relation types (predicates).
    pub num_relations: usize,
    /// Occurrence count per entity (counted once per subject or object slot).
    pub entity_frequency: HashMap<String, usize>,
    /// Occurrence count per relation.
    pub relation_frequency: HashMap<String, usize>,
    /// Average entity degree: `2 * num_triples / num_entities`.
    pub avg_degree: f64,
    /// Graph density: `num_triples / num_entities^2`.
    pub density: f64,
}
436
437pub fn compute_dataset_statistics(triples: &[(String, String, String)]) -> DatasetStatistics {
439 let mut entities = HashSet::new();
440 let mut relations = HashSet::new();
441 let mut entity_frequency = HashMap::new();
442 let mut relation_frequency = HashMap::new();
443
444 for (subject, predicate, object) in triples {
445 entities.insert(subject.clone());
446 entities.insert(object.clone());
447 relations.insert(predicate.clone());
448
449 *entity_frequency.entry(subject.clone()).or_insert(0) += 1;
450 *entity_frequency.entry(object.clone()).or_insert(0) += 1;
451 *relation_frequency.entry(predicate.clone()).or_insert(0) += 1;
452 }
453
454 let num_entities = entities.len();
455 let num_relations = relations.len();
456 let num_triples = triples.len();
457
458 let avg_degree = if num_entities > 0 {
459 (num_triples * 2) as f64 / num_entities as f64
460 } else {
461 0.0
462 };
463
464 let max_possible_edges = num_entities * num_entities;
465 let density = if max_possible_edges > 0 {
466 num_triples as f64 / max_possible_edges as f64
467 } else {
468 0.0
469 };
470
471 DatasetStatistics {
472 num_triples,
473 num_entities,
474 num_relations,
475 entity_frequency,
476 relation_frequency,
477 avg_degree,
478 density,
479 }
480}
481
482pub mod embedding_analysis {
484 use super::*;
485
486 pub fn analyze_embedding_distribution(embeddings: &Array2<f64>) -> EmbeddingDistributionStats {
488 let flat_values: Vec<f64> = embeddings.iter().cloned().collect();
489
490 let mean = flat_values.iter().sum::<f64>() / flat_values.len() as f64;
491 let variance =
492 flat_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / flat_values.len() as f64;
493 let std_dev = variance.sqrt();
494
495 let mut sorted_values = flat_values.clone();
496 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497
498 let min_val = sorted_values[0];
499 let max_val = sorted_values[sorted_values.len() - 1];
500 let median = sorted_values[sorted_values.len() / 2];
501
502 EmbeddingDistributionStats {
503 mean,
504 std_dev,
505 variance,
506 min: min_val,
507 max: max_val,
508 median,
509 num_parameters: embeddings.len(),
510 }
511 }
512
513 pub fn compute_embedding_norms(embeddings: &Array2<f64>) -> Vec<f64> {
515 embeddings
516 .rows()
517 .into_iter()
518 .map(|row| row.dot(&row).sqrt())
519 .collect()
520 }
521
522 pub fn analyze_embedding_similarities(
524 embeddings: &Array2<f64>,
525 sample_size: usize,
526 ) -> SimilarityStats {
527 let num_embeddings = embeddings.nrows();
528 let mut similarities = Vec::new();
529
530 let sample_size = sample_size.min(num_embeddings * (num_embeddings - 1) / 2);
531 let mut rng = Random::default();
532
533 for _ in 0..sample_size {
534 let i = rng.random_range(0..num_embeddings);
535 let j = rng.random_range(0..num_embeddings);
536
537 if i != j {
538 let emb_i = embeddings.row(i);
539 let emb_j = embeddings.row(j);
540 let similarity = cosine_similarity(&emb_i.to_owned(), &emb_j.to_owned());
541 similarities.push(similarity);
542 }
543 }
544
545 similarities.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
546
547 let mean_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
548 let min_similarity = similarities[0];
549 let max_similarity = similarities[similarities.len() - 1];
550 let median_similarity = similarities[similarities.len() / 2];
551
552 SimilarityStats {
553 mean_similarity,
554 min_similarity,
555 max_similarity,
556 median_similarity,
557 num_comparisons: similarities.len(),
558 }
559 }
560
561 fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
563 let dot_product = a.dot(b);
564 let norm_a = a.dot(a).sqrt();
565 let norm_b = b.dot(b).sqrt();
566
567 if norm_a > 1e-10 && norm_b > 1e-10 {
568 dot_product / (norm_a * norm_b)
569 } else {
570 0.0
571 }
572 }
573}
574
/// Distribution summary of all scalar values in an embedding matrix, produced
/// by `embedding_analysis::analyze_embedding_distribution`.
#[derive(Debug, Clone)]
pub struct EmbeddingDistributionStats {
    /// Arithmetic mean of all values.
    pub mean: f64,
    /// Population standard deviation.
    pub std_dev: f64,
    /// Population variance.
    pub variance: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Median (element at index len/2 of the sorted values).
    pub median: f64,
    /// Total number of scalar parameters in the matrix.
    pub num_parameters: usize,
}
586
/// Summary of sampled pairwise cosine similarities, produced by
/// `embedding_analysis::analyze_embedding_similarities`.
#[derive(Debug, Clone)]
pub struct SimilarityStats {
    /// Mean of the sampled similarities.
    pub mean_similarity: f64,
    /// Smallest sampled similarity.
    pub min_similarity: f64,
    /// Largest sampled similarity.
    pub max_similarity: f64,
    /// Median of the sampled similarities.
    pub median_similarity: f64,
    /// Number of pairs actually compared (collided draws are skipped).
    pub num_comparisons: usize,
}
596
597pub mod graph_analysis {
599 use super::*;
600
601 pub fn compute_graph_metrics(triples: &[(String, String, String)]) -> GraphMetrics {
603 let estimated_entities = triples.len(); let estimated_relations = triples.len() / 10; let mut entity_degrees: HashMap<String, usize> = HashMap::with_capacity(estimated_entities);
608 let mut relation_counts: HashMap<String, usize> =
609 HashMap::with_capacity(estimated_relations);
610 let mut entities = HashSet::with_capacity(estimated_entities);
611
612 for (subject, predicate, object) in triples {
613 entities.insert(subject.clone());
614 entities.insert(object.clone());
615
616 *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
617 *entity_degrees.entry(object.clone()).or_insert(0) += 1;
618 *relation_counts.entry(predicate.clone()).or_insert(0) += 1;
619 }
620
621 let num_entities = entities.len();
622 let num_relations = relation_counts.len();
623 let num_triples = triples.len();
624
625 let degrees: Vec<usize> = entity_degrees.values().cloned().collect();
626 let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
627 let max_degree = degrees.iter().max().cloned().unwrap_or(0);
628 let min_degree = degrees.iter().min().cloned().unwrap_or(0);
629
630 GraphMetrics {
631 num_entities,
632 num_relations,
633 num_triples,
634 avg_degree,
635 max_degree,
636 min_degree,
637 density: num_triples as f64 / (num_entities * num_entities) as f64,
638 }
639 }
640}
641
/// Graph-level metrics for a triple set, produced by
/// `graph_analysis::compute_graph_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetrics {
    /// Number of distinct entities (subjects and objects).
    pub num_entities: usize,
    /// Number of distinct relation types.
    pub num_relations: usize,
    /// Total number of triples (edges).
    pub num_triples: usize,
    /// Mean entity degree.
    pub avg_degree: f64,
    /// Highest entity degree.
    pub max_degree: usize,
    /// Lowest entity degree.
    pub min_degree: usize,
    /// `num_triples / num_entities^2`.
    pub density: f64,
}
653
/// Console progress reporter: callers feed it the current item count and it
/// prints a throttled status line (at most once per second), plus a final
/// summary via [`ProgressTracker::finish`].
#[derive(Debug)]
pub struct ProgressTracker {
    // Total number of items expected.
    total: usize,
    // Items processed so far (set by `update`).
    current: usize,
    // When tracking began; basis for throughput numbers.
    start_time: std::time::Instant,
    // Last time a progress line was printed.
    last_update: std::time::Instant,
    // Minimum interval between progress prints.
    update_interval: std::time::Duration,
}

impl ProgressTracker {
    /// Creates a tracker expecting `total` items; progress lines are printed
    /// at most once per second.
    pub fn new(total: usize) -> Self {
        let now = std::time::Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            update_interval: std::time::Duration::from_secs(1),
        }
    }

    /// Records progress and prints a status line if the update interval has
    /// elapsed since the previous print.
    pub fn update(&mut self, current: usize) {
        self.current = current;
        let now = std::time::Instant::now();

        if now.duration_since(self.last_update) >= self.update_interval {
            self.print_progress();
            self.last_update = now;
        }
    }

    /// Number of items recorded so far.
    pub fn current(&self) -> usize {
        self.current
    }

    /// Total number of items this tracker expects.
    pub fn total(&self) -> usize {
        self.total
    }

    fn print_progress(&self) {
        // Guard total == 0: the unguarded division printed NaN.
        let percentage = if self.total == 0 {
            100.0
        } else {
            (self.current as f64 / self.total as f64) * 100.0
        };
        let elapsed = self.start_time.elapsed().as_secs_f64();
        // Guard a zero elapsed time (possible at coarse clock resolution).
        let rate = if elapsed > 0.0 {
            self.current as f64 / elapsed
        } else {
            0.0
        };

        println!(
            "Progress: {}/{} ({:.1}%) - {:.1} items/sec",
            self.current, self.total, percentage, rate
        );
    }

    /// Prints a final summary with total elapsed time and throughput.
    pub fn finish(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        let rate = if elapsed > 0.0 {
            self.total as f64 / elapsed
        } else {
            0.0
        };

        println!(
            "Completed: {} items in {:.2}s ({:.1} items/sec)",
            self.total, elapsed, rate
        );
    }
}
711
712pub mod performance_benchmark {
714 use super::*;
715 use std::collections::BTreeMap;
716 use std::time::{Duration, Instant};
717
    /// Bundled output of an `EmbeddingBenchmark` run: per-operation results,
    /// an aggregate summary, and the configuration that produced them.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSuite {
        /// Per-operation results keyed by name (BTreeMap keeps output sorted).
        pub results: BTreeMap<String, BenchmarkResult>,
        /// Aggregate statistics across all operations.
        pub summary: BenchmarkSummary,
        /// Configuration the run was executed with.
        pub config: BenchmarkConfig,
    }
728
    /// Timing and memory statistics for one benchmarked operation.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkResult {
        /// Name the operation was registered under.
        pub operation: String,
        /// Number of measured iterations (warmup excluded).
        pub iterations: usize,
        /// Sum of all measured iteration durations.
        pub total_duration: Duration,
        /// Mean iteration duration.
        pub avg_duration: Duration,
        /// Fastest iteration.
        pub min_duration: Duration,
        /// Slowest iteration.
        pub max_duration: Duration,
        /// Population standard deviation of iteration durations.
        pub std_deviation: Duration,
        /// Throughput derived from `avg_duration` (1e9 / avg nanoseconds).
        pub ops_per_second: f64,
        /// Memory statistics gathered during the run (zeros when profiling is
        /// disabled or the collector is the built-in placeholder).
        pub memory_stats: MemoryStats,
        /// Extra caller-defined metrics.
        pub custom_metrics: HashMap<String, f64>,
    }
753
    /// Memory accounting attached to a `BenchmarkResult`.
    ///
    /// NOTE(review): the built-in collector (`EmbeddingBenchmark::get_memory_usage`)
    /// currently reports zeros for every field, so these values are
    /// placeholders until real sampling is wired in.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MemoryStats {
        /// Highest observed memory usage, in bytes.
        pub peak_memory_bytes: usize,
        /// Average observed memory usage, in bytes.
        pub avg_memory_bytes: usize,
        /// Number of allocation events recorded.
        pub allocations: usize,
        /// Number of deallocation events recorded.
        pub deallocations: usize,
    }
766
    /// Aggregate statistics across every operation in a `BenchmarkSuite`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSummary {
        /// Sum of all operations' total durations.
        pub total_duration: Duration,
        /// Number of distinct operations benchmarked.
        pub total_operations: usize,
        /// Mean ops/sec across operations.
        pub overall_throughput: f64,
        /// Consistency score in (0, 1]; higher means lower run-to-run variance.
        pub efficiency_score: f64,
        /// Human-readable descriptions of detected bottlenecks.
        pub bottlenecks: Vec<String>,
    }
781
    /// Knobs controlling how `EmbeddingBenchmark` measures an operation.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkConfig {
        /// Unmeasured iterations run first to warm caches/allocator.
        pub warmup_iterations: usize,
        /// Measured iterations used for the statistics.
        pub measurement_iterations: usize,
        /// Confidence level recorded with the run (not currently used in math).
        pub confidence_level: f64,
        /// When true, memory snapshots are taken around each iteration.
        pub enable_memory_profiling: bool,
        /// When true, detailed per-iteration timing is retained.
        pub enable_detailed_timing: bool,
    }
796
797 impl Default for BenchmarkConfig {
798 fn default() -> Self {
799 Self {
800 warmup_iterations: 100,
801 measurement_iterations: 1000,
802 confidence_level: 0.95,
803 enable_memory_profiling: true,
804 enable_detailed_timing: true,
805 }
806 }
807 }
808
809 pub struct PrecisionTimer {
811 start_time: Instant,
812 lap_times: Vec<Duration>,
813 }
814
815 impl Default for PrecisionTimer {
816 fn default() -> Self {
817 Self::new()
818 }
819 }
820
821 impl PrecisionTimer {
822 pub fn new() -> Self {
823 Self {
824 start_time: Instant::now(),
825 lap_times: Vec::new(),
826 }
827 }
828
829 pub fn start(&mut self) {
831 self.start_time = Instant::now();
832 self.lap_times.clear();
833 }
834
835 pub fn lap(&mut self) -> Duration {
837 let lap_duration = self.start_time.elapsed();
838 self.lap_times.push(lap_duration);
839 lap_duration
840 }
841
842 pub fn stop(&self) -> Duration {
844 self.start_time.elapsed()
845 }
846
847 pub fn lap_times(&self) -> &[Duration] {
849 &self.lap_times
850 }
851 }
852
    /// Harness that times repeated operations (see its `benchmark` method)
    /// and accumulates one `BenchmarkResult` per operation name.
    pub struct EmbeddingBenchmark {
        // Warmup/measurement schedule and profiling switches.
        config: BenchmarkConfig,
        // Results keyed by operation name; BTreeMap keeps report order stable.
        results: BTreeMap<String, BenchmarkResult>,
    }
858
859 impl EmbeddingBenchmark {
860 pub fn new(config: BenchmarkConfig) -> Self {
861 Self {
862 config,
863 results: BTreeMap::new(),
864 }
865 }
866
867 pub fn benchmark<F, T>(&mut self, name: &str, mut operation: F) -> Result<T>
869 where
870 F: FnMut() -> Result<T>,
871 {
872 for _ in 0..self.config.warmup_iterations {
874 let _ = operation()?;
875 }
876
877 let mut durations = Vec::with_capacity(self.config.measurement_iterations);
878 let mut memory_snapshots = Vec::new();
879 let mut result = None;
880
881 for i in 0..self.config.measurement_iterations {
883 let memory_before = self.get_memory_usage();
884 let start = Instant::now();
885
886 let op_result = operation()?;
887
888 let duration = start.elapsed();
889 let memory_after = self.get_memory_usage();
890
891 durations.push(duration);
892
893 if self.config.enable_memory_profiling {
894 memory_snapshots.push((memory_before, memory_after));
895 }
896
897 if i == 0 {
899 result = Some(op_result);
900 }
901 }
902
903 let total_duration: Duration = durations.iter().sum();
905 let avg_duration = total_duration / durations.len() as u32;
906 let min_duration = *durations
907 .iter()
908 .min()
909 .expect("durations should not be empty");
910 let max_duration = *durations
911 .iter()
912 .max()
913 .expect("durations should not be empty");
914
915 let variance: f64 = durations
917 .iter()
918 .map(|d| {
919 let diff = d.as_nanos() as f64 - avg_duration.as_nanos() as f64;
920 diff * diff
921 })
922 .sum::<f64>()
923 / durations.len() as f64;
924 let std_deviation = Duration::from_nanos(variance.sqrt() as u64);
925
926 let ops_per_second = 1_000_000_000.0 / avg_duration.as_nanos() as f64;
927
928 let memory_stats = if self.config.enable_memory_profiling
930 && !memory_snapshots.is_empty()
931 {
932 let peak_memory = memory_snapshots
933 .iter()
934 .map(|(_, after)| after.peak_memory_bytes)
935 .max()
936 .unwrap_or(0);
937
938 let avg_memory = memory_snapshots
939 .iter()
940 .map(|(before, after)| (before.avg_memory_bytes + after.avg_memory_bytes) / 2)
941 .sum::<usize>()
942 / memory_snapshots.len();
943
944 MemoryStats {
945 peak_memory_bytes: peak_memory,
946 avg_memory_bytes: avg_memory,
947 allocations: memory_snapshots.len(),
948 deallocations: 0, }
950 } else {
951 MemoryStats {
952 peak_memory_bytes: 0,
953 avg_memory_bytes: 0,
954 allocations: 0,
955 deallocations: 0,
956 }
957 };
958
959 let benchmark_result = BenchmarkResult {
960 operation: name.to_string(),
961 iterations: self.config.measurement_iterations,
962 total_duration,
963 avg_duration,
964 min_duration,
965 max_duration,
966 std_deviation,
967 ops_per_second,
968 memory_stats,
969 custom_metrics: HashMap::new(),
970 };
971
972 self.results.insert(name.to_string(), benchmark_result);
973
974 result.ok_or_else(|| anyhow!("Failed to capture benchmark result"))
975 }
976
977 pub fn generate_report(&self) -> BenchmarkSuite {
979 let total_duration = self.results.values().map(|r| r.total_duration).sum();
980
981 let total_operations = self.results.len();
982
983 let overall_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
984 / total_operations as f64;
985
986 let efficiency_score = self.calculate_efficiency_score();
988
989 let bottlenecks = self.identify_bottlenecks();
991
992 let summary = BenchmarkSummary {
993 total_duration,
994 total_operations,
995 overall_throughput,
996 efficiency_score,
997 bottlenecks,
998 };
999
1000 BenchmarkSuite {
1001 results: self.results.clone(),
1002 summary,
1003 config: self.config.clone(),
1004 }
1005 }
1006
1007 fn calculate_efficiency_score(&self) -> f64 {
1009 if self.results.is_empty() {
1010 return 0.0;
1011 }
1012
1013 let consistency_scores: Vec<f64> = self
1014 .results
1015 .values()
1016 .map(|result| {
1017 let cv = result.std_deviation.as_nanos() as f64
1019 / result.avg_duration.as_nanos() as f64;
1020 1.0 / (1.0 + cv)
1022 })
1023 .collect();
1024
1025 consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64
1026 }
1027
1028 fn identify_bottlenecks(&self) -> Vec<String> {
1030 let mut bottlenecks = Vec::new();
1031
1032 for (name, result) in &self.results {
1034 let cv =
1035 result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
1036 if cv > 0.2 {
1037 bottlenecks.push(format!("High variance in {}: {:.2}% CV", name, cv * 100.0));
1039 }
1040 }
1041
1042 let avg_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
1044 / self.results.len() as f64;
1045
1046 for (name, result) in &self.results {
1047 if result.ops_per_second < avg_throughput * 0.5 {
1048 bottlenecks.push(format!(
1050 "Slow operation {}: {:.0} ops/sec",
1051 name, result.ops_per_second
1052 ));
1053 }
1054 }
1055
1056 bottlenecks
1057 }
1058
1059 fn get_memory_usage(&self) -> MemoryStats {
1061 MemoryStats {
1064 peak_memory_bytes: 0,
1065 avg_memory_bytes: 0,
1066 allocations: 0,
1067 deallocations: 0,
1068 }
1069 }
1070 }
1071
1072 pub mod analysis {
1074 use super::*;
1075
1076 pub fn compare_benchmarks(
1078 baseline: &BenchmarkResult,
1079 comparison: &BenchmarkResult,
1080 ) -> BenchmarkComparison {
1081 let throughput_improvement =
1082 (comparison.ops_per_second - baseline.ops_per_second) / baseline.ops_per_second;
1083
1084 let latency_improvement = (baseline.avg_duration.as_nanos() as f64
1085 - comparison.avg_duration.as_nanos() as f64)
1086 / baseline.avg_duration.as_nanos() as f64;
1087
1088 let consistency_improvement = {
1089 let baseline_cv = baseline.std_deviation.as_nanos() as f64
1090 / baseline.avg_duration.as_nanos() as f64;
1091 let comparison_cv = comparison.std_deviation.as_nanos() as f64
1092 / comparison.avg_duration.as_nanos() as f64;
1093 (baseline_cv - comparison_cv) / baseline_cv
1094 };
1095
1096 BenchmarkComparison {
1097 baseline_name: baseline.operation.clone(),
1098 comparison_name: comparison.operation.clone(),
1099 throughput_improvement,
1100 latency_improvement,
1101 consistency_improvement,
1102 is_improvement: throughput_improvement > 0.0 && latency_improvement > 0.0,
1103 }
1104 }
1105
1106 pub fn analyze_regression(
1108 historical_results: &[BenchmarkResult],
1109 current_result: &BenchmarkResult,
1110 ) -> RegressionAnalysis {
1111 if historical_results.is_empty() {
1112 return RegressionAnalysis::default();
1113 }
1114
1115 let historical_avg_throughput = historical_results
1116 .iter()
1117 .map(|r| r.ops_per_second)
1118 .sum::<f64>()
1119 / historical_results.len() as f64;
1120
1121 let throughput_change = (current_result.ops_per_second - historical_avg_throughput)
1122 / historical_avg_throughput;
1123
1124 let is_regression = throughput_change < -0.05; RegressionAnalysis {
1127 throughput_change,
1128 is_regression,
1129 confidence_level: 0.95, analysis_notes: if is_regression {
1131 vec!["Performance regression detected".to_string()]
1132 } else {
1133 vec!["Performance within expected range".to_string()]
1134 },
1135 }
1136 }
1137 }
1138
    /// Head-to-head comparison of two benchmark results, produced by
    /// `analysis::compare_benchmarks`. Improvements are relative fractions
    /// (0.10 = 10% better than baseline).
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkComparison {
        /// Operation name of the baseline result.
        pub baseline_name: String,
        /// Operation name of the compared result.
        pub comparison_name: String,
        /// Relative throughput gain over baseline.
        pub throughput_improvement: f64,
        /// Relative mean-latency reduction over baseline.
        pub latency_improvement: f64,
        /// Relative reduction of the coefficient of variation.
        pub consistency_improvement: f64,
        /// True when both throughput and latency improved.
        pub is_improvement: bool,
    }
1149
    /// Outcome of comparing a new result against historical runs, produced by
    /// `analysis::analyze_regression`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct RegressionAnalysis {
        /// Relative throughput change vs. the historical mean (negative = slower).
        pub throughput_change: f64,
        /// True when throughput dropped more than 5%.
        pub is_regression: bool,
        /// Confidence level attached to the verdict (0.0 when no history).
        pub confidence_level: f64,
        /// Human-readable explanation of the verdict.
        pub analysis_notes: Vec<String>,
    }
1158
1159 impl Default for RegressionAnalysis {
1160 fn default() -> Self {
1161 Self {
1162 throughput_change: 0.0,
1163 is_regression: false,
1164 confidence_level: 0.0,
1165 analysis_notes: vec!["No historical data available".to_string()],
1166 }
1167 }
1168 }
1169}
1170
1171pub mod convenience {
1173 use super::*;
1174 use crate::{EmbeddingModel, ModelConfig, NamedNode, TransE, Triple};
1175
1176 pub fn create_simple_transe_model() -> TransE {
1178 let config = ModelConfig::default()
1179 .with_dimensions(128)
1180 .with_learning_rate(0.01)
1181 .with_max_epochs(100);
1182 TransE::new(config)
1183 }
1184
1185 pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
1187 let parts: Vec<&str> = triple_str.split_whitespace().collect();
1188 if parts.len() != 3 {
1189 return Err(anyhow!(
1190 "Invalid triple format. Expected 'subject predicate object', got: '{}'",
1191 triple_str
1192 ));
1193 }
1194
1195 let subject = if parts[0].starts_with("http") {
1196 NamedNode::new(parts[0])?
1197 } else {
1198 NamedNode::new(&format!("http://example.org/{}", parts[0]))?
1199 };
1200
1201 let predicate = if parts[1].starts_with("http") {
1202 NamedNode::new(parts[1])?
1203 } else {
1204 NamedNode::new(&format!("http://example.org/{}", parts[1]))?
1205 };
1206
1207 let object = if parts[2].starts_with("http") {
1208 NamedNode::new(parts[2])?
1209 } else {
1210 NamedNode::new(&format!("http://example.org/{}", parts[2]))?
1211 };
1212
1213 Ok(Triple::new(subject, predicate, object))
1214 }
1215
1216 pub fn add_triples_from_strings(
1218 model: &mut dyn EmbeddingModel,
1219 triple_strings: &[&str],
1220 ) -> Result<usize> {
1221 let mut added_count = 0;
1222 for triple_str in triple_strings {
1223 match parse_triple_from_string(triple_str) {
1224 Ok(triple) => {
1225 model.add_triple(triple)?;
1226 added_count += 1;
1227 }
1228 Err(e) => {
1229 eprintln!("Warning: Failed to parse triple '{triple_str}': {e}");
1230 }
1231 }
1232 }
1233 Ok(added_count)
1234 }
1235
1236 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
1238 if a.len() != b.len() {
1239 return Err(anyhow!(
1240 "Vector dimensions don't match: {} vs {}",
1241 a.len(),
1242 b.len()
1243 ));
1244 }
1245
1246 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
1247 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
1248 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
1249
1250 if norm_a == 0.0 || norm_b == 0.0 {
1251 return Ok(0.0);
1252 }
1253
1254 Ok(dot_product / (norm_a * norm_b))
1255 }
1256
1257 pub fn generate_sample_kg_data(
1259 num_entities: usize,
1260 num_relations: usize,
1261 ) -> Vec<(String, String, String)> {
1262 let mut rng = Random::default();
1263 let mut triples = Vec::new();
1264
1265 let entities: Vec<String> = (0..num_entities).map(|i| format!("entity_{i}")).collect();
1266
1267 let relations: Vec<String> = (0..num_relations)
1268 .map(|i| format!("relation_{i}"))
1269 .collect();
1270
1271 for _ in 0..(num_entities * 2) {
1273 let subject = entities[rng.random_range(0..entities.len())].clone();
1274 let relation = relations[rng.random_range(0..relations.len())].clone();
1275 let object = entities[rng.random_range(0..entities.len())].clone();
1276
1277 if subject != object {
1278 triples.push((subject, relation, object));
1279 }
1280 }
1281
1282 triples
1283 }
1284
1285 pub fn quick_performance_test<F>(
1287 name: &str,
1288 iterations: usize,
1289 operation: F,
1290 ) -> std::time::Duration
1291 where
1292 F: Fn(),
1293 {
1294 let start = std::time::Instant::now();
1295 for _ in 0..iterations {
1296 operation();
1297 }
1298 let duration = start.elapsed();
1299
1300 println!(
1301 "Performance test '{}': {} iterations in {:?} ({:.2} ops/sec)",
1302 name,
1303 iterations,
1304 duration,
1305 iterations as f64 / duration.as_secs_f64()
1306 );
1307
1308 duration
1309 }
1310}
1311
1312pub mod performance_utils {
1314 use super::*;
1315
    /// Boxed callback used by `BatchProcessor` to consume one full batch.
    type ProcessorFn<T> = Box<dyn Fn(&[T]) -> Result<()> + Send + Sync>;

    /// Accumulates items and hands them to a callback in fixed-size batches.
    ///
    /// NOTE(review): there is no `Drop`-time flush — callers must invoke
    /// `flush()` after the last `add()` or trailing items are silently lost.
    pub struct BatchProcessor<T> {
        // Number of items that triggers an automatic flush.
        batch_size: usize,
        // Items accumulated since the last flush.
        current_batch: Vec<T>,
        // Callback invoked with each complete (or final partial) batch.
        processor_fn: ProcessorFn<T>,
    }
1325
1326 impl<T> BatchProcessor<T> {
1327 pub fn new<F>(batch_size: usize, processor_fn: F) -> Self
1328 where
1329 F: Fn(&[T]) -> Result<()> + Send + Sync + 'static,
1330 {
1331 Self {
1332 batch_size,
1333 current_batch: Vec::with_capacity(batch_size),
1334 processor_fn: Box::new(processor_fn),
1335 }
1336 }
1337
1338 pub fn add(&mut self, item: T) -> Result<()> {
1339 self.current_batch.push(item);
1340 if self.current_batch.len() >= self.batch_size {
1341 return self.flush();
1342 }
1343 Ok(())
1344 }
1345
1346 pub fn flush(&mut self) -> Result<()> {
1347 if !self.current_batch.is_empty() {
1348 (self.processor_fn)(&self.current_batch)?;
1349 self.current_batch.clear();
1350 }
1351 Ok(())
1352 }
1353 }
1354
1355 #[derive(Debug, Clone)]
1357 pub struct MemoryMonitor {
1358 peak_usage: usize,
1359 current_usage: usize,
1360 allocations: usize,
1361 deallocations: usize,
1362 }
1363
1364 impl MemoryMonitor {
1365 pub fn new() -> Self {
1366 Self {
1367 peak_usage: 0,
1368 current_usage: 0,
1369 allocations: 0,
1370 deallocations: 0,
1371 }
1372 }
1373
1374 pub fn record_allocation(&mut self, size: usize) {
1375 self.current_usage += size;
1376 self.allocations += 1;
1377 if self.current_usage > self.peak_usage {
1378 self.peak_usage = self.current_usage;
1379 }
1380 }
1381
1382 pub fn record_deallocation(&mut self, size: usize) {
1383 self.current_usage = self.current_usage.saturating_sub(size);
1384 self.deallocations += 1;
1385 }
1386
1387 pub fn peak_usage(&self) -> usize {
1388 self.peak_usage
1389 }
1390
1391 pub fn current_usage(&self) -> usize {
1392 self.current_usage
1393 }
1394
1395 pub fn allocation_count(&self) -> usize {
1396 self.allocations
1397 }
1398 }
1399
1400 impl Default for MemoryMonitor {
1401 fn default() -> Self {
1402 Self::new()
1403 }
1404 }
1405}
1406
1407pub mod parallel_utils {
1409 use super::*;
1410 use rayon::prelude::*;
1411
1412 pub fn parallel_cosine_similarities(
1414 query_embedding: &[f32],
1415 candidate_embeddings: &[Vec<f32>],
1416 ) -> Result<Vec<f32>> {
1417 let similarities: Vec<f32> = candidate_embeddings
1418 .par_iter()
1419 .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
1420 .collect();
1421 Ok(similarities)
1422 }
1423
1424 pub fn parallel_batch_process<T, R, F>(
1426 items: &[T],
1427 batch_size: usize,
1428 processor: F,
1429 ) -> Result<Vec<R>>
1430 where
1431 T: Sync,
1432 R: Send,
1433 F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
1434 {
1435 let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();
1436
1437 Ok(results?.into_iter().flatten().collect())
1438 }
1439
1440 pub fn parallel_entity_frequencies(
1442 triples: &[(String, String, String)],
1443 ) -> HashMap<String, usize> {
1444 let entity_counts: HashMap<String, usize> = triples
1445 .par_iter()
1446 .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
1447 *acc.entry(subject.clone()).or_insert(0) += 1;
1448 *acc.entry(object.clone()).or_insert(0) += 1;
1449 acc
1450 })
1451 .reduce(HashMap::new, |mut acc1, acc2| {
1452 for (entity, count) in acc2 {
1453 *acc1.entry(entity).or_insert(0) += count;
1454 }
1455 acc1
1456 });
1457 entity_counts
1458 }
1459}
1460
1461#[cfg(test)]
1462mod tests {
1463 use super::*;
1464 use crate::quick_start::{
1465 add_triples_from_strings, cosine_similarity, create_simple_transe_model,
1466 generate_sample_kg_data, parse_triple_from_string, quick_performance_test,
1467 };
1468 use crate::EmbeddingModel;
1469 use std::io::Write;
1470 use tempfile::NamedTempFile;
1471
1472 #[test]
1473 fn test_load_triples_from_tsv() -> Result<()> {
1474 let mut temp_file = NamedTempFile::new()?;
1475 writeln!(temp_file, "subject\tpredicate\tobject")?;
1476 writeln!(temp_file, "alice\tknows\tbob")?;
1477 writeln!(temp_file, "bob\tlikes\tcharlie")?;
1478
1479 let triples = data_loader::load_triples_from_tsv(temp_file.path())?;
1480 assert_eq!(triples.len(), 2);
1481 assert_eq!(
1482 triples[0],
1483 ("alice".to_string(), "knows".to_string(), "bob".to_string())
1484 );
1485
1486 Ok(())
1487 }
1488
1489 #[test]
1490 fn test_dataset_split() -> Result<()> {
1491 let triples = vec![
1492 ("a".to_string(), "r1".to_string(), "b".to_string()),
1493 ("b".to_string(), "r2".to_string(), "c".to_string()),
1494 ("c".to_string(), "r3".to_string(), "d".to_string()),
1495 ("d".to_string(), "r4".to_string(), "e".to_string()),
1496 ];
1497
1498 let split = dataset_splitter::split_dataset(triples, 0.6, 0.2, Some(42))?;
1499
1500 assert_eq!(split.train.len(), 2);
1501 assert_eq!(split.validation.len(), 0); assert_eq!(split.test.len(), 2);
1503
1504 Ok(())
1505 }
1506
1507 #[test]
1508 fn test_load_triples_from_jsonl() -> Result<()> {
1509 let mut temp_file = NamedTempFile::new()?;
1510 writeln!(
1511 temp_file,
1512 r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
1513 )?;
1514 writeln!(
1515 temp_file,
1516 r#"{{"subject": "bob", "predicate": "likes", "object": "charlie"}}"#
1517 )?;
1518
1519 let triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
1520 assert_eq!(triples.len(), 2);
1521 assert_eq!(
1522 triples[0],
1523 ("alice".to_string(), "knows".to_string(), "bob".to_string())
1524 );
1525
1526 Ok(())
1527 }
1528
1529 #[test]
1530 fn test_save_triples_to_jsonl() -> Result<()> {
1531 let triples = vec![
1532 ("alice".to_string(), "knows".to_string(), "bob".to_string()),
1533 (
1534 "bob".to_string(),
1535 "likes".to_string(),
1536 "charlie".to_string(),
1537 ),
1538 ];
1539
1540 let temp_file = NamedTempFile::new()?;
1541 data_loader::save_triples_to_jsonl(&triples, temp_file.path())?;
1542
1543 let loaded_triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
1544 assert_eq!(loaded_triples, triples);
1545
1546 Ok(())
1547 }
1548
1549 #[test]
1550 fn test_load_triples_auto_detect() -> Result<()> {
1551 let mut tsv_file = NamedTempFile::with_suffix(".tsv")?;
1553 writeln!(tsv_file, "subject\tpredicate\tobject")?;
1554 writeln!(tsv_file, "alice\tknows\tbob")?;
1555
1556 let triples = data_loader::load_triples_auto_detect(tsv_file.path())?;
1557 assert_eq!(triples.len(), 1);
1558
1559 let mut jsonl_file = NamedTempFile::with_suffix(".jsonl")?;
1561 writeln!(
1562 jsonl_file,
1563 r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
1564 )?;
1565
1566 let triples = data_loader::load_triples_auto_detect(jsonl_file.path())?;
1567 assert_eq!(triples.len(), 1);
1568 assert_eq!(
1569 triples[0],
1570 ("alice".to_string(), "knows".to_string(), "bob".to_string())
1571 );
1572
1573 Ok(())
1574 }
1575
1576 #[test]
1577 fn test_dataset_statistics() {
1578 let triples = vec![
1579 ("alice".to_string(), "knows".to_string(), "bob".to_string()),
1580 (
1581 "bob".to_string(),
1582 "knows".to_string(),
1583 "charlie".to_string(),
1584 ),
1585 (
1586 "alice".to_string(),
1587 "likes".to_string(),
1588 "charlie".to_string(),
1589 ),
1590 ];
1591
1592 let stats = compute_dataset_statistics(&triples);
1593
1594 assert_eq!(stats.num_triples, 3);
1595 assert_eq!(stats.num_entities, 3); assert_eq!(stats.num_relations, 2); assert!(stats.avg_degree > 0.0);
1598 }
1599
1600 #[test]
1602 fn test_create_simple_transe_model() {
1603 let model = create_simple_transe_model();
1604 assert_eq!(model.config().dimensions, 128);
1605 assert_eq!(model.config().learning_rate, 0.01);
1606 assert_eq!(model.config().max_epochs, 100);
1607 }
1608
1609 #[test]
1610 fn test_parse_triple_from_string() -> Result<()> {
1611 let triple = parse_triple_from_string("alice knows bob")?;
1612 assert_eq!(triple.subject.iri.as_str(), "http://example.org/alice");
1613 assert_eq!(triple.predicate.iri.as_str(), "http://example.org/knows");
1614 assert_eq!(triple.object.iri.as_str(), "http://example.org/bob");
1615
1616 let triple2 = parse_triple_from_string(
1618 "http://example.org/alice http://example.org/knows http://example.org/bob",
1619 )?;
1620 assert_eq!(triple2.subject.iri.as_str(), "http://example.org/alice");
1621
1622 assert!(parse_triple_from_string("alice knows").is_err());
1624
1625 Ok(())
1626 }
1627
1628 #[test]
1629 fn test_add_triples_from_strings() -> Result<()> {
1630 let mut model = create_simple_transe_model();
1631 let triple_strings = &[
1632 "alice knows bob",
1633 "bob likes charlie",
1634 "charlie follows alice",
1635 ];
1636
1637 let added_count = add_triples_from_strings(&mut model, triple_strings)?;
1638 assert_eq!(added_count, 3);
1639
1640 Ok(())
1641 }
1642
1643 #[test]
1644 fn test_cosine_similarity() -> Result<()> {
1645 let a = vec![1.0, 0.0, 0.0];
1646 let b = vec![1.0, 0.0, 0.0];
1647 let similarity = cosine_similarity(&a, &b)?;
1648 assert!((similarity - 1.0).abs() < 1e-10);
1649
1650 let c = vec![0.0, 1.0, 0.0];
1651 let similarity2 = cosine_similarity(&a, &c)?;
1652 assert!((similarity2 - 0.0).abs() < 1e-10);
1653
1654 let d = vec![1.0, 0.0];
1656 assert!(cosine_similarity(&a, &d).is_err());
1657
1658 Ok(())
1659 }
1660
1661 #[test]
1662 fn test_generate_sample_kg_data() {
1663 let triples = generate_sample_kg_data(5, 3);
1664 assert!(!triples.is_empty());
1665
1666 for (subject, relation, object) in &triples {
1668 assert!(subject.starts_with("http://example.org/entity_"));
1669 assert!(relation.starts_with("http://example.org/relation_"));
1670 assert!(object.starts_with("http://example.org/entity_"));
1671 assert_ne!(subject, object); }
1673 }
1674
1675 #[test]
1676 fn test_quick_performance_test() {
1677 let duration = quick_performance_test("test_operation", 100, || {
1678 let _sum: i32 = (1..10).sum();
1680 });
1681
1682 let _nanos = duration.as_nanos();
1685 }
1686}