Skip to main content

oxirs_embed/
utils.rs

1//! Utility functions and helpers for embedding operations
2
3// Removed unused imports
4use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{Array1, Array2};
6#[allow(unused_imports)]
7use scirs2_core::random::{Random, Rng};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::Path;
12
13/// Data loading utilities
14pub mod data_loader {
15    use super::*;
16    use std::io::{BufRead, BufReader};
17
18    /// Load triples from TSV file format
19    pub fn load_triples_from_tsv<P: AsRef<Path>>(
20        file_path: P,
21    ) -> Result<Vec<(String, String, String)>> {
22        let file = fs::File::open(file_path)?;
23        let reader = BufReader::new(file);
24        let mut triples = Vec::new();
25
26        for (line_num, line) in reader.lines().enumerate() {
27            let line = line?;
28            if line.trim().is_empty() || line.starts_with('#') {
29                continue; // Skip empty lines and comments
30            }
31
32            // Skip header line (first line that contains common header terms)
33            if line_num == 0
34                && (line.contains("subject")
35                    || line.contains("predicate")
36                    || line.contains("object"))
37            {
38                continue;
39            }
40
41            let parts: Vec<&str> = line.split('\t').collect();
42            if parts.len() >= 3 {
43                let subject = parts[0].trim().to_string();
44                let predicate = parts[1].trim().to_string();
45                let object = parts[2].trim().to_string();
46                triples.push((subject, predicate, object));
47            } else {
48                eprintln!(
49                    "Warning: Invalid triple format at line {}: {}",
50                    line_num + 1,
51                    line
52                );
53            }
54        }
55
56        Ok(triples)
57    }
58
59    /// Load triples from CSV file format
60    pub fn load_triples_from_csv<P: AsRef<Path>>(
61        file_path: P,
62    ) -> Result<Vec<(String, String, String)>> {
63        let file = fs::File::open(file_path)?;
64        let reader = BufReader::new(file);
65        let mut triples = Vec::new();
66        let mut is_first_line = true;
67
68        for (line_num, line) in reader.lines().enumerate() {
69            let line = line?;
70            if is_first_line {
71                is_first_line = false;
72                // Skip header if it looks like one
73                if line.to_lowercase().contains("subject")
74                    && line.to_lowercase().contains("predicate")
75                {
76                    continue;
77                }
78            }
79
80            if line.trim().is_empty() {
81                continue;
82            }
83
84            let parts: Vec<&str> = line.split(',').collect();
85            if parts.len() >= 3 {
86                let subject = parts[0].trim().trim_matches('"').to_string();
87                let predicate = parts[1].trim().trim_matches('"').to_string();
88                let object = parts[2].trim().trim_matches('"').to_string();
89                triples.push((subject, predicate, object));
90            } else {
91                eprintln!(
92                    "Warning: Invalid triple format at line {}: {}",
93                    line_num + 1,
94                    line
95                );
96            }
97        }
98
99        Ok(triples)
100    }
101
102    /// Load triples from N-Triples format
103    pub fn load_triples_from_ntriples<P: AsRef<Path>>(
104        file_path: P,
105    ) -> Result<Vec<(String, String, String)>> {
106        let file = fs::File::open(file_path)?;
107        let reader = BufReader::new(file);
108        let mut triples = Vec::new();
109
110        for (line_num, line) in reader.lines().enumerate() {
111            let line = line?;
112            let line = line.trim();
113
114            if line.is_empty() || line.starts_with('#') {
115                continue;
116            }
117
118            // Simple N-Triples parser (very basic)
119            if let Some(triple) = parse_ntriple_line(line) {
120                triples.push(triple);
121            } else {
122                eprintln!(
123                    "Warning: Failed to parse N-Triple at line {}: {}",
124                    line_num + 1,
125                    line
126                );
127            }
128        }
129
130        Ok(triples)
131    }
132
133    /// Parse a single N-Triple line
134    fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
135        let line = line.trim_end_matches(" .");
136        let parts: Vec<&str> = line.split_whitespace().collect();
137
138        if parts.len() >= 3 {
139            let subject = clean_uri_or_literal(parts[0]);
140            let predicate = clean_uri_or_literal(parts[1]);
141            let object = clean_uri_or_literal(&parts[2..].join(" "));
142
143            Some((subject, predicate, object))
144        } else {
145            None
146        }
147    }
148
149    /// Clean URI or literal from N-Triple format
150    fn clean_uri_or_literal(term: &str) -> String {
151        if term.starts_with('<') && term.ends_with('>') {
152            term[1..term.len() - 1].to_string()
153        } else if term.starts_with('"') && term.contains('"') {
154            // Handle literals - just take the string part for now
155            let end_quote = term.rfind('"').unwrap_or(term.len());
156            term[1..end_quote].to_string()
157        } else {
158            term.to_string()
159        }
160    }
161
162    /// Load triples from JSON Lines format (one triple per line as JSON)
163    pub fn load_triples_from_jsonl<P: AsRef<Path>>(
164        file_path: P,
165    ) -> Result<Vec<(String, String, String)>> {
166        let file = fs::File::open(file_path)?;
167        let reader = BufReader::new(file);
168        let mut triples = Vec::new();
169
170        for (line_num, line) in reader.lines().enumerate() {
171            let line = line?;
172            if line.trim().is_empty() {
173                continue;
174            }
175
176            match serde_json::from_str::<serde_json::Value>(&line) {
177                Ok(json) => {
178                    if let (Some(subject), Some(predicate), Some(object)) = (
179                        json["subject"].as_str(),
180                        json["predicate"].as_str(),
181                        json["object"].as_str(),
182                    ) {
183                        triples.push((
184                            subject.to_string(),
185                            predicate.to_string(),
186                            object.to_string(),
187                        ));
188                    } else {
189                        eprintln!(
190                            "Warning: Invalid JSON structure at line {}: {}",
191                            line_num + 1,
192                            line
193                        );
194                    }
195                }
196                Err(e) => {
197                    eprintln!(
198                        "Warning: Failed to parse JSON at line {}: {} - Error: {}",
199                        line_num + 1,
200                        line,
201                        e
202                    );
203                }
204            }
205        }
206
207        Ok(triples)
208    }
209
210    /// Save triples to TSV format
211    pub fn save_triples_to_tsv<P: AsRef<Path>>(
212        triples: &[(String, String, String)],
213        file_path: P,
214    ) -> Result<()> {
215        let mut content = String::new();
216        content.push_str("subject\tpredicate\tobject\n");
217
218        for (subject, predicate, object) in triples {
219            content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
220        }
221
222        fs::write(file_path, content)?;
223        Ok(())
224    }
225
226    /// Save triples to JSON Lines format
227    pub fn save_triples_to_jsonl<P: AsRef<Path>>(
228        triples: &[(String, String, String)],
229        file_path: P,
230    ) -> Result<()> {
231        use std::io::Write;
232        let mut file = fs::File::create(file_path)?;
233
234        for (subject, predicate, object) in triples {
235            let json = serde_json::json!({
236                "subject": subject,
237                "predicate": predicate,
238                "object": object
239            });
240            writeln!(file, "{json}")?;
241        }
242
243        Ok(())
244    }
245
246    /// Auto-detect file format and load triples accordingly
247    pub fn load_triples_auto_detect<P: AsRef<Path>>(
248        file_path: P,
249    ) -> Result<Vec<(String, String, String)>> {
250        let path = file_path.as_ref();
251        let extension = path
252            .extension()
253            .and_then(|ext| ext.to_str())
254            .unwrap_or("")
255            .to_lowercase();
256
257        match extension.as_str() {
258            "tsv" => load_triples_from_tsv(path),
259            "csv" => load_triples_from_csv(path),
260            "nt" | "ntriples" => load_triples_from_ntriples(path),
261            "jsonl" | "ndjson" => load_triples_from_jsonl(path),
262            _ => {
263                // Try to auto-detect based on content
264                eprintln!(
265                    "Warning: Unknown file extension '{extension}', attempting auto-detection"
266                );
267
268                // Try TSV first (most common)
269                if let Ok(triples) = load_triples_from_tsv(path) {
270                    if !triples.is_empty() {
271                        return Ok(triples);
272                    }
273                }
274
275                // Try N-Triples
276                if let Ok(triples) = load_triples_from_ntriples(path) {
277                    if !triples.is_empty() {
278                        return Ok(triples);
279                    }
280                }
281
282                // Try JSON Lines
283                if let Ok(triples) = load_triples_from_jsonl(path) {
284                    if !triples.is_empty() {
285                        return Ok(triples);
286                    }
287                }
288
289                // Finally try CSV
290                load_triples_from_csv(path)
291            }
292        }
293    }
294}
295
296/// Dataset splitting utilities
297pub mod dataset_splitter {
298    use super::*;
299
300    /// Split dataset into train/validation/test sets
301    pub fn split_dataset(
302        triples: Vec<(String, String, String)>,
303        train_ratio: f64,
304        val_ratio: f64,
305        seed: Option<u64>,
306    ) -> Result<DatasetSplit> {
307        if train_ratio + val_ratio >= 1.0 {
308            return Err(anyhow!(
309                "Train and validation ratios must sum to less than 1.0"
310            ));
311        }
312
313        let mut rng = if let Some(s) = seed {
314            Random::seed(s)
315        } else {
316            Random::seed(42) // Use a default seed for consistency
317        };
318
319        let mut shuffled_triples = triples;
320        // Manual Fisher-Yates shuffle
321        for i in (1..shuffled_triples.len()).rev() {
322            let j = rng.random_range(0..i + 1);
323            shuffled_triples.swap(i, j);
324        }
325
326        let total = shuffled_triples.len();
327        let train_end = (total as f64 * train_ratio) as usize;
328        let val_end = train_end + (total as f64 * val_ratio) as usize;
329
330        let train_triples = shuffled_triples[..train_end].to_vec();
331        let val_triples = shuffled_triples[train_end..val_end].to_vec();
332        let test_triples = shuffled_triples[val_end..].to_vec();
333
334        Ok(DatasetSplit {
335            train: train_triples,
336            validation: val_triples,
337            test: test_triples,
338        })
339    }
340
341    /// Split dataset ensuring no entity leakage between splits
342    pub fn split_dataset_no_leakage(
343        triples: Vec<(String, String, String)>,
344        train_ratio: f64,
345        val_ratio: f64,
346        seed: Option<u64>,
347    ) -> Result<DatasetSplit> {
348        // Group triples by entities with pre-allocated capacity for better performance
349        let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
350            HashMap::with_capacity(triples.len() / 2); // Estimate capacity
351
352        for triple in &triples {
353            let entities = [&triple.0, &triple.2];
354            for entity in entities {
355                entity_triples
356                    .entry(entity.clone())
357                    .or_default()
358                    .push(triple.clone());
359            }
360        }
361
362        // Split entities first, then assign triples - optimized allocation
363        let entities: Vec<String> = entity_triples.keys().cloned().collect();
364        let dummy_string = "dummy".to_string();
365        let entity_split = split_dataset(
366            entities
367                .into_iter()
368                .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
369                .collect(),
370            train_ratio,
371            val_ratio,
372            seed,
373        )?;
374
375        let train_entities: HashSet<String> =
376            entity_split.train.into_iter().map(|(e, _, _)| e).collect();
377        let val_entities: HashSet<String> = entity_split
378            .validation
379            .into_iter()
380            .map(|(e, _, _)| e)
381            .collect();
382        let test_entities: HashSet<String> =
383            entity_split.test.into_iter().map(|(e, _, _)| e).collect();
384
385        // Assign triples based on entity membership with pre-allocated capacity
386        let estimated_capacity = triples.len() / 3;
387        let mut train_triples = Vec::with_capacity(estimated_capacity);
388        let mut val_triples = Vec::with_capacity(estimated_capacity);
389        let mut test_triples = Vec::with_capacity(estimated_capacity);
390
391        for (entity, entity_triple_list) in entity_triples {
392            if train_entities.contains(&entity) {
393                train_triples.extend(entity_triple_list);
394            } else if val_entities.contains(&entity) {
395                val_triples.extend(entity_triple_list);
396            } else if test_entities.contains(&entity) {
397                test_triples.extend(entity_triple_list);
398            }
399        }
400
401        // Remove duplicates
402        train_triples.sort();
403        train_triples.dedup();
404        val_triples.sort();
405        val_triples.dedup();
406        test_triples.sort();
407        test_triples.dedup();
408
409        Ok(DatasetSplit {
410            train: train_triples,
411            validation: val_triples,
412            test: test_triples,
413        })
414    }
415}
416
/// Dataset split result: three partitions of `(subject, predicate, object)` triples.
#[derive(Debug, Clone)]
pub struct DatasetSplit {
    /// Triples assigned to the training partition.
    pub train: Vec<(String, String, String)>,
    /// Triples assigned to the validation partition.
    pub validation: Vec<(String, String, String)>,
    /// Triples assigned to the test partition.
    pub test: Vec<(String, String, String)>,
}

