// tenflowers_dataset/zero_copy.rs

1//! Zero-copy operations for memory-efficient dataset loading
2//!
3//! This module provides utilities for zero-copy data access and tensor views
4//! that avoid unnecessary data copying during dataset operations.
5
6use crate::Dataset;
7use std::ops::Range;
8use std::sync::Arc;
9use tenflowers_core::{Result, Shape, Tensor, TensorError};
10
11#[cfg(feature = "mmap")]
12use memmap2::{Mmap, MmapOptions};
13#[cfg(feature = "mmap")]
14use std::fs::File;
15#[cfg(feature = "mmap")]
16use std::marker::PhantomData;
17#[cfg(feature = "mmap")]
18use std::path::Path;
19
/// A zero-copy view into a tensor that shares memory with the original tensor.
///
/// A view is described by an element `offset` into the source's flat buffer,
/// a `shape`, and per-dimension `strides`. Cloning a `TensorView` only bumps
/// the `Arc` reference count — no tensor data is copied.
#[derive(Debug, Clone)]
pub struct TensorView<T> {
    /// Reference to the original tensor whose buffer this view borrows
    source: Arc<Tensor<T>>,
    /// Offset (in elements) into the source tensor's flat data
    offset: usize,
    /// Shape of this view
    shape: Shape,
    /// Strides (in elements) used to address data within the source
    strides: Vec<usize>,
}
32
33impl<T> TensorView<T>
34where
35    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
36{
37    /// Create a new tensor view from a source tensor
38    pub fn new(
39        source: Arc<Tensor<T>>,
40        offset: usize,
41        shape: Vec<usize>,
42        strides: Vec<usize>,
43    ) -> Result<Self> {
44        if shape.len() != strides.len() {
45            return Err(TensorError::invalid_argument(
46                "Shape and strides must have the same length".to_string(),
47            ));
48        }
49
50        let shape = Shape::new(shape);
51
52        Ok(Self {
53            source,
54            offset,
55            shape,
56            strides,
57        })
58    }
59
60    /// Create a view that slices the tensor along specified dimensions
61    pub fn slice(source: Arc<Tensor<T>>, ranges: &[Range<usize>]) -> Result<Self> {
62        let source_shape = source.shape();
63
64        if ranges.len() != source_shape.rank() {
65            return Err(TensorError::invalid_argument(format!(
66                "Number of ranges ({}) must match tensor rank ({})",
67                ranges.len(),
68                source_shape.rank()
69            )));
70        }
71
72        // Calculate new shape and offset
73        let mut new_shape = Vec::new();
74        let mut offset = 0;
75        let mut stride = 1;
76
77        // Calculate strides (row-major order)
78        let mut strides = vec![1; ranges.len()];
79        for i in (0..ranges.len()).rev() {
80            strides[i] = stride;
81            stride *= source_shape.dims()[i];
82        }
83
84        // Calculate offset and new shape
85        for (i, range) in ranges.iter().enumerate() {
86            if range.end > source_shape.dims()[i] {
87                return Err(TensorError::invalid_argument(format!(
88                    "Range end {} exceeds dimension size {}",
89                    range.end,
90                    source_shape.dims()[i]
91                )));
92            }
93
94            offset += range.start * strides[i];
95            new_shape.push(range.end - range.start);
96        }
97
98        Self::new(source, offset, new_shape, strides)
99    }
100
101    /// Create a view that reshapes the tensor without copying data
102    pub fn reshape(source: Arc<Tensor<T>>, new_shape: Vec<usize>) -> Result<Self> {
103        let total_elements = new_shape.iter().product::<usize>();
104        let source_elements = source.shape().size();
105
106        if total_elements != source_elements {
107            return Err(TensorError::invalid_argument(
108                format!("Cannot reshape tensor with {source_elements} elements to shape with {total_elements} elements")
109            ));
110        }
111
112        // Calculate row-major strides for new shape
113        let mut strides = vec![1; new_shape.len()];
114        let mut stride = 1;
115        for i in (0..new_shape.len()).rev() {
116            strides[i] = stride;
117            stride *= new_shape[i];
118        }
119
120        Self::new(source, 0, new_shape, strides)
121    }
122
123    /// Get the shape of this view
124    pub fn shape(&self) -> &Shape {
125        &self.shape
126    }
127
128    /// Get the strides of this view
129    pub fn strides(&self) -> &[usize] {
130        &self.strides
131    }
132
133    /// Get the offset into the source tensor
134    pub fn offset(&self) -> usize {
135        self.offset
136    }
137
138    /// Check if this view is contiguous in memory
139    pub fn is_contiguous(&self) -> bool {
140        let mut expected_stride = 1;
141        for i in (0..self.shape.rank()).rev() {
142            if self.strides[i] != expected_stride {
143                return false;
144            }
145            expected_stride *= self.shape.dims()[i];
146        }
147        true
148    }
149
150    /// Materialize this view into a concrete tensor (performs copy)
151    pub fn materialize(&self) -> Result<Tensor<T>>
152    where
153        T: bytemuck::Pod + bytemuck::Zeroable,
154    {
155        if self.is_contiguous() {
156            // If contiguous, we can potentially slice directly
157            if let Some(slice) = self.source.as_slice() {
158                let start = self.offset;
159                let end = start + self.shape.size();
160                let data = slice[start..end].to_vec();
161                return Tensor::from_vec(data, self.shape.dims());
162            }
163        }
164
165        // Non-contiguous case: need to copy elements individually
166        let mut data = Vec::with_capacity(self.shape.size());
167        let indices = self.iter_indices();
168
169        if let Some(slice) = self.source.as_slice() {
170            for linear_idx in indices {
171                let source_idx = self.offset + linear_idx;
172                data.push(slice[source_idx]);
173            }
174        } else {
175            return Err(TensorError::invalid_argument(
176                "Cannot access GPU tensor data for materialization".to_string(),
177            ));
178        }
179
180        Tensor::from_vec(data, self.shape.dims())
181    }
182
183    /// Iterate over linear indices for this view
184    fn iter_indices(&self) -> LinearIndexIterator {
185        LinearIndexIterator::new(&self.shape, &self.strides)
186    }
187}
188
/// Iterator over linear indices for a tensor view.
///
/// Walks the view's index space in row-major order (rightmost dimension
/// fastest) and yields `sum(current[i] * strides[i])` for each position —
/// the element offset relative to the view's own offset.
struct LinearIndexIterator {
    /// Dimensions of the view being traversed
    shape: Vec<usize>,
    /// Per-dimension strides used to flatten a multi-dimensional index
    strides: Vec<usize>,
    /// Current multi-dimensional position (odometer state)
    current: Vec<usize>,
    /// Set once the whole index space has been exhausted
    done: bool,
}
196
197impl LinearIndexIterator {
198    fn new(shape: &Shape, strides: &[usize]) -> Self {
199        Self {
200            shape: shape.dims().to_vec(),
201            strides: strides.to_vec(),
202            current: vec![0; shape.rank()],
203            done: shape.size() == 0,
204        }
205    }
206}
207
208impl Iterator for LinearIndexIterator {
209    type Item = usize;
210
211    fn next(&mut self) -> Option<Self::Item> {
212        if self.done {
213            return None;
214        }
215
216        // Calculate current linear index
217        let linear_idx = self
218            .current
219            .iter()
220            .zip(&self.strides)
221            .map(|(&idx, &stride)| idx * stride)
222            .sum();
223
224        // Advance to next index
225        let mut carry = 1;
226        for i in (0..self.current.len()).rev() {
227            self.current[i] += carry;
228            if self.current[i] < self.shape[i] {
229                carry = 0;
230                break;
231            } else {
232                self.current[i] = 0;
233            }
234        }
235
236        if carry == 1 {
237            self.done = true;
238        }
239
240        Some(linear_idx)
241    }
242}
243
/// Zero-copy dataset wrapper that provides views into a large tensor.
///
/// All samples are stored in one flat `source` tensor, laid out per sample as
/// `[features..., labels...]`; `get_view` hands out `TensorView`s into that
/// buffer without copying.
pub struct ZeroCopyDataset<T> {
    /// Source tensor containing all data (features then labels, per sample)
    source: Arc<Tensor<T>>,
    /// Number of samples
    num_samples: usize,
    /// Size of each sample (number of elements, features + labels)
    sample_size: usize,
    /// Features shape (without batch dimension)
    feature_shape: Vec<usize>,
    /// Labels shape (without batch dimension)
    label_shape: Vec<usize>,
    /// Offset (in elements) from a sample's start to its labels
    labels_offset: usize,
}
259
260impl<T> ZeroCopyDataset<T>
261where
262    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
263{
264    /// Create a new zero-copy dataset from features and labels tensors
265    pub fn new(features: Tensor<T>, labels: Tensor<T>) -> Result<Self> {
266        let features_shape = features.shape();
267        let labels_shape = labels.shape();
268
269        if features_shape.dims()[0] != labels_shape.dims()[0] {
270            return Err(TensorError::invalid_argument(
271                "Features and labels must have same batch size".to_string(),
272            ));
273        }
274
275        let num_samples = features_shape.dims()[0];
276        let feature_elements = features_shape.size() / num_samples;
277        let label_elements = labels_shape.size() / num_samples;
278
279        // Concatenate features and labels into a single tensor for zero-copy access
280        let mut combined_data = Vec::new();
281
282        if let (Some(feat_slice), Some(label_slice)) = (features.as_slice(), labels.as_slice()) {
283            // Interleave features and labels for each sample
284            for i in 0..num_samples {
285                let feat_start = i * feature_elements;
286                let feat_end = feat_start + feature_elements;
287                combined_data.extend_from_slice(&feat_slice[feat_start..feat_end]);
288
289                let label_start = i * label_elements;
290                let label_end = label_start + label_elements;
291                combined_data.extend_from_slice(&label_slice[label_start..label_end]);
292            }
293        } else {
294            return Err(TensorError::invalid_argument(
295                "Cannot access tensor data (GPU tensors not supported for zero-copy dataset)"
296                    .to_string(),
297            ));
298        }
299
300        let sample_size = feature_elements + label_elements;
301        let combined_shape = vec![num_samples * sample_size];
302        let source = Arc::new(Tensor::from_vec(combined_data, &combined_shape)?);
303
304        Ok(Self {
305            source,
306            num_samples,
307            sample_size,
308            feature_shape: features_shape.dims()[1..].to_vec(),
309            label_shape: labels_shape.dims()[1..].to_vec(),
310            labels_offset: feature_elements,
311        })
312    }
313
314    /// Get a zero-copy view of a sample
315    pub fn get_view(&self, index: usize) -> Result<(TensorView<T>, TensorView<T>)> {
316        if index >= self.num_samples {
317            return Err(TensorError::invalid_argument(format!(
318                "Index {} out of bounds for dataset with {} samples",
319                index, self.num_samples
320            )));
321        }
322
323        let sample_offset = index * self.sample_size;
324
325        // Create feature view
326        let feature_strides = self.calculate_strides(&self.feature_shape);
327        let feature_view = TensorView::new(
328            Arc::clone(&self.source),
329            sample_offset,
330            self.feature_shape.clone(),
331            feature_strides,
332        )?;
333
334        // Create label view
335        let label_strides = self.calculate_strides(&self.label_shape);
336        let label_view = TensorView::new(
337            Arc::clone(&self.source),
338            sample_offset + self.labels_offset,
339            self.label_shape.clone(),
340            label_strides,
341        )?;
342
343        Ok((feature_view, label_view))
344    }
345
346    fn calculate_strides(&self, shape: &[usize]) -> Vec<usize> {
347        let mut strides = vec![1; shape.len()];
348        let mut stride = 1;
349        for i in (0..shape.len()).rev() {
350            strides[i] = stride;
351            stride *= shape[i];
352        }
353        strides
354    }
355}
356
357impl<T> Dataset<T> for ZeroCopyDataset<T>
358where
359    T: Clone
360        + Default
361        + scirs2_core::numeric::Zero
362        + Send
363        + Sync
364        + 'static
365        + bytemuck::Pod
366        + bytemuck::Zeroable,
367{
368    fn len(&self) -> usize {
369        self.num_samples
370    }
371
372    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
373        let (feature_view, label_view) = self.get_view(index)?;
374
375        // Materialize views into concrete tensors
376        let features = feature_view.materialize()?;
377        let labels = label_view.materialize()?;
378
379        Ok((features, labels))
380    }
381}
382
/// Memory-mapped zero-copy dataset for large datasets.
///
/// Samples are stored back-to-back in one shared flat buffer:
/// [sample0_features, sample0_labels, sample1_features, sample1_labels, ...]
pub struct MemoryMappedDataset<T> {
    /// Shared flat data buffer holding all samples
    data: Arc<[T]>,
    /// Number of samples
    num_samples: usize,
    /// Feature size per sample (number of elements)
    feature_size: usize,
    /// Label size per sample (number of elements)
    label_size: usize,
    /// Feature shape (without batch dimension)
    feature_shape: Vec<usize>,
    /// Label shape (without batch dimension)
    label_shape: Vec<usize>,
}
398
399impl<T> MemoryMappedDataset<T>
400where
401    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
402{
403    /// Create a memory-mapped dataset from raw data
404    /// Data layout: [sample0_features, sample0_labels, sample1_features, sample1_labels, ...]
405    pub fn new(
406        data: Arc<[T]>,
407        num_samples: usize,
408        feature_shape: Vec<usize>,
409        label_shape: Vec<usize>,
410    ) -> Result<Self> {
411        let feature_size = feature_shape.iter().product();
412        let label_size = label_shape.iter().product();
413        let expected_size = num_samples * (feature_size + label_size);
414
415        if data.len() != expected_size {
416            return Err(TensorError::invalid_argument(format!(
417                "Data size {} doesn't match expected size {} for {} samples",
418                data.len(),
419                expected_size,
420                num_samples
421            )));
422        }
423
424        Ok(Self {
425            data,
426            num_samples,
427            feature_size,
428            label_size,
429            feature_shape,
430            label_shape,
431        })
432    }
433
434    /// Get a zero-copy slice for a sample
435    pub fn get_raw_sample(&self, index: usize) -> Result<(&[T], &[T])> {
436        if index >= self.num_samples {
437            return Err(TensorError::invalid_argument(format!(
438                "Index {index} out of bounds"
439            )));
440        }
441
442        let sample_size = self.feature_size + self.label_size;
443        let start = index * sample_size;
444
445        let features = &self.data[start..start + self.feature_size];
446        let labels = &self.data[start + self.feature_size..start + sample_size];
447
448        Ok((features, labels))
449    }
450}
451
452impl<T> Dataset<T> for MemoryMappedDataset<T>
453where
454    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
455{
456    fn len(&self) -> usize {
457        self.num_samples
458    }
459
460    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
461        let (feat_slice, label_slice) = self.get_raw_sample(index)?;
462
463        // Create tensors from slices (this involves copying)
464        let features = Tensor::from_vec(feat_slice.to_vec(), &self.feature_shape)?;
465        let labels = Tensor::from_vec(label_slice.to_vec(), &self.label_shape)?;
466
467        Ok((features, labels))
468    }
469}
470
/// Enhanced memory-mapped dataset that loads data directly from files.
///
/// This provides true zero-copy access to very large datasets that exceed
/// available RAM: sample bytes are read straight out of the mapped region.
#[cfg(feature = "mmap")]
#[allow(unsafe_code)] // Required for memory mapping
pub struct MemoryMappedFileDataset<T> {
    /// Memory-mapped file contents (read-only mapping)
    mmap: Mmap,
    /// Number of samples
    num_samples: usize,
    /// Feature size per sample (in bytes)
    feature_size_bytes: usize,
    /// Label size per sample (in bytes)
    label_size_bytes: usize,
    /// Feature shape (without batch dimension)
    feature_shape: Vec<usize>,
    /// Label shape (without batch dimension)
    label_shape: Vec<usize>,
    /// Size of one element of `T` in bytes
    #[allow(dead_code)] // Used for future validation features
    element_size: usize,
    /// Total size per sample (features + labels) in bytes
    sample_size_bytes: usize,
    /// File path, kept for debugging and statistics output
    file_path: String,
    /// Ties the untyped byte buffer to element type `T`
    _phantom: PhantomData<T>,
}
498
#[cfg(feature = "mmap")]
impl<T> MemoryMappedFileDataset<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Create a memory-mapped dataset directly from a file.
    ///
    /// The file should contain binary data in the format:
    /// [sample0_features, sample0_labels, sample1_features, sample1_labels, ...]
    ///
    /// # Errors
    /// Fails if the file cannot be opened or mapped, if `T` is zero-sized,
    /// if the requested geometry overflows `usize`, or if the file is smaller
    /// than `num_samples` samples require.
    #[allow(unsafe_code)] // Required for memory mapping
    pub fn from_file<P: AsRef<Path>>(
        file_path: P,
        num_samples: usize,
        feature_shape: Vec<usize>,
        label_shape: Vec<usize>,
    ) -> Result<Self> {
        let file_path = file_path.as_ref();
        let file = File::open(file_path).map_err(|e| {
            TensorError::io_error_simple(format!(
                "Failed to open file {}: {}",
                file_path.display(),
                e
            ))
        })?;

        // SAFETY: the mapping is read-only. Per the memmap2 contract, callers
        // must not truncate the underlying file while the map is alive.
        let mmap = unsafe {
            MmapOptions::new().map(&file).map_err(|e| {
                TensorError::io_error_simple(format!(
                    "Failed to memory map file {}: {}",
                    file_path.display(),
                    e
                ))
            })?
        };

        let element_size = std::mem::size_of::<T>();
        let feature_size = feature_shape.iter().product::<usize>();
        let label_size = label_shape.iter().product::<usize>();
        let feature_size_bytes = feature_size * element_size;
        let label_size_bytes = label_size * element_size;
        let sample_size_bytes = feature_size_bytes + label_size_bytes;

        // Robustness fix: a zero-sized sample (zero-sized T) would make every
        // offset computation meaningless.
        if sample_size_bytes == 0 {
            return Err(TensorError::invalid_argument(
                "Sample size in bytes must be non-zero".to_string(),
            ));
        }

        // Robustness fix: guard against usize overflow for huge requests.
        let expected_file_size = num_samples.checked_mul(sample_size_bytes).ok_or_else(|| {
            TensorError::invalid_argument(format!(
                "Sample count {num_samples} times sample size {sample_size_bytes} overflows usize"
            ))
        })?;

        if mmap.len() < expected_file_size {
            return Err(TensorError::invalid_argument(format!(
                "File {} size {} is smaller than expected size {} for {} samples",
                file_path.display(),
                mmap.len(),
                expected_file_size,
                num_samples
            )));
        }

        Ok(Self {
            mmap,
            num_samples,
            feature_size_bytes,
            label_size_bytes,
            feature_shape,
            label_shape,
            element_size,
            sample_size_bytes,
            file_path: file_path.display().to_string(),
            _phantom: PhantomData,
        })
    }

    /// Create a memory-mapped dataset from an existing file with automatic
    /// sample-count detection. Assumes all samples have the same shape.
    ///
    /// # Errors
    /// Fails if file metadata cannot be read, if `T` is zero-sized, or if
    /// the file size is not an exact multiple of the per-sample byte size.
    pub fn auto_detect<P: AsRef<Path>>(
        file_path: P,
        feature_shape: Vec<usize>,
        label_shape: Vec<usize>,
    ) -> Result<Self> {
        let file_path_ref = file_path.as_ref();
        let metadata = std::fs::metadata(file_path_ref).map_err(|e| {
            TensorError::io_error_simple(format!(
                "Failed to get metadata for {}: {}",
                file_path_ref.display(),
                e
            ))
        })?;

        let element_size = std::mem::size_of::<T>();
        let feature_size = feature_shape.iter().product::<usize>();
        let label_size = label_shape.iter().product::<usize>();
        let sample_size_bytes = (feature_size + label_size) * element_size;

        // Robustness fix: the division and modulo below would panic for a
        // zero-sized element type.
        if sample_size_bytes == 0 {
            return Err(TensorError::invalid_argument(
                "Sample size in bytes must be non-zero".to_string(),
            ));
        }

        if metadata.len() as usize % sample_size_bytes != 0 {
            return Err(TensorError::invalid_argument(format!(
                "File {} size {} is not evenly divisible by sample size {}",
                file_path_ref.display(),
                metadata.len(),
                sample_size_bytes
            )));
        }

        let num_samples = metadata.len() as usize / sample_size_bytes;

        Self::from_file(file_path, num_samples, feature_shape, label_shape)
    }

    /// Get raw byte slices (features, labels) for sample `index`.
    fn get_raw_sample_bytes(&self, index: usize) -> Result<(&[u8], &[u8])> {
        if index >= self.num_samples {
            return Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset with {} samples",
                index, self.num_samples
            )));
        }

        // Layout per sample: features first, labels immediately after.
        let sample_offset = index * self.sample_size_bytes;
        let feature_start = sample_offset;
        let feature_end = feature_start + self.feature_size_bytes;
        let label_end = feature_end + self.label_size_bytes;

        let feature_bytes = &self.mmap[feature_start..feature_end];
        let label_bytes = &self.mmap[feature_end..label_end];

        Ok((feature_bytes, label_bytes))
    }

    /// Get file statistics for monitoring.
    pub fn file_stats(&self) -> MemoryMappedFileStats {
        MemoryMappedFileStats {
            file_path: self.file_path.clone(),
            file_size: self.mmap.len(),
            num_samples: self.num_samples,
            sample_size_bytes: self.sample_size_bytes,
            feature_shape: self.feature_shape.clone(),
            label_shape: self.label_shape.clone(),
        }
    }

    /// Create a raw byte view spanning `batch_size` consecutive samples
    /// starting at `start_index`, for batch processing.
    ///
    /// # Errors
    /// Fails when the requested batch overflows `usize` or extends past the
    /// last sample.
    pub fn get_batch_view(&self, start_index: usize, batch_size: usize) -> Result<&[u8]> {
        // Robustness fix: checked_add prevents usize overflow in the bound test.
        let end_index = start_index.checked_add(batch_size).ok_or_else(|| {
            TensorError::invalid_argument(format!(
                "Batch start {start_index} plus size {batch_size} overflows usize"
            ))
        })?;

        if end_index > self.num_samples {
            return Err(TensorError::invalid_argument(format!(
                "Batch {}..{} out of bounds for dataset with {} samples",
                start_index, end_index, self.num_samples
            )));
        }

        let start_offset = start_index * self.sample_size_bytes;
        let end_offset = end_index * self.sample_size_bytes;

        Ok(&self.mmap[start_offset..end_offset])
    }
}
650
#[cfg(feature = "mmap")]
impl<T> Dataset<T> for MemoryMappedFileDataset<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + bytemuck::Pod + 'static,
{
    /// Number of samples in the mapped file.
    fn len(&self) -> usize {
        self.num_samples
    }

