oxirs_embed/
utils.rs

1//! Utility functions and helpers for embedding operations
2
3// Removed unused imports
4use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{Array1, Array2};
6#[allow(unused_imports)]
7use scirs2_core::random::{Random, Rng};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::Path;
12
13/// Data loading utilities
14pub mod data_loader {
15    use super::*;
16    use std::io::{BufRead, BufReader};
17
18    /// Load triples from TSV file format
19    pub fn load_triples_from_tsv<P: AsRef<Path>>(
20        file_path: P,
21    ) -> Result<Vec<(String, String, String)>> {
22        let file = fs::File::open(file_path)?;
23        let reader = BufReader::new(file);
24        let mut triples = Vec::new();
25
26        for (line_num, line) in reader.lines().enumerate() {
27            let line = line?;
28            if line.trim().is_empty() || line.starts_with('#') {
29                continue; // Skip empty lines and comments
30            }
31
32            // Skip header line (first line that contains common header terms)
33            if line_num == 0
34                && (line.contains("subject")
35                    || line.contains("predicate")
36                    || line.contains("object"))
37            {
38                continue;
39            }
40
41            let parts: Vec<&str> = line.split('\t').collect();
42            if parts.len() >= 3 {
43                let subject = parts[0].trim().to_string();
44                let predicate = parts[1].trim().to_string();
45                let object = parts[2].trim().to_string();
46                triples.push((subject, predicate, object));
47            } else {
48                eprintln!(
49                    "Warning: Invalid triple format at line {}: {}",
50                    line_num + 1,
51                    line
52                );
53            }
54        }
55
56        Ok(triples)
57    }
58
59    /// Load triples from CSV file format
60    pub fn load_triples_from_csv<P: AsRef<Path>>(
61        file_path: P,
62    ) -> Result<Vec<(String, String, String)>> {
63        let file = fs::File::open(file_path)?;
64        let reader = BufReader::new(file);
65        let mut triples = Vec::new();
66        let mut is_first_line = true;
67
68        for (line_num, line) in reader.lines().enumerate() {
69            let line = line?;
70            if is_first_line {
71                is_first_line = false;
72                // Skip header if it looks like one
73                if line.to_lowercase().contains("subject")
74                    && line.to_lowercase().contains("predicate")
75                {
76                    continue;
77                }
78            }
79
80            if line.trim().is_empty() {
81                continue;
82            }
83
84            let parts: Vec<&str> = line.split(',').collect();
85            if parts.len() >= 3 {
86                let subject = parts[0].trim().trim_matches('"').to_string();
87                let predicate = parts[1].trim().trim_matches('"').to_string();
88                let object = parts[2].trim().trim_matches('"').to_string();
89                triples.push((subject, predicate, object));
90            } else {
91                eprintln!(
92                    "Warning: Invalid triple format at line {}: {}",
93                    line_num + 1,
94                    line
95                );
96            }
97        }
98
99        Ok(triples)
100    }
101
102    /// Load triples from N-Triples format
103    pub fn load_triples_from_ntriples<P: AsRef<Path>>(
104        file_path: P,
105    ) -> Result<Vec<(String, String, String)>> {
106        let file = fs::File::open(file_path)?;
107        let reader = BufReader::new(file);
108        let mut triples = Vec::new();
109
110        for (line_num, line) in reader.lines().enumerate() {
111            let line = line?;
112            let line = line.trim();
113
114            if line.is_empty() || line.starts_with('#') {
115                continue;
116            }
117
118            // Simple N-Triples parser (very basic)
119            if let Some(triple) = parse_ntriple_line(line) {
120                triples.push(triple);
121            } else {
122                eprintln!(
123                    "Warning: Failed to parse N-Triple at line {}: {}",
124                    line_num + 1,
125                    line
126                );
127            }
128        }
129
130        Ok(triples)
131    }
132
133    /// Parse a single N-Triple line
134    fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
135        let line = line.trim_end_matches(" .");
136        let parts: Vec<&str> = line.split_whitespace().collect();
137
138        if parts.len() >= 3 {
139            let subject = clean_uri_or_literal(parts[0]);
140            let predicate = clean_uri_or_literal(parts[1]);
141            let object = clean_uri_or_literal(&parts[2..].join(" "));
142
143            Some((subject, predicate, object))
144        } else {
145            None
146        }
147    }
148
149    /// Clean URI or literal from N-Triple format
150    fn clean_uri_or_literal(term: &str) -> String {
151        if term.starts_with('<') && term.ends_with('>') {
152            term[1..term.len() - 1].to_string()
153        } else if term.starts_with('"') && term.contains('"') {
154            // Handle literals - just take the string part for now
155            let end_quote = term.rfind('"').unwrap_or(term.len());
156            term[1..end_quote].to_string()
157        } else {
158            term.to_string()
159        }
160    }
161
162    /// Load triples from JSON Lines format (one triple per line as JSON)
163    pub fn load_triples_from_jsonl<P: AsRef<Path>>(
164        file_path: P,
165    ) -> Result<Vec<(String, String, String)>> {
166        let file = fs::File::open(file_path)?;
167        let reader = BufReader::new(file);
168        let mut triples = Vec::new();
169
170        for (line_num, line) in reader.lines().enumerate() {
171            let line = line?;
172            if line.trim().is_empty() {
173                continue;
174            }
175
176            match serde_json::from_str::<serde_json::Value>(&line) {
177                Ok(json) => {
178                    if let (Some(subject), Some(predicate), Some(object)) = (
179                        json["subject"].as_str(),
180                        json["predicate"].as_str(),
181                        json["object"].as_str(),
182                    ) {
183                        triples.push((
184                            subject.to_string(),
185                            predicate.to_string(),
186                            object.to_string(),
187                        ));
188                    } else {
189                        eprintln!(
190                            "Warning: Invalid JSON structure at line {}: {}",
191                            line_num + 1,
192                            line
193                        );
194                    }
195                }
196                Err(e) => {
197                    eprintln!(
198                        "Warning: Failed to parse JSON at line {}: {} - Error: {}",
199                        line_num + 1,
200                        line,
201                        e
202                    );
203                }
204            }
205        }
206
207        Ok(triples)
208    }
209
210    /// Save triples to TSV format
211    pub fn save_triples_to_tsv<P: AsRef<Path>>(
212        triples: &[(String, String, String)],
213        file_path: P,
214    ) -> Result<()> {
215        let mut content = String::new();
216        content.push_str("subject\tpredicate\tobject\n");
217
218        for (subject, predicate, object) in triples {
219            content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
220        }
221
222        fs::write(file_path, content)?;
223        Ok(())
224    }
225
226    /// Save triples to JSON Lines format
227    pub fn save_triples_to_jsonl<P: AsRef<Path>>(
228        triples: &[(String, String, String)],
229        file_path: P,
230    ) -> Result<()> {
231        use std::io::Write;
232        let mut file = fs::File::create(file_path)?;
233
234        for (subject, predicate, object) in triples {
235            let json = serde_json::json!({
236                "subject": subject,
237                "predicate": predicate,
238                "object": object
239            });
240            writeln!(file, "{json}")?;
241        }
242
243        Ok(())
244    }
245
246    /// Auto-detect file format and load triples accordingly
247    pub fn load_triples_auto_detect<P: AsRef<Path>>(
248        file_path: P,
249    ) -> Result<Vec<(String, String, String)>> {
250        let path = file_path.as_ref();
251        let extension = path
252            .extension()
253            .and_then(|ext| ext.to_str())
254            .unwrap_or("")
255            .to_lowercase();
256
257        match extension.as_str() {
258            "tsv" => load_triples_from_tsv(path),
259            "csv" => load_triples_from_csv(path),
260            "nt" | "ntriples" => load_triples_from_ntriples(path),
261            "jsonl" | "ndjson" => load_triples_from_jsonl(path),
262            _ => {
263                // Try to auto-detect based on content
264                eprintln!(
265                    "Warning: Unknown file extension '{extension}', attempting auto-detection"
266                );
267
268                // Try TSV first (most common)
269                if let Ok(triples) = load_triples_from_tsv(path) {
270                    if !triples.is_empty() {
271                        return Ok(triples);
272                    }
273                }
274
275                // Try N-Triples
276                if let Ok(triples) = load_triples_from_ntriples(path) {
277                    if !triples.is_empty() {
278                        return Ok(triples);
279                    }
280                }
281
282                // Try JSON Lines
283                if let Ok(triples) = load_triples_from_jsonl(path) {
284                    if !triples.is_empty() {
285                        return Ok(triples);
286                    }
287                }
288
289                // Finally try CSV
290                load_triples_from_csv(path)
291            }
292        }
293    }
294}
295
296/// Dataset splitting utilities
297pub mod dataset_splitter {
298    use super::*;
299
300    /// Split dataset into train/validation/test sets
301    pub fn split_dataset(
302        triples: Vec<(String, String, String)>,
303        train_ratio: f64,
304        val_ratio: f64,
305        seed: Option<u64>,
306    ) -> Result<DatasetSplit> {
307        if train_ratio + val_ratio >= 1.0 {
308            return Err(anyhow!(
309                "Train and validation ratios must sum to less than 1.0"
310            ));
311        }
312
313        let mut rng = if let Some(s) = seed {
314            Random::seed(s)
315        } else {
316            Random::seed(42) // Use a default seed for consistency
317        };
318
319        let mut shuffled_triples = triples;
320        // Manual Fisher-Yates shuffle
321        for i in (1..shuffled_triples.len()).rev() {
322            let j = rng.random_range(0..i + 1);
323            shuffled_triples.swap(i, j);
324        }
325
326        let total = shuffled_triples.len();
327        let train_end = (total as f64 * train_ratio) as usize;
328        let val_end = train_end + (total as f64 * val_ratio) as usize;
329
330        let train_triples = shuffled_triples[..train_end].to_vec();
331        let val_triples = shuffled_triples[train_end..val_end].to_vec();
332        let test_triples = shuffled_triples[val_end..].to_vec();
333
334        Ok(DatasetSplit {
335            train: train_triples,
336            validation: val_triples,
337            test: test_triples,
338        })
339    }
340
341    /// Split dataset ensuring no entity leakage between splits
342    pub fn split_dataset_no_leakage(
343        triples: Vec<(String, String, String)>,
344        train_ratio: f64,
345        val_ratio: f64,
346        seed: Option<u64>,
347    ) -> Result<DatasetSplit> {
348        // Group triples by entities with pre-allocated capacity for better performance
349        let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
350            HashMap::with_capacity(triples.len() / 2); // Estimate capacity
351
352        for triple in &triples {
353            let entities = [&triple.0, &triple.2];
354            for entity in entities {
355                entity_triples
356                    .entry(entity.clone())
357                    .or_default()
358                    .push(triple.clone());
359            }
360        }
361
362        // Split entities first, then assign triples - optimized allocation
363        let entities: Vec<String> = entity_triples.keys().cloned().collect();
364        let dummy_string = "dummy".to_string();
365        let entity_split = split_dataset(
366            entities
367                .into_iter()
368                .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
369                .collect(),
370            train_ratio,
371            val_ratio,
372            seed,
373        )?;
374
375        let train_entities: HashSet<String> =
376            entity_split.train.into_iter().map(|(e, _, _)| e).collect();
377        let val_entities: HashSet<String> = entity_split
378            .validation
379            .into_iter()
380            .map(|(e, _, _)| e)
381            .collect();
382        let test_entities: HashSet<String> =
383            entity_split.test.into_iter().map(|(e, _, _)| e).collect();
384
385        // Assign triples based on entity membership with pre-allocated capacity
386        let estimated_capacity = triples.len() / 3;
387        let mut train_triples = Vec::with_capacity(estimated_capacity);
388        let mut val_triples = Vec::with_capacity(estimated_capacity);
389        let mut test_triples = Vec::with_capacity(estimated_capacity);
390
391        for (entity, entity_triple_list) in entity_triples {
392            if train_entities.contains(&entity) {
393                train_triples.extend(entity_triple_list);
394            } else if val_entities.contains(&entity) {
395                val_triples.extend(entity_triple_list);
396            } else if test_entities.contains(&entity) {
397                test_triples.extend(entity_triple_list);
398            }
399        }
400
401        // Remove duplicates
402        train_triples.sort();
403        train_triples.dedup();
404        val_triples.sort();
405        val_triples.dedup();
406        test_triples.sort();
407        test_triples.dedup();
408
409        Ok(DatasetSplit {
410            train: train_triples,
411            validation: val_triples,
412            test: test_triples,
413        })
414    }
415}
416
/// Result of splitting triples into train / validation / test partitions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatasetSplit {
    /// Training partition.
    pub train: Vec<(String, String, String)>,
    /// Validation partition.
    pub validation: Vec<(String, String, String)>,
    /// Test partition.
    pub test: Vec<(String, String, String)>,
}
424
425/// Statistics about a dataset
426#[derive(Debug, Clone, Serialize, Deserialize)]
427pub struct DatasetStatistics {
428    pub num_triples: usize,
429    pub num_entities: usize,
430    pub num_relations: usize,
431    pub entity_frequency: HashMap<String, usize>,
432    pub relation_frequency: HashMap<String, usize>,
433    pub avg_degree: f64,
434    pub density: f64,
435}
436
437/// Compute dataset statistics
438pub fn compute_dataset_statistics(triples: &[(String, String, String)]) -> DatasetStatistics {
439    let mut entities = HashSet::new();
440    let mut relations = HashSet::new();
441    let mut entity_frequency = HashMap::new();
442    let mut relation_frequency = HashMap::new();
443
444    for (subject, predicate, object) in triples {
445        entities.insert(subject.clone());
446        entities.insert(object.clone());
447        relations.insert(predicate.clone());
448
449        *entity_frequency.entry(subject.clone()).or_insert(0) += 1;
450        *entity_frequency.entry(object.clone()).or_insert(0) += 1;
451        *relation_frequency.entry(predicate.clone()).or_insert(0) += 1;
452    }
453
454    let num_entities = entities.len();
455    let num_relations = relations.len();
456    let num_triples = triples.len();
457
458    let avg_degree = if num_entities > 0 {
459        (num_triples * 2) as f64 / num_entities as f64
460    } else {
461        0.0
462    };
463
464    let max_possible_edges = num_entities * num_entities;
465    let density = if max_possible_edges > 0 {
466        num_triples as f64 / max_possible_edges as f64
467    } else {
468        0.0
469    };
470
471    DatasetStatistics {
472        num_triples,
473        num_entities,
474        num_relations,
475        entity_frequency,
476        relation_frequency,
477        avg_degree,
478        density,
479    }
480}
481
482/// Embedding dimension analysis utilities
483pub mod embedding_analysis {
484    use super::*;
485
486    /// Analyze embedding distribution
487    pub fn analyze_embedding_distribution(embeddings: &Array2<f64>) -> EmbeddingDistributionStats {
488        let flat_values: Vec<f64> = embeddings.iter().cloned().collect();
489
490        let mean = flat_values.iter().sum::<f64>() / flat_values.len() as f64;
491        let variance =
492            flat_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / flat_values.len() as f64;
493        let std_dev = variance.sqrt();
494
495        let mut sorted_values = flat_values.clone();
496        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
497
498        let min_val = sorted_values[0];
499        let max_val = sorted_values[sorted_values.len() - 1];
500        let median = sorted_values[sorted_values.len() / 2];
501
502        EmbeddingDistributionStats {
503            mean,
504            std_dev,
505            variance,
506            min: min_val,
507            max: max_val,
508            median,
509            num_parameters: embeddings.len(),
510        }
511    }
512
513    /// Compute embedding norms
514    pub fn compute_embedding_norms(embeddings: &Array2<f64>) -> Vec<f64> {
515        embeddings
516            .rows()
517            .into_iter()
518            .map(|row| row.dot(&row).sqrt())
519            .collect()
520    }
521
522    /// Analyze embedding similarities
523    pub fn analyze_embedding_similarities(
524        embeddings: &Array2<f64>,
525        sample_size: usize,
526    ) -> SimilarityStats {
527        let num_embeddings = embeddings.nrows();
528        let mut similarities = Vec::new();
529
530        let sample_size = sample_size.min(num_embeddings * (num_embeddings - 1) / 2);
531        let mut rng = Random::default();
532
533        for _ in 0..sample_size {
534            let i = rng.random_range(0..num_embeddings);
535            let j = rng.random_range(0..num_embeddings);
536
537            if i != j {
538                let emb_i = embeddings.row(i);
539                let emb_j = embeddings.row(j);
540                let similarity = cosine_similarity(&emb_i.to_owned(), &emb_j.to_owned());
541                similarities.push(similarity);
542            }
543        }
544
545        similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
546
547        let mean_similarity = similarities.iter().sum::<f64>() / similarities.len() as f64;
548        let min_similarity = similarities[0];
549        let max_similarity = similarities[similarities.len() - 1];
550        let median_similarity = similarities[similarities.len() / 2];
551
552        SimilarityStats {
553            mean_similarity,
554            min_similarity,
555            max_similarity,
556            median_similarity,
557            num_comparisons: similarities.len(),
558        }
559    }
560
561    /// Cosine similarity between two vectors
562    fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
563        let dot_product = a.dot(b);
564        let norm_a = a.dot(a).sqrt();
565        let norm_b = b.dot(b).sqrt();
566
567        if norm_a > 1e-10 && norm_b > 1e-10 {
568            dot_product / (norm_a * norm_b)
569        } else {
570            0.0
571        }
572    }
573}
574
/// Summary statistics over all values of an embedding matrix.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingDistributionStats {
    /// Arithmetic mean of all values.
    pub mean: f64,
    /// Square root of `variance`.
    pub std_dev: f64,
    /// Population variance (divides by the number of values).
    pub variance: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Median value.
    pub median: f64,
    /// Total number of values in the matrix.
    pub num_parameters: usize,
}
586
/// Statistics over sampled pairwise cosine similarities.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SimilarityStats {
    /// Mean of the sampled similarities.
    pub mean_similarity: f64,
    /// Smallest sampled similarity.
    pub min_similarity: f64,
    /// Largest sampled similarity.
    pub max_similarity: f64,
    /// Median of the sampled similarities.
    pub median_similarity: f64,
    /// Number of similarities the statistics are based on.
    pub num_comparisons: usize,
}
596
597/// Graph analysis utilities
598pub mod graph_analysis {
599    use super::*;
600
601    /// Compute graph metrics for knowledge graph - optimized for performance
602    pub fn compute_graph_metrics(triples: &[(String, String, String)]) -> GraphMetrics {
603        // Pre-allocate with estimated capacity for better performance
604        let estimated_entities = triples.len(); // Conservative estimate
605        let estimated_relations = triples.len() / 10; // Rough estimate
606
607        let mut entity_degrees: HashMap<String, usize> = HashMap::with_capacity(estimated_entities);
608        let mut relation_counts: HashMap<String, usize> =
609            HashMap::with_capacity(estimated_relations);
610        let mut entities = HashSet::with_capacity(estimated_entities);
611
612        for (subject, predicate, object) in triples {
613            entities.insert(subject.clone());
614            entities.insert(object.clone());
615
616            *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
617            *entity_degrees.entry(object.clone()).or_insert(0) += 1;
618            *relation_counts.entry(predicate.clone()).or_insert(0) += 1;
619        }
620
621        let num_entities = entities.len();
622        let num_relations = relation_counts.len();
623        let num_triples = triples.len();
624
625        let degrees: Vec<usize> = entity_degrees.values().cloned().collect();
626        let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
627        let max_degree = degrees.iter().max().cloned().unwrap_or(0);
628        let min_degree = degrees.iter().min().cloned().unwrap_or(0);
629
630        GraphMetrics {
631            num_entities,
632            num_relations,
633            num_triples,
634            avg_degree,
635            max_degree,
636            min_degree,
637            density: num_triples as f64 / (num_entities * num_entities) as f64,
638        }
639    }
640}
641
642/// Graph metrics
643#[derive(Debug, Clone, Serialize, Deserialize)]
644pub struct GraphMetrics {
645    pub num_entities: usize,
646    pub num_relations: usize,
647    pub num_triples: usize,
648    pub avg_degree: f64,
649    pub max_degree: usize,
650    pub min_degree: usize,
651    pub density: f64,
652}
653
/// Throttled console progress reporter.
///
/// `update` records the latest count and prints at most once per `update_interval`
/// (one second). `finish` prints a final summary line.
#[derive(Debug)]
pub struct ProgressTracker {
    /// Total number of items expected.
    total: usize,
    /// Most recently reported number of completed items.
    current: usize,
    /// When the tracker was created.
    start_time: std::time::Instant,
    /// When progress was last printed (initially the creation time).
    last_update: std::time::Instant,
    /// Minimum time between two progress lines.
    update_interval: std::time::Duration,
}