/// Statistics about a dataset of `(subject, predicate, object)` triples, as produced by
/// `compute_dataset_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Number of triples in the input (duplicates included).
    pub num_triples: usize,
    /// Number of distinct entities appearing as subject or object.
    pub num_entities: usize,
    /// Number of distinct predicates.
    pub num_relations: usize,
    /// Occurrence count per entity, counting both subject and object positions.
    pub entity_frequency: HashMap<String, usize>,
    /// Occurrence count per predicate.
    pub relation_frequency: HashMap<String, usize>,
    /// Average triple endpoints per entity: `2 * num_triples / num_entities`.
    pub avg_degree: f64,
    /// `num_triples / num_entities²`; `0.0` when there are no entities.
    pub density: f64,
}
436
437/// Compute dataset statistics
438pub fn compute_dataset_statistics(triples: &[(String, String, String)]) -> DatasetStatistics {
439    let mut entities = HashSet::new();
440    let mut relations = HashSet::new();
441    let mut entity_frequency = HashMap::new();
442    let mut relation_frequency = HashMap::new();
443
444    for (subject, predicate, object) in triples {
445        entities.insert(subject.clone());
446        entities.insert(object.clone());
447        relations.insert(predicate.clone());
448
449        *entity_frequency.entry(subject.clone()).or_insert(0) += 1;
450        *entity_frequency.entry(object.clone()).or_insert(0) += 1;
451        *relation_frequency.entry(predicate.clone()).or_insert(0) += 1;
452    }
453
454    let num_entities = entities.len();
455    let num_relations = relations.len();
456    let num_triples = triples.len();
457
458    let avg_degree = if num_entities > 0 {
459        (num_triples * 2) as f64 / num_entities as f64
460    } else {
461        0.0
462    };
463
464    let max_possible_edges = num_entities * num_entities;
465    let density = if max_possible_edges > 0 {
466        num_triples as f64 / max_possible_edges as f64
467    } else {
468        0.0
469    };
470
471    DatasetStatistics {
472        num_triples,
473        num_entities,
474        num_relations,
475        entity_frequency,
476        relation_frequency,
477        avg_degree,
478        density,
479    }
480}
481
482/// Embedding dimension analysis utilities
483pub mod embedding_analysis {
484    use super::*;
485
486    /// Analyze embedding distribution
487    pub fn analyze_embedding_distribution(embeddings: &Array2<f64>) -> EmbeddingDistributionStats {
488        let flat_values: Vec<f64> = embeddings.iter().cloned().collect();
489
490        let mean = flat_values.iter().sum::<f64>() / flat_values.len() as f64;
491        let variance =
492            flat_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / flat_values.len() as f64;
493        let std_dev = variance.sqrt();
494
495        let mut sorted_values = flat_values.clone();
496        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497
498        let min_val = sorted_values[0];
499        let max_val = sorted_values[sorted_values.len() - 1];
500        let median = sorted_values[sorted_values.len() / 2];
501
502        EmbeddingDistributionStats {
503            mean,
504            std_dev,
505            variance,
506            min: min_val,
507            max: max_val,
508            median,
509            num_parameters: embeddings.len(),
510        }
511    }
512
513    /// Compute embedding norms
514    pub fn compute_embedding_norms(embeddings: &Array2<f64>) -> Vec<f64> {
515        embeddings
516            .rows()
517            .into_iter()
518            .map(|row| row.dot(&row).sqrt())
519            .collect()
520    }
521
522    /// Analyze embedding similarities
523    pub fn analyze_embedding_similarities(
524        embeddings: &Array2<f64>,
525        sample_size: usize,
526    ) -> SimilarityStats {
527        let num_embeddings = embeddings.nrows();
528        let mut similarities = Vec::new();
529
530        let sample_size = sample_size.min(num_embeddings * (num_embeddings - 1) / 2);
531        let mut rng = Random::default();
532
533        for _ in 0..sample_size {
534            let i = rng.random_range(0..num_embeddings);
535            let j = rng.random_range(0..num_embeddings);
536
537            if i != j {
538                let emb_i = embeddings.row(i);
539                let emb_j = embeddings.row(j);
540                let similarity = cosine_similarity(&emb_i.to_owned(), &emb_j.to_owned());
541                similarities.push(similarity);
542            }
543        }
544
545        similarities.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
546
547        let mean_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
548        let min_similarity = similarities[0];
549        let max_similarity = similarities[similarities.len() - 1];
550        let median_similarity = similarities[similarities.len() / 2];
551
552        SimilarityStats {
553            mean_similarity,
554            min_similarity,
555            max_similarity,
556            median_similarity,
557            num_comparisons: similarities.len(),
558        }
559    }
560
561    /// Cosine similarity between two vectors
562    fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
563        let dot_product = a.dot(b);
564        let norm_a = a.dot(a).sqrt();
565        let norm_b = b.dot(b).sqrt();
566
567        if norm_a > 1e-10 && norm_b > 1e-10 {
568            dot_product / (norm_a * norm_b)
569        } else {
570            0.0
571        }
572    }
573}
574
/// Embedding distribution statistics computed over every scalar value of an embedding
/// matrix by `embedding_analysis::analyze_embedding_distribution`.
#[derive(Debug, Clone)]
pub struct EmbeddingDistributionStats {
    /// Arithmetic mean of all values.
    pub mean: f64,
    /// Population standard deviation (square root of `variance`).
    pub std_dev: f64,
    /// Population variance (sum of squared deviations divided by the value count).
    pub variance: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Median value.
    pub median: f64,
    /// Total number of scalar values in the matrix.
    pub num_parameters: usize,
}

