use std::path::Path;
use burn::prelude::*;
use crate::error::BrainHarmonyError;
/// Row-major gradient matrix loaded from a CSV file.
///
/// `values` holds `n_rois * grad_dim` entries: one row per ROI, one
/// column per gradient component (see [`load_csv_data`] for the format).
#[derive(Debug)]
pub struct GradientData {
    /// Flattened matrix entries, row-major (`n_rois` rows of `grad_dim` values).
    pub values: Vec<f32>,
    /// Number of ROI rows parsed from the CSV.
    pub n_rois: usize,
    /// Number of gradient components per ROI (columns).
    pub grad_dim: usize,
}
impl GradientData {
    /// Parses a gradient CSV at `path`; delegates to the shared
    /// [`load_csv_data`] parser with the "gradient CSV" error label.
    pub fn from_csv(path: &str) -> crate::error::Result<Self> {
        load_csv_data(path, "gradient CSV")
    }

    /// Copies the values into a `[n_rois, grad_dim]` tensor on `device`.
    pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
        let shape = vec![self.n_rois, self.grad_dim];
        let data = TensorData::new(self.values.clone(), shape);
        Tensor::<B, 2>::from_data(data, device)
    }
}
/// Row-major geometric-harmonics matrix loaded from a CSV file.
///
/// `values` holds `n_rois * geoh_dim` entries: one row per ROI, one
/// column per harmonic component.
#[derive(Debug)]
pub struct GeohData {
    /// Flattened matrix entries, row-major (`n_rois` rows of `geoh_dim` values).
    pub values: Vec<f32>,
    /// Number of ROI rows parsed from the CSV.
    pub n_rois: usize,
    /// Number of harmonic components per ROI (columns).
    pub geoh_dim: usize,
}
impl GeohData {
    /// Parses a geometric-harmonics CSV at `path`.
    ///
    /// Format handled here (unlike the plain [`load_csv_data`] parser):
    /// - blank lines and lines starting with `#` are skipped;
    /// - the first physical line (`line_no == 0`) is skipped as a header;
    ///   NOTE(review): this check runs *after* the comment check, so if the
    ///   file starts with a `#` comment the real header on the next line is
    ///   not explicitly skipped — its textual cells only survive being
    ///   dropped by the numeric parse below. Confirm against real inputs;
    /// - the first column of every row is discarded (presumably an ROI
    ///   label column — TODO confirm);
    /// - cells that fail to parse as `f32` are silently dropped, so `got`
    ///   in `InconsistentCsvRow` counts parsed values, not raw columns.
    pub fn from_csv(path: &str) -> crate::error::Result<Self> {
        let p = Path::new(path);
        if !p.exists() {
            return Err(BrainHarmonyError::FileNotFound {
                kind: "geometric harmonics CSV",
                path: p.to_path_buf(),
            });
        }
        let content = std::fs::read_to_string(p)?;
        let mut values = Vec::new();
        let mut n_rois = 0usize;
        let mut geoh_dim = 0usize;
        for (line_no, line) in content.lines().enumerate() {
            let line = line.trim();
            // Skip blanks and comment lines.
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            // Skip the header row (see the caveat in the doc comment above).
            if line_no == 0 {
                continue;
            }
            // Drop the first column, then keep only cells that parse as f32.
            let parts: Vec<f32> = line
                .split(',')
                .skip(1) .filter_map(|s| s.trim().parse::<f32>().ok())
                .collect();
            if parts.is_empty() {
                continue;
            }
            // The first data row fixes the column count; later rows must match.
            if geoh_dim == 0 {
                geoh_dim = parts.len();
            } else if parts.len() != geoh_dim {
                return Err(BrainHarmonyError::InconsistentCsvRow {
                    path: p.to_path_buf(),
                    row: line_no + 1,
                    expected: geoh_dim,
                    got: parts.len(),
                });
            }
            values.extend_from_slice(&parts);
            n_rois += 1;
        }
        if n_rois == 0 {
            return Err(BrainHarmonyError::EmptyCsv {
                path: p.to_path_buf(),
            });
        }
        Ok(Self {
            values,
            n_rois,
            geoh_dim,
        })
    }

