Skip to main content

oxirs_embed/
utils_io.rs

1//! Serialization/deserialization helpers and file I/O utilities for embeddings.
2
3use crate::utils_types::DatasetSplit;
4use anyhow::{anyhow, Result};
5use scirs2_core::random::Random;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::io::{BufRead, BufReader};
9use std::path::Path;
10
11/// Data loading utilities
12pub mod data_loader {
13    use super::*;
14
15    /// Load triples from TSV file format
16    pub fn load_triples_from_tsv<P: AsRef<Path>>(
17        file_path: P,
18    ) -> Result<Vec<(String, String, String)>> {
19        let file = fs::File::open(file_path)?;
20        let reader = BufReader::new(file);
21        let mut triples = Vec::new();
22
23        for (line_num, line) in reader.lines().enumerate() {
24            let line = line?;
25            if line.trim().is_empty() || line.starts_with('#') {
26                continue;
27            }
28
29            if line_num == 0
30                && (line.contains("subject")
31                    || line.contains("predicate")
32                    || line.contains("object"))
33            {
34                continue;
35            }
36
37            let parts: Vec<&str> = line.split('\t').collect();
38            if parts.len() >= 3 {
39                let subject = parts[0].trim().to_string();
40                let predicate = parts[1].trim().to_string();
41                let object = parts[2].trim().to_string();
42                triples.push((subject, predicate, object));
43            } else {
44                eprintln!(
45                    "Warning: Invalid triple format at line {}: {}",
46                    line_num + 1,
47                    line
48                );
49            }
50        }
51
52        Ok(triples)
53    }
54
55    /// Load triples from CSV file format
56    pub fn load_triples_from_csv<P: AsRef<Path>>(
57        file_path: P,
58    ) -> Result<Vec<(String, String, String)>> {
59        let file = fs::File::open(file_path)?;
60        let reader = BufReader::new(file);
61        let mut triples = Vec::new();
62        let mut is_first_line = true;
63
64        for (line_num, line) in reader.lines().enumerate() {
65            let line = line?;
66            if is_first_line {
67                is_first_line = false;
68                if line.to_lowercase().contains("subject")
69                    && line.to_lowercase().contains("predicate")
70                {
71                    continue;
72                }
73            }
74
75            if line.trim().is_empty() {
76                continue;
77            }
78
79            let parts: Vec<&str> = line.split(',').collect();
80            if parts.len() >= 3 {
81                let subject = parts[0].trim().trim_matches('"').to_string();
82                let predicate = parts[1].trim().trim_matches('"').to_string();
83                let object = parts[2].trim().trim_matches('"').to_string();
84                triples.push((subject, predicate, object));
85            } else {
86                eprintln!(
87                    "Warning: Invalid triple format at line {}: {}",
88                    line_num + 1,
89                    line
90                );
91            }
92        }
93
94        Ok(triples)
95    }
96
97    /// Load triples from N-Triples format
98    pub fn load_triples_from_ntriples<P: AsRef<Path>>(
99        file_path: P,
100    ) -> Result<Vec<(String, String, String)>> {
101        let file = fs::File::open(file_path)?;
102        let reader = BufReader::new(file);
103        let mut triples = Vec::new();
104
105        for (line_num, line) in reader.lines().enumerate() {
106            let line = line?;
107            let line = line.trim();
108
109            if line.is_empty() || line.starts_with('#') {
110                continue;
111            }
112
113            if let Some(triple) = parse_ntriple_line(line) {
114                triples.push(triple);
115            } else {
116                eprintln!(
117                    "Warning: Failed to parse N-Triple at line {}: {}",
118                    line_num + 1,
119                    line
120                );
121            }
122        }
123
124        Ok(triples)
125    }
126
127    fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
128        let line = line.trim_end_matches(" .");
129        let parts: Vec<&str> = line.split_whitespace().collect();
130
131        if parts.len() >= 3 {
132            let subject = clean_uri_or_literal(parts[0]);
133            let predicate = clean_uri_or_literal(parts[1]);
134            let object = clean_uri_or_literal(&parts[2..].join(" "));
135            Some((subject, predicate, object))
136        } else {
137            None
138        }
139    }
140
141    fn clean_uri_or_literal(term: &str) -> String {
142        if term.starts_with('<') && term.ends_with('>') {
143            term[1..term.len() - 1].to_string()
144        } else if term.starts_with('"') && term.contains('"') {
145            let end_quote = term.rfind('"').unwrap_or(term.len());
146            term[1..end_quote].to_string()
147        } else {
148            term.to_string()
149        }
150    }
151
152    /// Load triples from JSON Lines format
153    pub fn load_triples_from_jsonl<P: AsRef<Path>>(
154        file_path: P,
155    ) -> Result<Vec<(String, String, String)>> {
156        let file = fs::File::open(file_path)?;
157        let reader = BufReader::new(file);
158        let mut triples = Vec::new();
159
160        for (line_num, line) in reader.lines().enumerate() {
161            let line = line?;
162            if line.trim().is_empty() {
163                continue;
164            }
165
166            match serde_json::from_str::<serde_json::Value>(&line) {
167                Ok(json) => {
168                    if let (Some(subject), Some(predicate), Some(object)) = (
169                        json["subject"].as_str(),
170                        json["predicate"].as_str(),
171                        json["object"].as_str(),
172                    ) {
173                        triples.push((
174                            subject.to_string(),
175                            predicate.to_string(),
176                            object.to_string(),
177                        ));
178                    } else {
179                        eprintln!(
180                            "Warning: Invalid JSON structure at line {}: {}",
181                            line_num + 1,
182                            line
183                        );
184                    }
185                }
186                Err(e) => {
187                    eprintln!(
188                        "Warning: Failed to parse JSON at line {}: {} - Error: {}",
189                        line_num + 1,
190                        line,
191                        e
192                    );
193                }
194            }
195        }
196
197        Ok(triples)
198    }
199
200    /// Save triples to TSV format
201    pub fn save_triples_to_tsv<P: AsRef<Path>>(
202        triples: &[(String, String, String)],
203        file_path: P,
204    ) -> Result<()> {
205        let mut content = String::new();
206        content.push_str("subject\tpredicate\tobject\n");
207
208        for (subject, predicate, object) in triples {
209            content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
210        }
211
212        fs::write(file_path, content)?;
213        Ok(())
214    }
215
216    /// Save triples to JSON Lines format
217    pub fn save_triples_to_jsonl<P: AsRef<Path>>(
218        triples: &[(String, String, String)],
219        file_path: P,
220    ) -> Result<()> {
221        use std::io::Write;
222        let mut file = fs::File::create(file_path)?;
223
224        for (subject, predicate, object) in triples {
225            let json = serde_json::json!({
226                "subject": subject,
227                "predicate": predicate,
228                "object": object
229            });
230            writeln!(file, "{json}")?;
231        }
232
233        Ok(())
234    }
235
236    /// Auto-detect file format and load triples accordingly
237    pub fn load_triples_auto_detect<P: AsRef<Path>>(
238        file_path: P,
239    ) -> Result<Vec<(String, String, String)>> {
240        let path = file_path.as_ref();
241        let extension = path
242            .extension()
243            .and_then(|ext| ext.to_str())
244            .unwrap_or("")
245            .to_lowercase();
246
247        match extension.as_str() {
248            "tsv" => load_triples_from_tsv(path),
249            "csv" => load_triples_from_csv(path),
250            "nt" | "ntriples" => load_triples_from_ntriples(path),
251            "jsonl" | "ndjson" => load_triples_from_jsonl(path),
252            _ => {
253                eprintln!(
254                    "Warning: Unknown file extension '{extension}', attempting auto-detection"
255                );
256
257                if let Ok(triples) = load_triples_from_tsv(path) {
258                    if !triples.is_empty() {
259                        return Ok(triples);
260                    }
261                }
262
263                if let Ok(triples) = load_triples_from_ntriples(path) {
264                    if !triples.is_empty() {
265                        return Ok(triples);
266                    }
267                }
268
269                if let Ok(triples) = load_triples_from_jsonl(path) {
270                    if !triples.is_empty() {
271                        return Ok(triples);
272                    }
273                }
274
275                load_triples_from_csv(path)
276            }
277        }
278    }
279}
280
281/// Dataset splitting utilities
282pub mod dataset_splitter {
283    use super::*;
284
285    /// Split dataset into train/validation/test sets
286    pub fn split_dataset(
287        triples: Vec<(String, String, String)>,
288        train_ratio: f64,
289        val_ratio: f64,
290        seed: Option<u64>,
291    ) -> Result<DatasetSplit> {
292        if train_ratio + val_ratio >= 1.0 {
293            return Err(anyhow!(
294                "Train and validation ratios must sum to less than 1.0"
295            ));
296        }
297
298        let mut rng = if let Some(s) = seed {
299            Random::seed(s)
300        } else {
301            Random::seed(42)
302        };
303
304        let mut shuffled_triples = triples;
305        for i in (1..shuffled_triples.len()).rev() {
306            let j = rng.random_range(0..i + 1);
307            shuffled_triples.swap(i, j);
308        }
309
310        let total = shuffled_triples.len();
311        let train_end = (total as f64 * train_ratio) as usize;
312        let val_end = train_end + (total as f64 * val_ratio) as usize;
313
314        let train_triples = shuffled_triples[..train_end].to_vec();
315        let val_triples = shuffled_triples[train_end..val_end].to_vec();
316        let test_triples = shuffled_triples[val_end..].to_vec();
317
318        Ok(DatasetSplit {
319            train: train_triples,
320            validation: val_triples,
321            test: test_triples,
322        })
323    }
324
325    /// Split dataset ensuring no entity leakage between splits
326    pub fn split_dataset_no_leakage(
327        triples: Vec<(String, String, String)>,
328        train_ratio: f64,
329        val_ratio: f64,
330        seed: Option<u64>,
331    ) -> Result<DatasetSplit> {
332        let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
333            HashMap::with_capacity(triples.len() / 2);
334
335        for triple in &triples {
336            let entities = [&triple.0, &triple.2];
337            for entity in entities {
338                entity_triples
339                    .entry(entity.clone())
340                    .or_default()
341                    .push(triple.clone());
342            }
343        }
344
345        let entities: Vec<String> = entity_triples.keys().cloned().collect();
346        let dummy_string = "dummy".to_string();
347        let entity_split = split_dataset(
348            entities
349                .into_iter()
350                .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
351                .collect(),
352            train_ratio,
353            val_ratio,
354            seed,
355        )?;
356
357        let train_entities: HashSet<String> =
358            entity_split.train.into_iter().map(|(e, _, _)| e).collect();
359        let val_entities: HashSet<String> = entity_split
360            .validation
361            .into_iter()
362            .map(|(e, _, _)| e)
363            .collect();
364        let test_entities: HashSet<String> =
365            entity_split.test.into_iter().map(|(e, _, _)| e).collect();
366
367        let estimated_capacity = triples.len() / 3;
368        let mut train_triples = Vec::with_capacity(estimated_capacity);
369        let mut val_triples = Vec::with_capacity(estimated_capacity);
370        let mut test_triples = Vec::with_capacity(estimated_capacity);
371
372        for (entity, entity_triple_list) in entity_triples {
373            if train_entities.contains(&entity) {
374                train_triples.extend(entity_triple_list);
375            } else if val_entities.contains(&entity) {
376                val_triples.extend(entity_triple_list);
377            } else if test_entities.contains(&entity) {
378                test_triples.extend(entity_triple_list);
379            }
380        }
381
382        train_triples.sort();
383        train_triples.dedup();
384        val_triples.sort();
385        val_triples.dedup();
386        test_triples.sort();
387        test_triples.dedup();
388
389        Ok(DatasetSplit {
390            train: train_triples,
391            validation: val_triples,
392            test: test_triples,
393        })
394    }
395}
396
397/// Parallel processing utilities for embedding operations
398pub mod parallel_utils {
399    use anyhow::Result;
400    use rayon::prelude::*;
401    use std::collections::HashMap;
402
403    /// Parallel computation of embedding similarities
404    pub fn parallel_cosine_similarities(
405        query_embedding: &[f32],
406        candidate_embeddings: &[Vec<f32>],
407    ) -> Result<Vec<f32>> {
408        let similarities: Vec<f32> = candidate_embeddings
409            .par_iter()
410            .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
411            .collect();
412        Ok(similarities)
413    }
414
415    /// Parallel batch processing with configurable thread pool
416    pub fn parallel_batch_process<T, R, F>(
417        items: &[T],
418        batch_size: usize,
419        processor: F,
420    ) -> Result<Vec<R>>
421    where
422        T: Sync,
423        R: Send,
424        F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
425    {
426        let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();
427        Ok(results?.into_iter().flatten().collect())
428    }
429
430    /// Parallel graph analysis with optimized memory usage
431    pub fn parallel_entity_frequencies(
432        triples: &[(String, String, String)],
433    ) -> HashMap<String, usize> {
434        let entity_counts: HashMap<String, usize> = triples
435            .par_iter()
436            .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
437                *acc.entry(subject.clone()).or_insert(0) += 1;
438                *acc.entry(object.clone()).or_insert(0) += 1;
439                acc
440            })
441            .reduce(HashMap::new, |mut acc1, acc2| {
442                for (entity, count) in acc2 {
443                    *acc1.entry(entity).or_insert(0) += count;
444                }
445                acc1
446            });
447        entity_counts
448    }
449}