impl ProgressTracker {
    /// Create a tracker for `total` items that prints at most once per second.
    pub fn new(total: usize) -> Self {
        let now = std::time::Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            update_interval: std::time::Duration::from_secs(1),
        }
    }

    /// Record `current` completed items, printing a progress line if at least
    /// `update_interval` has passed since the last one.
    pub fn update(&mut self, current: usize) {
        self.current = current;
        let now = std::time::Instant::now();

        if now.duration_since(self.last_update) >= self.update_interval {
            self.print_progress();
            self.last_update = now;
        }
    }

    /// Percentage of `total` completed. An empty workload (`total == 0`) reports 100%
    /// instead of NaN.
    fn percentage(&self) -> f64 {
        if self.total == 0 {
            100.0
        } else {
            (self.current as f64 / self.total as f64) * 100.0
        }
    }

    /// Items per second over `elapsed_secs`. Returns 0.0 when no measurable time has
    /// passed, instead of infinity or NaN.
    fn rate(items: usize, elapsed_secs: f64) -> f64 {
        if elapsed_secs > 0.0 {
            items as f64 / elapsed_secs
        } else {
            0.0
        }
    }

    /// Print current progress
    fn print_progress(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();

        println!(
            "Progress: {}/{} ({:.1}%) - {:.1} items/sec",
            self.current,
            self.total,
            self.percentage(),
            Self::rate(self.current, elapsed)
        );
    }

    /// Print the final summary: `total` items, elapsed seconds and the average rate.
    pub fn finish(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();

        println!(
            "Completed: {} items in {:.2}s ({:.1} items/sec)",
            self.total,
            elapsed,
            Self::rate(self.total, elapsed)
        );
    }
}
711
712/// Performance benchmarking and profiling utilities
713pub mod performance_benchmark {
714    use super::*;
715    use std::collections::BTreeMap;
716    use std::time::{Duration, Instant};
717
718    /// Comprehensive performance benchmarking for embedding operations
719    #[derive(Debug, Clone, Serialize, Deserialize)]
720    pub struct BenchmarkSuite {
721        /// Results organized by operation type
722        pub results: BTreeMap<String, BenchmarkResult>,
723        /// Overall benchmark statistics
724        pub summary: BenchmarkSummary,
725        /// Benchmark configuration
726        pub config: BenchmarkConfig,
727    }
728
729    /// Individual benchmark result for a specific operation
730    #[derive(Debug, Clone, Serialize, Deserialize)]
731    pub struct BenchmarkResult {
732        /// Operation name
733        pub operation: String,
734        /// Total number of iterations
735        pub iterations: usize,
736        /// Total elapsed time
737        pub total_duration: Duration,
738        /// Average time per operation
739        pub avg_duration: Duration,
740        /// Minimum time observed
741        pub min_duration: Duration,
742        /// Maximum time observed
743        pub max_duration: Duration,
744        /// Standard deviation of durations
745        pub std_deviation: Duration,
746        /// Operations per second
747        pub ops_per_second: f64,
748        /// Memory usage statistics
749        pub memory_stats: MemoryStats,
750        /// Additional metrics
751        pub custom_metrics: HashMap<String, f64>,
752    }
753
754    /// Memory usage statistics
755    #[derive(Debug, Clone, Serialize, Deserialize)]
756    pub struct MemoryStats {
757        /// Peak memory usage (bytes)
758        pub peak_memory_bytes: usize,
759        /// Average memory usage (bytes)
760        pub avg_memory_bytes: usize,
761        /// Memory allocations count
762        pub allocations: usize,
763        /// Memory deallocations count
764        pub deallocations: usize,
765    }
766
767    /// Overall benchmark summary
768    #[derive(Debug, Clone, Serialize, Deserialize)]
769    pub struct BenchmarkSummary {
770        /// Total benchmark duration
771        pub total_duration: Duration,
772        /// Number of operations benchmarked
773        pub total_operations: usize,
774        /// Overall throughput (ops/sec)
775        pub overall_throughput: f64,
776        /// Performance efficiency score (0.0-1.0)
777        pub efficiency_score: f64,
778        /// Bottleneck analysis
779        pub bottlenecks: Vec<String>,
780    }
781
782    /// Benchmarking configuration
783    #[derive(Debug, Clone, Serialize, Deserialize)]
784    pub struct BenchmarkConfig {
785        /// Number of warmup iterations
786        pub warmup_iterations: usize,
787        /// Number of measurement iterations
788        pub measurement_iterations: usize,
789        /// Target confidence level (0.0-1.0)
790        pub confidence_level: f64,
791        /// Enable memory profiling
792        pub enable_memory_profiling: bool,
793        /// Enable detailed timing analysis
794        pub enable_detailed_timing: bool,
795    }
796
797    impl Default for BenchmarkConfig {
798        fn default() -> Self {
799            Self {
800                warmup_iterations: 100,
801                measurement_iterations: 1000,
802                confidence_level: 0.95,
803                enable_memory_profiling: true,
804                enable_detailed_timing: true,
805            }
806        }
807    }
808
809    /// High-precision timer for micro-benchmarking
810    pub struct PrecisionTimer {
811        start_time: Instant,
812        lap_times: Vec<Duration>,
813    }
814
815    impl Default for PrecisionTimer {
816        fn default() -> Self {
817            Self::new()
818        }
819    }
820
821    impl PrecisionTimer {
822        pub fn new() -> Self {
823            Self {
824                start_time: Instant::now(),
825                lap_times: Vec::new(),
826            }
827        }
828
829        /// Start timing
830        pub fn start(&mut self) {
831            self.start_time = Instant::now();
832            self.lap_times.clear();
833        }
834
835        /// Record a lap time
836        pub fn lap(&mut self) -> Duration {
837            let lap_duration = self.start_time.elapsed();
838            self.lap_times.push(lap_duration);
839            lap_duration
840        }
841
842        /// Stop timing and return final duration
843        pub fn stop(&self) -> Duration {
844            self.start_time.elapsed()
845        }
846
847        /// Get all recorded lap times
848        pub fn lap_times(&self) -> &[Duration] {
849            &self.lap_times
850        }
851    }
852
853    /// Benchmarking framework for embedding operations
854    pub struct EmbeddingBenchmark {
855        config: BenchmarkConfig,
856        results: BTreeMap<String, BenchmarkResult>,
857    }
858
859    impl EmbeddingBenchmark {
860        pub fn new(config: BenchmarkConfig) -> Self {
861            Self {
862                config,
863                results: BTreeMap::new(),
864            }
865        }
866
867        /// Benchmark a function with comprehensive timing and memory analysis
868        pub fn benchmark<F, T>(&mut self, name: &str, mut operation: F) -> Result<T>
869        where
870            F: FnMut() -> Result<T>,
871        {
872            // Warmup phase
873            for _ in 0..self.config.warmup_iterations {
874                let _ = operation()?;
875            }
876
877            let mut durations = Vec::with_capacity(self.config.measurement_iterations);
878            let mut memory_snapshots = Vec::new();
879            let mut result = None;
880
881            // Measurement phase
882            for i in 0..self.config.measurement_iterations {
883                let memory_before = self.get_memory_usage();
884                let start = Instant::now();
885
886                let op_result = operation()?;
887
888                let duration = start.elapsed();
889                let memory_after = self.get_memory_usage();
890
891                durations.push(duration);
892
893                if self.config.enable_memory_profiling {
894                    memory_snapshots.push((memory_before, memory_after));
895                }
896
897                // Store result from the first iteration
898                if i == 0 {
899                    result = Some(op_result);
900                }
901            }
902
903            // Calculate statistics
904            let total_duration: Duration = durations.iter().sum();
905            let avg_duration = total_duration / durations.len() as u32;
906            let min_duration = *durations.iter().min().unwrap();
907            let max_duration = *durations.iter().max().unwrap();
908
909            // Calculate standard deviation
910            let variance: f64 = durations
911                .iter()
912                .map(|d| {
913                    let diff = d.as_nanos() as f64 - avg_duration.as_nanos() as f64;
914                    diff * diff
915                })
916                .sum::<f64>()
917                / durations.len() as f64;
918            let std_deviation = Duration::from_nanos(variance.sqrt() as u64);
919
920            let ops_per_second = 1_000_000_000.0 / avg_duration.as_nanos() as f64;
921
922            // Memory statistics
923            let memory_stats = if self.config.enable_memory_profiling
924                && !memory_snapshots.is_empty()
925            {
926                let peak_memory = memory_snapshots
927                    .iter()
928                    .map(|(_, after)| after.peak_memory_bytes)
929                    .max()
930                    .unwrap_or(0);
931
932                let avg_memory = memory_snapshots
933                    .iter()
934                    .map(|(before, after)| (before.avg_memory_bytes + after.avg_memory_bytes) / 2)
935                    .sum::<usize>()
936                    / memory_snapshots.len();
937
938                MemoryStats {
939                    peak_memory_bytes: peak_memory,
940                    avg_memory_bytes: avg_memory,
941                    allocations: memory_snapshots.len(),
942                    deallocations: 0, // Simplified for now
943                }
944            } else {
945                MemoryStats {
946                    peak_memory_bytes: 0,
947                    avg_memory_bytes: 0,
948                    allocations: 0,
949                    deallocations: 0,
950                }
951            };
952
953            let benchmark_result = BenchmarkResult {
954                operation: name.to_string(),
955                iterations: self.config.measurement_iterations,
956                total_duration,
957                avg_duration,
958                min_duration,
959                max_duration,
960                std_deviation,
961                ops_per_second,
962                memory_stats,
963                custom_metrics: HashMap::new(),
964            };
965
966            self.results.insert(name.to_string(), benchmark_result);
967
968            result.ok_or_else(|| anyhow!("Failed to capture benchmark result"))
969        }
970
971        /// Generate comprehensive benchmark report
972        pub fn generate_report(&self) -> BenchmarkSuite {
973            let total_duration = self.results.values().map(|r| r.total_duration).sum();
974
975            let total_operations = self.results.len();
976
977            let overall_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
978                / total_operations as f64;
979
980            // Calculate efficiency score based on consistency and performance
981            let efficiency_score = self.calculate_efficiency_score();
982
983            // Identify bottlenecks
984            let bottlenecks = self.identify_bottlenecks();
985
986            let summary = BenchmarkSummary {
987                total_duration,
988                total_operations,
989                overall_throughput,
990                efficiency_score,
991                bottlenecks,
992            };
993
994            BenchmarkSuite {
995                results: self.results.clone(),
996                summary,
997                config: self.config.clone(),
998            }
999        }
1000
1001        /// Calculate efficiency score based on performance consistency
1002        fn calculate_efficiency_score(&self) -> f64 {
1003            if self.results.is_empty() {
1004                return 0.0;
1005            }
1006
1007            let consistency_scores: Vec<f64> = self
1008                .results
1009                .values()
1010                .map(|result| {
1011                    // Calculate coefficient of variation (std_dev / mean)
1012                    let cv = result.std_deviation.as_nanos() as f64
1013                        / result.avg_duration.as_nanos() as f64;
1014                    // Convert to consistency score (lower CV = higher consistency)
1015                    1.0 / (1.0 + cv)
1016                })
1017                .collect();
1018
1019            consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64
1020        }
1021
1022        /// Identify performance bottlenecks
1023        fn identify_bottlenecks(&self) -> Vec<String> {
1024            let mut bottlenecks = Vec::new();
1025
1026            // Find operations with high standard deviation (inconsistent performance)
1027            for (name, result) in &self.results {
1028                let cv =
1029                    result.std_deviation.as_nanos() as f64 / result.avg_duration.as_nanos() as f64;
1030                if cv > 0.2 {
1031                    // 20% coefficient of variation threshold
1032                    bottlenecks.push(format!("High variance in {}: {:.2}% CV", name, cv * 100.0));
1033                }
1034            }
1035
1036            // Find slow operations (below average throughput)
1037            let avg_throughput = self.results.values().map(|r| r.ops_per_second).sum::<f64>()
1038                / self.results.len() as f64;
1039
1040            for (name, result) in &self.results {
1041                if result.ops_per_second < avg_throughput * 0.5 {
1042                    // 50% below average
1043                    bottlenecks.push(format!(
1044                        "Slow operation {}: {:.0} ops/sec",
1045                        name, result.ops_per_second
1046                    ));
1047                }
1048            }
1049
1050            bottlenecks
1051        }
1052
1053        /// Get current memory usage (simplified implementation)
1054        fn get_memory_usage(&self) -> MemoryStats {
1055            // This is a simplified implementation
1056            // In a real-world scenario, you'd use proper memory profiling tools
1057            MemoryStats {
1058                peak_memory_bytes: 0,
1059                avg_memory_bytes: 0,
1060                allocations: 0,
1061                deallocations: 0,
1062            }
1063        }
1064    }
1065
1066    /// Utility functions for performance analysis
1067    pub mod analysis {
1068        use super::*;
1069
1070        /// Compare two benchmark results
1071        pub fn compare_benchmarks(
1072            baseline: &BenchmarkResult,
1073            comparison: &BenchmarkResult,
1074        ) -> BenchmarkComparison {
1075            let throughput_improvement =
1076                (comparison.ops_per_second - baseline.ops_per_second) / baseline.ops_per_second;
1077
1078            let latency_improvement = (baseline.avg_duration.as_nanos() as f64
1079                - comparison.avg_duration.as_nanos() as f64)
1080                / baseline.avg_duration.as_nanos() as f64;
1081
1082            let consistency_improvement = {
1083                let baseline_cv = baseline.std_deviation.as_nanos() as f64
1084                    / baseline.avg_duration.as_nanos() as f64;
1085                let comparison_cv = comparison.std_deviation.as_nanos() as f64
1086                    / comparison.avg_duration.as_nanos() as f64;
1087                (baseline_cv - comparison_cv) / baseline_cv
1088            };
1089
1090            BenchmarkComparison {
1091                baseline_name: baseline.operation.clone(),
1092                comparison_name: comparison.operation.clone(),
1093                throughput_improvement,
1094                latency_improvement,
1095                consistency_improvement,
1096                is_improvement: throughput_improvement > 0.0 && latency_improvement > 0.0,
1097            }
1098        }
1099
1100        /// Generate performance regression analysis
1101        pub fn analyze_regression(
1102            historical_results: &[BenchmarkResult],
1103            current_result: &BenchmarkResult,
1104        ) -> RegressionAnalysis {
1105            if historical_results.is_empty() {
1106                return RegressionAnalysis::default();
1107            }
1108
1109            let historical_avg_throughput = historical_results
1110                .iter()
1111                .map(|r| r.ops_per_second)
1112                .sum::<f64>()
1113                / historical_results.len() as f64;
1114
1115            let throughput_change = (current_result.ops_per_second - historical_avg_throughput)
1116                / historical_avg_throughput;
1117
1118            let is_regression = throughput_change < -0.05; // 5% threshold
1119
1120            RegressionAnalysis {
1121                throughput_change,
1122                is_regression,
1123                confidence_level: 0.95, // Simplified
1124                analysis_notes: if is_regression {
1125                    vec!["Performance regression detected".to_string()]
1126                } else {
1127                    vec!["Performance within expected range".to_string()]
1128                },
1129            }
1130        }
1131    }
1132
    /// Benchmark comparison result
    ///
    /// Produced by `analysis::compare_benchmarks`. The three `*_improvement` fields are
    /// relative changes of the comparison against the baseline. They are signed so that
    /// a positive value means the comparison performed better.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct BenchmarkComparison {
        /// `operation` name of the baseline result.
        pub baseline_name: String,
        /// `operation` name of the result being compared against the baseline.
        pub comparison_name: String,
        /// Relative change in ops/sec: `(comparison - baseline) / baseline`.
        pub throughput_improvement: f64,
        /// Relative reduction in average duration: `(baseline - comparison) / baseline`.
        pub latency_improvement: f64,
        /// Relative reduction in the coefficient of variation (std deviation / mean duration).
        pub consistency_improvement: f64,
        /// `true` only when both throughput and latency improved (strictly positive).
        pub is_improvement: bool,
    }