    /// Copies the values into a `[n_rois, geoh_dim]` tensor on `device`.
    pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::from_data(
            TensorData::new(self.values.clone(), vec![self.n_rois, self.geoh_dim]),
            device,
        )
    }
}
#[derive(Debug)]
/// A loaded signal tensor, normalized to rank 4, plus its dimensions.
#[derive(Debug)]
pub struct SignalInput<B: Backend> {
    // Shaped [batch, channel, n_rois, signal_length] by the loaders below.
    pub data: Tensor<B, 4>,
    // Size of the third tensor dimension (spatial / ROI axis).
    pub n_rois: usize,
    // Size of the fourth tensor dimension (temporal axis).
    pub signal_length: usize,
}
/// Loads a signal tensor from a safetensors file at `path`.
///
/// Looks for a tensor stored under the key `"signal"`, falling back to
/// `"fmri"`. Accepts F32/BF16/F16 payloads (safetensors stores bytes
/// little-endian) of rank 2–4 and normalizes them to a 4-D
/// `[batch, channel, n_rois, signal_length]` tensor:
/// - rank 2 `[n_rois, time]` gains batch and channel dims,
/// - rank 3 `[batch, n_rois, time]` gains a channel dim,
/// - rank 4 is used as-is.
///
/// # Errors
/// Fails when the file is missing, neither key exists, the dtype is
/// unsupported, the payload size disagrees with the declared shape, or
/// the rank is outside 2–4.
pub fn load_signal_safetensors<B: Backend>(
    path: &str,
    device: &B::Device,
) -> anyhow::Result<SignalInput<B>> {
    let p = Path::new(path);
    if !p.exists() {
        return Err(BrainHarmonyError::FileNotFound {
            kind: "signal input",
            path: p.to_path_buf(),
        }
        .into());
    }
    let bytes = std::fs::read(p)?;
    let st = safetensors::SafeTensors::deserialize(&bytes)?;
    // Prefer "signal"; fall back to "fmri" for older exports.
    let key = if st.tensor("signal").is_ok() {
        "signal"
    } else {
        "fmri"
    };
    // If this lookup fails, *both* keys were tried — say so instead of
    // only naming the fallback key.
    let view = st
        .tensor(key)
        .map_err(|e| anyhow::anyhow!("found neither 'signal' nor 'fmri' key: {e}"))?;
    let shape = view.shape().to_vec();
    let data_bytes = view.data();
    // Decode the raw little-endian bytes to f32.
    let f32s: Vec<f32> = match view.dtype() {
        safetensors::Dtype::F32 => data_bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
        safetensors::Dtype::BF16 => data_bytes
            .chunks_exact(2)
            .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        safetensors::Dtype::F16 => data_bytes
            .chunks_exact(2)
            .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        other => anyhow::bail!("unsupported dtype {:?}", other),
    };
    // chunks_exact silently drops trailing bytes, so a truncated or corrupt
    // payload would otherwise slip through — verify the element count
    // matches the declared shape before constructing a tensor.
    let expected: usize = shape.iter().product();
    if f32s.len() != expected {
        anyhow::bail!(
            "tensor '{key}' decoded to {} elements but shape {:?} implies {expected}",
            f32s.len(),
            shape
        );
    }
    let (n_rois, signal_length, tensor) = match shape.len() {
        2 => {
            let t = Tensor::<B, 2>::from_data(
                TensorData::new(f32s, shape.clone()),
                device,
            );
            (shape[0], shape[1], t.unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0))
        }
        3 => {
            let t = Tensor::<B, 3>::from_data(
                TensorData::new(f32s, shape.clone()),
                device,
            );
            (shape[1], shape[2], t.unsqueeze_dim::<4>(1))
        }
        4 => {
            let t = Tensor::<B, 4>::from_data(
                TensorData::new(f32s, shape.clone()),
                device,
            );
            (shape[2], shape[3], t)
        }
        _ => anyhow::bail!("unexpected signal tensor rank: {}", shape.len()),
    };
    Ok(SignalInput {
        data: tensor,
        n_rois,
        signal_length,
    })
}
pub fn load_signal_csv<B: Backend>(
path: &str,
device: &B::Device,
) -> crate::error::Result<SignalInput<B>> {
let p = Path::new(path);
if !p.exists() {
return Err(BrainHarmonyError::FileNotFound {
kind: "signal CSV",
path: p.to_path_buf(),
});
}
let content = std::fs::read_to_string(p)?;
let mut values = Vec::new();
let mut n_rois = 0usize;
let mut n_time = 0usize;
for (line_no, line) in content.lines().enumerate() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<f32> = line
.split(',')
.filter_map(|s| s.trim().parse::<f32>().ok())
.collect();
if parts.is_empty() {
continue;
}
if n_time == 0 {
n_time = parts.len();
} else if parts.len() != n_time {
return Err(BrainHarmonyError::InconsistentCsvRow {
path: p.to_path_buf(),
row: line_no + 1,
expected: n_time,
got: parts.len(),
});
}
values.extend_from_slice(&parts);
n_rois += 1;
}
if n_rois == 0 {
return Err(BrainHarmonyError::EmptyCsv {
path: p.to_path_buf(),
});
}
let t = Tensor::<B, 2>::from_data(
TensorData::new(values, vec![n_rois, n_time]),
device,
)
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0);
Ok(SignalInput {
data: t,
n_rois,
signal_length: n_time,
})
}
/// Standardizes `x` to zero mean and unit variance over ALL elements
/// (one global mean/std across every dimension, not per-channel).
///
/// Uses the population variance (divides by `n`, not `n - 1`) and adds
/// a small epsilon (1e-8) to the standard deviation so a constant input
/// maps to all-zeros instead of dividing by zero. An empty tensor is
/// returned unchanged; the previous implementation divided by an
/// element count of zero, producing NaNs.
pub fn standardize<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    let [b, c, h, w] = x.dims();
    let n = (b * c * h * w) as f32;
    // Guard the degenerate case: nothing to normalize, and n == 0 would
    // turn both the mean and variance into NaN.
    if n == 0.0 {
        return x;
    }
    let sum: f32 = x.clone().sum().into_scalar().elem();
    let mean = sum / n;
    let centered = x.sub_scalar(mean);
    let var_sum: f32 = centered
        .clone()
        .powf_scalar(2.0f32)
        .sum()
        .into_scalar()
        .elem();
    let std = (var_sum / n).sqrt() + 1e-8;
    centered.div_scalar(std)
}
/// Parses a simple numeric CSV into a row-major [`GradientData`].
///
/// Blank lines and lines starting with `#` are skipped. Each remaining
/// line is split on commas and every cell that parses as `f32` is kept
/// (non-numeric cells — e.g. a textual header row — are silently
/// dropped, and an all-text line is skipped entirely). The first data
/// row fixes the column count; any later row with a different count
/// yields an `InconsistentCsvRow` error. `kind` labels the file in a
/// `FileNotFound` error; a file with no data rows yields `EmptyCsv`.
fn load_csv_data(path: &str, kind: &'static str) -> crate::error::Result<GradientData> {
    let file = Path::new(path);
    if !file.exists() {
        return Err(BrainHarmonyError::FileNotFound {
            kind,
            path: file.to_path_buf(),
        });
    }
    let text = std::fs::read_to_string(file)?;
    let mut values: Vec<f32> = Vec::new();
    let mut row_count = 0usize;
    let mut width = 0usize;
    for (idx, raw) in text.lines().enumerate() {
        let trimmed = raw.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let row: Vec<f32> = trimmed
            .split(',')
            .filter_map(|cell| cell.trim().parse::<f32>().ok())
            .collect();
        if row.is_empty() {
            continue;
        }
        match width {
            // First data row establishes the expected column count.
            0 => width = row.len(),
            w if w != row.len() => {
                return Err(BrainHarmonyError::InconsistentCsvRow {
                    path: file.to_path_buf(),
                    row: idx + 1,
                    expected: w,
                    got: row.len(),
                });
            }
            _ => {}
        }
        values.extend(row);
        row_count += 1;
    }
    if row_count == 0 {
        return Err(BrainHarmonyError::EmptyCsv {
            path: file.to_path_buf(),
        });
    }
    Ok(GradientData {
        values,
        n_rois: row_count,
        grad_dim: width,
    })
}