1use crate::error::{DatasetsError, Result};
39use std::collections::HashMap;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42
/// Magic bytes that begin every Arrow IPC file ("ARROW1"); checked by
/// `ArrowDataset::validate_arrow_magic`.
const ARROW_MAGIC: &[u8; 6] = b"ARROW1";
45
/// Logical type of a single dataset feature (column), mirroring the feature
/// descriptions found in a `dataset_info.json` file (see
/// `ArrowDataset::parse_feature_type` for the accepted JSON shapes).
#[derive(Debug, Clone)]
pub enum FeatureType {
    /// Scalar value with a named dtype string (e.g. "int64", "float32").
    Value {
        dtype: String,
    },
    /// Variable-length sequence of a nested feature type.
    Sequence {
        feature: Box<FeatureType>,
    },
    /// Class label with its human-readable class names.
    ClassLabel {
        names: Vec<String>,
    },
    /// Free-form text ("text" / "string" in JSON).
    Text,
    /// Image data.
    Image,
    /// Any description that could not be mapped to one of the above.
    Unknown,
}
75
/// Dataset metadata, typically parsed from a `dataset_info.json` file.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    // Dataset name; empty string when absent from the JSON.
    pub dataset_name: String,
    // Version string; defaults to "0.0.0" when absent.
    pub version: String,
    // Feature (column) name -> parsed feature type.
    pub features: HashMap<String, FeatureType>,
    // Row count, read from "num_rows" (or the "num_examples" alias) when present.
    pub num_rows: Option<usize>,
    // Split name (e.g. "train"), when recorded.
    pub split: Option<String>,
}
90
91impl Default for DatasetInfo {
92 fn default() -> Self {
93 Self {
94 dataset_name: String::new(),
95 version: "0.0.0".to_string(),
96 features: HashMap::new(),
97 num_rows: None,
98 split: None,
99 }
100 }
101}
102
/// A dataset backed by one or more Arrow IPC files on disk.
#[derive(Debug)]
pub struct ArrowDataset {
    // Parsed `dataset_info.json` metadata, if one was found next to the files.
    pub info: Option<DatasetInfo>,
    // Column names; empty unless the full `parquet_io` reader decoded the file.
    pub column_names: Vec<String>,
    // Total row count; 0 unless the full `parquet_io` reader decoded the file.
    pub num_rows: usize,
    // Paths of the backing `.arrow` files, sorted for determinism.
    pub(crate) file_paths: Vec<PathBuf>,
    // Raw column buffer bytes keyed by column name; only populated by the
    // `parquet_io` full reader.
    columns: HashMap<String, Vec<u8>>,
}
121
122impl ArrowDataset {
123 pub fn from_directory(dir: impl AsRef<Path>) -> Result<Self> {
136 let dir = dir.as_ref();
137
138 if !dir.exists() {
139 return Err(DatasetsError::NotFound(format!(
140 "Directory not found: {}",
141 dir.display()
142 )));
143 }
144
145 let mut arrow_files: Vec<PathBuf> = Vec::new();
147 for entry in std::fs::read_dir(dir).map_err(DatasetsError::IoError)? {
148 let entry = entry.map_err(DatasetsError::IoError)?;
149 let path = entry.path();
150 if path.is_file() {
151 if path.extension().and_then(|e| e.to_str()) == Some("arrow") {
152 arrow_files.push(path);
153 }
154 } else if path.is_dir() {
155 for sub in std::fs::read_dir(&path).map_err(DatasetsError::IoError)? {
157 let sub = sub.map_err(DatasetsError::IoError)?;
158 let sub_path = sub.path();
159 if sub_path.is_file()
160 && sub_path.extension().and_then(|e| e.to_str()) == Some("arrow")
161 {
162 arrow_files.push(sub_path);
163 }
164 }
165 }
166 }
167
168 if arrow_files.is_empty() {
169 return Err(DatasetsError::NotFound(format!(
170 "No .arrow files found under: {}",
171 dir.display()
172 )));
173 }
174
175 arrow_files.sort();
177
178 let info = Self::try_load_dataset_info(dir).or_else(|_| {
180 if let Some(parent) = arrow_files
182 .first()
183 .and_then(|p| p.parent())
184 .and_then(|p| p.parent())
185 {
186 Self::try_load_dataset_info(parent).ok()
187 } else {
188 None
189 }
190 .ok_or(DatasetsError::NotFound("no dataset_info.json".to_string()))
191 });
192
193 for path in &arrow_files {
195 Self::validate_arrow_magic(path)?;
196 }
197
198 Ok(Self {
199 info: info.ok(),
200 column_names: Vec::new(),
201 num_rows: 0,
202 file_paths: arrow_files,
203 columns: HashMap::new(),
204 })
205 }
206
207 pub fn from_arrow_file(path: impl AsRef<Path>) -> Result<Self> {
216 let path = path.as_ref();
217
218 if !path.exists() {
219 return Err(DatasetsError::NotFound(format!(
220 "Arrow file not found: {}",
221 path.display()
222 )));
223 }
224
225 Self::validate_arrow_magic(path)?;
227
228 #[cfg(feature = "parquet_io")]
229 {
230 Self::from_arrow_file_full(path)
231 }
232
233 #[cfg(not(feature = "parquet_io"))]
234 {
235 Ok(Self {
236 info: None,
237 column_names: Vec::new(),
238 num_rows: 0,
239 file_paths: vec![path.to_path_buf()],
240 columns: HashMap::new(),
241 })
242 }
243 }
244
245 pub fn column_names(&self) -> &[String] {
251 &self.column_names
252 }
253
254 pub fn num_rows(&self) -> usize {
256 self.num_rows
257 }
258
259 pub fn info(&self) -> Option<&DatasetInfo> {
261 self.info.as_ref()
262 }
263
264 pub fn file_paths(&self) -> &[PathBuf] {
266 &self.file_paths
267 }
268
269 pub fn validate_arrow_magic(path: impl AsRef<Path>) -> Result<bool> {
278 let path = path.as_ref();
279 let mut f = std::fs::File::open(path).map_err(DatasetsError::IoError)?;
280 let mut buf = [0u8; 6];
281 f.read_exact(&mut buf).map_err(|e| {
282 DatasetsError::InvalidFormat(format!(
283 "Could not read magic bytes from {}: {}",
284 path.display(),
285 e
286 ))
287 })?;
288 if &buf == ARROW_MAGIC {
289 Ok(true)
290 } else {
291 Err(DatasetsError::InvalidFormat(format!(
292 "Not an Arrow IPC file (bad magic bytes): {}",
293 path.display()
294 )))
295 }
296 }
297
298 fn try_load_dataset_info(dir: &Path) -> Result<DatasetInfo> {
304 let info_path = dir.join("dataset_info.json");
305 if !info_path.exists() {
306 return Err(DatasetsError::NotFound(
307 "dataset_info.json not found".to_string(),
308 ));
309 }
310
311 let content = std::fs::read_to_string(&info_path).map_err(DatasetsError::IoError)?;
312
313 Self::parse_dataset_info_json(&content)
314 }
315
316 fn parse_dataset_info_json(json: &str) -> Result<DatasetInfo> {
320 let value: serde_json::Value =
321 serde_json::from_str(json).map_err(|e| DatasetsError::SerdeError(e.to_string()))?;
322
323 let dataset_name = value
324 .get("dataset_name")
325 .and_then(|v| v.as_str())
326 .unwrap_or("")
327 .to_string();
328
329 let version = value
330 .get("version")
331 .and_then(|v| v.as_str())
332 .unwrap_or("0.0.0")
333 .to_string();
334
335 let split = value
336 .get("split")
337 .and_then(|v| v.as_str())
338 .map(|s| s.to_string());
339
340 let num_rows = value
341 .get("num_rows")
342 .or_else(|| value.get("num_examples"))
343 .and_then(|v| v.as_u64())
344 .map(|n| n as usize);
345
346 let features = if let Some(feat_map) = value.get("features").and_then(|v| v.as_object()) {
348 feat_map
349 .iter()
350 .map(|(k, v)| (k.clone(), Self::parse_feature_type(v)))
351 .collect()
352 } else {
353 HashMap::new()
354 };
355
356 Ok(DatasetInfo {
357 dataset_name,
358 version,
359 features,
360 num_rows,
361 split,
362 })
363 }
364
365 fn parse_feature_type(v: &serde_json::Value) -> FeatureType {
367 if let Some(s) = v.as_str() {
369 return match s {
370 "text" | "string" => FeatureType::Text,
371 "image" => FeatureType::Image,
372 other => FeatureType::Value {
373 dtype: other.to_string(),
374 },
375 };
376 }
377
378 if let Some(obj) = v.as_object() {
379 if let Some(names_val) = obj.get("names") {
381 if let Some(names_arr) = names_val.as_array() {
382 let names: Vec<String> = names_arr
383 .iter()
384 .filter_map(|n| n.as_str().map(|s| s.to_string()))
385 .collect();
386 return FeatureType::ClassLabel { names };
387 }
388 }
389
390 if let Some(inner) = obj.get("feature") {
392 return FeatureType::Sequence {
393 feature: Box::new(Self::parse_feature_type(inner)),
394 };
395 }
396
397 if let Some(dtype) = obj.get("dtype").and_then(|d| d.as_str()) {
399 return FeatureType::Value {
400 dtype: dtype.to_string(),
401 };
402 }
403
404 if obj.get("_type").and_then(|t| t.as_str()) == Some("Value") {
406 let dtype = obj
407 .get("dtype")
408 .and_then(|d| d.as_str())
409 .unwrap_or("unknown")
410 .to_string();
411 return FeatureType::Value { dtype };
412 }
413
414 if obj.get("_type").and_then(|t| t.as_str()) == Some("ClassLabel") {
415 if let Some(names_arr) = obj.get("names").and_then(|n| n.as_array()) {
416 let names: Vec<String> = names_arr
417 .iter()
418 .filter_map(|n| n.as_str().map(|s| s.to_string()))
419 .collect();
420 return FeatureType::ClassLabel { names };
421 }
422 }
423
424 if obj.get("_type").and_then(|t| t.as_str()) == Some("Image") {
425 return FeatureType::Image;
426 }
427
428 if obj.get("_type").and_then(|t| t.as_str()) == Some("Sequence") {
429 if let Some(inner) = obj.get("feature") {
430 return FeatureType::Sequence {
431 feature: Box::new(Self::parse_feature_type(inner)),
432 };
433 }
434 }
435 }
436
437 FeatureType::Unknown
438 }
439
440 #[cfg(feature = "parquet_io")]
445 fn from_arrow_file_full(path: &Path) -> Result<Self> {
446 use arrow::ipc::reader::FileReader;
447 use std::fs::File;
448
449 let file = File::open(path).map_err(DatasetsError::IoError)?;
450 let reader = FileReader::try_new(file, None)
451 .map_err(|e| DatasetsError::InvalidFormat(format!("Arrow IPC read error: {}", e)))?;
452
453 let schema = reader.schema();
454 let column_names: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
455
456 let mut total_rows = 0usize;
457 let mut columns: HashMap<String, Vec<u8>> = HashMap::new();
458
459 for batch_result in reader {
460 let batch = batch_result.map_err(|e| {
461 DatasetsError::InvalidFormat(format!("Arrow batch read error: {}", e))
462 })?;
463 total_rows += batch.num_rows();
464
465 for (i, field) in schema.fields().iter().enumerate() {
467 let col = batch.column(i);
468 let buffers = col.to_data().buffers().to_vec();
469 let entry = columns.entry(field.name().clone()).or_default();
470 for buf in buffers {
471 entry.extend_from_slice(buf.as_slice());
472 }
473 }
474 }
475
476 Ok(Self {
477 info: None,
478 column_names,
479 num_rows: total_rows,
480 file_paths: vec![path.to_path_buf()],
481 columns,
482 })
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates a throwaway `.arrow` fixture named `{name}_{pid}.arrow` in
    /// the system temp dir, containing either a valid Arrow magic header
    /// or garbage bytes.
    ///
    /// Each test must pass a distinct `name`: the test harness runs tests
    /// in parallel, and the previous shared fixed filenames let one test's
    /// truncating `File::create` race another test's read.
    fn temp_arrow_file(name: &str, valid: bool) -> std::path::PathBuf {
        let path =
            std::env::temp_dir().join(format!("{}_{}.arrow", name, std::process::id()));
        let mut f = std::fs::File::create(&path).expect("create temp file");
        if valid {
            f.write_all(b"ARROW1\x00\x00some_padding_bytes_for_test")
                .expect("write magic");
        } else {
            f.write_all(b"NOTARROW_FILE_CONTENT")
                .expect("write wrong magic");
        }
        path
    }

    #[test]
    fn arrow_dataset_validates_magic_bytes() {
        let path = temp_arrow_file("validates_magic", true);
        let result = ArrowDataset::validate_arrow_magic(&path);
        assert!(result.is_ok(), "valid Arrow magic should succeed");
        assert!(
            result.expect("valid arrow result"),
            "validate_arrow_magic should return true for valid magic"
        );
    }

    #[test]
    fn arrow_dataset_rejects_wrong_magic() {
        let path = temp_arrow_file("rejects_wrong_magic", false);
        let result = ArrowDataset::validate_arrow_magic(&path);
        assert!(result.is_err(), "wrong magic should return an error");
        if let Err(DatasetsError::InvalidFormat(msg)) = result {
            assert!(
                msg.contains("magic bytes"),
                "error should mention magic bytes, got: {}",
                msg
            );
        } else {
            panic!("expected InvalidFormat error");
        }
    }

    #[test]
    #[cfg(not(feature = "parquet_io"))]
    fn arrow_dataset_from_arrow_file_valid() {
        let path = temp_arrow_file("from_file_valid", true);
        let result = ArrowDataset::from_arrow_file(&path);
        assert!(
            result.is_ok(),
            "from_arrow_file with valid magic should succeed"
        );
        let ds = result.expect("valid arrow dataset");
        assert_eq!(ds.file_paths().len(), 1);
    }

    #[test]
    #[cfg(feature = "parquet_io")]
    fn arrow_dataset_from_arrow_file_valid_parquet_io() {
        // The fixture has valid magic bytes but no real IPC payload, so
        // the full reader may legitimately reject it as InvalidFormat.
        let path = temp_arrow_file("from_file_valid_parquet", true);
        let result = ArrowDataset::from_arrow_file(&path);
        match result {
            Ok(_) => {}
            Err(DatasetsError::InvalidFormat(_)) => {}
            Err(other) => panic!("unexpected error variant: {:?}", other),
        }
    }

    #[test]
    fn arrow_dataset_from_arrow_file_invalid() {
        let path = temp_arrow_file("from_file_invalid", false);
        let result = ArrowDataset::from_arrow_file(&path);
        assert!(
            result.is_err(),
            "from_arrow_file with bad magic should fail"
        );
    }

    #[test]
    fn arrow_dataset_from_directory_empty_dir() {
        let dir = std::env::temp_dir()
            .join(format!("test_empty_arrow_dir_{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        // Clear out any .arrow leftovers from a previous run of this test.
        for entry in std::fs::read_dir(&dir).expect("read dir") {
            let entry = entry.expect("entry");
            if entry.path().extension().and_then(|e| e.to_str()) == Some("arrow") {
                std::fs::remove_file(entry.path()).ok();
            }
        }
        let result = ArrowDataset::from_directory(&dir);
        assert!(result.is_err(), "empty dir should return NotFound");
        if let Err(DatasetsError::NotFound(_)) = result {
        } else {
            panic!("expected NotFound error for empty directory");
        }
    }

    #[test]
    fn arrow_dataset_from_directory_with_arrow_file() {
        let dir = std::env::temp_dir()
            .join(format!("test_arrow_dir_with_file_{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let arrow_path = dir.join("data-00000-of-00001.arrow");
        {
            let mut f = std::fs::File::create(&arrow_path).expect("create arrow");
            f.write_all(b"ARROW1\x00\x00dummy_ipc_content_for_test")
                .expect("write arrow");
        }
        let result = ArrowDataset::from_directory(&dir);
        assert!(
            result.is_ok(),
            "directory with valid arrow file should succeed"
        );
        let ds = result.expect("arrow dataset from dir");
        assert_eq!(ds.file_paths().len(), 1);
    }

    #[test]
    fn dataset_info_default() {
        let info = DatasetInfo::default();
        assert!(info.dataset_name.is_empty());
        assert_eq!(info.version, "0.0.0");
        assert!(info.features.is_empty());
        assert!(info.num_rows.is_none());
        assert!(info.split.is_none());
    }

    #[test]
    fn dataset_info_parse_json() {
        let json = r#"{
            "dataset_name": "my_dataset",
            "version": "1.0.0",
            "split": "train",
            "num_rows": 42,
            "features": {
                "text": "text",
                "label": {"names": ["neg", "pos"]},
                "score": {"dtype": "float32"}
            }
        }"#;
        let info = ArrowDataset::parse_dataset_info_json(json).expect("parse dataset_info.json");
        assert_eq!(info.dataset_name, "my_dataset");
        assert_eq!(info.version, "1.0.0");
        assert_eq!(info.split.as_deref(), Some("train"));
        assert_eq!(info.num_rows, Some(42));
        assert_eq!(info.features.len(), 3);
        if let FeatureType::ClassLabel { names } = &info.features["label"] {
            assert_eq!(names, &["neg", "pos"]);
        } else {
            panic!("expected ClassLabel for 'label' feature");
        }
    }

    #[test]
    fn arrow_dataset_nonexistent_file() {
        let result = ArrowDataset::from_arrow_file("/nonexistent/path/data.arrow");
        assert!(result.is_err());
        if let Err(DatasetsError::NotFound(_)) = result {
        } else {
            panic!("expected NotFound for nonexistent file");
        }
    }

    #[test]
    fn arrow_dataset_nonexistent_directory() {
        let result = ArrowDataset::from_directory("/nonexistent/arrow_dataset_dir_xyz");
        assert!(result.is_err());
        if let Err(DatasetsError::NotFound(_)) = result {
        } else {
            panic!("expected NotFound for nonexistent directory");
        }
    }
}