    /// Decode sample `index` from the mapped file into owned tensors.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        let (feature_bytes, label_bytes) = self.get_raw_sample_bytes(index)?;

        // bytemuck reinterprets the raw bytes as `T` values with size checks.
        // NOTE(review): cast_slice also requires the byte slices to be
        // aligned for T — assumed to hold for primitive element types here;
        // confirm if exotic Pod types are ever used.
        let feature_data: &[T] = bytemuck::cast_slice(feature_bytes);
        let label_data: &[T] = bytemuck::cast_slice(label_bytes);

        Ok((
            Tensor::from_vec(feature_data.to_vec(), &self.feature_shape)?,
            Tensor::from_vec(label_data.to_vec(), &self.label_shape)?,
        ))
    }
}
674
/// Statistics for memory-mapped file datasets.
#[cfg(feature = "mmap")]
#[derive(Debug, Clone)]
pub struct MemoryMappedFileStats {
    /// Path of the mapped file
    pub file_path: String,
    /// Total mapped file size in bytes
    pub file_size: usize,
    /// Number of samples in the dataset
    pub num_samples: usize,
    /// Bytes occupied by one sample (features + labels)
    pub sample_size_bytes: usize,
    /// Feature shape (without batch dimension)
    pub feature_shape: Vec<usize>,
    /// Label shape (without batch dimension)
    pub label_shape: Vec<usize>,
}
686
#[cfg(feature = "mmap")]
impl MemoryMappedFileStats {
    /// Calculate memory efficiency: the fraction of the mapped file that is
    /// actually occupied by sample data.
    pub fn memory_efficiency(&self) -> f64 {
        let used = (self.num_samples * self.sample_size_bytes) as f64;
        used / self.file_size as f64
    }

