use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;

pub mod data_loader {
    use super::*;
    use std::io::{BufRead, BufReader};

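    /// Loads `(subject, predicate, object)` triples from a tab-separated file,
    /// skipping blank lines, `#` comments, and an optional header row.
    ///
    /// A minimal usage sketch (the file path is a hypothetical example):
    ///
    /// ```ignore
    /// let triples = data_loader::load_triples_from_tsv("data/triples.tsv")?;
    /// println!("loaded {} triples", triples.len());
    /// ```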
    pub fn load_triples_from_tsv<P: AsRef<Path>>(
        file_path: P,
    ) -> Result<Vec<(String, String, String)>> {
        let file = fs::File::open(file_path)?;
        let reader = BufReader::new(file);
        let mut triples = Vec::new();

        for (line_num, line) in reader.lines().enumerate() {
            let line = line?;
            if line.trim().is_empty() || line.starts_with('#') {
                continue;
            }

            if line_num == 0
                && (line.contains("subject")
                    || line.contains("predicate")
                    || line.contains("object"))
            {
                continue;
            }

            let parts: Vec<&str> = line.split('\t').collect();
            if parts.len() >= 3 {
                let subject = parts[0].trim().to_string();
                let predicate = parts[1].trim().to_string();
                let object = parts[2].trim().to_string();
                triples.push((subject, predicate, object));
            } else {
                eprintln!(
                    "Warning: Invalid triple format at line {}: {}",
                    line_num + 1,
                    line
                );
            }
        }

        Ok(triples)
    }

    pub fn load_triples_from_csv<P: AsRef<Path>>(
        file_path: P,
    ) -> Result<Vec<(String, String, String)>> {
        let file = fs::File::open(file_path)?;
        let reader = BufReader::new(file);
        let mut triples = Vec::new();
        let mut is_first_line = true;

        for (line_num, line) in reader.lines().enumerate() {
            let line = line?;
            if is_first_line {
                is_first_line = false;
                if line.to_lowercase().contains("subject")
                    && line.to_lowercase().contains("predicate")
                {
                    continue;
                }
            }

            if line.trim().is_empty() {
                continue;
            }

            // Note: a plain `split(',')` does not handle commas inside quoted
            // fields; use a CSV crate if the data may contain them.
            let parts: Vec<&str> = line.split(',').collect();
            if parts.len() >= 3 {
                let subject = parts[0].trim().trim_matches('"').to_string();
                let predicate = parts[1].trim().trim_matches('"').to_string();
                let object = parts[2].trim().trim_matches('"').to_string();
                triples.push((subject, predicate, object));
            } else {
                eprintln!(
                    "Warning: Invalid triple format at line {}: {}",
                    line_num + 1,
                    line
                );
            }
        }

        Ok(triples)
    }

    pub fn load_triples_from_ntriples<P: AsRef<Path>>(
        file_path: P,
    ) -> Result<Vec<(String, String, String)>> {
        let file = fs::File::open(file_path)?;
        let reader = BufReader::new(file);
        let mut triples = Vec::new();

        for (line_num, line) in reader.lines().enumerate() {
            let line = line?;
            let line = line.trim();

            if line.is_empty() || line.starts_with('#') {
                continue;
            }

            if let Some(triple) = parse_ntriple_line(line) {
                triples.push(triple);
            } else {
                eprintln!(
                    "Warning: Failed to parse N-Triple at line {}: {}",
                    line_num + 1,
                    line
                );
            }
        }

        Ok(triples)
    }

    fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
        // Strip the statement terminator, with or without a preceding space.
        let line = line.trim_end_matches('.').trim_end();
        let parts: Vec<&str> = line.split_whitespace().collect();

        if parts.len() >= 3 {
            let subject = clean_uri_or_literal(parts[0]);
            let predicate = clean_uri_or_literal(parts[1]);
            let object = clean_uri_or_literal(&parts[2..].join(" "));

            Some((subject, predicate, object))
        } else {
            None
        }
    }

    fn clean_uri_or_literal(term: &str) -> String {
        if term.starts_with('<') && term.ends_with('>') {
            term[1..term.len() - 1].to_string()
        } else if term.len() > 1 && term.starts_with('"') {
            // Take the text between the opening quote and the last closing
            // quote, dropping any language tag or datatype suffix.
            match term[1..].rfind('"') {
                Some(end) => term[1..=end].to_string(),
                None => term[1..].to_string(),
            }
        } else {
            term.to_string()
        }
    }

    pub fn load_triples_from_jsonl<P: AsRef<Path>>(
        file_path: P,
    ) -> Result<Vec<(String, String, String)>> {
        let file = fs::File::open(file_path)?;
        let reader = BufReader::new(file);
        let mut triples = Vec::new();

        for (line_num, line) in reader.lines().enumerate() {
            let line = line?;
            if line.trim().is_empty() {
                continue;
            }

            match serde_json::from_str::<serde_json::Value>(&line) {
                Ok(json) => {
                    if let (Some(subject), Some(predicate), Some(object)) = (
                        json["subject"].as_str(),
                        json["predicate"].as_str(),
                        json["object"].as_str(),
                    ) {
                        triples.push((
                            subject.to_string(),
                            predicate.to_string(),
                            object.to_string(),
                        ));
                    } else {
                        eprintln!(
                            "Warning: Invalid JSON structure at line {}: {}",
                            line_num + 1,
                            line
                        );
                    }
                }
                Err(e) => {
                    eprintln!(
                        "Warning: Failed to parse JSON at line {}: {} - Error: {}",
                        line_num + 1,
                        line,
                        e
                    );
                }
            }
        }

        Ok(triples)
    }

    pub fn save_triples_to_tsv<P: AsRef<Path>>(
        triples: &[(String, String, String)],
        file_path: P,
    ) -> Result<()> {
        let mut content = String::new();
        content.push_str("subject\tpredicate\tobject\n");

        for (subject, predicate, object) in triples {
            content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
        }

        fs::write(file_path, content)?;
        Ok(())
    }

    pub fn save_triples_to_jsonl<P: AsRef<Path>>(
        triples: &[(String, String, String)],
        file_path: P,
    ) -> Result<()> {
        use std::io::Write;
        let mut file = fs::File::create(file_path)?;

        for (subject, predicate, object) in triples {
            let json = serde_json::json!({
                "subject": subject,
                "predicate": predicate,
                "object": object
            });
            writeln!(file, "{json}")?;
        }

        Ok(())
    }

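    /// Loads triples from a file, picking the parser from the file extension
    /// and falling back to trying each format in turn.
    ///
    /// A usage sketch (the file name is a hypothetical example):
    ///
    /// ```ignore
    /// let triples = data_loader::load_triples_auto_detect("graph.nt")?;
    /// println!("loaded {} triples", triples.len());
    /// ```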
    pub fn load_triples_auto_detect<P: AsRef<Path>>(
        file_path: P,
    ) -> Result<Vec<(String, String, String)>> {
        let path = file_path.as_ref();
        let extension = path
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("")
            .to_lowercase();

        match extension.as_str() {
            "tsv" => load_triples_from_tsv(path),
            "csv" => load_triples_from_csv(path),
            "nt" | "ntriples" => load_triples_from_ntriples(path),
            "jsonl" | "ndjson" => load_triples_from_jsonl(path),
            _ => {
                eprintln!(
                    "Warning: Unknown file extension '{extension}', attempting auto-detection"
                );

                if let Ok(triples) = load_triples_from_tsv(path) {
                    if !triples.is_empty() {
                        return Ok(triples);
                    }
                }

                if let Ok(triples) = load_triples_from_ntriples(path) {
                    if !triples.is_empty() {
                        return Ok(triples);
                    }
                }

                if let Ok(triples) = load_triples_from_jsonl(path) {
                    if !triples.is_empty() {
                        return Ok(triples);
                    }
                }

                load_triples_from_csv(path)
            }
        }
    }
}

pub mod dataset_splitter {
    use super::*;

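    /// Shuffles the triples with a seeded Fisher-Yates pass and splits them
    /// into train/validation/test partitions by ratio.
    ///
    /// A usage sketch (ratios and seed are arbitrary example values):
    ///
    /// ```ignore
    /// let split = dataset_splitter::split_dataset(triples, 0.8, 0.1, Some(42))?;
    /// println!(
    ///     "{} train / {} val / {} test",
    ///     split.train.len(),
    ///     split.validation.len(),
    ///     split.test.len()
    /// );
    /// ```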
    pub fn split_dataset(
        triples: Vec<(String, String, String)>,
        train_ratio: f64,
        val_ratio: f64,
        seed: Option<u64>,
    ) -> Result<DatasetSplit> {
        if train_ratio + val_ratio >= 1.0 {
            return Err(anyhow!(
                "Train and validation ratios must sum to less than 1.0"
            ));
        }

        let mut rng = if let Some(s) = seed {
            Random::seed(s)
        } else {
            Random::seed(42)
        };

        // Fisher-Yates shuffle.
        let mut shuffled_triples = triples;
        for i in (1..shuffled_triples.len()).rev() {
            let j = rng.random_range(0..i + 1);
            shuffled_triples.swap(i, j);
        }

        let total = shuffled_triples.len();
        let train_end = (total as f64 * train_ratio) as usize;
        let val_end = train_end + (total as f64 * val_ratio) as usize;

        let train_triples = shuffled_triples[..train_end].to_vec();
        let val_triples = shuffled_triples[train_end..val_end].to_vec();
        let test_triples = shuffled_triples[val_end..].to_vec();

        Ok(DatasetSplit {
            train: train_triples,
            validation: val_triples,
            test: test_triples,
        })
    }

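    /// Entity-based split: entities are partitioned first, then every triple
    /// is routed to the partition(s) of the entities it mentions and
    /// deduplicated. This keeps each entity's triples together, though a
    /// triple whose subject and object land in different partitions can still
    /// appear in both.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let split = dataset_splitter::split_dataset_no_leakage(triples, 0.8, 0.1, Some(7))?;
    /// ```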
    pub fn split_dataset_no_leakage(
        triples: Vec<(String, String, String)>,
        train_ratio: f64,
        val_ratio: f64,
        seed: Option<u64>,
    ) -> Result<DatasetSplit> {
        // Group triples by the entities they mention.
        let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
            HashMap::with_capacity(triples.len() / 2);

        for triple in &triples {
            let entities = [&triple.0, &triple.2];
            for entity in entities {
                entity_triples
                    .entry(entity.clone())
                    .or_default()
                    .push(triple.clone());
            }
        }

        // Split the entities themselves, reusing `split_dataset` on dummy triples.
        let entities: Vec<String> = entity_triples.keys().cloned().collect();
        let dummy_string = "dummy".to_string();
        let entity_split = split_dataset(
            entities
                .into_iter()
                .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
                .collect(),
            train_ratio,
            val_ratio,
            seed,
        )?;

        let train_entities: HashSet<String> =
            entity_split.train.into_iter().map(|(e, _, _)| e).collect();
        let val_entities: HashSet<String> = entity_split
            .validation
            .into_iter()
            .map(|(e, _, _)| e)
            .collect();
        let test_entities: HashSet<String> =
            entity_split.test.into_iter().map(|(e, _, _)| e).collect();

        let estimated_capacity = triples.len() / 3;
        let mut train_triples = Vec::with_capacity(estimated_capacity);
        let mut val_triples = Vec::with_capacity(estimated_capacity);
        let mut test_triples = Vec::with_capacity(estimated_capacity);

        for (entity, entity_triple_list) in entity_triples {
            if train_entities.contains(&entity) {
                train_triples.extend(entity_triple_list);
            } else if val_entities.contains(&entity) {
                val_triples.extend(entity_triple_list);
            } else if test_entities.contains(&entity) {
                test_triples.extend(entity_triple_list);
            }
        }

        // Each triple was indexed under both of its entities, so deduplicate.
        train_triples.sort();
        train_triples.dedup();
        val_triples.sort();
        val_triples.dedup();
        test_triples.sort();
        test_triples.dedup();

        Ok(DatasetSplit {
            train: train_triples,
            validation: val_triples,
            test: test_triples,
        })
    }
}

#[derive(Debug, Clone)]
pub struct DatasetSplit {
    pub train: Vec<(String, String, String)>,
    pub validation: Vec<(String, String, String)>,
    pub test: Vec<(String, String, String)>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    pub num_triples: usize,
    pub num_entities: usize,
    pub num_relations: usize,
    pub entity_frequency: HashMap<String, usize>,
    pub relation_frequency: HashMap<String, usize>,
    pub avg_degree: f64,
    pub density: f64,
}

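/// Computes basic corpus statistics: entity/relation counts, per-term
/// frequencies, average entity degree, and density (triples divided by the
/// `n * n` possible directed entity pairs).
///
/// A usage sketch:
///
/// ```ignore
/// let stats = compute_dataset_statistics(&triples);
/// println!("{} entities, density {:.6}", stats.num_entities, stats.density);
/// ```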
pub fn compute_dataset_statistics(triples: &[(String, String, String)]) -> DatasetStatistics {
    let mut entities = HashSet::new();
    let mut relations = HashSet::new();
    let mut entity_frequency = HashMap::new();
    let mut relation_frequency = HashMap::new();

    for (subject, predicate, object) in triples {
        entities.insert(subject.clone());
        entities.insert(object.clone());
        relations.insert(predicate.clone());

        *entity_frequency.entry(subject.clone()).or_insert(0) += 1;
        *entity_frequency.entry(object.clone()).or_insert(0) += 1;
        *relation_frequency.entry(predicate.clone()).or_insert(0) += 1;
    }

    let num_entities = entities.len();
    let num_relations = relations.len();
    let num_triples = triples.len();

    let avg_degree = if num_entities > 0 {
        (num_triples * 2) as f64 / num_entities as f64
    } else {
        0.0
    };

    let max_possible_edges = num_entities * num_entities;
    let density = if max_possible_edges > 0 {
        num_triples as f64 / max_possible_edges as f64
    } else {
        0.0
    };

    DatasetStatistics {
        num_triples,
        num_entities,
        num_relations,
        entity_frequency,
        relation_frequency,
        avg_degree,
        density,
    }
}

pub mod embedding_analysis {
    use super::*;

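    /// Summarizes the value distribution of an embedding matrix (mean,
    /// variance, min/max, median over all parameters).
    ///
    /// A usage sketch (the matrix shape is an arbitrary example):
    ///
    /// ```ignore
    /// let embeddings = Array2::<f64>::zeros((100, 64));
    /// let stats = embedding_analysis::analyze_embedding_distribution(&embeddings);
    /// assert_eq!(stats.num_parameters, 100 * 64);
    /// ```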
    pub fn analyze_embedding_distribution(embeddings: &Array2<f64>) -> EmbeddingDistributionStats {
        let flat_values: Vec<f64> = embeddings.iter().cloned().collect();

        let mean = flat_values.iter().sum::<f64>() / flat_values.len() as f64;
        let variance =
            flat_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / flat_values.len() as f64;
        let std_dev = variance.sqrt();

        let mut sorted_values = flat_values.clone();
        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let min_val = sorted_values[0];
        let max_val = sorted_values[sorted_values.len() - 1];
        let median = sorted_values[sorted_values.len() / 2];

        EmbeddingDistributionStats {
            mean,
            std_dev,
            variance,
            min: min_val,
            max: max_val,
            median,
            num_parameters: embeddings.len(),
        }
    }

    pub fn compute_embedding_norms(embeddings: &Array2<f64>) -> Vec<f64> {
        embeddings
            .rows()
            .into_iter()
            .map(|row| row.dot(&row).sqrt())
            .collect()
    }

    pub fn analyze_embedding_similarities(
        embeddings: &Array2<f64>,
        sample_size: usize,
    ) -> SimilarityStats {
        let num_embeddings = embeddings.nrows();
        let mut similarities = Vec::new();

        let sample_size =
            sample_size.min(num_embeddings * num_embeddings.saturating_sub(1) / 2);
        let mut rng = Random::default();

        for _ in 0..sample_size {
            let i = rng.random_range(0..num_embeddings);
            let j = rng.random_range(0..num_embeddings);

            if i != j {
                let emb_i = embeddings.row(i);
                let emb_j = embeddings.row(j);
                let similarity = cosine_similarity(&emb_i.to_owned(), &emb_j.to_owned());
                similarities.push(similarity);
            }
        }

        // Guard against an empty sample (fewer than two embeddings, or every
        // random pair collided), which would otherwise panic below.
        if similarities.is_empty() {
            return SimilarityStats {
                mean_similarity: 0.0,
                min_similarity: 0.0,
                max_similarity: 0.0,
                median_similarity: 0.0,
                num_comparisons: 0,
            };
        }

        similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mean_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
        let min_similarity = similarities[0];
        let max_similarity = similarities[similarities.len() - 1];
        let median_similarity = similarities[similarities.len() / 2];

        SimilarityStats {
            mean_similarity,
            min_similarity,
            max_similarity,
            median_similarity,
            num_comparisons: similarities.len(),
        }
    }

    fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
        let dot_product = a.dot(b);
        let norm_a = a.dot(a).sqrt();
        let norm_b = b.dot(b).sqrt();

        if norm_a > 1e-10 && norm_b > 1e-10 {
            dot_product / (norm_a * norm_b)
        } else {
            0.0
        }
    }
}

#[derive(Debug, Clone)]
pub struct EmbeddingDistributionStats {
    pub mean: f64,
    pub std_dev: f64,
    pub variance: f64,
    pub min: f64,
    pub max: f64,
    pub median: f64,
    pub num_parameters: usize,
}

#[derive(Debug, Clone)]
pub struct SimilarityStats {
    pub mean_similarity: f64,
    pub min_similarity: f64,
    pub max_similarity: f64,
    pub median_similarity: f64,
    pub num_comparisons: usize,
}

pub mod graph_analysis {
    use super::*;

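    /// Computes degree and density metrics over a triple list in one pass.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let metrics = graph_analysis::compute_graph_metrics(&triples);
    /// println!("avg degree {:.2}, max {}", metrics.avg_degree, metrics.max_degree);
    /// ```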
    pub fn compute_graph_metrics(triples: &[(String, String, String)]) -> GraphMetrics {
        // Rough capacity estimates to limit rehashing.
        let estimated_entities = triples.len();
        let estimated_relations = triples.len() / 10;

        let mut entity_degrees: HashMap<String, usize> =
            HashMap::with_capacity(estimated_entities);
        let mut relation_counts: HashMap<String, usize> =
            HashMap::with_capacity(estimated_relations);
        let mut entities = HashSet::with_capacity(estimated_entities);

        for (subject, predicate, object) in triples {
            entities.insert(subject.clone());
            entities.insert(object.clone());

            *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
            *entity_degrees.entry(object.clone()).or_insert(0) += 1;
            *relation_counts.entry(predicate.clone()).or_insert(0) += 1;
        }

        let num_entities = entities.len();
        let num_relations = relation_counts.len();
        let num_triples = triples.len();

        let degrees: Vec<usize> = entity_degrees.values().cloned().collect();
        // Guard against division by zero on an empty graph.
        let avg_degree = if degrees.is_empty() {
            0.0
        } else {
            degrees.iter().sum::<usize>() as f64 / degrees.len() as f64
        };
        let max_degree = degrees.iter().max().cloned().unwrap_or(0);
        let min_degree = degrees.iter().min().cloned().unwrap_or(0);

        let density = if num_entities > 0 {
            num_triples as f64 / (num_entities * num_entities) as f64
        } else {
            0.0
        };

        GraphMetrics {
            num_entities,
            num_relations,
            num_triples,
            avg_degree,
            max_degree,
            min_degree,
            density,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetrics {
    pub num_entities: usize,
    pub num_relations: usize,
    pub num_triples: usize,
    pub avg_degree: f64,
    pub max_degree: usize,
    pub min_degree: usize,
    pub density: f64,
}

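/// Prints throttled progress updates (at most once per second) for
/// long-running loops.
///
/// A usage sketch (`process` is a hypothetical work function):
///
/// ```ignore
/// let mut tracker = ProgressTracker::new(items.len());
/// for (i, item) in items.iter().enumerate() {
///     process(item);
///     tracker.update(i + 1);
/// }
/// tracker.finish();
/// ```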
#[derive(Debug)]
pub struct ProgressTracker {
    total: usize,
    current: usize,
    start_time: std::time::Instant,
    last_update: std::time::Instant,
    update_interval: std::time::Duration,
}

impl ProgressTracker {
    pub fn new(total: usize) -> Self {
        let now = std::time::Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            update_interval: std::time::Duration::from_secs(1),
        }
    }

    pub fn update(&mut self, current: usize) {
        self.current = current;
        let now = std::time::Instant::now();

        if now.duration_since(self.last_update) >= self.update_interval {
            self.print_progress();
            self.last_update = now;
        }
    }

    fn print_progress(&self) {
        let percentage = (self.current as f64 / self.total as f64) * 100.0;
        let elapsed = self.start_time.elapsed().as_secs_f64();
        let rate = self.current as f64 / elapsed;

        println!(
            "Progress: {}/{} ({:.1}%) - {:.1} items/sec",
            self.current, self.total, percentage, rate
        );
    }

    pub fn finish(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        let rate = self.total as f64 / elapsed;

        println!(
            "Completed: {} items in {:.2}s ({:.1} items/sec)",
            self.total, elapsed, rate
        );
    }
}

pub mod performance_benchmark {
    use super::*;
    use std::collections::BTreeMap;
    use std::time::{Duration, Instant};

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSuite {
        pub results: BTreeMap<String, BenchmarkResult>,
        pub summary: BenchmarkSummary,
        pub config: BenchmarkConfig,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkResult {
        pub operation: String,
        pub iterations: usize,
        pub total_duration: Duration,
        pub avg_duration: Duration,
        pub min_duration: Duration,
        pub max_duration: Duration,
        pub std_deviation: Duration,
        pub ops_per_second: f64,
        pub memory_stats: MemoryStats,
        pub custom_metrics: HashMap<String, f64>,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MemoryStats {
        pub peak_memory_bytes: usize,
        pub avg_memory_bytes: usize,
        pub allocations: usize,
        pub deallocations: usize,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSummary {
        pub total_duration: Duration,
        pub total_operations: usize,
        pub overall_throughput: f64,
        pub efficiency_score: f64,
        pub bottlenecks: Vec<String>,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkConfig {
        pub warmup_iterations: usize,
        pub measurement_iterations: usize,
        pub confidence_level: f64,
        pub enable_memory_profiling: bool,
        pub enable_detailed_timing: bool,
    }

    impl Default for BenchmarkConfig {
        fn default() -> Self {
            Self {
                warmup_iterations: 100,
                measurement_iterations: 1000,
                confidence_level: 0.95,
                enable_memory_profiling: true,
                enable_detailed_timing: true,
            }
        }
    }

    pub struct PrecisionTimer {
        start_time: Instant,
        lap_times: Vec<Duration>,
    }

    impl Default for PrecisionTimer {
        fn default() -> Self {
            Self::new()
        }
    }

    impl PrecisionTimer {
        pub fn new() -> Self {
            Self {
                start_time: Instant::now(),
                lap_times: Vec::new(),
            }
        }

        pub fn start(&mut self) {
            self.start_time = Instant::now();
            self.lap_times.clear();
        }

        pub fn lap(&mut self) -> Duration {
            let lap_duration = self.start_time.elapsed();
            self.lap_times.push(lap_duration);
            lap_duration
        }

        pub fn stop(&self) -> Duration {
            self.start_time.elapsed()
        }

        pub fn lap_times(&self) -> &[Duration] {
            &self.lap_times
        }
    }

    pub struct EmbeddingBenchmark {
        config: BenchmarkConfig,
        results: BTreeMap<String, BenchmarkResult>,
    }

    impl EmbeddingBenchmark {
        pub fn new(config: BenchmarkConfig) -> Self {
            Self {
                config,
                results: BTreeMap::new(),
            }
        }

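        /// Runs `operation` for the configured warmup and measurement
        /// iterations, records timing (and optional memory) statistics under
        /// `name`, and returns the first measured result.
        ///
        /// A usage sketch (the closure body is an arbitrary example):
        ///
        /// ```ignore
        /// let mut bench = EmbeddingBenchmark::new(BenchmarkConfig::default());
        /// let value = bench.benchmark("vector_sum", || Ok((0..1_000u64).sum::<u64>()))?;
        /// let report = bench.generate_report();
        /// ```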
        pub fn benchmark<F, T>(&mut self, name: &str, mut operation: F) -> Result<T>
        where
            F: FnMut() -> Result<T>,
        {
            // Warmup phase: run the operation without recording timings.
            for _ in 0..self.config.warmup_iterations {
                let _ = operation()?;
            }

            let mut durations = Vec::with_capacity(self.config.measurement_iterations);
            let mut memory_snapshots = Vec::new();
            let mut result = None;

            // Measurement phase.
            for i in 0..self.config.measurement_iterations {
                let memory_before = self.get_memory_usage();
                let start = Instant::now();

                let op_result = operation()?;

                let duration = start.elapsed();
                let memory_after = self.get_memory_usage();

                durations.push(duration);

                if self.config.enable_memory_profiling {
                    memory_snapshots.push((memory_before, memory_after));
                }

                // Keep the first iteration's result to return to the caller.
                if i == 0 {
                    result = Some(op_result);
                }
            }

            let total_duration: Duration = durations.iter().sum();
            let avg_duration = total_duration / durations.len() as u32;
            let min_duration = *durations.iter().min().unwrap();
            let max_duration = *durations.iter().max().unwrap();

            let variance: f64 = durations
                .iter()
                .map(|d| {
                    let diff = d.as_nanos() as f64 - avg_duration.as_nanos() as f64;
                    diff * diff
                })
                .sum::<f64>()
                / durations.len() as f64;
            let std_deviation = Duration::from_nanos(variance.sqrt() as u64);

            let ops_per_second = 1_000_000_000.0 / avg_duration.as_nanos() as f64;

            let memory_stats = if self.config.enable_memory_profiling
                && !memory_snapshots.is_empty()
            {
                let peak_memory = memory_snapshots
                    .iter()
                    .map(|(_, after)| after.peak_memory_bytes)
                    .max()
                    .unwrap_or(0);

                let avg_memory = memory_snapshots
                    .iter()
                    .map(|(before, after)| (before.avg_memory_bytes + after.avg_memory_bytes) / 2)
                    .sum::<usize>()
                    / memory_snapshots.len();

                MemoryStats {
                    peak_memory_bytes: peak_memory,
                    avg_memory_bytes: avg_memory,
                    allocations: memory_snapshots.len(),
                    deallocations: 0, // not tracked by the simplified profiler
                }
            } else {
                MemoryStats {
                    peak_memory_bytes: 0,
                    avg_memory_bytes: 0,
                    allocations: 0,
                    deallocations: 0,
                }
            };

            let benchmark_result = BenchmarkResult {
                operation: name.to_string(),
                iterations: self.config.measurement_iterations,
                total_duration,
                avg_duration,
                min_duration,
                max_duration,
                std_deviation,
                ops_per_second,
                memory_stats,
                custom_metrics: HashMap::new(),
            };

            self.results.insert(name.to_string(), benchmark_result);

            result.ok_or_else(|| anyhow!("Failed to capture benchmark result"))
        }

        pub fn generate_report(&self) -> BenchmarkSuite {
            let total_duration = self.results.values().map(|r| r.total_duration).sum();

            let total_operations = self.results.len();

            let overall_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
                / total_operations as f64;

            let efficiency_score = self.calculate_efficiency_score();

            let bottlenecks = self.identify_bottlenecks();

            let summary = BenchmarkSummary {
                total_duration,
                total_operations,
                overall_throughput,
                efficiency_score,
                bottlenecks,
            };

            BenchmarkSuite {
                results: self.results.clone(),
                summary,
                config: self.config.clone(),
            }
        }

        fn calculate_efficiency_score(&self) -> f64 {
            if self.results.is_empty() {
                return 0.0;
            }

            // Score each operation by timing consistency: a lower coefficient
            // of variation yields a score closer to 1.0.
            let consistency_scores: Vec<f64> = self
                .results
                .values()
                .map(|result| {
                    let cv = result.std_deviation.as_nanos() as f64
                        / result.avg_duration.as_nanos() as f64;
                    1.0 / (1.0 + cv)
                })
                .collect();

            consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64
        }

        fn identify_bottlenecks(&self) -> Vec<String> {
            let mut bottlenecks = Vec::new();

            // Flag operations with noisy timings.
            for (name, result) in &self.results {
                let cv =
                    result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
                if cv > 0.2 {
                    bottlenecks.push(format!("High variance in {}: {:.2}% CV", name, cv * 100.0));
                }
            }

            // Flag operations that are much slower than the average.
            let avg_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
                / self.results.len() as f64;

            for (name, result) in &self.results {
                if result.ops_per_second < avg_throughput * 0.5 {
                    bottlenecks.push(format!(
                        "Slow operation {}: {:.0} ops/sec",
                        name, result.ops_per_second
                    ));
                }
            }

            bottlenecks
        }

        fn get_memory_usage(&self) -> MemoryStats {
            // Placeholder: real memory profiling would query the allocator or
            // the operating system here.
            MemoryStats {
                peak_memory_bytes: 0,
                avg_memory_bytes: 0,
                allocations: 0,
                deallocations: 0,
            }
        }
    }

    pub mod analysis {
        use super::*;

        pub fn compare_benchmarks(
            baseline: &BenchmarkResult,
            comparison: &BenchmarkResult,
        ) -> BenchmarkComparison {
            let throughput_improvement =
                (comparison.ops_per_second - baseline.ops_per_second) / baseline.ops_per_second;

            let latency_improvement = (baseline.avg_duration.as_nanos() as f64
                - comparison.avg_duration.as_nanos() as f64)
                / baseline.avg_duration.as_nanos() as f64;

            let consistency_improvement = {
                let baseline_cv = baseline.std_deviation.as_nanos() as f64
                    / baseline.avg_duration.as_nanos() as f64;
                let comparison_cv = comparison.std_deviation.as_nanos() as f64
                    / comparison.avg_duration.as_nanos() as f64;
                (baseline_cv - comparison_cv) / baseline_cv
            };

            BenchmarkComparison {
                baseline_name: baseline.operation.clone(),
                comparison_name: comparison.operation.clone(),
                throughput_improvement,
                latency_improvement,
                consistency_improvement,
                is_improvement: throughput_improvement > 0.0 && latency_improvement > 0.0,
            }
        }

        pub fn analyze_regression(
            historical_results: &[BenchmarkResult],
            current_result: &BenchmarkResult,
        ) -> RegressionAnalysis {
            if historical_results.is_empty() {
                return RegressionAnalysis::default();
            }

            let historical_avg_throughput = historical_results
                .iter()
                .map(|r| r.ops_per_second)
                .sum::<f64>()
                / historical_results.len() as f64;

            let throughput_change = (current_result.ops_per_second - historical_avg_throughput)
                / historical_avg_throughput;

            // Treat a drop of more than 5% below the historical average as a regression.
            let is_regression = throughput_change < -0.05;

            RegressionAnalysis {
                throughput_change,
                is_regression,
                confidence_level: 0.95,
                analysis_notes: if is_regression {
                    vec!["Performance regression detected".to_string()]
                } else {
                    vec!["Performance within expected range".to_string()]
                },
            }
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkComparison {
        pub baseline_name: String,
        pub comparison_name: String,
        pub throughput_improvement: f64,
        pub latency_improvement: f64,
        pub consistency_improvement: f64,
        pub is_improvement: bool,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct RegressionAnalysis {
        pub throughput_change: f64,
        pub is_regression: bool,
        pub confidence_level: f64,
        pub analysis_notes: Vec<String>,
    }

    impl Default for RegressionAnalysis {
        fn default() -> Self {
            Self {
                throughput_change: 0.0,
                is_regression: false,
                confidence_level: 0.0,
                analysis_notes: vec!["No historical data available".to_string()],
            }
        }
    }
}

pub mod convenience {
    use super::*;
    use crate::{EmbeddingModel, ModelConfig, NamedNode, TransE, Triple};

    pub fn create_simple_transe_model() -> TransE {
        let config = ModelConfig::default()
            .with_dimensions(128)
            .with_learning_rate(0.01)
            .with_max_epochs(100);
        TransE::new(config)
    }

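    /// Parses a whitespace-separated `subject predicate object` string into a
    /// `Triple`, prefixing bare names with `http://example.org/`.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// // Bare names are expanded: the subject IRI becomes http://example.org/alice.
    /// let triple = parse_triple_from_string("alice knows bob")?;
    /// ```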
    pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
        let parts: Vec<&str> = triple_str.split_whitespace().collect();
        if parts.len() != 3 {
            return Err(anyhow!(
                "Invalid triple format. Expected 'subject predicate object', got: '{}'",
                triple_str
            ));
        }

        let subject = if parts[0].starts_with("http") {
            NamedNode::new(parts[0])?
        } else {
            NamedNode::new(&format!("http://example.org/{}", parts[0]))?
        };

        let predicate = if parts[1].starts_with("http") {
            NamedNode::new(parts[1])?
        } else {
            NamedNode::new(&format!("http://example.org/{}", parts[1]))?
        };

        let object = if parts[2].starts_with("http") {
            NamedNode::new(parts[2])?
        } else {
            NamedNode::new(&format!("http://example.org/{}", parts[2]))?
        };

        Ok(Triple::new(subject, predicate, object))
    }

    pub fn add_triples_from_strings(
        model: &mut dyn EmbeddingModel,
        triple_strings: &[&str],
    ) -> Result<usize> {
        let mut added_count = 0;
        for triple_str in triple_strings {
            match parse_triple_from_string(triple_str) {
                Ok(triple) => {
                    model.add_triple(triple)?;
                    added_count += 1;
                }
                Err(e) => {
                    eprintln!("Warning: Failed to parse triple '{triple_str}': {e}");
                }
            }
        }
        Ok(added_count)
    }

    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
        if a.len() != b.len() {
            return Err(anyhow!(
                "Vector dimensions don't match: {} vs {}",
                a.len(),
                b.len()
            ));
        }

        let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return Ok(0.0);
        }

        Ok(dot_product / (norm_a * norm_b))
    }

    pub fn generate_sample_kg_data(
        num_entities: usize,
        num_relations: usize,
    ) -> Vec<(String, String, String)> {
        let mut rng = Random::default();
        let mut triples = Vec::new();

        let entities: Vec<String> = (0..num_entities).map(|i| format!("entity_{i}")).collect();

        let relations: Vec<String> = (0..num_relations)
            .map(|i| format!("relation_{i}"))
            .collect();

        // Generate random triples, skipping self-loops.
        for _ in 0..(num_entities * 2) {
            let subject = entities[rng.random_range(0..entities.len())].clone();
            let relation = relations[rng.random_range(0..relations.len())].clone();
            let object = entities[rng.random_range(0..entities.len())].clone();

            if subject != object {
                triples.push((subject, relation, object));
            }
        }

        triples
    }

    pub fn quick_performance_test<F>(
        name: &str,
        iterations: usize,
        operation: F,
    ) -> std::time::Duration
    where
        F: Fn(),
    {
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            operation();
        }
        let duration = start.elapsed();

        println!(
            "Performance test '{}': {} iterations in {:?} ({:.2} ops/sec)",
            name,
            iterations,
            duration,
            iterations as f64 / duration.as_secs_f64()
        );

        duration
    }
}

pub mod performance_utils {
    use super::*;

    type ProcessorFn<T> = Box<dyn Fn(&[T]) -> Result<()> + Send + Sync>;

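    /// Buffers items and hands them to a processing closure in fixed-size
    /// batches; call `flush` to process any remainder.
    ///
    /// A usage sketch (batch size and closure are arbitrary examples):
    ///
    /// ```ignore
    /// let mut processor = BatchProcessor::new(1000, |batch: &[u64]| {
    ///     println!("processing {} items", batch.len());
    ///     Ok(())
    /// });
    /// for item in 0..2500u64 {
    ///     processor.add(item)?;
    /// }
    /// processor.flush()?; // handles the trailing 500 items
    /// ```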
    pub struct BatchProcessor<T> {
        batch_size: usize,
        current_batch: Vec<T>,
        processor_fn: ProcessorFn<T>,
    }

    impl<T> BatchProcessor<T> {
        pub fn new<F>(batch_size: usize, processor_fn: F) -> Self
        where
            F: Fn(&[T]) -> Result<()> + Send + Sync + 'static,
        {
            Self {
                batch_size,
                current_batch: Vec::with_capacity(batch_size),
                processor_fn: Box::new(processor_fn),
            }
        }

        pub fn add(&mut self, item: T) -> Result<()> {
            self.current_batch.push(item);
            if self.current_batch.len() >= self.batch_size {
                return self.flush();
            }
            Ok(())
        }

        pub fn flush(&mut self) -> Result<()> {
            if !self.current_batch.is_empty() {
                (self.processor_fn)(&self.current_batch)?;
                self.current_batch.clear();
            }
            Ok(())
        }
    }

    #[derive(Debug, Clone)]
    pub struct MemoryMonitor {
        peak_usage: usize,
        current_usage: usize,
        allocations: usize,
        deallocations: usize,
    }

    impl MemoryMonitor {
        pub fn new() -> Self {
            Self {
                peak_usage: 0,
                current_usage: 0,
                allocations: 0,
                deallocations: 0,
            }
        }

        pub fn record_allocation(&mut self, size: usize) {
            self.current_usage += size;
            self.allocations += 1;
            if self.current_usage > self.peak_usage {
                self.peak_usage = self.current_usage;
            }
        }

        pub fn record_deallocation(&mut self, size: usize) {
            self.current_usage = self.current_usage.saturating_sub(size);
            self.deallocations += 1;
        }

        pub fn peak_usage(&self) -> usize {
            self.peak_usage
        }

        pub fn current_usage(&self) -> usize {
            self.current_usage
        }

        pub fn allocation_count(&self) -> usize {
            self.allocations
        }
    }

    impl Default for MemoryMonitor {
        fn default() -> Self {
            Self::new()
        }
    }
}

pub mod parallel_utils {
    use super::*;
    use rayon::prelude::*;

    pub fn parallel_cosine_similarities(
        query_embedding: &[f32],
        candidate_embeddings: &[Vec<f32>],
    ) -> Result<Vec<f32>> {
        let similarities: Vec<f32> = candidate_embeddings
            .par_iter()
            .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
            .collect();
        Ok(similarities)
    }

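    /// Splits `items` into chunks of `batch_size` and processes the chunks in
    /// parallel with rayon, flattening the per-chunk results.
    ///
    /// A usage sketch (the doubling closure is an arbitrary example):
    ///
    /// ```ignore
    /// let doubled = parallel_utils::parallel_batch_process(&numbers, 256, |chunk| {
    ///     Ok(chunk.iter().map(|x| x * 2).collect())
    /// })?;
    /// ```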
    pub fn parallel_batch_process<T, R, F>(
        items: &[T],
        batch_size: usize,
        processor: F,
    ) -> Result<Vec<R>>
    where
        T: Sync,
        R: Send,
        F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
    {
        let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();

        Ok(results?.into_iter().flatten().collect())
    }

    pub fn parallel_entity_frequencies(
        triples: &[(String, String, String)],
    ) -> HashMap<String, usize> {
        let entity_counts: HashMap<String, usize> = triples
            .par_iter()
            .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
                *acc.entry(subject.clone()).or_insert(0) += 1;
                *acc.entry(object.clone()).or_insert(0) += 1;
                acc
            })
            .reduce(HashMap::new, |mut acc1, acc2| {
                for (entity, count) in acc2 {
                    *acc1.entry(entity).or_insert(0) += count;
                }
                acc1
            });
        entity_counts
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quick_start::{
        add_triples_from_strings, cosine_similarity, create_simple_transe_model,
        generate_sample_kg_data, parse_triple_from_string, quick_performance_test,
    };
    use crate::EmbeddingModel;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_load_triples_from_tsv() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(temp_file, "subject\tpredicate\tobject")?;
        writeln!(temp_file, "alice\tknows\tbob")?;
        writeln!(temp_file, "bob\tlikes\tcharlie")?;

        let triples = data_loader::load_triples_from_tsv(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_dataset_split() -> Result<()> {
        let triples = vec![
            ("a".to_string(), "r1".to_string(), "b".to_string()),
            ("b".to_string(), "r2".to_string(), "c".to_string()),
            ("c".to_string(), "r3".to_string(), "d".to_string()),
            ("d".to_string(), "r4".to_string(), "e".to_string()),
        ];

        let split = dataset_splitter::split_dataset(triples, 0.6, 0.2, Some(42))?;

        assert_eq!(split.train.len(), 2);
        // With 4 triples, 0.6 and 0.2 floor to counts of 2 and 0.
        assert_eq!(split.validation.len(), 0);
        assert_eq!(split.test.len(), 2);

        Ok(())
    }

    #[test]
    fn test_load_triples_from_jsonl() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(
            temp_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;
        writeln!(
            temp_file,
            r#"{{"subject": "bob", "predicate": "likes", "object": "charlie"}}"#
        )?;

        let triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_save_triples_to_jsonl() -> Result<()> {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let temp_file = NamedTempFile::new()?;
        data_loader::save_triples_to_jsonl(&triples, temp_file.path())?;

        let loaded_triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(loaded_triples, triples);

        Ok(())
    }

    #[test]
    fn test_load_triples_auto_detect() -> Result<()> {
        let mut tsv_file = NamedTempFile::with_suffix(".tsv")?;
        writeln!(tsv_file, "subject\tpredicate\tobject")?;
        writeln!(tsv_file, "alice\tknows\tbob")?;

        let triples = data_loader::load_triples_auto_detect(tsv_file.path())?;
        assert_eq!(triples.len(), 1);

        let mut jsonl_file = NamedTempFile::with_suffix(".jsonl")?;
        writeln!(
            jsonl_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;

        let triples = data_loader::load_triples_auto_detect(jsonl_file.path())?;
        assert_eq!(triples.len(), 1);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_dataset_statistics() {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "knows".to_string(),
                "charlie".to_string(),
            ),
            (
                "alice".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let stats = compute_dataset_statistics(&triples);

        assert_eq!(stats.num_triples, 3);
        assert_eq!(stats.num_entities, 3); // alice, bob, charlie
        assert_eq!(stats.num_relations, 2); // knows, likes
        assert!(stats.avg_degree > 0.0);
    }

    #[test]
    fn test_create_simple_transe_model() {
        let model = create_simple_transe_model();
        assert_eq!(model.config().dimensions, 128);
        assert_eq!(model.config().learning_rate, 0.01);
        assert_eq!(model.config().max_epochs, 100);
    }

    #[test]
    fn test_parse_triple_from_string() -> Result<()> {
        let triple = parse_triple_from_string("alice knows bob")?;
        assert_eq!(triple.subject.iri.as_str(), "http://example.org/alice");
        assert_eq!(triple.predicate.iri.as_str(), "http://example.org/knows");
        assert_eq!(triple.object.iri.as_str(), "http://example.org/bob");

        // Full URIs are passed through unchanged.
        let triple2 = parse_triple_from_string(
            "http://example.org/alice http://example.org/knows http://example.org/bob",
        )?;
        assert_eq!(triple2.subject.iri.as_str(), "http://example.org/alice");

        // A two-token string is rejected.
        assert!(parse_triple_from_string("alice knows").is_err());

        Ok(())
    }

    #[test]
    fn test_add_triples_from_strings() -> Result<()> {
        let mut model = create_simple_transe_model();
        let triple_strings = &[
            "alice knows bob",
            "bob likes charlie",
            "charlie follows alice",
        ];

        let added_count = add_triples_from_strings(&mut model, triple_strings)?;
        assert_eq!(added_count, 3);

        Ok(())
    }

    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&a, &b)?;
        assert!((similarity - 1.0).abs() < 1e-10);

        let c = vec![0.0, 1.0, 0.0];
        let similarity2 = cosine_similarity(&a, &c)?;
        assert!((similarity2 - 0.0).abs() < 1e-10);

        // Mismatched dimensions are an error.
        let d = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &d).is_err());

        Ok(())
    }

    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        for (subject, relation, object) in &triples {
            assert!(subject.starts_with("entity_"));
            assert!(relation.starts_with("relation_"));
            assert!(object.starts_with("entity_"));
            assert_ne!(subject, object); // no self-loops
        }
    }

    #[test]
    fn test_quick_performance_test() {
        let duration = quick_performance_test("test_operation", 100, || {
            // Trivial workload.
            let _sum: i32 = (1..10).sum();
        });

        // Just verify that a measurable duration is returned.
        let _nanos = duration.as_nanos();
    }
}