Skip to main content

brainjepa/
data.rs

//! fMRI data loading and preprocessing utilities.
//!
//! Handles loading fMRI time series from CSV/safetensors and
//! preparing them as model input tensors.
5use std::path::Path;
6
7use crate::error::BrainJepaError;
8
/// fMRI input stored as a plain, flat `f32` buffer.
///
/// The buffer is row-major with logical layout `[1, 1, n_rois, n_time]`.
#[derive(Clone, Debug)]
pub struct FmriInputF32 {
    /// Flat sample buffer in the row-major layout described above.
    pub data: Vec<f32>,
    /// Extent of the ROI axis.
    pub n_rois: usize,
    /// Extent of the time axis.
    pub n_time: usize,
}
16
/// Brain gradient coordinates loaded from CSV.
///
/// `values` holds one row per ROI, flattened row-major, so data produced by
/// `GradientData::from_csv` satisfies `values.len() == n_rois * grad_dim`.
///
/// Derives `Clone` (like `FmriInputF32`) so callers can keep an independent
/// copy of the table.
#[derive(Debug, Clone)]
pub struct GradientData {
    /// Gradient values: `[n_rois, grad_dim]` as a flat `Vec`.
    pub values: Vec<f32>,
    /// Number of ROI rows.
    pub n_rois: usize,
    /// Number of gradient axes (columns) per ROI.
    pub grad_dim: usize,
}
25
26impl GradientData {
27    /// Load gradient mapping from a CSV file.
28    ///
29    /// Expected format: each row is an ROI, columns are gradient axes.
30    /// All rows must have the same number of columns.
31    pub fn from_csv(path: &str) -> crate::error::Result<Self> {
32        let p = Path::new(path);
33        if !p.exists() {
34            return Err(BrainJepaError::FileNotFound {
35                kind: "gradient CSV",
36                path: p.to_path_buf(),
37            });
38        }
39
40        let content = std::fs::read_to_string(p)?;
41        let mut values = Vec::new();
42        let mut n_rois = 0usize;
43        let mut grad_dim = 0usize;
44
45        for (line_no, line) in content.lines().enumerate() {
46            let line = line.trim();
47            if line.is_empty() || line.starts_with('#') {
48                continue;
49            }
50            let parts: Vec<f32> = line
51                .split(',')
52                .filter_map(|s| s.trim().parse::<f32>().ok())
53                .collect();
54            if parts.is_empty() {
55                continue;
56            }
57            if grad_dim == 0 {
58                grad_dim = parts.len();
59            } else if parts.len() != grad_dim {
60                return Err(BrainJepaError::InconsistentCsvRow {
61                    path: p.to_path_buf(),
62                    row: line_no + 1,
63                    expected: grad_dim,
64                    got: parts.len(),
65                });
66            }
67            values.extend_from_slice(&parts);
68            n_rois += 1;
69        }
70
71        if n_rois == 0 {
72            return Err(BrainJepaError::EmptyCsv {
73                path: p.to_path_buf(),
74            });
75        }
76
77        Ok(Self {
78            values,
79            n_rois,
80            grad_dim,
81        })
82    }
83}
84
85/// Load fMRI data from a safetensors file as a plain f32 buffer.
86///
87/// Accepts shapes `[n_rois, n_time]`, `[1, n_rois, n_time]`, or `[1, 1, n_rois, n_time]`.
88pub fn load_fmri_safetensors_f32(path: &str) -> anyhow::Result<FmriInputF32> {
89    let p = Path::new(path);
90    if !p.exists() {
91        return Err(BrainJepaError::FileNotFound {
92            kind: "fMRI input",
93            path: p.to_path_buf(),
94        }
95        .into());
96    }
97
98    let bytes = std::fs::read(p)?;
99    let st = safetensors::SafeTensors::deserialize(&bytes)?;
100
101    let view = st
102        .tensor("fmri")
103        .map_err(|e| anyhow::anyhow!("missing 'fmri' key: {e}"))?;
104    let shape = view.shape().to_vec();
105    let data_bytes = view.data();
106
107    let f32s: Vec<f32> = match view.dtype() {
108        safetensors::Dtype::F32 => data_bytes
109            .chunks_exact(4)
110            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
111            .collect(),
112        safetensors::Dtype::BF16 => data_bytes
113            .chunks_exact(2)
114            .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
115            .collect(),
116        safetensors::Dtype::F16 => data_bytes
117            .chunks_exact(2)
118            .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
119            .collect(),
120        other => anyhow::bail!("unsupported dtype {:?}", other),
121    };
122
123    let (n_rois, n_time, data) = match shape.len() {
124        2 => {
125            let (h, w) = (shape[0], shape[1]);
126            let mut out = vec![0f32; 1 * 1 * h * w];
127            out.copy_from_slice(&f32s);
128            (h, w, out)
129        }
130        3 => {
131            let (h, w) = (shape[1], shape[2]);
132            let mut out = vec![0f32; 1 * 1 * h * w];
133            out.copy_from_slice(&f32s);
134            (h, w, out)
135        }
136        4 => {
137            let (h, w) = (shape[2], shape[3]);
138            (h, w, f32s)
139        }
140        _ => anyhow::bail!("unexpected fmri tensor rank: {}", shape.len()),
141    };
142
143    Ok(FmriInputF32 {
144        data,
145        n_rois,
146        n_time,
147    })
148}
149
150/// Load fMRI from CSV (rows = ROIs, columns = time points) as f32 buffer.
151pub fn load_fmri_csv_f32(path: &str) -> crate::error::Result<FmriInputF32> {
152    let p = Path::new(path);
153    if !p.exists() {
154        return Err(BrainJepaError::FileNotFound {
155            kind: "fMRI CSV",
156            path: p.to_path_buf(),
157        });
158    }
159
160    let content = std::fs::read_to_string(p)?;
161    let mut values = Vec::new();
162    let mut n_rois = 0usize;
163    let mut n_time = 0usize;
164
165    for (line_no, line) in content.lines().enumerate() {
166        let line = line.trim();
167        if line.is_empty() || line.starts_with('#') {
168            continue;
169        }
170        let parts: Vec<f32> = line
171            .split(',')
172            .filter_map(|s| s.trim().parse::<f32>().ok())
173            .collect();
174        if parts.is_empty() {
175            continue;
176        }
177        if n_time == 0 {
178            n_time = parts.len();
179        } else if parts.len() != n_time {
180            return Err(BrainJepaError::InconsistentCsvRow {
181                path: p.to_path_buf(),
182                row: line_no + 1,
183                expected: n_time,
184                got: parts.len(),
185            });
186        }
187        values.extend_from_slice(&parts);
188        n_rois += 1;
189    }
190
191    if n_rois == 0 {
192        return Err(BrainJepaError::EmptyCsv {
193            path: p.to_path_buf(),
194        });
195    }
196
197    Ok(FmriInputF32 {
198        data: values,
199        n_rois,
200        n_time,
201    })
202}
203
/// Z-score `x` in place using statistics pooled over the whole slice.
///
/// Every element becomes `(x - mean) / (std + 1e-8)`, where `std` is the
/// population standard deviation (sum of squared deviations divided by the
/// element count). An empty slice is left untouched.
pub fn standardize_f32_inplace(x: &mut [f32]) {
    // `max(1)` keeps the divisions below finite for an empty slice.
    let count = x.len().max(1) as f32;

    let mean = x.iter().sum::<f32>() / count;
    x.iter_mut().for_each(|sample| *sample -= mean);

    // The data is centred now, so the squares sum to the unnormalised variance.
    let sum_sq = x.iter().map(|sample| sample * sample).sum::<f32>();
    let denom = (sum_sq / count).sqrt() + 1e-8;
    x.iter_mut().for_each(|sample| *sample /= denom);
}
217
218/// Downsample (if needed) + standardize.
219pub fn preprocess_fmri_f32(
220    mut data: Vec<f32>,
221    n_rois: usize,
222    n_time: usize,
223    target_time: usize,
224    downsample: bool,
225) -> crate::error::Result<Vec<f32>> {
226    if n_time != target_time && downsample {
227        data = temporal_downsample_f32(data, n_rois, n_time, target_time)?;
228    }
229    standardize_f32_inplace(&mut data);
230    Ok(data)
231}
232
233/// Temporal downsampling for an f32 NCHW buffer.
234///
235/// `x` is `[1, 1, n_rois, n_time]`. Returns `[1, 1, n_rois, target_frames]`.
236pub fn temporal_downsample_f32(
237    x: Vec<f32>,
238    n_rois: usize,
239    n_time: usize,
240    target_frames: usize,
241) -> crate::error::Result<Vec<f32>> {
242    if n_time == target_frames {
243        return Ok(x);
244    }
245    if target_frames > n_time {
246        return Err(BrainJepaError::DownsampleUpscale {
247            src: n_time,
248            dst: target_frames,
249        });
250    }
251    let step = n_time as f64 / target_frames as f64;
252    let indices: Vec<usize> = (0..target_frames)
253        .map(|i| ((i as f64 * step) as usize).min(n_time - 1))
254        .collect();
255    let mut out = vec![0f32; 1 * 1 * n_rois * target_frames];
256    for roi in 0..n_rois {
257        for (j, &src_t) in indices.iter().enumerate() {
258            out[roi * target_frames + j] = x[roi * n_time + src_t];
259        }
260    }
261    Ok(out)
262}