    /// Get a human-readable rendering of the mapped file size.
    pub fn human_readable_size(&self) -> String {
        const STEP: f64 = 1024.0;
        let mut value = self.file_size as f64;
        // Walk up the unit ladder; whatever survives past MB is reported in GB.
        for unit in ["B", "KB", "MB"] {
            if value < STEP {
                return format!("{value:.1} {unit}");
            }
            value /= STEP;
        }
        format!("{value:.1} GB")
    }
}
712
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_view_slice() {
        let source = Arc::new(
            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("test: tensor creation should succeed"),
        );

        // Take the first row of the 2x3 tensor: [1, 2, 3]
        let row =
            TensorView::slice(source, &[0..1, 0..3]).expect("test: operation should succeed");
        assert_eq!(row.shape().dims(), &[1, 3]);
        assert_eq!(row.offset(), 0);
        assert!(row.is_contiguous());
    }

    #[test]
    fn test_tensor_view_reshape() {
        let source = Arc::new(
            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("test: tensor creation should succeed"),
        );

        // A [2, 3] tensor reshaped to [6, 1] keeps all six elements.
        let reshaped =
            TensorView::reshape(source, vec![6, 1]).expect("test: operation should succeed");
        assert_eq!(reshaped.shape().dims(), &[6, 1]);
        assert_eq!(reshaped.offset(), 0);
    }

    #[test]
    fn test_zero_copy_dataset() {
        let feats = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let targets = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let ds = ZeroCopyDataset::new(feats, targets).expect("test: operation should succeed");
        assert_eq!(ds.len(), 2);

        let (x, y) = ds.get(0).expect("index should be in bounds");
        assert_eq!(x.shape().dims(), &[2]);
        assert_eq!(y.shape().dims(), &[] as &[usize]);
    }

    #[test]
    fn test_memory_mapped_dataset() {
        // Layout: [feat0_0, feat0_1, label0, feat1_0, feat1_1, label1]
        let raw: Arc<[f32]> = Arc::from(vec![1.0, 2.0, 0.0, 3.0, 4.0, 1.0]);

        let ds = MemoryMappedDataset::new(
            raw,
            2,       // 2 samples
            vec![2], // 2 features per sample
            vec![],  // scalar labels
        )
        .expect("test: operation should succeed");

        assert_eq!(ds.len(), 2);

        // Both samples decode to a 2-element feature vector and scalar label.
        for idx in 0..2 {
            let (x, y) = ds.get(idx).expect("index should be in bounds");
            assert_eq!(x.shape().dims(), &[2]);
            assert_eq!(y.shape().dims(), &[] as &[usize]);
        }
    }
}