use std::path::Path;
use memmap2::Mmap;
use crate::data::dataset::{Batch, Dataset};
use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// A token dataset backed by a read-only memory map of a binary file.
///
/// The file is treated as a flat array of 4-byte (`u32`) tokens. Construct it
/// with [`MmapDataset::open`], which validates the file size against the
/// requested sequence length.
#[derive(Debug)]
pub struct MmapDataset {
    /// Read-only mapping of the whole token file.
    mmap: Mmap,
    /// Number of tokens in each input (and each target) sequence.
    seq_len: usize,
    /// Total number of `u32` tokens in the file (file size in bytes / 4).
    num_tokens: usize,
}
impl MmapDataset {
pub fn open(path: impl AsRef<Path>, seq_len: usize) -> Result<Self> {
let path = path.as_ref();
let file = std::fs::File::open(path).map_err(|e| Error::DataError {
reason: format!("failed to open {}: {e}", path.display()),
})?;
let metadata = file.metadata()?;
let file_len = metadata.len() as usize;
if file_len % 4 != 0 {
return Err(Error::DataError {
reason: format!(
"file size {} is not a multiple of 4 (expected u32 tokens)",
file_len
),
});
}
let num_tokens = file_len / 4;
if num_tokens < seq_len + 1 {
return Err(Error::DataError {
reason: format!(
"file has {num_tokens} tokens but need at least {} for seq_len={seq_len}",
seq_len + 1
),
});
}
let mmap = unsafe { Mmap::map(&file)? };
Ok(Self {
mmap,
seq_len,
num_tokens,
})
}
pub fn num_tokens(&self) -> usize {
self.num_tokens
}
pub fn seq_len(&self) -> usize {
self.seq_len
}
pub fn num_samples(&self) -> usize {
self.num_tokens / (self.seq_len + 1)
}
fn read_tokens(&self, token_offset: usize, count: usize) -> &[u32] {
let byte_offset = token_offset * 4;
let byte_end = byte_offset + count * 4;
let bytes = &self.mmap[byte_offset..byte_end];
unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, count) }
}
}
impl<R: Runtime<DType = DType>> Dataset<R> for MmapDataset {
    /// Returns the number of samples, i.e. [`MmapDataset::num_samples`].
    fn len(&self) -> usize {
        self.num_samples()
    }

    /// Builds the `idx`-th sample as a pair of 1-D `f32` tensors of length
    /// `seq_len`.
    ///
    /// Sample `idx` covers the `seq_len + 1` tokens starting at token
    /// `idx * (seq_len + 1)`. `inputs` holds the first `seq_len` of them and
    /// `targets` holds the same window shifted forward by one token.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DataError`] if `idx` is not smaller than the number of
    /// samples.
    fn get(&self, idx: usize, device: &R::Device) -> Result<Batch<R>> {
        let len = self.num_samples();
        if idx >= len {
            return Err(Error::DataError {
                reason: format!("index {idx} out of bounds for dataset of size {len}"),
            });
        }

        let sample_len = self.seq_len + 1;
        let start = idx * sample_len;
        let tokens = self.read_tokens(start, sample_len);

        // Token files are stored as little-endian `u32`s, while `read_tokens`
        // reinterprets the bytes in native order. `u32::from_le` is a no-op on
        // little-endian hosts and byte-swaps on big-endian ones.
        // NOTE: `f32` represents integers exactly only up to 2^24; larger
        // token ids are rounded by this cast.
        let to_f32 = |&t: &u32| u32::from_le(t) as f32;

        let input_f32: Vec<f32> = tokens[..self.seq_len].iter().map(to_f32).collect();
        let target_f32: Vec<f32> = tokens[1..=self.seq_len].iter().map(to_f32).collect();

        let inputs = Tensor::<R>::from_slice(&input_f32, &[self.seq_len], device);
        let targets = Tensor::<R>::from_slice(&target_f32, &[self.seq_len], device);
        Ok(Batch { inputs, targets })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `tokens` serialized back-to-back as
    /// little-endian `u32` values.
    fn write_token_file(tokens: &[u32]) -> NamedTempFile {
        let payload: Vec<u8> = tokens.iter().flat_map(|tok| tok.to_le_bytes()).collect();
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&payload).unwrap();
        file.flush().unwrap();
        file
    }

    /// Ten tokens with `seq_len = 3` yield two samples of four tokens each.
    #[test]
    fn test_mmap_dataset_basic() {
        let all_tokens: Vec<u32> = (0..10).collect();
        let file = write_token_file(&all_tokens);
        let dataset = MmapDataset::open(file.path(), 3).unwrap();
        assert_eq!(dataset.num_tokens(), 10);
        assert_eq!(dataset.num_samples(), 2);

        let device = CpuDevice::new();
        // Coerce to a trait object so the runtime parameter is unambiguous.
        let dataset: &dyn Dataset<CpuRuntime> = &dataset;

        let first = dataset.get(0, &device).unwrap();
        let first_inputs: Vec<f32> = first.inputs.to_vec();
        let first_targets: Vec<f32> = first.targets.to_vec();
        assert_eq!(first_inputs, vec![0.0, 1.0, 2.0]);
        assert_eq!(first_targets, vec![1.0, 2.0, 3.0]);

        let second = dataset.get(1, &device).unwrap();
        let second_inputs: Vec<f32> = second.inputs.to_vec();
        assert_eq!(second_inputs, vec![4.0, 5.0, 6.0]);
    }

    /// Eight tokens with `seq_len = 3` give two samples, so index 2 is invalid.
    #[test]
    fn test_mmap_dataset_out_of_bounds() {
        let all_tokens: Vec<u32> = (0..8).collect();
        let file = write_token_file(&all_tokens);
        let dataset = MmapDataset::open(file.path(), 3).unwrap();

        let device = CpuDevice::new();
        let dataset: &dyn Dataset<CpuRuntime> = &dataset;
        let result = dataset.get(2, &device);
        assert!(result.is_err());
    }

    /// A file with fewer than `seq_len + 1` tokens is rejected by `open`.
    #[test]
    fn test_mmap_dataset_too_small() {
        let file = write_token_file(&[0, 1]);
        let result = MmapDataset::open(file.path(), 3);
        assert!(result.is_err());
    }

    /// A file whose size is not a multiple of four bytes is rejected by `open`.
    #[test]
    fn test_mmap_dataset_bad_alignment() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&[0u8; 5]).unwrap();
        file.flush().unwrap();

        let result = MmapDataset::open(file.path(), 1);
        assert!(result.is_err());
    }
}