/// Summary of cosine similarities between randomly sampled pairs of embedding rows, as
/// produced by `embedding_analysis::analyze_embedding_similarities`.
#[derive(Debug, Clone)]
pub struct SimilarityStats {
    /// Mean cosine similarity over the compared pairs.
    pub mean_similarity: f64,
    /// Lowest cosine similarity observed.
    pub min_similarity: f64,
    /// Highest cosine similarity observed.
    pub max_similarity: f64,
    /// Median cosine similarity.
    pub median_similarity: f64,
    /// Number of row pairs actually compared.
    pub num_comparisons: usize,
}
596
597/// Graph analysis utilities
598pub mod graph_analysis {
599    use super::*;
600
601    /// Compute graph metrics for knowledge graph - optimized for performance
602    pub fn compute_graph_metrics(triples: &[(String, String, String)]) -> GraphMetrics {
603        // Pre-allocate with estimated capacity for better performance
604        let estimated_entities = triples.len(); // Conservative estimate
605        let estimated_relations = triples.len() / 10; // Rough estimate
606
607        let mut entity_degrees: HashMap<String, usize> = HashMap::with_capacity(estimated_entities);
608        let mut relation_counts: HashMap<String, usize> =
609            HashMap::with_capacity(estimated_relations);
610        let mut entities = HashSet::with_capacity(estimated_entities);
611
612        for (subject, predicate, object) in triples {
613            entities.insert(subject.clone());
614            entities.insert(object.clone());
615
616            *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
617            *entity_degrees.entry(object.clone()).or_insert(0) += 1;
618            *relation_counts.entry(predicate.clone()).or_insert(0) += 1;
619        }
620
621        let num_entities = entities.len();
622        let num_relations = relation_counts.len();
623        let num_triples = triples.len();
624
625        let degrees: Vec<usize> = entity_degrees.values().cloned().collect();
626        let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
627        let max_degree = degrees.iter().max().cloned().unwrap_or(0);
628        let min_degree = degrees.iter().min().cloned().unwrap_or(0);
629
630        GraphMetrics {
631            num_entities,
632            num_relations,
633            num_triples,
634            avg_degree,
635            max_degree,
636            min_degree,
637            density: num_triples as f64 / (num_entities * num_entities) as f64,
638        }
639    }
640}
641
/// Graph metrics for a knowledge graph, as produced by
/// `graph_analysis::compute_graph_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetrics {
    /// Number of distinct entities appearing as subject or object.
    pub num_entities: usize,
    /// Number of distinct predicates.
    pub num_relations: usize,
    /// Number of input triples (duplicates included).
    pub num_triples: usize,
    /// Mean entity degree (triple endpoints per entity).
    pub avg_degree: f64,
    /// Largest entity degree.
    pub max_degree: usize,
    /// Smallest entity degree.
    pub min_degree: usize,
    /// `num_triples / num_entities²`.
    pub density: f64,
}
653
/// Progress tracking utilities: prints `current/total` progress and throughput to stdout,
/// at most once per update interval.
#[derive(Debug)]
pub struct ProgressTracker {
    total: usize,
    current: usize,
    start_time: std::time::Instant,
    last_update: std::time::Instant,
    update_interval: std::time::Duration,
}

