1#![allow(unsafe_code)]
33
34use std::{
35 fs::File,
36 path::{Path, PathBuf},
37 sync::Arc,
38};
39
40use arrow::{array::RecordBatch, datatypes::SchemaRef};
41use memmap2::Mmap;
42use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
43
44use crate::{
45 dataset::Dataset,
46 error::{Error, Result},
47};
48
/// A dataset backed by a memory-mapped Parquet file.
///
/// The file is mapped with `memmap2` and eagerly decoded into Arrow
/// `RecordBatch`es when opened; row and batch accessors then serve data
/// from the in-memory batches.
#[derive(Debug)]
pub struct MmapDataset {
    /// Memory mapping over the source file; `mmap_size()` reports its length.
    // NOTE(review): `mmap_size` reads this field, so the `dead_code` allow
    // below may be stale — confirm before removing it.
    #[allow(dead_code)]
    mmap: Mmap,
    /// Path the dataset was opened from; `try_clone` re-opens it.
    path: PathBuf,
    /// Arrow schema of the decoded batches.
    schema: SchemaRef,
    /// All record batches decoded from the file (guaranteed non-empty:
    /// opening an empty file fails with `Error::EmptyDataset`).
    batches: Vec<RecordBatch>,
    /// Total number of rows summed across `batches`.
    row_count: usize,
}
81
82impl MmapDataset {
83 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
105 let path = path.as_ref();
106 let file = File::open(path).map_err(|e| Error::io(e, path))?;
107
108 let mmap = unsafe { Mmap::map(&file) }.map_err(|e| Error::io(e, path))?;
112
113 let bytes = bytes::Bytes::copy_from_slice(&mmap[..]);
115 let builder = ParquetRecordBatchReaderBuilder::try_new(bytes).map_err(Error::Parquet)?;
116
117 let schema = builder.schema().clone();
118 let reader = builder.build().map_err(Error::Parquet)?;
119
120 let batches: Vec<RecordBatch> = reader
121 .collect::<std::result::Result<Vec<_>, _>>()
122 .map_err(Error::Arrow)?;
123
124 if batches.is_empty() {
125 return Err(Error::EmptyDataset);
126 }
127
128 let row_count = batches.iter().map(|b| b.num_rows()).sum();
129
130 Ok(Self {
131 mmap,
132 path: path.to_path_buf(),
133 schema,
134 batches,
135 row_count,
136 })
137 }
138
139 pub fn open_with_batch_size(path: impl AsRef<Path>, batch_size: usize) -> Result<Self> {
150 let path = path.as_ref();
151 let file = File::open(path).map_err(|e| Error::io(e, path))?;
152
153 let mmap = unsafe { Mmap::map(&file) }.map_err(|e| Error::io(e, path))?;
155
156 let bytes = bytes::Bytes::copy_from_slice(&mmap[..]);
157 let builder = ParquetRecordBatchReaderBuilder::try_new(bytes).map_err(Error::Parquet)?;
158
159 let schema = builder.schema().clone();
160 let reader = builder
161 .with_batch_size(batch_size)
162 .build()
163 .map_err(Error::Parquet)?;
164
165 let batches: Vec<RecordBatch> = reader
166 .collect::<std::result::Result<Vec<_>, _>>()
167 .map_err(Error::Arrow)?;
168
169 if batches.is_empty() {
170 return Err(Error::EmptyDataset);
171 }
172
173 let row_count = batches.iter().map(|b| b.num_rows()).sum();
174
175 Ok(Self {
176 mmap,
177 path: path.to_path_buf(),
178 schema,
179 batches,
180 row_count,
181 })
182 }
183
184 pub fn path(&self) -> &Path {
186 &self.path
187 }
188
189 pub fn mmap_size(&self) -> usize {
191 self.mmap.len()
192 }
193
194 pub fn to_arrow_dataset(&self) -> Result<crate::ArrowDataset> {
203 crate::ArrowDataset::new(self.batches.clone())
204 }
205
206 fn find_row(&self, global_index: usize) -> Option<(usize, usize)> {
208 if global_index >= self.row_count {
209 return None;
210 }
211
212 let mut remaining = global_index;
213 for (batch_idx, batch) in self.batches.iter().enumerate() {
214 let batch_rows = batch.num_rows();
215 if remaining < batch_rows {
216 return Some((batch_idx, remaining));
217 }
218 remaining -= batch_rows;
219 }
220
221 None
222 }
223}
224
225impl MmapDataset {
226 pub fn try_clone(&self) -> crate::Result<Self> {
230 Self::open(&self.path)
231 }
232}
233
234impl Dataset for MmapDataset {
235 fn len(&self) -> usize {
236 self.row_count
237 }
238
239 fn get(&self, index: usize) -> Option<RecordBatch> {
240 let (batch_idx, local_idx) = self.find_row(index)?;
241 let batch = &self.batches[batch_idx];
242 Some(batch.slice(local_idx, 1))
243 }
244
245 fn schema(&self) -> SchemaRef {
246 Arc::clone(&self.schema)
247 }
248
249 fn iter(&self) -> Box<dyn Iterator<Item = RecordBatch> + Send + '_> {
250 Box::new(self.batches.iter().cloned())
251 }
252
253 fn num_batches(&self) -> usize {
254 self.batches.len()
255 }
256
257 fn get_batch(&self, index: usize) -> Option<&RecordBatch> {
258 self.batches.get(index)
259 }
260}
261
/// Builder for [`MmapDataset`] with an optional batch size and an optional
/// column projection.
#[derive(Debug, Default)]
pub struct MmapDatasetBuilder {
    /// Maximum rows per decoded batch; `None` leaves the reader's default.
    batch_size: Option<usize>,
    /// Top-level column names to read; `None` reads every column.
    columns: Option<Vec<String>>,
}
268
269impl MmapDatasetBuilder {
270 pub fn new() -> Self {
272 Self::default()
273 }
274
275 #[must_use]
277 pub fn batch_size(mut self, size: usize) -> Self {
278 self.batch_size = Some(size);
279 self
280 }
281
282 #[must_use]
284 pub fn columns(mut self, cols: Vec<String>) -> Self {
285 self.columns = Some(cols);
286 self
287 }
288
289 pub fn open(self, path: impl AsRef<Path>) -> Result<MmapDataset> {
295 let path = path.as_ref();
296 let file = File::open(path).map_err(|e| Error::io(e, path))?;
297
298 let mmap = unsafe { Mmap::map(&file) }.map_err(|e| Error::io(e, path))?;
300
301 let bytes = bytes::Bytes::copy_from_slice(&mmap[..]);
302 let mut builder =
303 ParquetRecordBatchReaderBuilder::try_new(bytes).map_err(Error::Parquet)?;
304
305 if let Some(batch_size) = self.batch_size {
306 builder = builder.with_batch_size(batch_size);
307 }
308
309 if let Some(ref cols) = self.columns {
311 let indices: Vec<usize> = {
313 let parquet_schema = builder.parquet_schema();
314 cols.iter()
315 .filter_map(|name| {
316 parquet_schema
317 .columns()
318 .iter()
319 .position(|col| col.name() == name)
320 })
321 .collect()
322 };
323
324 if !indices.is_empty() {
325 let mask = parquet::arrow::ProjectionMask::roots(builder.parquet_schema(), indices);
326 builder = builder.with_projection(mask);
327 }
328 }
329
330 let schema = builder.schema().clone();
331 let reader = builder.build().map_err(Error::Parquet)?;
332
333 let batches: Vec<RecordBatch> = reader
334 .collect::<std::result::Result<Vec<_>, _>>()
335 .map_err(Error::Arrow)?;
336
337 if batches.is_empty() {
338 return Err(Error::EmptyDataset);
339 }
340
341 let row_count = batches.iter().map(|b| b.num_rows()).sum();
342
343 Ok(MmapDataset {
344 mmap,
345 path: path.to_path_buf(),
346 schema,
347 batches,
348 row_count,
349 })
350 }
351}
352
// Unit tests: cover opening, schema/row access, iteration, batching, the
// builder (batch size + column projection), cloning, and error paths.
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used
)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Float64Array, Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };
    use parquet::{arrow::ArrowWriter, file::properties::WriterProperties},;

    use super::*;

    /// Writes a Parquet file at `path` with `rows` rows and three columns:
    /// `id: i32` (0..rows), `value: f64` (id * 1.5), `name: utf8`
    /// (`item_{id}`). All rows are written as a single record batch.
    fn create_test_parquet(path: &Path, rows: usize) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Float64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<f64> = ids.iter().map(|i| *i as f64 * 1.5).collect();
        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(Float64Array::from(values)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap();

        let file = File::create(path).unwrap();
        let props = WriterProperties::builder().build();
        let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }

    #[test]
    fn test_mmap_dataset_open() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();
        assert_eq!(dataset.len(), 100);
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_mmap_dataset_schema() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 50);

        let dataset = MmapDataset::open(&path).unwrap();
        let schema = dataset.schema();

        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "name");
    }

    #[test]
    fn test_mmap_dataset_get_row() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();

        // First, middle, and last row round-trip through find_row + slice.
        let row = dataset.get(0).unwrap();
        assert_eq!(row.num_rows(), 1);
        let ids = row.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ids.value(0), 0);

        let row = dataset.get(50).unwrap();
        let ids = row.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ids.value(0), 50);

        let row = dataset.get(99).unwrap();
        let ids = row.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ids.value(0), 99);

        // Out-of-range indices return None rather than panicking.
        assert!(dataset.get(100).is_none());
        assert!(dataset.get(1000).is_none());
    }

    #[test]
    fn test_mmap_dataset_iter() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();

        let total_rows: usize = dataset.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_mmap_dataset_num_batches() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();
        // Batch count depends on the reader's default batch size; only the
        // lower bound is stable.
        assert!(dataset.num_batches() >= 1);
    }

    #[test]
    fn test_mmap_dataset_get_batch() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();

        let batch = dataset.get_batch(0);
        assert!(batch.is_some());

        let out_of_bounds = dataset.get_batch(1000);
        assert!(out_of_bounds.is_none());
    }

    #[test]
    fn test_mmap_dataset_path() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();
        assert_eq!(dataset.path(), path);
    }

    #[test]
    fn test_mmap_dataset_mmap_size() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();
        assert!(dataset.mmap_size() > 0);
    }

    #[test]
    fn test_mmap_dataset_to_arrow() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let mmap_dataset = MmapDataset::open(&path).unwrap();
        let arrow_dataset = mmap_dataset.to_arrow_dataset().unwrap();

        assert_eq!(arrow_dataset.len(), mmap_dataset.len());
        assert_eq!(arrow_dataset.schema(), mmap_dataset.schema());
    }

    #[test]
    fn test_mmap_dataset_with_batch_size() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open_with_batch_size(&path, 10).unwrap();
        assert_eq!(dataset.len(), 100);
    }

    #[test]
    fn test_mmap_dataset_clone() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 50);

        let dataset = MmapDataset::open(&path).unwrap();
        let cloned = dataset.try_clone().unwrap();

        assert_eq!(cloned.len(), dataset.len());
        assert_eq!(cloned.schema(), dataset.schema());
        assert_eq!(cloned.path(), dataset.path());
    }

    #[test]
    fn test_mmap_dataset_debug() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 50);

        let dataset = MmapDataset::open(&path).unwrap();
        let debug_str = format!("{:?}", dataset);
        assert!(debug_str.contains("MmapDataset"));
    }

    #[test]
    fn test_mmap_dataset_open_nonexistent() {
        let result = MmapDataset::open("/nonexistent/path/to/file.parquet");
        assert!(result.is_err());
    }

    #[test]
    fn test_mmap_dataset_open_invalid_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("not_parquet.txt");
        std::fs::write(&path, "this is not parquet data").unwrap();

        let result = MmapDataset::open(&path);
        assert!(result.is_err());
    }

    #[test]
    fn test_mmap_builder_basic() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDatasetBuilder::new().open(&path).unwrap();

        assert_eq!(dataset.len(), 100);
    }

    #[test]
    fn test_mmap_builder_with_batch_size() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDatasetBuilder::new()
            .batch_size(10)
            .open(&path)
            .unwrap();

        assert_eq!(dataset.len(), 100);
    }

    #[test]
    fn test_mmap_builder_with_columns() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDatasetBuilder::new()
            .columns(vec!["id".to_string(), "name".to_string()])
            .open(&path)
            .unwrap();

        assert_eq!(dataset.len(), 100);
        let schema = dataset.schema();
        assert!(schema.field_with_name("id").is_ok());
        assert!(schema.field_with_name("name").is_ok());
    }

    #[test]
    fn test_mmap_builder_debug() {
        let builder = MmapDatasetBuilder::new().batch_size(100);
        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("MmapDatasetBuilder"));
    }

    #[test]
    fn test_mmap_builder_default() {
        let builder = MmapDatasetBuilder::default();
        assert!(builder.batch_size.is_none());
        assert!(builder.columns.is_none());
    }

    #[test]
    fn test_mmap_dataset_large_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("large.parquet");
        create_test_parquet(&path, 10000);

        let dataset = MmapDataset::open(&path).unwrap();
        assert_eq!(dataset.len(), 10000);

        assert!(dataset.get(0).is_some());
        assert!(dataset.get(5000).is_some());
        assert!(dataset.get(9999).is_some());
    }

    #[test]
    fn test_mmap_dataset_with_dataloader() {
        use crate::DataLoader;

        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        let dataset = MmapDataset::open(&path).unwrap();
        let loader = DataLoader::new(dataset).batch_size(10);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_mmap_builder_nonexistent_columns() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.parquet");
        create_test_parquet(&path, 100);

        // Unknown column names are silently ignored: the projection is
        // skipped entirely and all columns/rows are read.
        let dataset = MmapDatasetBuilder::new()
            .columns(vec!["nonexistent_col".to_string()])
            .open(&path)
            .unwrap();

        assert_eq!(dataset.len(), 100);
    }
}