1143
    /// Performance regression analysis
    ///
    /// Produced by `analysis::analyze_regression`, which compares a current result's
    /// throughput against the mean throughput of historical results.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct RegressionAnalysis {
        /// Relative throughput change versus the historical mean: `(current - mean) / mean`.
        pub throughput_change: f64,
        /// `true` when throughput dropped by more than the regression threshold
        /// (5% in `analyze_regression`).
        pub is_regression: bool,
        /// Confidence attached to the verdict. This is a fixed placeholder (`0.95` from
        /// `analyze_regression`, `0.0` by default), not a computed statistic.
        pub confidence_level: f64,
        /// Human-readable notes describing the verdict.
        pub analysis_notes: Vec<String>,
    }
1152
1153    impl Default for RegressionAnalysis {
1154        fn default() -> Self {
1155            Self {
1156                throughput_change: 0.0,
1157                is_regression: false,
1158                confidence_level: 0.0,
1159                analysis_notes: vec!["No historical data available".to_string()],
1160            }
1161        }
1162    }
1163}
1164
1165/// Convenience functions for quick operations
1166pub mod convenience {
1167    use super::*;
1168    use crate::{EmbeddingModel, ModelConfig, NamedNode, TransE, Triple};
1169
1170    /// Create a simple TransE model with sensible defaults for quick prototyping
1171    pub fn create_simple_transe_model() -> TransE {
1172        let config = ModelConfig::default()
1173            .with_dimensions(128)
1174            .with_learning_rate(0.01)
1175            .with_max_epochs(100);
1176        TransE::new(config)
1177    }
1178
1179    /// Parse a triple from a simple string format "subject predicate object"
1180    pub fn parse_triple_from_string(triple_str: &str) -> Result<Triple> {
1181        let parts: Vec<&str> = triple_str.split_whitespace().collect();
1182        if parts.len() != 3 {
1183            return Err(anyhow!(
1184                "Invalid triple format. Expected 'subject predicate object', got: '{}'",
1185                triple_str
1186            ));
1187        }
1188
1189        let subject = if parts[0].starts_with("http") {
1190            NamedNode::new(parts[0])?
1191        } else {
1192            NamedNode::new(&format!("http://example.org/{}", parts[0]))?
1193        };
1194
1195        let predicate = if parts[1].starts_with("http") {
1196            NamedNode::new(parts[1])?
1197        } else {
1198            NamedNode::new(&format!("http://example.org/{}", parts[1]))?
1199        };
1200
1201        let object = if parts[2].starts_with("http") {
1202            NamedNode::new(parts[2])?
1203        } else {
1204            NamedNode::new(&format!("http://example.org/{}", parts[2]))?
1205        };
1206
1207        Ok(Triple::new(subject, predicate, object))
1208    }
1209
1210    /// Add multiple triples from string array to a model
1211    pub fn add_triples_from_strings(
1212        model: &mut dyn EmbeddingModel,
1213        triple_strings: &[&str],
1214    ) -> Result<usize> {
1215        let mut added_count = 0;
1216        for triple_str in triple_strings {
1217            match parse_triple_from_string(triple_str) {
1218                Ok(triple) => {
1219                    model.add_triple(triple)?;
1220                    added_count += 1;
1221                }
1222                Err(e) => {
1223                    eprintln!("Warning: Failed to parse triple '{triple_str}': {e}");
1224                }
1225            }
1226        }
1227        Ok(added_count)
1228    }
1229
1230    /// Quick function to compute similarity between two embedding vectors
1231    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> Result<f64> {
1232        if a.len() != b.len() {
1233            return Err(anyhow!(
1234                "Vector dimensions don't match: {} vs {}",
1235                a.len(),
1236                b.len()
1237            ));
1238        }
1239
1240        let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
1241        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
1242        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
1243
1244        if norm_a == 0.0 || norm_b == 0.0 {
1245            return Ok(0.0);
1246        }
1247
1248        Ok(dot_product / (norm_a * norm_b))
1249    }
1250
1251    /// Generate sample knowledge graph data for testing
1252    pub fn generate_sample_kg_data(
1253        num_entities: usize,
1254        num_relations: usize,
1255    ) -> Vec<(String, String, String)> {
1256        let mut rng = Random::default();
1257        let mut triples = Vec::new();
1258
1259        let entities: Vec<String> = (0..num_entities).map(|i| format!("entity_{i}")).collect();
1260
1261        let relations: Vec<String> = (0..num_relations)
1262            .map(|i| format!("relation_{i}"))
1263            .collect();
1264
1265        // Generate random triples
1266        for _ in 0..(num_entities * 2) {
1267            let subject = entities[rng.random_range(0..entities.len())].clone();
1268            let relation = relations[rng.random_range(0..relations.len())].clone();
1269            let object = entities[rng.random_range(0..entities.len())].clone();
1270
1271            if subject != object {
1272                triples.push((subject, relation, object));
1273            }
1274        }
1275
1276        triples
1277    }
1278
1279    /// Quick performance test function
1280    pub fn quick_performance_test<F>(
1281        name: &str,
1282        iterations: usize,
1283        operation: F,
1284    ) -> std::time::Duration
1285    where
1286        F: Fn(),
1287    {
1288        let start = std::time::Instant::now();
1289        for _ in 0..iterations {
1290            operation();
1291        }
1292        let duration = start.elapsed();
1293
1294        println!(
1295            "Performance test '{}': {} iterations in {:?} ({:.2} ops/sec)",
1296            name,
1297            iterations,
1298            duration,
1299            iterations as f64 / duration.as_secs_f64()
1300        );
1301
1302        duration
1303    }
1304}
1305
1306/// Advanced performance utilities for embedding operations
1307pub mod performance_utils {
1308    use super::*;
1309
1310    /// Type alias for batch processor function
1311    type ProcessorFn<T> = Box<dyn Fn(&[T]) -> Result<()> + Send + Sync>;
1312
1313    /// Memory-efficient batch processor for large datasets
1314    pub struct BatchProcessor<T> {
1315        batch_size: usize,
1316        current_batch: Vec<T>,
1317        processor_fn: ProcessorFn<T>,
1318    }
1319
1320    impl<T> BatchProcessor<T> {
1321        pub fn new<F>(batch_size: usize, processor_fn: F) -> Self
1322        where
1323            F: Fn(&[T]) -> Result<()> + Send + Sync + 'static,
1324        {
1325            Self {
1326                batch_size,
1327                current_batch: Vec::with_capacity(batch_size),
1328                processor_fn: Box::new(processor_fn),
1329            }
1330        }
1331
1332        pub fn add(&mut self, item: T) -> Result<()> {
1333            self.current_batch.push(item);
1334            if self.current_batch.len() >= self.batch_size {
1335                return self.flush();
1336            }
1337            Ok(())
1338        }
1339
1340        pub fn flush(&mut self) -> Result<()> {
1341            if !self.current_batch.is_empty() {
1342                (self.processor_fn)(&self.current_batch)?;
1343                self.current_batch.clear();
1344            }
1345            Ok(())
1346        }
1347    }
1348
1349    /// Enhanced memory monitoring for embedding operations
1350    #[derive(Debug, Clone)]
1351    pub struct MemoryMonitor {
1352        peak_usage: usize,
1353        current_usage: usize,
1354        allocations: usize,
1355        deallocations: usize,
1356    }
1357
1358    impl MemoryMonitor {
1359        pub fn new() -> Self {
1360            Self {
1361                peak_usage: 0,
1362                current_usage: 0,
1363                allocations: 0,
1364                deallocations: 0,
1365            }
1366        }
1367
1368        pub fn record_allocation(&mut self, size: usize) {
1369            self.current_usage += size;
1370            self.allocations += 1;
1371            if self.current_usage > self.peak_usage {
1372                self.peak_usage = self.current_usage;
1373            }
1374        }
1375
1376        pub fn record_deallocation(&mut self, size: usize) {
1377            self.current_usage = self.current_usage.saturating_sub(size);
1378            self.deallocations += 1;
1379        }
1380
1381        pub fn peak_usage(&self) -> usize {
1382            self.peak_usage
1383        }
1384
1385        pub fn current_usage(&self) -> usize {
1386            self.current_usage
1387        }
1388
1389        pub fn allocation_count(&self) -> usize {
1390            self.allocations
1391        }
1392    }
1393
1394    impl Default for MemoryMonitor {
1395        fn default() -> Self {
1396            Self::new()
1397        }
1398    }
1399}
1400
1401/// Parallel processing utilities for embedding operations
1402pub mod parallel_utils {
1403    use super::*;
1404    use rayon::prelude::*;
1405
1406    /// Parallel computation of embedding similarities
1407    pub fn parallel_cosine_similarities(
1408        query_embedding: &[f32],
1409        candidate_embeddings: &[Vec<f32>],
1410    ) -> Result<Vec<f32>> {
1411        let similarities: Vec<f32> = candidate_embeddings
1412            .par_iter()
1413            .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
1414            .collect();
1415        Ok(similarities)
1416    }
1417
1418    /// Parallel batch processing with configurable thread pool
1419    pub fn parallel_batch_process<T, R, F>(
1420        items: &[T],
1421        batch_size: usize,
1422        processor: F,
1423    ) -> Result<Vec<R>>
1424    where
1425        T: Sync,
1426        R: Send,
1427        F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
1428    {
1429        let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();
1430
1431        Ok(results?.into_iter().flatten().collect())
1432    }
1433
1434    /// Parallel graph analysis with optimized memory usage
1435    pub fn parallel_entity_frequencies(
1436        triples: &[(String, String, String)],
1437    ) -> HashMap<String, usize> {
1438        let entity_counts: HashMap<String, usize> = triples
1439            .par_iter()
1440            .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
1441                *acc.entry(subject.clone()).or_insert(0) += 1;
1442                *acc.entry(object.clone()).or_insert(0) += 1;
1443                acc
1444            })
1445            .reduce(HashMap::new, |mut acc1, acc2| {
1446                for (entity, count) in acc2 {
1447                    *acc1.entry(entity).or_insert(0) += count;
1448                }
1449                acc1
1450            });
1451        entity_counts
1452    }
1453}
1454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quick_start::{
        add_triples_from_strings, cosine_similarity, create_simple_transe_model,
        generate_sample_kg_data, parse_triple_from_string, quick_performance_test,
    };
    use crate::EmbeddingModel;
    use std::io::Write;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use tempfile::NamedTempFile;

    #[test]
    fn test_load_triples_from_tsv() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(temp_file, "subject\tpredicate\tobject")?;
        writeln!(temp_file, "alice\tknows\tbob")?;
        writeln!(temp_file, "bob\tlikes\tcharlie")?;

        let triples = data_loader::load_triples_from_tsv(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_dataset_split() -> Result<()> {
        let triples = vec![
            ("a".to_string(), "r1".to_string(), "b".to_string()),
            ("b".to_string(), "r2".to_string(), "c".to_string()),
            ("c".to_string(), "r3".to_string(), "d".to_string()),
            ("d".to_string(), "r4".to_string(), "e".to_string()),
        ];

        let split = dataset_splitter::split_dataset(triples, 0.6, 0.2, Some(42))?;

        assert_eq!(split.train.len(), 2);
        assert_eq!(split.validation.len(), 0); // 0.2 * 4 = 0.8, rounded down
        assert_eq!(split.test.len(), 2);

        Ok(())
    }

    #[test]
    fn test_load_triples_from_jsonl() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(
            temp_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;
        writeln!(
            temp_file,
            r#"{{"subject": "bob", "predicate": "likes", "object": "charlie"}}"#
        )?;

        let triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_save_triples_to_jsonl() -> Result<()> {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let temp_file = NamedTempFile::new()?;
        data_loader::save_triples_to_jsonl(&triples, temp_file.path())?;

        let loaded_triples = data_loader::load_triples_from_jsonl(temp_file.path())?;
        assert_eq!(loaded_triples, triples);

        Ok(())
    }

    #[test]
    fn test_load_triples_auto_detect() -> Result<()> {
        // Test TSV auto-detection
        let mut tsv_file = NamedTempFile::with_suffix(".tsv")?;
        writeln!(tsv_file, "subject\tpredicate\tobject")?;
        writeln!(tsv_file, "alice\tknows\tbob")?;

        let triples = data_loader::load_triples_auto_detect(tsv_file.path())?;
        assert_eq!(triples.len(), 1);

        // Test JSON Lines auto-detection
        let mut jsonl_file = NamedTempFile::with_suffix(".jsonl")?;
        writeln!(
            jsonl_file,
            r#"{{"subject": "alice", "predicate": "knows", "object": "bob"}}"#
        )?;

        let triples = data_loader::load_triples_auto_detect(jsonl_file.path())?;
        assert_eq!(triples.len(), 1);
        assert_eq!(
            triples[0],
            ("alice".to_string(), "knows".to_string(), "bob".to_string())
        );

        Ok(())
    }

    #[test]
    fn test_dataset_statistics() {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            (
                "bob".to_string(),
                "knows".to_string(),
                "charlie".to_string(),
            ),
            (
                "alice".to_string(),
                "likes".to_string(),
                "charlie".to_string(),
            ),
        ];

        let stats = compute_dataset_statistics(&triples);

        assert_eq!(stats.num_triples, 3);
        assert_eq!(stats.num_entities, 3); // alice, bob, charlie
        assert_eq!(stats.num_relations, 2); // knows, likes
        assert!(stats.avg_degree > 0.0);
    }

    // Tests for convenience functions
    #[test]
    fn test_create_simple_transe_model() {
        let model = create_simple_transe_model();
        assert_eq!(model.config().dimensions, 128);
        assert_eq!(model.config().learning_rate, 0.01);
        assert_eq!(model.config().max_epochs, 100);
    }

    #[test]
    fn test_parse_triple_from_string() -> Result<()> {
        let triple = parse_triple_from_string("alice knows bob")?;
        assert_eq!(triple.subject.iri.as_str(), "http://example.org/alice");
        assert_eq!(triple.predicate.iri.as_str(), "http://example.org/knows");
        assert_eq!(triple.object.iri.as_str(), "http://example.org/bob");

        // Test with full URIs
        let triple2 = parse_triple_from_string(
            "http://example.org/alice http://example.org/knows http://example.org/bob",
        )?;
        assert_eq!(triple2.subject.iri.as_str(), "http://example.org/alice");

        // Test invalid format
        assert!(parse_triple_from_string("alice knows").is_err());

        Ok(())
    }

    #[test]
    fn test_add_triples_from_strings() -> Result<()> {
        let mut model = create_simple_transe_model();
        let triple_strings = &[
            "alice knows bob",
            "bob likes charlie",
            "charlie follows alice",
        ];

        let added_count = add_triples_from_strings(&mut model, triple_strings)?;
        assert_eq!(added_count, 3);

        Ok(())
    }

    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&a, &b)?;
        assert!((similarity - 1.0).abs() < 1e-10);

        let c = vec![0.0, 1.0, 0.0];
        let similarity2 = cosine_similarity(&a, &c)?;
        assert!((similarity2 - 0.0).abs() < 1e-10);

        // Test different dimensions
        let d = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &d).is_err());

        Ok(())
    }

    #[test]
    fn test_generate_sample_kg_data() {
        let triples = generate_sample_kg_data(5, 3);
        assert!(!triples.is_empty());

        // Check that all subjects and objects are in the expected format
        for (subject, relation, object) in &triples {
            assert!(subject.starts_with("http://example.org/entity_"));
            assert!(relation.starts_with("http://example.org/relation_"));
            assert!(object.starts_with("http://example.org/entity_"));
            assert_ne!(subject, object); // No self-loops
        }
    }

    #[test]
    fn test_quick_performance_test() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let duration = quick_performance_test("test_operation", 100, move || {
            // Simple operation for testing
            let _sum: i32 = (1..10).sum();
            counter.fetch_add(1, Ordering::SeqCst);
        });

        // Timing is too environment-dependent to assert on (release builds can make
        // this effectively instantaneous), but every requested iteration must run.
        assert_eq!(calls.load(Ordering::SeqCst), 100);
        let _nanos = duration.as_nanos();
    }

    // Tests for the `convenience` module in this file
    #[test]
    fn test_convenience_cosine_similarity_zero_vector() -> Result<()> {
        let zero = vec![0.0, 0.0];
        let other = vec![1.0, 2.0];
        assert_eq!(convenience::cosine_similarity(&zero, &other)?, 0.0);
        Ok(())
    }

    #[test]
    fn test_convenience_generate_sample_kg_data_format() {
        let triples = convenience::generate_sample_kg_data(5, 3);
        for (subject, relation, object) in &triples {
            assert!(subject.starts_with("entity_"));
            assert!(relation.starts_with("relation_"));
            assert!(object.starts_with("entity_"));
            assert_ne!(subject, object);
        }
    }

    // Tests for `performance_utils`
    #[test]
    fn test_memory_monitor_tracks_peak_and_current() {
        let mut monitor = performance_utils::MemoryMonitor::new();
        monitor.record_allocation(100);
        monitor.record_allocation(50);
        monitor.record_deallocation(120);
        assert_eq!(monitor.current_usage(), 30);
        assert_eq!(monitor.peak_usage(), 150);
        assert_eq!(monitor.allocation_count(), 2);

        // Deallocating more than is tracked saturates at zero instead of underflowing.
        monitor.record_deallocation(1_000);
        assert_eq!(monitor.current_usage(), 0);
    }

    #[test]
    fn test_batch_processor_flushes_full_and_partial_batches() -> Result<()> {
        let seen = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let mut processor = performance_utils::BatchProcessor::new(2, move |batch: &[u32]| {
            sink.lock()
                .expect("batch sink mutex poisoned")
                .push(batch.to_vec());
            Ok(())
        });

        for item in 0..5u32 {
            processor.add(item)?;
        }
        processor.flush()?;

        assert_eq!(
            *seen.lock().expect("batch sink mutex poisoned"),
            vec![vec![0u32, 1], vec![2, 3], vec![4]]
        );
        Ok(())
    }

    // Tests for `parallel_utils`
    #[test]
    fn test_parallel_batch_process_preserves_order() -> Result<()> {
        let items: Vec<i32> = (1..=5).collect();
        let doubled = parallel_utils::parallel_batch_process(&items, 2, |chunk: &[i32]| {
            Ok(chunk.iter().map(|x| x * 2).collect::<Vec<i32>>())
        })?;
        assert_eq!(doubled, vec![2, 4, 6, 8, 10]);
        Ok(())
    }

    #[test]
    fn test_parallel_entity_frequencies() {
        let triples = vec![
            ("a".to_string(), "r".to_string(), "b".to_string()),
            ("b".to_string(), "r".to_string(), "c".to_string()),
        ];
        let counts = parallel_utils::parallel_entity_frequencies(&triples);
        assert_eq!(counts.len(), 3);
        assert_eq!(counts["a"], 1);
        assert_eq!(counts["b"], 2);
        assert_eq!(counts["c"], 1);
    }
}