impl ProgressTracker {
    /// Create a tracker for `total` items that prints at most once per second.
    pub fn new(total: usize) -> Self {
        Self::with_update_interval(total, std::time::Duration::from_secs(1))
    }

    /// Create a tracker for `total` items that prints at most once per `update_interval`.
    pub fn with_update_interval(total: usize, update_interval: std::time::Duration) -> Self {
        let now = std::time::Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            update_interval,
        }
    }

    /// Record that `current` items are done.
    ///
    /// Progress is printed only when at least one update interval has passed since the
    /// last print.
    pub fn update(&mut self, current: usize) {
        self.current = current;
        let now = std::time::Instant::now();

        if now.duration_since(self.last_update) >= self.update_interval {
            self.print_progress();
            self.last_update = now;
        }
    }

    /// Completion percentage. A tracker with `total == 0` has nothing left to do, so it
    /// reports 100% instead of NaN.
    fn percentage(&self) -> f64 {
        if self.total == 0 {
            100.0
        } else {
            (self.current as f64 / self.total as f64) * 100.0
        }
    }

    /// Items per second, or 0.0 when no measurable time has elapsed (avoids `inf`).
    fn rate(count: usize, elapsed_secs: f64) -> f64 {
        if elapsed_secs > 0.0 {
            count as f64 / elapsed_secs
        } else {
            0.0
        }
    }

    /// Print current progress
    fn print_progress(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();

        println!(
            "Progress: {}/{} ({:.1}%) - {:.1} items/sec",
            self.current,
            self.total,
            self.percentage(),
            Self::rate(self.current, elapsed)
        );
    }

    /// Print the final item count, total elapsed time and overall throughput.
    pub fn finish(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();

        println!(
            "Completed: {} items in {:.2}s ({:.1} items/sec)",
            self.total,
            elapsed,
            Self::rate(self.total, elapsed)
        );
    }
}
711
712/// Performance benchmarking and profiling utilities
713pub mod performance_benchmark {
714    use super::*;
715    use std::collections::BTreeMap;
716    use std::time::{Duration, Instant};
717
    /// Comprehensive performance benchmarking for embedding operations: per-operation
    /// results plus an overall summary and the configuration used.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSuite {
        /// Results organized by operation type (presumably keyed by operation name — TODO confirm)
        pub results: BTreeMap<String, BenchmarkResult>,
        /// Overall benchmark statistics
        pub summary: BenchmarkSummary,
        /// Benchmark configuration
        pub config: BenchmarkConfig,
    }

    /// Individual benchmark result for a specific operation
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkResult {
        /// Operation name
        pub operation: String,
        /// Total number of iterations
        pub iterations: usize,
        /// Total elapsed time
        pub total_duration: Duration,
        /// Average time per operation
        pub avg_duration: Duration,
        /// Minimum time observed
        pub min_duration: Duration,
        /// Maximum time observed
        pub max_duration: Duration,
        /// Standard deviation of durations (population form, computed in nanoseconds by
        /// `EmbeddingBenchmark::benchmark`)
        pub std_deviation: Duration,
        /// Operations per second, derived from `avg_duration` (1e9 / average nanoseconds)
        pub ops_per_second: f64,
        /// Memory usage statistics
        pub memory_stats: MemoryStats,
        /// Additional metrics
        pub custom_metrics: HashMap<String, f64>,
    }

    /// Memory usage statistics
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MemoryStats {
        /// Peak memory usage (bytes)
        pub peak_memory_bytes: usize,
        /// Average memory usage (bytes)
        pub avg_memory_bytes: usize,
        /// Memory allocations count.
        /// NOTE(review): `EmbeddingBenchmark::benchmark` currently stores the number of
        /// memory snapshots here, not a true allocation count.
        pub allocations: usize,
        /// Memory deallocations count.
        /// NOTE(review): `EmbeddingBenchmark::benchmark` currently always stores 0.
        pub deallocations: usize,
    }

    /// Overall benchmark summary
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkSummary {
        /// Total benchmark duration
        pub total_duration: Duration,
        /// Number of operations benchmarked
        pub total_operations: usize,
        /// Overall throughput (ops/sec)
        pub overall_throughput: f64,
        /// Performance efficiency score (0.0-1.0)
        pub efficiency_score: f64,
        /// Bottleneck analysis
        pub bottlenecks: Vec<String>,
    }

    /// Benchmarking configuration
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkConfig {
        /// Number of warmup iterations; their results are discarded before measuring
        pub warmup_iterations: usize,
        /// Number of timed measurement iterations
        pub measurement_iterations: usize,
        /// Target confidence level (0.0-1.0)
        pub confidence_level: f64,
        /// Enable memory profiling: when true, the before/after memory snapshots of each
        /// measured iteration are kept and summarized into `MemoryStats`
        pub enable_memory_profiling: bool,
        /// Enable detailed timing analysis
        pub enable_detailed_timing: bool,
    }
