1use std::path::Path;
6
7use crate::error::BrainJepaError;
8
/// fMRI samples held as one flat `f32` buffer plus the two dimensions the
/// loaders in this file report for it.
///
/// NOTE(review): "ROI" presumably means region of interest — confirm against
/// the callers.
#[derive(Debug, Clone)]
pub struct FmriInputF32 {
    /// Flat sample buffer. The CSV loader appends one parsed row after
    /// another, so for CSV input the layout is row-major with `n_rois` rows
    /// of `n_time` values each.
    pub data: Vec<f32>,
    /// Row count: number of data rows for CSV input, second-to-last tensor
    /// dimension for safetensors input.
    pub n_rois: usize,
    /// Column count: values per row for CSV input, last tensor dimension for
    /// safetensors input.
    pub n_time: usize,
}
16
/// Numeric table read from a gradient CSV file, stored as one flat buffer.
///
/// Built by `GradientData::from_csv`, which appends each accepted CSV row to
/// `values` in file order.
#[derive(Debug)]
pub struct GradientData {
    /// Row-major cell values: `n_rois` consecutive runs of `grad_dim` floats.
    pub values: Vec<f32>,
    /// Number of data rows accepted from the file.
    pub n_rois: usize,
    /// Number of values in every accepted row.
    pub grad_dim: usize,
}
25
26impl GradientData {
27 pub fn from_csv(path: &str) -> crate::error::Result<Self> {
32 let p = Path::new(path);
33 if !p.exists() {
34 return Err(BrainJepaError::FileNotFound {
35 kind: "gradient CSV",
36 path: p.to_path_buf(),
37 });
38 }
39
40 let content = std::fs::read_to_string(p)?;
41 let mut values = Vec::new();
42 let mut n_rois = 0usize;
43 let mut grad_dim = 0usize;
44
45 for (line_no, line) in content.lines().enumerate() {
46 let line = line.trim();
47 if line.is_empty() || line.starts_with('#') {
48 continue;
49 }
50 let parts: Vec<f32> = line
51 .split(',')
52 .filter_map(|s| s.trim().parse::<f32>().ok())
53 .collect();
54 if parts.is_empty() {
55 continue;
56 }
57 if grad_dim == 0 {
58 grad_dim = parts.len();
59 } else if parts.len() != grad_dim {
60 return Err(BrainJepaError::InconsistentCsvRow {
61 path: p.to_path_buf(),
62 row: line_no + 1,
63 expected: grad_dim,
64 got: parts.len(),
65 });
66 }
67 values.extend_from_slice(&parts);
68 n_rois += 1;
69 }
70
71 if n_rois == 0 {
72 return Err(BrainJepaError::EmptyCsv {
73 path: p.to_path_buf(),
74 });
75 }
76
77 Ok(Self {
78 values,
79 n_rois,
80 grad_dim,
81 })
82 }
83}
84
85pub fn load_fmri_safetensors_f32(path: &str) -> anyhow::Result<FmriInputF32> {
89 let p = Path::new(path);
90 if !p.exists() {
91 return Err(BrainJepaError::FileNotFound {
92 kind: "fMRI input",
93 path: p.to_path_buf(),
94 }
95 .into());
96 }
97
98 let bytes = std::fs::read(p)?;
99 let st = safetensors::SafeTensors::deserialize(&bytes)?;
100
101 let view = st
102 .tensor("fmri")
103 .map_err(|e| anyhow::anyhow!("missing 'fmri' key: {e}"))?;
104 let shape = view.shape().to_vec();
105 let data_bytes = view.data();
106
107 let f32s: Vec<f32> = match view.dtype() {
108 safetensors::Dtype::F32 => data_bytes
109 .chunks_exact(4)
110 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
111 .collect(),
112 safetensors::Dtype::BF16 => data_bytes
113 .chunks_exact(2)
114 .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
115 .collect(),
116 safetensors::Dtype::F16 => data_bytes
117 .chunks_exact(2)
118 .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
119 .collect(),
120 other => anyhow::bail!("unsupported dtype {:?}", other),
121 };
122
123 let (n_rois, n_time, data) = match shape.len() {
124 2 => {
125 let (h, w) = (shape[0], shape[1]);
126 let mut out = vec![0f32; 1 * 1 * h * w];
127 out.copy_from_slice(&f32s);
128 (h, w, out)
129 }
130 3 => {
131 let (h, w) = (shape[1], shape[2]);
132 let mut out = vec![0f32; 1 * 1 * h * w];
133 out.copy_from_slice(&f32s);
134 (h, w, out)
135 }
136 4 => {
137 let (h, w) = (shape[2], shape[3]);
138 (h, w, f32s)
139 }
140 _ => anyhow::bail!("unexpected fmri tensor rank: {}", shape.len()),
141 };
142
143 Ok(FmriInputF32 {
144 data,
145 n_rois,
146 n_time,
147 })
148}
149
150pub fn load_fmri_csv_f32(path: &str) -> crate::error::Result<FmriInputF32> {
152 let p = Path::new(path);
153 if !p.exists() {
154 return Err(BrainJepaError::FileNotFound {
155 kind: "fMRI CSV",
156 path: p.to_path_buf(),
157 });
158 }
159
160 let content = std::fs::read_to_string(p)?;
161 let mut values = Vec::new();
162 let mut n_rois = 0usize;
163 let mut n_time = 0usize;
164
165 for (line_no, line) in content.lines().enumerate() {
166 let line = line.trim();
167 if line.is_empty() || line.starts_with('#') {
168 continue;
169 }
170 let parts: Vec<f32> = line
171 .split(',')
172 .filter_map(|s| s.trim().parse::<f32>().ok())
173 .collect();
174 if parts.is_empty() {
175 continue;
176 }
177 if n_time == 0 {
178 n_time = parts.len();
179 } else if parts.len() != n_time {
180 return Err(BrainJepaError::InconsistentCsvRow {
181 path: p.to_path_buf(),
182 row: line_no + 1,
183 expected: n_time,
184 got: parts.len(),
185 });
186 }
187 values.extend_from_slice(&parts);
188 n_rois += 1;
189 }
190
191 if n_rois == 0 {
192 return Err(BrainJepaError::EmptyCsv {
193 path: p.to_path_buf(),
194 });
195 }
196
197 Ok(FmriInputF32 {
198 data: values,
199 n_rois,
200 n_time,
201 })
202}
203
/// Standardizes `x` in place over the whole slice: subtracts the global mean,
/// then divides by the population standard deviation plus `1e-8`.
///
/// The statistics are computed across every element at once (not per row).
/// An empty slice is left untouched; the element count is clamped to at
/// least one so the divisions stay finite.
pub fn standardize_f32_inplace(x: &mut [f32]) {
    let count = x.len().max(1) as f32;

    // Centre the data first so the second pass can sum plain squares.
    let mean = x.iter().sum::<f32>() / count;
    x.iter_mut().for_each(|sample| *sample -= mean);

    let sum_of_squares = x.iter().map(|sample| sample * sample).sum::<f32>();
    // The small constant keeps the divisor non-zero for constant input.
    let scale = (sum_of_squares / count).sqrt() + 1e-8;
    x.iter_mut().for_each(|sample| *sample /= scale);
}
217
218pub fn preprocess_fmri_f32(
220 mut data: Vec<f32>,
221 n_rois: usize,
222 n_time: usize,
223 target_time: usize,
224 downsample: bool,
225) -> crate::error::Result<Vec<f32>> {
226 if n_time != target_time && downsample {
227 data = temporal_downsample_f32(data, n_rois, n_time, target_time)?;
228 }
229 standardize_f32_inplace(&mut data);
230 Ok(data)
231}
232
233pub fn temporal_downsample_f32(
237 x: Vec<f32>,
238 n_rois: usize,
239 n_time: usize,
240 target_frames: usize,
241) -> crate::error::Result<Vec<f32>> {
242 if n_time == target_frames {
243 return Ok(x);
244 }
245 if target_frames > n_time {
246 return Err(BrainJepaError::DownsampleUpscale {
247 src: n_time,
248 dst: target_frames,
249 });
250 }
251 let step = n_time as f64 / target_frames as f64;
252 let indices: Vec<usize> = (0..target_frames)
253 .map(|i| ((i as f64 * step) as usize).min(n_time - 1))
254 .collect();
255 let mut out = vec![0f32; 1 * 1 * n_rois * target_frames];
256 for roi in 0..n_rois {
257 for (j, &src_t) in indices.iter().enumerate() {
258 out[roi * target_frames + j] = x[roi * n_time + src_t];
259 }
260 }
261 Ok(out)
262}