use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Writes an NPY format version 1.0 header to `writer`.
///
/// `dtype` is inserted verbatim as the `descr` entry (e.g. `"<i8"`), and
/// `shape` verbatim as the `shape` entry (a Python tuple literal such as
/// `"(3,)"` or `"(2, 4)"`). The header dictionary is padded with spaces and
/// terminated by `\n` so that the 10-byte preamble plus the header is a
/// multiple of 64 bytes.
///
/// # Errors
///
/// Returns [`std::io::ErrorKind::InvalidInput`] if the padded header does not
/// fit in the 16-bit length field of format version 1.0; nothing is written in
/// that case. Otherwise propagates any error from `writer`.
pub fn write_npy_header<W: Write>(
    writer: &mut W,
    dtype: &str,
    shape: &str,
) -> std::io::Result<()> {
    // Magic string `\x93NUMPY` followed by format version 1.0.
    const PREAMBLE: &[u8] = b"\x93NUMPY\x01\x00";
    // Preamble plus the 2-byte little-endian header-length field.
    const PREFIX_LEN: usize = PREAMBLE.len() + 2;
    const ALIGN: usize = 64;

    let dict = format!("{{'descr': '{dtype}', 'fortran_order': False, 'shape': {shape} }}");
    // The `+ 1` accounts for the terminating newline.
    let unpadded = PREFIX_LEN + dict.len() + 1;
    let padding = (ALIGN - unpadded % ALIGN) % ALIGN;
    let padded_len = dict.len() + padding + 1;

    // A plain `as u16` cast would silently truncate and produce a corrupt file.
    let len_field = u16::try_from(padded_len).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("NPY v1.0 header is {padded_len} bytes; the maximum is {}", u16::MAX),
        )
    })?;

    let mut buf = Vec::with_capacity(PREFIX_LEN + padded_len);
    buf.extend_from_slice(PREAMBLE);
    buf.extend_from_slice(&len_field.to_le_bytes());
    buf.extend_from_slice(dict.as_bytes());
    buf.resize(buf.len() + padding, b' ');
    buf.push(b'\n');
    writer.write_all(&buf)
}
/// Writes `data` to `path` as a 1-D little-endian `i64` (`<i8`) `.npy` file
/// of shape `(data.len(),)`.
///
/// # Errors
///
/// Propagates any I/O error from creating or writing the file, including
/// errors raised by the final buffer flush.
pub fn write_npy_1d_i64(path: &Path, data: &[i64]) -> std::io::Result<()> {
    let mut writer = BufWriter::with_capacity(256 * 1024, File::create(path)?);
    write_npy_header(&mut writer, "<i8", &format!("({},)", data.len()))?;
    for val in data {
        writer.write_all(&val.to_le_bytes())?;
    }
    // BufWriter's Drop ignores flush errors; flush explicitly so they surface.
    writer.flush()
}
/// Writes `data` to `path` as a 1-D boolean (`|b1`) `.npy` file of shape
/// `(data.len(),)`, one byte per element (`1` for `true`, `0` for `false`).
///
/// # Errors
///
/// Propagates any I/O error from creating or writing the file, including
/// errors raised by the final buffer flush.
pub fn write_npy_1d_bool(path: &Path, data: &[bool]) -> std::io::Result<()> {
    let mut writer = BufWriter::with_capacity(256 * 1024, File::create(path)?);
    write_npy_header(&mut writer, "|b1", &format!("({},)", data.len()))?;
    for &val in data {
        writer.write_all(&[u8::from(val)])?;
    }
    // BufWriter's Drop ignores flush errors; flush explicitly so they surface.
    writer.flush()
}
/// Writes a 2-D array of `i64` to `path` as a little-endian (`<i8`),
/// C-ordered `.npy` file.
///
/// The declared shape is `(data.len(), cols)`, where `cols` is the length of
/// the longest row. Shorter rows are padded with `0`, so the data written
/// always matches the declared shape. An empty `data` produces shape `(0, 0)`.
///
/// # Errors
///
/// Propagates any I/O error from creating or writing the file, including
/// errors raised by the final buffer flush.
pub fn write_npy_2d_i64(path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
    let rows = data.len();
    // Size by the widest row, not just the first: a row longer than the
    // declared column count would spill over and misalign every later row.
    let cols = data.iter().map(Vec::len).max().unwrap_or(0);

    let mut writer = BufWriter::with_capacity(256 * 1024, File::create(path)?);
    write_npy_header(&mut writer, "<i8", &format!("({rows}, {cols})"))?;
    for row in data {
        for &val in row {
            writer.write_all(&val.to_le_bytes())?;
        }
        for _ in row.len()..cols {
            writer.write_all(&0_i64.to_le_bytes())?;
        }
    }
    // BufWriter's Drop ignores flush errors; flush explicitly so they surface.
    writer.flush()
}
/// Writes a 2-D array of `f64` to `path` as a little-endian (`<f8`),
/// C-ordered `.npy` file.
///
/// The declared shape is `(data.len(), cols)`, where `cols` is the length of
/// the longest row. Shorter rows are padded with `0.0`, so the data written
/// always matches the declared shape. An empty `data` produces shape `(0, 0)`.
///
/// # Errors
///
/// Propagates any I/O error from creating or writing the file, including
/// errors raised by the final buffer flush.
pub fn write_npy_2d_f64(path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
    let rows = data.len();
    // Size by the widest row, not just the first: a row longer than the
    // declared column count would spill over and misalign every later row.
    let cols = data.iter().map(Vec::len).max().unwrap_or(0);

    let mut writer = BufWriter::with_capacity(256 * 1024, File::create(path)?);
    write_npy_header(&mut writer, "<f8", &format!("({rows}, {cols})"))?;
    for row in data {
        for &val in row {
            writer.write_all(&val.to_le_bytes())?;
        }
        for _ in row.len()..cols {
            writer.write_all(&0.0_f64.to_le_bytes())?;
        }
    }
    // BufWriter's Drop ignores flush errors; flush explicitly so they surface.
    writer.flush()
}
/// Randomly partitions the indices `0..n` into train / validation / test
/// splits and writes each as a boolean `.npy` mask of length `n` into
/// `output_dir` as `train_mask.npy`, `val_mask.npy` and `test_mask.npy`.
///
/// The indices are shuffled with a Fisher–Yates shuffle driven by
/// [`SimpleRng`] seeded with `seed`, so the same arguments always produce the
/// same masks. The first `n * train_ratio` (truncated) shuffled indices go to
/// train, the next `n * val_ratio` (truncated) go to validation, and the rest
/// go to test. The ratios are not validated. Split sizes are clamped to `n`,
/// so ratios summing past 1 leave the later splits short or empty.
///
/// # Errors
///
/// Propagates any I/O error from writing the three mask files.
pub fn export_masks(
    output_dir: &Path,
    n: usize,
    seed: u64,
    train_ratio: f64,
    val_ratio: f64,
) -> std::io::Result<()> {
    let mut rng = SimpleRng::new(seed);

    // Float-to-int `as` saturates (negative/NaN -> 0, huge -> usize::MAX).
    // Clamp to `n`, and use saturating_add so absurd ratios cannot overflow.
    let train_end = ((n as f64 * train_ratio) as usize).min(n);
    let val_end = train_end
        .saturating_add((n as f64 * val_ratio) as usize)
        .min(n);

    let mut indices: Vec<usize> = (0..n).collect();
    // Fisher–Yates from the back. The `%` has a slight modulo bias. It is kept
    // unchanged, together with the RNG call order, so existing seeds reproduce
    // the same splits.
    for i in (1..n).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        indices.swap(i, j);
    }

    let mut train_mask = vec![false; n];
    let mut val_mask = vec![false; n];
    let mut test_mask = vec![false; n];
    for &idx in &indices[..train_end] {
        train_mask[idx] = true;
    }
    for &idx in &indices[train_end..val_end] {
        val_mask[idx] = true;
    }
    for &idx in &indices[val_end..] {
        test_mask[idx] = true;
    }

    write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
    write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
    write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
    Ok(())
}
/// Minimal deterministic xorshift64 pseudo-random generator (shift triple
/// 13, 7, 17).
///
/// It is not cryptographically secure. In this file it drives the
/// reproducible shuffle in `export_masks`.
#[derive(Debug, Clone)]
pub struct SimpleRng {
    /// Current generator state. It is never zero: zero is xorshift's fixed
    /// point, and each xorshift step maps non-zero states to non-zero states.
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`.
    ///
    /// A seed of `0` is replaced with `1`, because an all-zero state would
    /// make the generator return `0` forever.
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.max(1),
        }
    }

    /// Advances the state by one xorshift64 step and returns the new state.
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}