796
797    impl Default for BenchmarkConfig {
798        fn default() -> Self {
799            Self {
800                warmup_iterations: 100,
801                measurement_iterations: 1000,
802                confidence_level: 0.95,
803                enable_memory_profiling: true,
804                enable_detailed_timing: true,
805            }
806        }
807    }
808
809    /// High-precision timer for micro-benchmarking
810    pub struct PrecisionTimer {
811        start_time: Instant,
812        lap_times: Vec<Duration>,
813    }
814
815    impl Default for PrecisionTimer {
816        fn default() -> Self {
817            Self::new()
818        }
819    }
820
821    impl PrecisionTimer {
822        pub fn new() -> Self {
823            Self {
824                start_time: Instant::now(),
825                lap_times: Vec::new(),
826            }
827        }
828
829        /// Start timing
830        pub fn start(&mut self) {
831            self.start_time = Instant::now();
832            self.lap_times.clear();
833        }
834
835        /// Record a lap time
836        pub fn lap(&mut self) -> Duration {
837            let lap_duration = self.start_time.elapsed();
838            self.lap_times.push(lap_duration);
839            lap_duration
840        }
841
842        /// Stop timing and return final duration
843        pub fn stop(&self) -> Duration {
844            self.start_time.elapsed()
845        }
846
847        /// Get all recorded lap times
848        pub fn lap_times(&self) -> &[Duration] {
849            &self.lap_times
850        }
851    }
852
    /// Benchmarking framework for embedding operations.
    ///
    /// Holds the run configuration together with the accumulated per-operation results.
    pub struct EmbeddingBenchmark {
        // Warmup/measurement settings read by `benchmark`.
        config: BenchmarkConfig,
        // Per-operation results; presumably keyed by the `name` passed to `benchmark` —
        // TODO confirm against the remainder of `benchmark`.
        results: BTreeMap<String, BenchmarkResult>,
    }
858
859    impl EmbeddingBenchmark {
860        pub fn new(config: BenchmarkConfig) -> Self {
861            Self {
862                config,
863                results: BTreeMap::new(),
864            }
865        }
866
867        /// Benchmark a function with comprehensive timing and memory analysis
868        pub fn benchmark<F, T>(&mut self, name: &str, mut operation: F) -> Result<T>
869        where
870            F: FnMut() -> Result<T>,
871        {
872            // Warmup phase
873            for _ in 0..self.config.warmup_iterations {
874                let _ = operation()?;
875            }
876
877            let mut durations = Vec::with_capacity(self.config.measurement_iterations);
878            let mut memory_snapshots = Vec::new();
879            let mut result = None;
880
881            // Measurement phase
882            for i in 0..self.config.measurement_iterations {
883                let memory_before = self.get_memory_usage();
884                let start = Instant::now();
885
886                let op_result = operation()?;
887
888                let duration = start.elapsed();
889                let memory_after = self.get_memory_usage();
890
891                durations.push(duration);
892
893                if self.config.enable_memory_profiling {
894                    memory_snapshots.push((memory_before, memory_after));
895                }
896
897                // Store result from the first iteration
898                if i == 0 {
899                    result = Some(op_result);
900                }
901            }
902
903            // Calculate statistics
904            let total_duration: Duration = durations.iter().sum();
905            let avg_duration = total_duration / durations.len() as u32;
906            let min_duration = *durations
907                .iter()
908                .min()
909                .expect("durations should not be empty");
910            let max_duration = *durations
911                .iter()
912                .max()
913                .expect("durations should not be empty");
914
915            // Calculate standard deviation
916            let variance: f64 = durations
917                .iter()
918                .map(|d| {
919                    let diff = d.as_nanos() as f64 - avg_duration.as_nanos() as f64;
920                    diff * diff
921                })
922                .sum::<f64>()
923                / durations.len() as f64;
924            let std_deviation = Duration::from_nanos(variance.sqrt() as u64);
925
926            let ops_per_second = 1_000_000_000.0 / avg_duration.as_nanos() as f64;
927
928            // Memory statistics
929            let memory_stats = if self.config.enable_memory_profiling
930                && !memory_snapshots.is_empty()
931            {
932                let peak_memory = memory_snapshots
933                    .iter()
934                    .map(|(_, after)| after.peak_memory_bytes)
935                    .max()
936                    .unwrap_or(0);
937
938                let avg_memory = memory_snapshots
939                    .iter()
940                    .map(|(before, after)| (before.avg_memory_bytes + after.avg_memory_bytes) / 2)
941                    .sum::<usize>()
942                    / memory_snapshots.len();
943
944                MemoryStats {
945                    peak_memory_bytes: peak_memory,
946                    avg_memory_bytes: avg_memory,
947                    allocations: memory_snapshots.len(),
948                    deallocations: 0, // Simplified for now
949                }
950            } else {
951                MemoryStats {
952                    peak_memory_bytes: 0,
953                    avg_memory_bytes: 0,
954                    allocations: 0,
955                    deallocations: 0,
956                }
957            };
958
959            let benchmark_result = BenchmarkResult {
960                operation: name.to_string(),
961                iterations: self.config.measurement_iterations,
962                total_duration,
963                avg_duration,
964                min_duration,
965                max_duration,
966                std_deviation,
967                ops_per_second,
968                memory_stats,
969                custom_metrics: HashMap::new(),
970            };
971
972            self.results.insert(name.to_string(), benchmark_result);
973
974            result.ok_or_else(|| anyhow!("Failed to capture benchmark result"))
975        }
976
977        /// Generate comprehensive benchmark report
978        pub fn generate_report(&self) -> BenchmarkSuite {
979            let total_duration = self.results.values().map(|r| r.total_duration).sum();
980
981            let total_operations = self.results.len();
982
983            let overall_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
984                / total_operations as f64;
985
986            // Calculate efficiency score based on consistency and performance
987            let efficiency_score = self.calculate_efficiency_score();
988
989            // Identify bottlenecks
990            let bottlenecks = self.identify_bottlenecks();
991
992            let summary = BenchmarkSummary {
993                total_duration,
994                total_operations,
995                overall_throughput,
996                efficiency_score,
997                bottlenecks,
998            };
999
1000            BenchmarkSuite {
1001                results: self.results.clone(),
1002                summary,
1003                config: self.config.clone(),
1004            }
1005        }
1006
1007        /// Calculate efficiency score based on performance consistency
1008        fn calculate_efficiency_score(&self) -> f64 {
1009            if self.results.is_empty() {
1010                return 0.0;
1011            }
1012
1013            let consistency_scores: Vec<f64> = self
1014                .results
1015                .values()
1016                .map(|result| {
1017                    // Calculate coefficient of variation (std_dev / mean)
1018                    let cv = result.std_deviation.as_nanos() as f64
1019                        / result.avg_duration.as_nanos() as f64;
1020                    // Convert to consistency score (lower CV = higher consistency)
1021                    1.0 / (1.0 + cv)
1022                })
1023                .collect();
1024
1025            consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64
1026        }
1027
1028        /// Identify performance bottlenecks
1029        fn identify_bottlenecks(&self) -> Vec<String> {
1030            let mut bottlenecks = Vec::new();
1031
1032            // Find operations with high standard deviation (inconsistent performance)
1033            for (name, result) in &self.results {
1034                let cv =
1035                    result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
1036                if cv > 0.2 {
1037                    // 20% coefficient of variation threshold
1038                    bottlenecks.push(format!("High variance in {}: {:.2}% CV", name, cv * 100.0));
1039                }
1040            }
1041
1042            // Find slow operations (below average throughput)
1043            let avg_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
1044                / self.results.len() as f64;
1045
1046            for (name, result) in &self.results {
1047                if result.ops_per_second < avg_throughput * 0.5 {
1048                    // 50% below average
1049                    bottlenecks.push(format!(
1050                        "Slow operation {}: {:.0} ops/sec",
1051                        name, result.ops_per_second
1052                    ));
1053                }
1054            }
1055
1056            bottlenecks
1057        }
1058
1059        /// Get current memory usage (simplified implementation)
1060        fn get_memory_usage(&self) -> MemoryStats {
1061            // This is a simplified implementation
1062            // In a real-world scenario, you'd use proper memory profiling tools
1063            MemoryStats {
1064                peak_memory_bytes: 0,
1065                avg_memory_bytes: 0,
1066                allocations: 0,
1067                deallocations: 0,
1068            }
1069        }
1070    }
1071
1072    /// Utility functions for performance analysis
1073    pub mod analysis {
1074        use super::*;
1075
1076        /// Compare two benchmark results
1077        pub fn compare_benchmarks(
1078            baseline: &BenchmarkResult,
1079            comparison: &BenchmarkResult,
1080        ) -> BenchmarkComparison {
1081            let throughput_improvement =
1082                (comparison.ops_per_second - baseline.ops_per_second) / baseline.ops_per_second;
1083
1084            let latency_improvement = (baseline.avg_duration.as_nanos() as f64
1085                - comparison.avg_duration.as_nanos() as f64)
1086                / baseline.avg_duration.as_nanos() as f64;
1087
1088            let consistency_improvement = {
1089                let baseline_cv = baseline.std_deviation.as_nanos() as f64
1090                    / baseline.avg_duration.as_nanos() as f64;
1091                let comparison_cv = comparison.std_deviation.as_nanos() as f64
1092                    / comparison.avg_duration.as_nanos() as f64;
1093                (baseline_cv - comparison_cv) / baseline_cv
1094            };
1095
1096            BenchmarkComparison {
1097                baseline_name: baseline.operation.clone(),
1098                comparison_name: comparison.operation.clone(),
1099                throughput_improvement,
1100                latency_improvement,
1101                consistency_improvement,
1102                is_improvement: throughput_improvement > 0.0 && latency_improvement > 0.0,
1103            }
1104        }
1105
1106        /// Generate performance regression analysis
1107        pub fn analyze_regression(
1108            historical_results: &[BenchmarkResult],
1109            current_result: &BenchmarkResult,
1110        ) -> RegressionAnalysis {
1111            if historical_results.is_empty() {
1112                return RegressionAnalysis::default();
1113            }
1114
1115            let historical_avg_throughput = historical_results
1116                .iter()
1117                .map(|r| r.ops_per_second)
1118                .sum::<f64>()
1119                / historical_results.len() as f64;
1120
1121            let throughput_change = (current_result.ops_per_second - historical_avg_throughput)
1122                / historical_avg_throughput;
1123
1124            let is_regression = throughput_change < -0.05; // 5% threshold
1125
1126            RegressionAnalysis {
1127                throughput_change,
1128                is_regression,
1129                confidence_level: 0.95, // Simplified
1130                analysis_notes: if is_regression {
1131                    vec!["Performance regression detected".to_string()]
1132                } else {
1133                    vec!["Performance within expected range".to_string()]
1134                },
1135            }
1136        }
1137    }
1138
    /// Benchmark comparison result
    ///
    /// Produced by `analysis::compare_benchmarks`. The improvement fields are
    /// ratios relative to the baseline (e.g. `0.25` means 25% better); positive
    /// values mean the comparison run is better.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkComparison {
        /// `operation` name of the baseline result.
        pub baseline_name: String,
        /// `operation` name of the compared result.
        pub comparison_name: String,
        /// Relative change in ops/sec: `(comparison - baseline) / baseline`.
        pub throughput_improvement: f64,
        /// Relative reduction in mean duration: `(baseline - comparison) / baseline`.
        pub latency_improvement: f64,
        /// Relative reduction in coefficient of variation (std deviation / mean).
        pub consistency_improvement: f64,
        /// `true` when both throughput and latency improved.
        pub is_improvement: bool,
    }
