1use crate::utils_types::DatasetSplit;
4use anyhow::{anyhow, Result};
5use scirs2_core::random::Random;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::io::{BufRead, BufReader};
9use std::path::Path;
10
11pub mod data_loader {
13 use super::*;
14
15 pub fn load_triples_from_tsv<P: AsRef<Path>>(
17 file_path: P,
18 ) -> Result<Vec<(String, String, String)>> {
19 let file = fs::File::open(file_path)?;
20 let reader = BufReader::new(file);
21 let mut triples = Vec::new();
22
23 for (line_num, line) in reader.lines().enumerate() {
24 let line = line?;
25 if line.trim().is_empty() || line.starts_with('#') {
26 continue;
27 }
28
29 if line_num == 0
30 && (line.contains("subject")
31 || line.contains("predicate")
32 || line.contains("object"))
33 {
34 continue;
35 }
36
37 let parts: Vec<&str> = line.split('\t').collect();
38 if parts.len() >= 3 {
39 let subject = parts[0].trim().to_string();
40 let predicate = parts[1].trim().to_string();
41 let object = parts[2].trim().to_string();
42 triples.push((subject, predicate, object));
43 } else {
44 eprintln!(
45 "Warning: Invalid triple format at line {}: {}",
46 line_num + 1,
47 line
48 );
49 }
50 }
51
52 Ok(triples)
53 }
54
55 pub fn load_triples_from_csv<P: AsRef<Path>>(
57 file_path: P,
58 ) -> Result<Vec<(String, String, String)>> {
59 let file = fs::File::open(file_path)?;
60 let reader = BufReader::new(file);
61 let mut triples = Vec::new();
62 let mut is_first_line = true;
63
64 for (line_num, line) in reader.lines().enumerate() {
65 let line = line?;
66 if is_first_line {
67 is_first_line = false;
68 if line.to_lowercase().contains("subject")
69 && line.to_lowercase().contains("predicate")
70 {
71 continue;
72 }
73 }
74
75 if line.trim().is_empty() {
76 continue;
77 }
78
79 let parts: Vec<&str> = line.split(',').collect();
80 if parts.len() >= 3 {
81 let subject = parts[0].trim().trim_matches('"').to_string();
82 let predicate = parts[1].trim().trim_matches('"').to_string();
83 let object = parts[2].trim().trim_matches('"').to_string();
84 triples.push((subject, predicate, object));
85 } else {
86 eprintln!(
87 "Warning: Invalid triple format at line {}: {}",
88 line_num + 1,
89 line
90 );
91 }
92 }
93
94 Ok(triples)
95 }
96
97 pub fn load_triples_from_ntriples<P: AsRef<Path>>(
99 file_path: P,
100 ) -> Result<Vec<(String, String, String)>> {
101 let file = fs::File::open(file_path)?;
102 let reader = BufReader::new(file);
103 let mut triples = Vec::new();
104
105 for (line_num, line) in reader.lines().enumerate() {
106 let line = line?;
107 let line = line.trim();
108
109 if line.is_empty() || line.starts_with('#') {
110 continue;
111 }
112
113 if let Some(triple) = parse_ntriple_line(line) {
114 triples.push(triple);
115 } else {
116 eprintln!(
117 "Warning: Failed to parse N-Triple at line {}: {}",
118 line_num + 1,
119 line
120 );
121 }
122 }
123
124 Ok(triples)
125 }
126
127 fn parse_ntriple_line(line: &str) -> Option<(String, String, String)> {
128 let line = line.trim_end_matches(" .");
129 let parts: Vec<&str> = line.split_whitespace().collect();
130
131 if parts.len() >= 3 {
132 let subject = clean_uri_or_literal(parts[0]);
133 let predicate = clean_uri_or_literal(parts[1]);
134 let object = clean_uri_or_literal(&parts[2..].join(" "));
135 Some((subject, predicate, object))
136 } else {
137 None
138 }
139 }
140
141 fn clean_uri_or_literal(term: &str) -> String {
142 if term.starts_with('<') && term.ends_with('>') {
143 term[1..term.len() - 1].to_string()
144 } else if term.starts_with('"') && term.contains('"') {
145 let end_quote = term.rfind('"').unwrap_or(term.len());
146 term[1..end_quote].to_string()
147 } else {
148 term.to_string()
149 }
150 }
151
152 pub fn load_triples_from_jsonl<P: AsRef<Path>>(
154 file_path: P,
155 ) -> Result<Vec<(String, String, String)>> {
156 let file = fs::File::open(file_path)?;
157 let reader = BufReader::new(file);
158 let mut triples = Vec::new();
159
160 for (line_num, line) in reader.lines().enumerate() {
161 let line = line?;
162 if line.trim().is_empty() {
163 continue;
164 }
165
166 match serde_json::from_str::<serde_json::Value>(&line) {
167 Ok(json) => {
168 if let (Some(subject), Some(predicate), Some(object)) = (
169 json["subject"].as_str(),
170 json["predicate"].as_str(),
171 json["object"].as_str(),
172 ) {
173 triples.push((
174 subject.to_string(),
175 predicate.to_string(),
176 object.to_string(),
177 ));
178 } else {
179 eprintln!(
180 "Warning: Invalid JSON structure at line {}: {}",
181 line_num + 1,
182 line
183 );
184 }
185 }
186 Err(e) => {
187 eprintln!(
188 "Warning: Failed to parse JSON at line {}: {} - Error: {}",
189 line_num + 1,
190 line,
191 e
192 );
193 }
194 }
195 }
196
197 Ok(triples)
198 }
199
200 pub fn save_triples_to_tsv<P: AsRef<Path>>(
202 triples: &[(String, String, String)],
203 file_path: P,
204 ) -> Result<()> {
205 let mut content = String::new();
206 content.push_str("subject\tpredicate\tobject\n");
207
208 for (subject, predicate, object) in triples {
209 content.push_str(&format!("{subject}\t{predicate}\t{object}\n"));
210 }
211
212 fs::write(file_path, content)?;
213 Ok(())
214 }
215
216 pub fn save_triples_to_jsonl<P: AsRef<Path>>(
218 triples: &[(String, String, String)],
219 file_path: P,
220 ) -> Result<()> {
221 use std::io::Write;
222 let mut file = fs::File::create(file_path)?;
223
224 for (subject, predicate, object) in triples {
225 let json = serde_json::json!({
226 "subject": subject,
227 "predicate": predicate,
228 "object": object
229 });
230 writeln!(file, "{json}")?;
231 }
232
233 Ok(())
234 }
235
236 pub fn load_triples_auto_detect<P: AsRef<Path>>(
238 file_path: P,
239 ) -> Result<Vec<(String, String, String)>> {
240 let path = file_path.as_ref();
241 let extension = path
242 .extension()
243 .and_then(|ext| ext.to_str())
244 .unwrap_or("")
245 .to_lowercase();
246
247 match extension.as_str() {
248 "tsv" => load_triples_from_tsv(path),
249 "csv" => load_triples_from_csv(path),
250 "nt" | "ntriples" => load_triples_from_ntriples(path),
251 "jsonl" | "ndjson" => load_triples_from_jsonl(path),
252 _ => {
253 eprintln!(
254 "Warning: Unknown file extension '{extension}', attempting auto-detection"
255 );
256
257 if let Ok(triples) = load_triples_from_tsv(path) {
258 if !triples.is_empty() {
259 return Ok(triples);
260 }
261 }
262
263 if let Ok(triples) = load_triples_from_ntriples(path) {
264 if !triples.is_empty() {
265 return Ok(triples);
266 }
267 }
268
269 if let Ok(triples) = load_triples_from_jsonl(path) {
270 if !triples.is_empty() {
271 return Ok(triples);
272 }
273 }
274
275 load_triples_from_csv(path)
276 }
277 }
278 }
279}
280
281pub mod dataset_splitter {
283 use super::*;
284
285 pub fn split_dataset(
287 triples: Vec<(String, String, String)>,
288 train_ratio: f64,
289 val_ratio: f64,
290 seed: Option<u64>,
291 ) -> Result<DatasetSplit> {
292 if train_ratio + val_ratio >= 1.0 {
293 return Err(anyhow!(
294 "Train and validation ratios must sum to less than 1.0"
295 ));
296 }
297
298 let mut rng = if let Some(s) = seed {
299 Random::seed(s)
300 } else {
301 Random::seed(42)
302 };
303
304 let mut shuffled_triples = triples;
305 for i in (1..shuffled_triples.len()).rev() {
306 let j = rng.random_range(0..i + 1);
307 shuffled_triples.swap(i, j);
308 }
309
310 let total = shuffled_triples.len();
311 let train_end = (total as f64 * train_ratio) as usize;
312 let val_end = train_end + (total as f64 * val_ratio) as usize;
313
314 let train_triples = shuffled_triples[..train_end].to_vec();
315 let val_triples = shuffled_triples[train_end..val_end].to_vec();
316 let test_triples = shuffled_triples[val_end..].to_vec();
317
318 Ok(DatasetSplit {
319 train: train_triples,
320 validation: val_triples,
321 test: test_triples,
322 })
323 }
324
325 pub fn split_dataset_no_leakage(
327 triples: Vec<(String, String, String)>,
328 train_ratio: f64,
329 val_ratio: f64,
330 seed: Option<u64>,
331 ) -> Result<DatasetSplit> {
332 let mut entity_triples: HashMap<String, Vec<(String, String, String)>> =
333 HashMap::with_capacity(triples.len() / 2);
334
335 for triple in &triples {
336 let entities = [&triple.0, &triple.2];
337 for entity in entities {
338 entity_triples
339 .entry(entity.clone())
340 .or_default()
341 .push(triple.clone());
342 }
343 }
344
345 let entities: Vec<String> = entity_triples.keys().cloned().collect();
346 let dummy_string = "dummy".to_string();
347 let entity_split = split_dataset(
348 entities
349 .into_iter()
350 .map(|e| (e, dummy_string.clone(), dummy_string.clone()))
351 .collect(),
352 train_ratio,
353 val_ratio,
354 seed,
355 )?;
356
357 let train_entities: HashSet<String> =
358 entity_split.train.into_iter().map(|(e, _, _)| e).collect();
359 let val_entities: HashSet<String> = entity_split
360 .validation
361 .into_iter()
362 .map(|(e, _, _)| e)
363 .collect();
364 let test_entities: HashSet<String> =
365 entity_split.test.into_iter().map(|(e, _, _)| e).collect();
366
367 let estimated_capacity = triples.len() / 3;
368 let mut train_triples = Vec::with_capacity(estimated_capacity);
369 let mut val_triples = Vec::with_capacity(estimated_capacity);
370 let mut test_triples = Vec::with_capacity(estimated_capacity);
371
372 for (entity, entity_triple_list) in entity_triples {
373 if train_entities.contains(&entity) {
374 train_triples.extend(entity_triple_list);
375 } else if val_entities.contains(&entity) {
376 val_triples.extend(entity_triple_list);
377 } else if test_entities.contains(&entity) {
378 test_triples.extend(entity_triple_list);
379 }
380 }
381
382 train_triples.sort();
383 train_triples.dedup();
384 val_triples.sort();
385 val_triples.dedup();
386 test_triples.sort();
387 test_triples.dedup();
388
389 Ok(DatasetSplit {
390 train: train_triples,
391 validation: val_triples,
392 test: test_triples,
393 })
394 }
395}
396
397pub mod parallel_utils {
399 use anyhow::Result;
400 use rayon::prelude::*;
401 use std::collections::HashMap;
402
403 pub fn parallel_cosine_similarities(
405 query_embedding: &[f32],
406 candidate_embeddings: &[Vec<f32>],
407 ) -> Result<Vec<f32>> {
408 let similarities: Vec<f32> = candidate_embeddings
409 .par_iter()
410 .map(|embedding| oxirs_vec::similarity::cosine_similarity(query_embedding, embedding))
411 .collect();
412 Ok(similarities)
413 }
414
415 pub fn parallel_batch_process<T, R, F>(
417 items: &[T],
418 batch_size: usize,
419 processor: F,
420 ) -> Result<Vec<R>>
421 where
422 T: Sync,
423 R: Send,
424 F: Fn(&[T]) -> Result<Vec<R>> + Sync + Send,
425 {
426 let results: Result<Vec<Vec<R>>> = items.par_chunks(batch_size).map(processor).collect();
427 Ok(results?.into_iter().flatten().collect())
428 }
429
430 pub fn parallel_entity_frequencies(
432 triples: &[(String, String, String)],
433 ) -> HashMap<String, usize> {
434 let entity_counts: HashMap<String, usize> = triples
435 .par_iter()
436 .fold(HashMap::new, |mut acc, (subject, _predicate, object)| {
437 *acc.entry(subject.clone()).or_insert(0) += 1;
438 *acc.entry(object.clone()).or_insert(0) += 1;
439 acc
440 })
441 .reduce(HashMap::new, |mut acc1, acc2| {
442 for (entity, count) in acc2 {
443 *acc1.entry(entity).or_insert(0) += count;
444 }
445 acc1
446 });
447 entity_counts
448 }
449}