1use crate::Dataset;
7use std::ops::Range;
8use std::sync::Arc;
9use tenflowers_core::{Result, Shape, Tensor, TensorError};
10
11#[cfg(feature = "mmap")]
12use memmap2::{Mmap, MmapOptions};
13#[cfg(feature = "mmap")]
14use std::fs::File;
15#[cfg(feature = "mmap")]
16use std::marker::PhantomData;
17#[cfg(feature = "mmap")]
18use std::path::Path;
19
/// A non-owning, strided view into a shared tensor's flat storage.
///
/// The view borrows `source` via `Arc` and never copies data until
/// `materialize` is called.
#[derive(Debug, Clone)]
pub struct TensorView<T> {
    /// The tensor whose storage this view borrows (shared, never copied).
    source: Arc<Tensor<T>>,
    /// Flat element offset into the source storage where the view begins.
    offset: usize,
    /// Logical shape of the view (may differ from the source's shape).
    shape: Shape,
    /// Per-dimension element strides mapping view indices into the source.
    strides: Vec<usize>,
}
32
33impl<T> TensorView<T>
34where
35 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
36{
37 pub fn new(
39 source: Arc<Tensor<T>>,
40 offset: usize,
41 shape: Vec<usize>,
42 strides: Vec<usize>,
43 ) -> Result<Self> {
44 if shape.len() != strides.len() {
45 return Err(TensorError::invalid_argument(
46 "Shape and strides must have the same length".to_string(),
47 ));
48 }
49
50 let shape = Shape::new(shape);
51
52 Ok(Self {
53 source,
54 offset,
55 shape,
56 strides,
57 })
58 }
59
60 pub fn slice(source: Arc<Tensor<T>>, ranges: &[Range<usize>]) -> Result<Self> {
62 let source_shape = source.shape();
63
64 if ranges.len() != source_shape.rank() {
65 return Err(TensorError::invalid_argument(format!(
66 "Number of ranges ({}) must match tensor rank ({})",
67 ranges.len(),
68 source_shape.rank()
69 )));
70 }
71
72 let mut new_shape = Vec::new();
74 let mut offset = 0;
75 let mut stride = 1;
76
77 let mut strides = vec![1; ranges.len()];
79 for i in (0..ranges.len()).rev() {
80 strides[i] = stride;
81 stride *= source_shape.dims()[i];
82 }
83
84 for (i, range) in ranges.iter().enumerate() {
86 if range.end > source_shape.dims()[i] {
87 return Err(TensorError::invalid_argument(format!(
88 "Range end {} exceeds dimension size {}",
89 range.end,
90 source_shape.dims()[i]
91 )));
92 }
93
94 offset += range.start * strides[i];
95 new_shape.push(range.end - range.start);
96 }
97
98 Self::new(source, offset, new_shape, strides)
99 }
100
101 pub fn reshape(source: Arc<Tensor<T>>, new_shape: Vec<usize>) -> Result<Self> {
103 let total_elements = new_shape.iter().product::<usize>();
104 let source_elements = source.shape().size();
105
106 if total_elements != source_elements {
107 return Err(TensorError::invalid_argument(
108 format!("Cannot reshape tensor with {source_elements} elements to shape with {total_elements} elements")
109 ));
110 }
111
112 let mut strides = vec![1; new_shape.len()];
114 let mut stride = 1;
115 for i in (0..new_shape.len()).rev() {
116 strides[i] = stride;
117 stride *= new_shape[i];
118 }
119
120 Self::new(source, 0, new_shape, strides)
121 }
122
123 pub fn shape(&self) -> &Shape {
125 &self.shape
126 }
127
128 pub fn strides(&self) -> &[usize] {
130 &self.strides
131 }
132
133 pub fn offset(&self) -> usize {
135 self.offset
136 }
137
138 pub fn is_contiguous(&self) -> bool {
140 let mut expected_stride = 1;
141 for i in (0..self.shape.rank()).rev() {
142 if self.strides[i] != expected_stride {
143 return false;
144 }
145 expected_stride *= self.shape.dims()[i];
146 }
147 true
148 }
149
150 pub fn materialize(&self) -> Result<Tensor<T>>
152 where
153 T: bytemuck::Pod + bytemuck::Zeroable,
154 {
155 if self.is_contiguous() {
156 if let Some(slice) = self.source.as_slice() {
158 let start = self.offset;
159 let end = start + self.shape.size();
160 let data = slice[start..end].to_vec();
161 return Tensor::from_vec(data, self.shape.dims());
162 }
163 }
164
165 let mut data = Vec::with_capacity(self.shape.size());
167 let indices = self.iter_indices();
168
169 if let Some(slice) = self.source.as_slice() {
170 for linear_idx in indices {
171 let source_idx = self.offset + linear_idx;
172 data.push(slice[source_idx]);
173 }
174 } else {
175 return Err(TensorError::invalid_argument(
176 "Cannot access GPU tensor data for materialization".to_string(),
177 ));
178 }
179
180 Tensor::from_vec(data, self.shape.dims())
181 }
182
183 fn iter_indices(&self) -> LinearIndexIterator {
185 LinearIndexIterator::new(&self.shape, &self.strides)
186 }
187}
188
/// Iterates over the flat (linear) offsets addressed by a strided view, in
/// row-major order of the view's logical indices.
struct LinearIndexIterator {
    /// Logical dimensions being traversed.
    shape: Vec<usize>,
    /// Element stride for each dimension.
    strides: Vec<usize>,
    /// Current multi-dimensional index (odometer state).
    current: Vec<usize>,
    /// True once every index has been yielded.
    done: bool,
}
196
197impl LinearIndexIterator {
198 fn new(shape: &Shape, strides: &[usize]) -> Self {
199 Self {
200 shape: shape.dims().to_vec(),
201 strides: strides.to_vec(),
202 current: vec![0; shape.rank()],
203 done: shape.size() == 0,
204 }
205 }
206}
207
208impl Iterator for LinearIndexIterator {
209 type Item = usize;
210
211 fn next(&mut self) -> Option<Self::Item> {
212 if self.done {
213 return None;
214 }
215
216 let linear_idx = self
218 .current
219 .iter()
220 .zip(&self.strides)
221 .map(|(&idx, &stride)| idx * stride)
222 .sum();
223
224 let mut carry = 1;
226 for i in (0..self.current.len()).rev() {
227 self.current[i] += carry;
228 if self.current[i] < self.shape[i] {
229 carry = 0;
230 break;
231 } else {
232 self.current[i] = 0;
233 }
234 }
235
236 if carry == 1 {
237 self.done = true;
238 }
239
240 Some(linear_idx)
241 }
242}
243
/// Dataset backed by a single shared flat buffer holding interleaved
/// `[features_i, labels_i]` records; samples are served through `TensorView`s
/// so reads after construction avoid per-access copies.
pub struct ZeroCopyDataset<T> {
    /// Flat combined storage: `num_samples` records of `sample_size` elements.
    source: Arc<Tensor<T>>,
    /// Number of samples in the dataset.
    num_samples: usize,
    /// Elements per record (features + labels).
    sample_size: usize,
    /// Per-sample feature shape (batch dimension stripped).
    feature_shape: Vec<usize>,
    /// Per-sample label shape (batch dimension stripped).
    label_shape: Vec<usize>,
    /// Element offset of the labels within each record.
    labels_offset: usize,
}
259
260impl<T> ZeroCopyDataset<T>
261where
262 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
263{
264 pub fn new(features: Tensor<T>, labels: Tensor<T>) -> Result<Self> {
266 let features_shape = features.shape();
267 let labels_shape = labels.shape();
268
269 if features_shape.dims()[0] != labels_shape.dims()[0] {
270 return Err(TensorError::invalid_argument(
271 "Features and labels must have same batch size".to_string(),
272 ));
273 }
274
275 let num_samples = features_shape.dims()[0];
276 let feature_elements = features_shape.size() / num_samples;
277 let label_elements = labels_shape.size() / num_samples;
278
279 let mut combined_data = Vec::new();
281
282 if let (Some(feat_slice), Some(label_slice)) = (features.as_slice(), labels.as_slice()) {
283 for i in 0..num_samples {
285 let feat_start = i * feature_elements;
286 let feat_end = feat_start + feature_elements;
287 combined_data.extend_from_slice(&feat_slice[feat_start..feat_end]);
288
289 let label_start = i * label_elements;
290 let label_end = label_start + label_elements;
291 combined_data.extend_from_slice(&label_slice[label_start..label_end]);
292 }
293 } else {
294 return Err(TensorError::invalid_argument(
295 "Cannot access tensor data (GPU tensors not supported for zero-copy dataset)"
296 .to_string(),
297 ));
298 }
299
300 let sample_size = feature_elements + label_elements;
301 let combined_shape = vec![num_samples * sample_size];
302 let source = Arc::new(Tensor::from_vec(combined_data, &combined_shape)?);
303
304 Ok(Self {
305 source,
306 num_samples,
307 sample_size,
308 feature_shape: features_shape.dims()[1..].to_vec(),
309 label_shape: labels_shape.dims()[1..].to_vec(),
310 labels_offset: feature_elements,
311 })
312 }
313
314 pub fn get_view(&self, index: usize) -> Result<(TensorView<T>, TensorView<T>)> {
316 if index >= self.num_samples {
317 return Err(TensorError::invalid_argument(format!(
318 "Index {} out of bounds for dataset with {} samples",
319 index, self.num_samples
320 )));
321 }
322
323 let sample_offset = index * self.sample_size;
324
325 let feature_strides = self.calculate_strides(&self.feature_shape);
327 let feature_view = TensorView::new(
328 Arc::clone(&self.source),
329 sample_offset,
330 self.feature_shape.clone(),
331 feature_strides,
332 )?;
333
334 let label_strides = self.calculate_strides(&self.label_shape);
336 let label_view = TensorView::new(
337 Arc::clone(&self.source),
338 sample_offset + self.labels_offset,
339 self.label_shape.clone(),
340 label_strides,
341 )?;
342
343 Ok((feature_view, label_view))
344 }
345
346 fn calculate_strides(&self, shape: &[usize]) -> Vec<usize> {
347 let mut strides = vec![1; shape.len()];
348 let mut stride = 1;
349 for i in (0..shape.len()).rev() {
350 strides[i] = stride;
351 stride *= shape[i];
352 }
353 strides
354 }
355}
356
357impl<T> Dataset<T> for ZeroCopyDataset<T>
358where
359 T: Clone
360 + Default
361 + scirs2_core::numeric::Zero
362 + Send
363 + Sync
364 + 'static
365 + bytemuck::Pod
366 + bytemuck::Zeroable,
367{
368 fn len(&self) -> usize {
369 self.num_samples
370 }
371
372 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
373 let (feature_view, label_view) = self.get_view(index)?;
374
375 let features = feature_view.materialize()?;
377 let labels = label_view.materialize()?;
378
379 Ok((features, labels))
380 }
381}
382
/// Dataset over a pre-loaded shared buffer of consecutive
/// `[features, labels]` records; no per-sample allocation until `get`.
pub struct MemoryMappedDataset<T> {
    /// Shared flat storage holding all samples back-to-back.
    data: Arc<[T]>,
    /// Number of samples in `data`.
    num_samples: usize,
    /// Elements of features per sample.
    feature_size: usize,
    /// Elements of labels per sample.
    label_size: usize,
    /// Per-sample feature shape.
    feature_shape: Vec<usize>,
    /// Per-sample label shape.
    label_shape: Vec<usize>,
}
398
399impl<T> MemoryMappedDataset<T>
400where
401 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
402{
403 pub fn new(
406 data: Arc<[T]>,
407 num_samples: usize,
408 feature_shape: Vec<usize>,
409 label_shape: Vec<usize>,
410 ) -> Result<Self> {
411 let feature_size = feature_shape.iter().product();
412 let label_size = label_shape.iter().product();
413 let expected_size = num_samples * (feature_size + label_size);
414
415 if data.len() != expected_size {
416 return Err(TensorError::invalid_argument(format!(
417 "Data size {} doesn't match expected size {} for {} samples",
418 data.len(),
419 expected_size,
420 num_samples
421 )));
422 }
423
424 Ok(Self {
425 data,
426 num_samples,
427 feature_size,
428 label_size,
429 feature_shape,
430 label_shape,
431 })
432 }
433
434 pub fn get_raw_sample(&self, index: usize) -> Result<(&[T], &[T])> {
436 if index >= self.num_samples {
437 return Err(TensorError::invalid_argument(format!(
438 "Index {index} out of bounds"
439 )));
440 }
441
442 let sample_size = self.feature_size + self.label_size;
443 let start = index * sample_size;
444
445 let features = &self.data[start..start + self.feature_size];
446 let labels = &self.data[start + self.feature_size..start + sample_size];
447
448 Ok((features, labels))
449 }
450}
451
452impl<T> Dataset<T> for MemoryMappedDataset<T>
453where
454 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
455{
456 fn len(&self) -> usize {
457 self.num_samples
458 }
459
460 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
461 let (feat_slice, label_slice) = self.get_raw_sample(index)?;
462
463 let features = Tensor::from_vec(feat_slice.to_vec(), &self.feature_shape)?;
465 let labels = Tensor::from_vec(label_slice.to_vec(), &self.label_shape)?;
466
467 Ok((features, labels))
468 }
469}
470
471#[cfg(feature = "mmap")]
474#[allow(unsafe_code)] pub struct MemoryMappedFileDataset<T> {
476 mmap: Mmap,
478 num_samples: usize,
480 feature_size_bytes: usize,
482 label_size_bytes: usize,
484 feature_shape: Vec<usize>,
486 label_shape: Vec<usize>,
488 #[allow(dead_code)] element_size: usize,
491 sample_size_bytes: usize,
493 file_path: String,
495 _phantom: PhantomData<T>,
497}
498
499#[cfg(feature = "mmap")]
500impl<T> MemoryMappedFileDataset<T>
501where
502 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
503{
504 #[allow(unsafe_code)] pub fn from_file<P: AsRef<Path>>(
509 file_path: P,
510 num_samples: usize,
511 feature_shape: Vec<usize>,
512 label_shape: Vec<usize>,
513 ) -> Result<Self> {
514 let file_path = file_path.as_ref();
515 let file = File::open(file_path).map_err(|e| {
516 TensorError::io_error_simple(format!(
517 "Failed to open file {}: {}",
518 file_path.display(),
519 e
520 ))
521 })?;
522
523 let mmap = unsafe {
524 MmapOptions::new().map(&file).map_err(|e| {
525 TensorError::io_error_simple(format!(
526 "Failed to memory map file {}: {}",
527 file_path.display(),
528 e
529 ))
530 })?
531 };
532
533 let element_size = std::mem::size_of::<T>();
534 let feature_size = feature_shape.iter().product::<usize>();
535 let label_size = label_shape.iter().product::<usize>();
536 let feature_size_bytes = feature_size * element_size;
537 let label_size_bytes = label_size * element_size;
538 let sample_size_bytes = feature_size_bytes + label_size_bytes;
539 let expected_file_size = num_samples * sample_size_bytes;
540
541 if mmap.len() < expected_file_size {
542 return Err(TensorError::invalid_argument(format!(
543 "File {} size {} is smaller than expected size {} for {} samples",
544 file_path.display(),
545 mmap.len(),
546 expected_file_size,
547 num_samples
548 )));
549 }
550
551 Ok(Self {
552 mmap,
553 num_samples,
554 feature_size_bytes,
555 label_size_bytes,
556 feature_shape,
557 label_shape,
558 element_size,
559 sample_size_bytes,
560 file_path: file_path.display().to_string(),
561 _phantom: PhantomData,
562 })
563 }
564
565 pub fn auto_detect<P: AsRef<Path>>(
568 file_path: P,
569 feature_shape: Vec<usize>,
570 label_shape: Vec<usize>,
571 ) -> Result<Self> {
572 let file_path_ref = file_path.as_ref();
573 let metadata = std::fs::metadata(file_path_ref).map_err(|e| {
574 TensorError::io_error_simple(format!(
575 "Failed to get metadata for {}: {}",
576 file_path_ref.display(),
577 e
578 ))
579 })?;
580
581 let element_size = std::mem::size_of::<T>();
582 let feature_size = feature_shape.iter().product::<usize>();
583 let label_size = label_shape.iter().product::<usize>();
584 let sample_size_bytes = (feature_size + label_size) * element_size;
585
586 let num_samples = metadata.len() as usize / sample_size_bytes;
587
588 if metadata.len() as usize % sample_size_bytes != 0 {
589 return Err(TensorError::invalid_argument(format!(
590 "File {} size {} is not evenly divisible by sample size {}",
591 file_path_ref.display(),
592 metadata.len(),
593 sample_size_bytes
594 )));
595 }
596
597 Self::from_file(file_path, num_samples, feature_shape, label_shape)
598 }
599
600 fn get_raw_sample_bytes(&self, index: usize) -> Result<(&[u8], &[u8])> {
602 if index >= self.num_samples {
603 return Err(TensorError::invalid_argument(format!(
604 "Index {} out of bounds for dataset with {} samples",
605 index, self.num_samples
606 )));
607 }
608
609 let sample_offset = index * self.sample_size_bytes;
610 let feature_start = sample_offset;
611 let feature_end = feature_start + self.feature_size_bytes;
612 let label_start = feature_end;
613 let label_end = label_start + self.label_size_bytes;
614
615 let feature_bytes = &self.mmap[feature_start..feature_end];
616 let label_bytes = &self.mmap[label_start..label_end];
617
618 Ok((feature_bytes, label_bytes))
619 }
620
621 pub fn file_stats(&self) -> MemoryMappedFileStats {
623 MemoryMappedFileStats {
624 file_path: self.file_path.clone(),
625 file_size: self.mmap.len(),
626 num_samples: self.num_samples,
627 sample_size_bytes: self.sample_size_bytes,
628 feature_shape: self.feature_shape.clone(),
629 label_shape: self.label_shape.clone(),
630 }
631 }
632
633 pub fn get_batch_view(&self, start_index: usize, batch_size: usize) -> Result<&[u8]> {
635 if start_index + batch_size > self.num_samples {
636 return Err(TensorError::invalid_argument(format!(
637 "Batch {}..{} out of bounds for dataset with {} samples",
638 start_index,
639 start_index + batch_size,
640 self.num_samples
641 )));
642 }
643
644 let start_offset = start_index * self.sample_size_bytes;
645 let end_offset = (start_index + batch_size) * self.sample_size_bytes;
646
647 Ok(&self.mmap[start_offset..end_offset])
648 }
649}
650
651#[cfg(feature = "mmap")]
652impl<T> Dataset<T> for MemoryMappedFileDataset<T>
653where
654 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + bytemuck::Pod + 'static,
655{
656 fn len(&self) -> usize {
657 self.num_samples
658 }
659
660 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
661 let (feature_bytes, label_bytes) = self.get_raw_sample_bytes(index)?;
662
663 let feature_data: &[T] = bytemuck::cast_slice(feature_bytes);
665 let label_data: &[T] = bytemuck::cast_slice(label_bytes);
666
667 let features = Tensor::from_vec(feature_data.to_vec(), &self.feature_shape)?;
669 let labels = Tensor::from_vec(label_data.to_vec(), &self.label_shape)?;
670
671 Ok((features, labels))
672 }
673}
674
675#[cfg(feature = "mmap")]
677#[derive(Debug, Clone)]
678pub struct MemoryMappedFileStats {
679 pub file_path: String,
680 pub file_size: usize,
681 pub num_samples: usize,
682 pub sample_size_bytes: usize,
683 pub feature_shape: Vec<usize>,
684 pub label_shape: Vec<usize>,
685}
686
687#[cfg(feature = "mmap")]
688impl MemoryMappedFileStats {
689 pub fn memory_efficiency(&self) -> f64 {
691 let used_size = self.num_samples * self.sample_size_bytes;
692 used_size as f64 / self.file_size as f64
693 }
694
695 pub fn human_readable_size(&self) -> String {
697 let size = self.file_size as f64;
698 if size < 1024.0 {
699 format!("{size:.1} B")
700 } else if size < 1024.0 * 1024.0 {
701 let kb = size / 1024.0;
702 format!("{kb:.1} KB")
703 } else if size < 1024.0 * 1024.0 * 1024.0 {
704 let mb = size / (1024.0 * 1024.0);
705 format!("{mb:.1} MB")
706 } else {
707 let gb = size / (1024.0 * 1024.0 * 1024.0);
708 format!("{gb:.1} GB")
709 }
710 }
711}
712
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_view_slice() {
        // 2x3 tensor; selecting the full first row should yield a contiguous
        // [1, 3] view starting at offset 0.
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Arc::new(
            Tensor::from_vec(data, &[2, 3]).expect("test: tensor creation should succeed"),
        );

        let view =
            TensorView::slice(tensor, &[0..1, 0..3]).expect("test: operation should succeed");
        assert_eq!(view.shape().dims(), &[1, 3]);
        assert_eq!(view.offset(), 0);
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_tensor_view_reshape() {
        // Reshaping [2, 3] -> [6, 1] keeps all 6 elements and offset 0.
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Arc::new(
            Tensor::from_vec(data, &[2, 3]).expect("test: tensor creation should succeed"),
        );

        let view = TensorView::reshape(tensor, vec![6, 1]).expect("test: operation should succeed");
        assert_eq!(view.shape().dims(), &[6, 1]);
        assert_eq!(view.offset(), 0);
    }

    #[test]
    fn test_zero_copy_dataset() {
        // Two samples: features are [2]-vectors, labels are scalars — the
        // batch dimension is stripped from the per-sample shapes.
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let dataset =
            ZeroCopyDataset::new(features, labels).expect("test: operation should succeed");
        assert_eq!(dataset.len(), 2);

        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);
    }

    #[test]
    fn test_memory_mapped_dataset() {
        // Interleaved layout: [feat0a, feat0b, label0, feat1a, feat1b, label1]
        // for 2 samples of [2]-shaped features and scalar labels.
        let data: Arc<[f32]> = Arc::from(vec![1.0, 2.0, 0.0, 3.0, 4.0, 1.0]);

        let dataset = MemoryMappedDataset::new(
            data,
            2, vec![2], vec![], )
        .expect("test: operation should succeed");

        assert_eq!(dataset.len(), 2);

        let (feat0, label0) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat0.shape().dims(), &[2]);
        assert_eq!(label0.shape().dims(), &[] as &[usize]);

        let (feat1, label1) = dataset.get(1).expect("index should be in bounds");
        assert_eq!(feat1.shape().dims(), &[2]);
        assert_eq!(label1.shape().dims(), &[] as &[usize]);
    }
}