1149
    /// Performance regression analysis
    ///
    /// Produced by `analysis::analyze_regression`, which compares a current
    /// result's throughput against the mean throughput of historical results.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct RegressionAnalysis {
        /// Relative throughput change: `(current - historical_mean) / historical_mean`.
        pub throughput_change: f64,
        /// `true` when throughput dropped by more than 5% (`throughput_change < -0.05`).
        pub is_regression: bool,
        /// Fixed placeholder (0.95 from `analyze_regression`, 0.0 by default);
        /// not derived from a statistical test.
        pub confidence_level: f64,
        /// Human-readable notes describing the outcome.
        pub analysis_notes: Vec<String>,
    }
1158
1159    impl Default for RegressionAnalysis {
1160        fn default() -> Self {
1161            Self {
1162                throughput_change: 0.0,
1163                is_regression: false,
1164                confidence_level: 0.0,
1165                analysis_notes: vec!["No historical data available".to_string()],
1166            }
1167        }
1168    }
1169}
1170
1171/// Convenience functions for quick operations
1172pub mod convenience {
1173    use super::*;
1174    use crate::{EmbeddingModel, ModelConfig, NamedNode, TransE, Triple};
1175
1176    /// Create a simple TransE model with sensible defaults for quick prototyping
1177    pub fn create_simple_transe_model() -> TransE {
1178        let config = ModelConfig::default()
1179            .with_dimensions(128)
1180            .with_learning_rate(0.01)
1181            .with_max_epochs(100);
1182        TransE::new(config)
1183    }
1184
1185    /// Parse a triple from a simple string format "subject predicate object"
1186    pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
1187        let parts: Vec<&str> = triple_str.split_whitespace().collect();
1188        if parts.len() != 3 {
1189            return Err(anyhow!(
1190                "Invalid triple format. Expected 'subject predicate object', got: '{}'",
1191                triple_str
1192            ));
1193        }
1194
1195        let subject = if parts[0].starts_with("http") {
1196            NamedNode::new(parts[0])?
1197        } else {
1198            NamedNode::new(&format!("http://example.org/{}", parts[0]))?
1199        };
1200
1201        let predicate = if parts[1].starts_with("http") {
1202            NamedNode::new(parts[1])?
1203        } else {
1204            NamedNode::new(&format!("http://example.org/{}", parts[1]))?
1205        };
1206
1207        let object = if parts[2].starts_with("http") {
1208            NamedNode::new(parts[2])?
1209        } else {
1210            NamedNode::new(&format!("http://example.org/{}", parts[2]))?
1211        };
1212
1213        Ok(Triple::new(subject, predicate, object))
1214    }
1215
1216    /// Add multiple triples from string array to a model
1217    pub fn add_triples_from_strings(
1218        model: &mut dyn EmbeddingModel,
1219        triple_strings: &[&str],
1220    ) -> Result<usize> {
1221        let mut added_count = 0;
1222        for triple_str in triple_strings {
1223            match parse_triple_from_string(triple_str) {
1224                Ok(triple) => {
1225                    model.add_triple(triple)?;
1226                    added_count += 1;
1227                }
1228                Err(e) => {
1229                    eprintln!("Warning: Failed to parse triple '{triple_str}': {e}");
1230                }
1231            }
1232        }
1233        Ok(added_count)
1234    }
1235
1236    /// Quick function to compute similarity between two embedding vectors
1237    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
1238        if a.len() != b.len() {
1239            return Err(anyhow!(
1240                "Vector dimensions don't match: {} vs {}",
1241                a.len(),
1242                b.len()
1243            ));
1244        }
1245
1246        let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
1247        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
1248        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
1249
1250        if norm_a == 0.0 || norm_b == 0.0 {
1251            return Ok(0.0);
1252        }
1253
1254        Ok(dot_product / (norm_a * norm_b))
1255    }
1256
1257    /// Generate sample knowledge graph data for testing
1258    pub fn generate_sample_kg_data(
1259        num_entities: usize,
1260        num_relations: usize,
1261    ) -> Vec<(String, String, String)> {
1262        let mut rng = Random::default();
1263        let mut triples = Vec::new();
1264
1265        let entities: Vec<String> = (0..num_entities).map(|i| format!("entity_{i}")).collect();
1266
1267        let relations: Vec<String> = (0..num_relations)
1268            .map(|i| format!("relation_{i}"))
1269            .collect();
1270
1271        // Generate random triples
1272        for _ in 0..(num_entities * 2) {
1273            let subject = entities[rng.random_range(0..entities.len())].clone();
1274            let relation = relations[rng.random_range(0..relations.len())].clone();
1275            let object = entities[rng.random_range(0..entities.len())].clone();
1276
1277            if subject != object {
1278                triples.push((subject, relation, object));
1279            }
1280        }
1281
1282        triples
1283    }
1284
1285    /// Quick performance test function
1286    pub fn quick_performance_test<F>(
1287        name: &str,
1288        iterations: usize,
1289        operation: F,
1290    ) -> std::time::Duration
1291    where
1292        F: Fn(),
1293    {
1294        let start = std::time::Instant::now();
1295        for _ in 0..iterations {
1296            operation();
1297        }
1298        let duration = start.elapsed();
1299
1300        println!(
1301            "Performance test '{}': {} iterations in {:?} ({:.2} ops/sec)",
1302            name,
1303            iterations,
1304            duration,
1305            iterations as f64 / duration.as_secs_f64()
1306        );
1307
1308        duration
1309    }
1310}
1311
1312/// Advanced performance utilities for embedding operations
1313pub mod performance_utils {
1314    use super::*;
1315
1316    /// Type alias for batch processor function
1317    type ProcessorFn<T> = Box<dyn Fn(&[T]) -> Result<()> + Send + Sync>;
1318
1319    /// Memory-efficient batch processor for large datasets
1320    pub struct BatchProcessor<T> {
1321        batch_size: usize,
1322        current_batch: Vec<T>,
1323        processor_fn: ProcessorFn<T>,
1324    }
1325
1326    impl<T> BatchProcessor<T> {
1327        pub fn new<F>(batch_size: usize, processor_fn: F) -> Self
1328        where
1329            F: Fn(&[T]) -> Result<()> + Send + Sync + 'static,
1330        {
1331            Self {
1332                batch_size,
1333                current_batch: Vec::with_capacity(batch_size),
1334                processor_fn: Box::new(processor_fn),
1335            }
1336        }
1337
1338        pub fn add(&mut self, item: T) -> Result<()> {
1339            self.current_batch.push(item);
1340            if self.current_batch.len() >= self.batch_size {
1341                return self.flush();
1342            }
1343            Ok(())
1344        }
1345
1346        pub fn flush(&mut self) -> Result<()> {
1347            if !self.current_batch.is_empty() {
1348                (self.processor_fn)(&self.current_batch)?;
1349                self.current_batch.clear();
1350            }
1351            Ok(())
1352        }
1353    }
1354
1355    /// Enhanced memory monitoring for embedding operations
1356    #[derive(Debug, Clone)]
1357    pub struct MemoryMonitor {
1358        peak_usage: usize,
1359        current_usage: usize,
1360        allocations: usize,
1361        deallocations: usize,
1362    }
1363
1364    impl MemoryMonitor {
1365        pub fn new() -> Self {
1366            Self {
1367                peak_usage: 0,
1368                current_usage: 0,
1369                allocations: 0,
1370                deallocations: 0,
1371            }
1372        }
1373
1374        pub fn record_allocation(&mut self, size: usize) {
1375            self.current_usage += size;
1376            self.allocations += 1;
1377            if self.current_usage > self.peak_usage {
1378                self.peak_usage = self.current_usage;
1379            }
1380        }
1381
1382        pub fn record_deallocation(&mut self, size: usize) {
1383            self.current_usage = self.current_usage.saturating_sub(size);
1384            self.deallocations += 1;
1385        }
1386
1387        pub fn peak_usage(&self) -> usize {
1388            self.peak_usage
1389        }
1390
1391        pub fn current_usage(&self) -> usize {
1392            self.current_usage
1393        }
1394
1395        pub fn allocation_count(&self) -> usize {
1396            self.allocations
1397        }
1398    }
1399
1400    impl Default for MemoryMonitor {
1401        fn default() -> Self {
1402            Self::new()
1403        }
1404    }
1405}
1406
1407/// Parallel processing utilities for embedding operations
1408pub mod parallel_utils {
1409    use super::*;
1410    use rayon::prelude::*;
1411
1412    /// Parallel computation of embedding similarities
1413    pub fn parallel_cosine_similarities(
1414        query_embedding: &[f32],
1415        candidate_embeddings: &[Vec<f32>],
1416    ) -> Result<Vec<f32>> {
1417        let similarities: Vec<f32> = candidate_embeddings
1418            .par_iter()
1419            .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
1420            .collect();
1421        Ok(similarities)
1422    }
1423
1424    /// Parallel batch processing with configurable thread pool
1425    pub fn parallel_batch_process<T, R, F>(
1426        items: &[T],
1427        batch_size: usize,
1428        processor: F,
1429    ) -> Result<Vec<R>>
1430    where
1431        T: Sync,
1432        R: Send,
1433        F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
1434    {
1435        let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();
1436
1437        Ok(results?.into_iter().flatten().collect())
1438    }
1439
1440    /// Parallel graph analysis with optimized memory usage
1441    pub fn parallel_entity_frequencies(
1442        triples: &[(String, String, String)],
1443    ) -> HashMap<String, usize> {
1444        let entity_counts: HashMap<String, usize> = triples
1445            .par_iter()
1446            .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
1447                *acc.entry(subject.clone()).or_insert(0) += 1;
1448                *acc.entry(object.clone()).or_insert(0) += 1;
1449                acc
1450            })
1451            .reduce(HashMap::new, |mut acc1, acc2| {
1452                for (entity, count) in acc2 {
1453                    *acc1.entry(entity).or_insert(0) += count;
1454                }
1455                acc1
1456            });
1457        entity_counts
1458    }
1459}
1460
#[cfg(test)]
mod tests {
    // Unit tests for the data-loading, dataset-splitting and statistics helpers,
    // plus the quick-start convenience API.
    use super::*;
    // NOTE(review): the convenience helpers under test are imported from
    // `crate::quick_start`, not from this file's `convenience` module. The sample
    // data test below expects `http://example.org/`-prefixed names, which
    // `convenience::generate_sample_kg_data` here does not produce. This suggests
    // a different implementation, i.e. `convenience` may be untested — confirm.
    use crate::quick_start::{
        add_triples_from_strings, cosine_similarity, create_simple_transe_model,
        generate_sample_kg_data, parse_triple_from_string, quick_performance_test,
    };
    use crate::EmbeddingModel;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // The first line is a header containing "subject"/"predicate"/"object" and
    // is skipped, leaving only the two data rows.
    #[test]
    fn test_load_triples_from_tsv() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(temp_file, "subject\tpredicate\tobject")?;
        writeln!(temp_file, "alice\tknows\tbob")?;
        writeln!(temp_file, "bob\tlikes\tcharlie")?;

