use burn::prelude::*;
use safetensors::SafeTensors;
use crate::config::DataConfig;
/// Quantises continuous channel coordinates onto an integer grid.
///
/// Every coordinate is min-max scaled with the per-axis bounds `cfg.xyz_min` /
/// `cfg.xyz_max`, multiplied by `cfg.num_bins`, converted to an integer, and clamped into
/// `0..=cfg.num_bins - 1`. Positions on or beyond the bounds therefore land in the edge bins.
///
/// NOTE(review): assumes `chan_pos` is `[channels, 3]` (the bounds are built as `[1, 3]` rows
/// and broadcast over it) and that `xyz_max > xyz_min` on every axis. A zero-width axis
/// divides by zero, and the resulting bin is presumably backend-dependent.
pub fn discretize_chan_pos<B: Backend>(
    chan_pos: Tensor<B, 2>,
    cfg: &DataConfig,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let lower = Tensor::<B, 2>::from_data(
        TensorData::new(cfg.xyz_min.to_vec(), vec![1, 3]),
        device,
    );
    let upper = Tensor::<B, 2>::from_data(
        TensorData::new(cfg.xyz_max.to_vec(), vec![1, 3]),
        device,
    );
    let extent = upper - lower.clone();
    let unit = (chan_pos - lower) / extent;

    let last_bin = cfg.num_bins as i32 - 1;
    unit.mul_scalar(cfg.num_bins as f32)
        .int()
        .clamp(0i32, last_bin)
}
/// Splits each channel's signal into non-overlapping patches of `tf` samples and flattens
/// them into one token sequence, together with per-token metadata.
///
/// With `eeg` of shape `[C, T]` and `tc = T / tf`, tokens are ordered channel-major: row
/// `ch * tc + k` holds patch `k` of channel `ch`.
///
/// Returns `(eeg_tokens, chan_pos, chan_pos_disc, t_coarse)`:
/// - `eeg_tokens`: `[C * tc, tf]`
/// - `chan_pos` / `chan_pos_disc`: each input row repeated `tc` times consecutively, so every
///   token carries its channel's position
/// - `t_coarse`: `[C * tc, 1]`, the patch index `0..tc` within its channel
///
/// # Panics
/// Panics if `tf == 0` or if `T` is not divisible by `tf`.
pub fn chop_and_reshape<B: Backend>(
    eeg: Tensor<B, 2>,
    chan_pos: Tensor<B, 2>,
    chan_pos_disc: Tensor<B, 2, Int>,
    tf: usize,
) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
    // Guard explicitly: `t_total % 0` would otherwise panic with an opaque arithmetic message.
    assert!(tf > 0, "tf must be positive");
    let [c, t_total] = eeg.dims();
    assert_eq!(t_total % tf, 0, "T must be divisible by tf");
    let tc = t_total / tf;
    let s = c * tc;
    let device = eeg.device();

    // Row-major layout means one reshape already yields the channel-major token order.
    let eeg_tokens = eeg.reshape([s, tf]);
    let pos = repeat_interleave_rows_f(chan_pos, tc);
    let posd = repeat_interleave_rows_i(chan_pos_disc, tc);

    // Patch index within its channel: 0, 1, .., tc-1 repeated once per channel.
    let tc_vals: Vec<i32> = (0..tc as i32).cycle().take(s).collect();
    let t_coarse = Tensor::<B, 1, Int>::from_data(TensorData::new(tc_vals, vec![s]), &device)
        .reshape([s, 1]);

    (eeg_tokens, pos, posd, t_coarse)
}
/// Builds the per-token integer index matrix by joining the discretised channel positions
/// and the coarse time index side by side (along dimension 1).
///
/// Both inputs must have the same number of rows. The output columns are those of
/// `chan_pos_disc` followed by those of `t_coarse`.
pub fn build_tok_idx<B: Backend>(
    chan_pos_disc: Tensor<B, 2, Int>,
    t_coarse: Tensor<B, 2, Int>,
) -> Tensor<B, 2, Int> {
    let column_groups = vec![chan_pos_disc, t_coarse];
    Tensor::cat(column_groups, 1)
}
/// One model-ready sample: the flattened EEG tokens plus the metadata needed to index and
/// later un-flatten them.
#[derive(Debug, Clone)]
pub struct InputBatch<B: Backend> {
    /// Token matrix with a leading batch axis of 1. The loaders in this module build it as
    /// `[1, n_channels * tc, tf]`.
    pub encoder_input: Tensor<B, 3>,
    /// Per-token integer indices: the discretised channel-position columns followed by the
    /// coarse-time column, one row per token.
    pub tok_idx: Tensor<B, 2, Int>,
    /// Continuous channel positions, one row per channel (`[n_channels, 3]` as loaded here).
    pub chan_pos: Tensor<B, 2>,
    /// Number of EEG channels in this sample.
    pub n_channels: usize,
    /// Number of coarse time patches per channel (`T / num_fine_time_pts`).
    pub tc: usize,
}
pub fn load_batch<B: Backend>(
path: &str,
cfg: &DataConfig,
device: &B::Device,
) -> anyhow::Result<Vec<InputBatch<B>>> {
let bytes = std::fs::read(path)?;
let st = SafeTensors::deserialize(&bytes)?;
let n_samples = {
let v = st.tensor("n_samples")?;
match v.dtype() {
safetensors::Dtype::I32 =>
i32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
safetensors::Dtype::F32 =>
f32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
other => anyhow::bail!("unexpected dtype for n_samples: {:?}", other),
}
};
let mut batches = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let eeg_view = st.tensor(&format!("eeg_{i}"))?;
let [c, t]: [usize; 2] = eeg_view.shape().try_into()
.map_err(|_| anyhow::anyhow!("eeg_{i} must be 2-D"))?;
let eeg_f32 = bytes_to_f32(eeg_view.data(), eeg_view.dtype())?;
let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_f32, vec![c, t]), device);
let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
let pos_f32 = bytes_to_f32(pos_view.data(), pos_view.dtype())?;
let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_f32, vec![c, 3]), device);
let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), cfg, device);
let tc = t / cfg.num_fine_time_pts;
let (eeg_tokens, _, posd, t_coarse) =
chop_and_reshape(eeg.clone(), chan_pos.clone(), chan_pos_disc, cfg.num_fine_time_pts);
let tok_idx = build_tok_idx(posd, t_coarse);
let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0);
batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
}
Ok(batches)
}
/// Undoes the token flattening performed by `chop_and_reshape`.
///
/// Turns `n_channels * tc` tokens of `tf` samples each back into one row of `tc * tf`
/// samples per channel. Tokens are stored channel-major in row-major memory, so a single
/// reshape restores the `[n_channels, tc * tf]` signal layout.
pub fn invert_reshape<B: Backend>(
    tokens: Tensor<B, 2>,
    n_channels: usize,
    tc: usize,
    tf: usize,
) -> Tensor<B, 2> {
    let samples_per_channel = tc * tf;
    tokens.reshape([n_channels, samples_per_channel])
}
/// Summary of a FIF recording and of the preprocessing `load_from_fif` applied to it.
#[derive(Debug, Clone, PartialEq)]
pub struct FifInfo {
    /// Channel names, in the order they appear in the file header.
    pub ch_names: Vec<String>,
    /// Per-channel `[x, y, z]` positions: the file's `loc` coordinates multiplied by 1000
    /// (presumably metres → millimetres).
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Sampling frequency of the raw recording (samples per second).
    pub sfreq: f32,
    /// Number of raw time samples before preprocessing.
    pub n_times_raw: usize,
    /// Recording duration in seconds, computed as `n_times_raw / sfreq`.
    pub duration_s: f32,
    /// Number of epochs produced by preprocessing.
    pub n_epochs: usize,
    /// Target sampling frequency configured for the preprocessing pipeline.
    pub target_sfreq: f32,
    /// Epoch duration configured for the preprocessing pipeline (presumably seconds).
    pub epoch_dur_s: f32,
}
/// Reads a FIF recording, runs the `exg` preprocessing pipeline on it, and converts every
/// resulting epoch into an [`InputBatch`].
///
/// Channel positions are the first three entries of each channel's `loc`. They are passed
/// unscaled to the pipeline and reported in [`FifInfo::ch_pos_mm`] multiplied by 1000.
/// `data_norm` overrides that field of the pipeline configuration; every other option keeps
/// `PipelineConfig::default()`.
///
/// Returns the batches together with a [`FifInfo`] summary of the recording.
///
/// # Errors
/// Returns an error if:
/// - the file cannot be opened or read;
/// - the per-channel `loc` data does not form an `(n_chan, 3)` array;
/// - preprocessing fails;
/// - `data_cfg.num_fine_time_pts` is zero, or an epoch's length is not a multiple of it;
/// - an epoch's position array does not hold three values per channel.
pub fn load_from_fif<B: Backend>(
    path: &std::path::Path,
    data_cfg: &DataConfig,
    data_norm: f32,
    device: &B::Device,
) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
    use exg::{
        fiff::raw::open_raw,
        PipelineConfig,
    };
    use ndarray::Array2;

    let tf = data_cfg.num_fine_time_pts;

    let raw_fif = open_raw(path)?;
    let src_sfreq = raw_fif.info.sfreq as f32;
    let n_ch = raw_fif.info.n_chan;
    let n_times_raw = raw_fif.n_times();
    let duration_s = n_times_raw as f32 / src_sfreq;
    let ch_names: Vec<String> = raw_fif.info.chs.iter()
        .map(|ch| ch.name.clone())
        .collect();
    // Reported positions are scaled by 1000 (presumably metres → millimetres). The pipeline
    // below receives the unscaled `loc` values.
    let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
        .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
        .collect();
    let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
        .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
        .collect();
    let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;
    let data_f64 = raw_fif.read_all_data()?;
    let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);
    let preproc_cfg = PipelineConfig {
        data_norm,
        ..PipelineConfig::default()
    };
    let epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
    let n_epochs = epochs.len();
    let mut batches = Vec::with_capacity(n_epochs);
    for (eeg_arr, pos_arr) in epochs {
        let (c, t) = eeg_arr.dim();
        // Validate here so bad configs/epochs surface as `Err` instead of the panics in
        // `t / tf` and `chop_and_reshape`.
        anyhow::ensure!(tf > 0, "num_fine_time_pts must be positive");
        anyhow::ensure!(
            t % tf == 0,
            "epoch length {t} is not divisible by num_fine_time_pts = {tf}"
        );
        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);
        let pos_data: Vec<f32> = pos_arr.iter().copied().collect();
        anyhow::ensure!(
            pos_data.len() == c * 3,
            "epoch channel positions hold {} values, expected {} (3 per channel)",
            pos_data.len(),
            c * 3
        );
        let chan_pos_t = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);
        let chan_pos_disc = discretize_chan_pos(chan_pos_t.clone(), data_cfg, device);
        let tc = t / tf;
        let (eeg_tokens, _, posd, t_coarse) = chop_and_reshape(
            eeg,
            chan_pos_t.clone(),
            chan_pos_disc,
            tf,
        );
        let tok_idx = build_tok_idx(posd, t_coarse);
        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0);
        batches.push(InputBatch {
            encoder_input,
            tok_idx,
            chan_pos: chan_pos_t,
            n_channels: c,
            tc,
        });
    }
    let info = FifInfo {
        ch_names,
        ch_pos_mm,
        sfreq: src_sfreq,
        n_times_raw,
        duration_s,
        n_epochs,
        target_sfreq: preproc_cfg.target_sfreq,
        epoch_dur_s: preproc_cfg.epoch_dur,
    };
    Ok((batches, info))
}
/// Repeats every row of a float matrix `repeats` times consecutively
/// (`[r0, r0, .., r1, r1, ..]`), mapping `[rows, cols]` to `[rows * repeats, cols]`.
fn repeat_interleave_rows_f<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
    let [rows, cols] = t.dims();
    // Add a unit axis after the rows, broadcast it to `repeats`, then fold it back into the
    // row axis so that copies of the same row stay adjacent.
    let widened: Tensor<B, 3> = t.unsqueeze_dim::<3>(1).expand([rows, repeats, cols]);
    widened.reshape([rows * repeats, cols])
}
/// Integer counterpart of the float row interleave: each row of `t` is emitted `repeats`
/// times consecutively, mapping `[rows, cols]` to `[rows * repeats, cols]`.
fn repeat_interleave_rows_i<B: Backend>(
    t: Tensor<B, 2, Int>,
    repeats: usize,
) -> Tensor<B, 2, Int> {
    let [rows, cols] = t.dims();
    let widened: Tensor<B, 3, Int> = t.unsqueeze_dim::<3>(1).expand([rows, repeats, cols]);
    widened.reshape([rows * repeats, cols])
}
fn bytes_to_f32(data: &[u8], dtype: safetensors::Dtype) -> anyhow::Result<Vec<f32>> {
match dtype {
safetensors::Dtype::F32 =>
Ok(data.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect()),
safetensors::Dtype::BF16 =>
Ok(data.chunks_exact(2)
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
.collect()),
other => anyhow::bail!("unsupported dtype {:?}", other),
}
}