        let triples = data_loader::load_triples_from_tsv(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    // Four triples split 0.6 / 0.2. The expected sizes assume the train and
    // validation counts are truncated and the test split takes the remainder.
    #[test]
    fn test_dataset_split() -> Result<()> {
        let triples = vec![
            ("a".to_string(), "r1".to_string(), "b".to_string()),
            ("b".to_string(), "r2".to_string(), "c".to_string()),
            ("c".to_string(), "r3".to_string(), "d".to_string()),
            ("d".to_string(), "r4".to_string(), "e".to_string()),
        ];

        let split = dataset_splitter::split_dataset(triples, 0.6, 0.2, Some(42))?;

        assert_eq!(split.train.len(), 2);
        assert_eq!(split.validation.len(), 0); // 0.2 * 4 = 0.8, rounded down
        assert_eq!(split.test.len(), 2);

        Ok(())
    }

    // One JSON object per line; `{{`/`}}` escape literal braces in `writeln!`.
    #[test]
    fn test_load_triples_from_jsonl() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(
            temp_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;
        writeln!(
            temp_file,
            r#"{{"subject": "bob", "predicate": "likes", "object": "charlie"}}"#
        )?;

        let triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    // Round trip: saving and then reloading JSON Lines returns the same triples.
    #[test]
    fn test_save_triples_to_jsonl() -> Result<()> {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let temp_file = NamedTempFile::new()?;
        data_loader::save_triples_to_jsonl(&triples, temp_file.path())?;

        let loaded_triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(loaded_triples, triples);

        Ok(())
    }

    // The temp files carry `.tsv` / `.jsonl` suffixes, presumably what the
    // auto-detection keys on.
    #[test]
    fn test_load_triples_auto_detect() -> Result<()> {
        // Test TSV auto-detection
        let mut tsv_file = NamedTempFile::with_suffix(".tsv")?;
        writeln!(tsv_file, "subject\tpredicate\tobject")?;
        writeln!(tsv_file, "alice\tknows\tbob")?;

        let triples = data_loader::load_triples_auto_detect(tsv_file.path())?;
        assert_eq!(triples.len(), 1);

        // Test JSON Lines auto-detection
        let mut jsonl_file = NamedTempFile::with_suffix(".jsonl")?;
        writeln!(
            jsonl_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;

        let triples = data_loader::load_triples_auto_detect(jsonl_file.path())?;
        assert_eq!(triples.len(), 1);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    // Three triples over entities {alice, bob, charlie} and relations {knows, likes}.
    #[test]
    fn test_dataset_statistics() {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "knows".to_string(),
                "charlie".to_string(),
            ),
            (
                "alice".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let stats = compute_dataset_statistics(&triples);

        assert_eq!(stats.num_triples, 3);
        assert_eq!(stats.num_entities, 3); // alice, bob, charlie
        assert_eq!(stats.num_relations, 2); // knows, likes
        assert!(stats.avg_degree > 0.0);
    }

    // Tests for convenience functions
    // Pins the prototyping defaults: 128 dimensions, learning rate 0.01, 100 epochs.
    #[test]
    fn test_create_simple_transe_model() {
        let model = create_simple_transe_model();
        assert_eq!(model.config().dimensions, 128);
        assert_eq!(model.config().learning_rate, 0.01);
        assert_eq!(model.config().max_epochs, 100);
    }

    // Bare names are expanded under http://example.org/, full IRIs are kept, and
    // input with only two terms is rejected.
    #[test]
    fn test_parse_triple_from_string() -> Result<()> {
        let triple = parse_triple_from_string("alice knows bob")?;
        assert_eq!(triple.subject.iri.as_str(), "http://example.org/alice");
        assert_eq!(triple.predicate.iri.as_str(), "http://example.org/knows");
        assert_eq!(triple.object.iri.as_str(), "http://example.org/bob");

        // Test with full URIs
        let triple2 = parse_triple_from_string(
            "http://example.org/alice http://example.org/knows http://example.org/bob",
        )?;
        assert_eq!(triple2.subject.iri.as_str(), "http://example.org/alice");

        // Test invalid format
        assert!(parse_triple_from_string("alice knows").is_err());

        Ok(())
    }

    // All three strings are well-formed, so all three should be added.
    #[test]
    fn test_add_triples_from_strings() -> Result<()> {
        let mut model = create_simple_transe_model();
        let triple_strings = &[
            "alice knows bob",
            "bob likes charlie",
            "charlie follows alice",
        ];

        let added_count = add_triples_from_strings(&mut model, triple_strings)?;
        assert_eq!(added_count, 3);

        Ok(())
    }

    // Identical vectors score 1, orthogonal vectors score 0, and mismatched
    // lengths are an error.
    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&a, &b)?;
        assert!((similarity - 1.0).abs() < 1e-10);

        let c = vec![0.0, 1.0, 0.0];
        let similarity2 = cosine_similarity(&a, &c)?;
        assert!((similarity2 - 0.0).abs() < 1e-10);

        // Test different dimensions
        let d = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &d).is_err());

        Ok(())
    }

    // NOTE(review): output is random; the non-empty assertion presumably holds
    // because only self-loop draws are discarded — confirm for the quick_start
    // implementation.
    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        // Check that all subjects and objects are in the expected format
        for (subject, relation, object) in &triples {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            assert_ne!(subject, object); // No self-loops
        }
    }

    // Smoke test only: the elapsed time is not asserted.
    #[test]
    fn test_quick_performance_test() {
        let duration = quick_performance_test("test_operation", 100, || {
            // Simple operation for testing
            let _sum: i32 = (1..10).sum();
        });

        // In release mode, operations can be extremely fast
        // Just verify the function completes and returns a valid duration
        let _nanos = duration.as_nanos();
    }
}