1use std::path::Path;
4use std::time::Instant;
5
6use anyhow::Context;
7
8use crate::config::{DataConfig, ModelConfig};
9use crate::data::{self, GradientData};
10use crate::error::BrainJepaError;
11
12use super::attn_layout::resolve_attn_layout;
13use super::device::ensure_device;
14use super::graph::{build_encoder_graph, EncoderSpec};
15use super::pos_embed_cpu::build_pos_embed;
16use super::weights::{apply_params, build_encoder_params, load_safetensors, ParamMap};
17
/// Output of one encoder pass: a flat embedding buffer plus the metadata
/// needed to interpret it.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// Flat `f32` embedding values as copied out of the encoder graph.
    pub embeddings: Vec<f32>,
    /// Logical shape of `embeddings`; the encoder in this file fills it with
    /// `[n_rois * n_time_patches, embed_dim]`.
    pub shape: Vec<usize>,
    /// Number of ROIs the encoder was built for.
    pub n_rois: usize,
    /// Number of patches along the time axis.
    pub n_time_patches: usize,
    /// Wall-clock time of the graph execution, in milliseconds.
    pub ms_encode: f64,
}
31
32impl EmbeddingResult {
33 pub fn n_patches(&self) -> usize {
34 self.n_rois * self.n_time_patches
35 }
36 pub fn embed_dim(&self) -> usize {
37 self.shape.get(1).copied().unwrap_or(0)
38 }
39
40 pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
41 use safetensors::{Dtype, View};
42 use std::borrow::Cow;
43
44 struct RawTensor {
45 data: Vec<u8>,
46 shape: Vec<usize>,
47 }
48 impl View for RawTensor {
49 fn dtype(&self) -> Dtype {
50 Dtype::F32
51 }
52 fn shape(&self) -> &[usize] {
53 &self.shape
54 }
55 fn data(&self) -> Cow<'_, [u8]> {
56 Cow::Borrowed(&self.data)
57 }
58 fn data_len(&self) -> usize {
59 self.data.len()
60 }
61 }
62
63 let bytes: Vec<u8> = self
64 .embeddings
65 .iter()
66 .flat_map(|f| f.to_le_bytes())
67 .collect();
68 let tensor = RawTensor {
69 data: bytes,
70 shape: self.shape.clone(),
71 };
72 let pairs: Vec<(&str, RawTensor)> = vec![("embeddings", tensor)];
73 let out = safetensors::serialize(pairs, None)?;
74 std::fs::write(path, out)?;
75 Ok(())
76 }
77}
78
79fn warmup_run(compiled: &mut rlx::CompiledGraph, x: &[f32]) {
81 if compiled.run_slots(&[x]).is_empty() {
82 let _ = compiled.run(&[("x", x)]);
83 }
84}
85
86fn read_output_f32(
88 compiled: &rlx::CompiledGraph,
89 off: usize,
90 len: usize,
91) -> anyhow::Result<Vec<f32>> {
92 let base = compiled.arena_ptr();
93 anyhow::ensure!(len > 0, "encoder output is empty");
94 let out = unsafe { std::slice::from_raw_parts(base.add(off) as *const f32, len) };
95 Ok(out.to_vec())
96}
97
/// Brain-JEPA encoder: a compiled RLX graph together with the configuration
/// and dimensions it was built from.
pub struct BrainJepaEncoder {
    /// Model configuration, cloned from the value given at construction.
    pub model_cfg: ModelConfig,
    /// Data configuration, cloned from the value given at construction.
    pub data_cfg: DataConfig,
    /// Device the graph was compiled for.
    pub device: rlx::Device,

    // Parameter map returned by `build_encoder_params`. It is stored on the
    // encoder but may have no readers, hence the lint allowance.
    // NOTE(review): presumably retained so the weights stay available after
    // they were applied to the compiled graph — confirm.
    #[allow(dead_code)]
    params: ParamMap,
    // Compiled encoder graph that every encode call executes.
    compiled: rlx::CompiledGraph,

    // Number of ROIs (`data_cfg.crop_size.0`).
    n_rois: usize,
    // Number of time points per input (`data_cfg.crop_size.1`); the lint
    // allowance covers the case where nothing reads it.
    #[allow(dead_code)]
    n_time: usize,
    // Number of patches along the time axis (`n_time / patch_size`).
    n_time_patches: usize,
}
112
113impl BrainJepaEncoder {
114 pub fn from_weights(
115 weights_path: &str,
116 gradient_csv_path: &str,
117 model_cfg: &ModelConfig,
118 data_cfg: &DataConfig,
119 device: &rlx::Device,
120 ) -> anyhow::Result<(Self, f64)> {
121 ensure_device(*device)?;
122
123 if !Path::new(weights_path).exists() {
124 return Err(BrainJepaError::FileNotFound {
125 kind: "weights",
126 path: weights_path.into(),
127 }
128 .into());
129 }
130
131 let grad = GradientData::from_csv(gradient_csv_path)?;
132 let expected_rois = data_cfg.crop_size.0;
133 if grad.n_rois != expected_rois {
134 return Err(BrainJepaError::GradientRoiMismatch {
135 expected: expected_rois,
136 got: grad.n_rois,
137 }
138 .into());
139 }
140
141 let t = Instant::now();
142 let mut raw = load_safetensors(weights_path)?;
143 let (params, grad_proj) = build_encoder_params(&mut raw, model_cfg)?;
144 let ms_weights = t.elapsed().as_secs_f64() * 1000.0;
145
146 let n_rois = data_cfg.crop_size.0;
147 let n_time = data_cfg.crop_size.1;
148 let patch = model_cfg.patch_size;
149 let n_time_patches = n_time / patch;
150 let n = n_rois * n_time_patches;
151
152 let (grad_w, grad_b, grad_dim) = grad_proj
154 .map(|(w, b, gd)| (Some(w), Some(b), gd))
155 .unwrap_or((None, None, grad.grad_dim));
156
157 let pos = build_pos_embed(
158 &model_cfg.pos_mode,
159 n_rois,
160 n_time_patches,
161 model_cfg.embed_dim,
162 &grad.values,
163 grad_dim,
164 grad_w.as_deref(),
165 grad_b.as_deref(),
166 )?;
167
168 let spec = EncoderSpec {
169 b: 1,
170 h: n_rois,
171 w: n_time,
172 patch,
173 w_p: n_time_patches,
174 n,
175 dim: model_cfg.embed_dim,
176 depth: model_cfg.depth,
177 num_heads: model_cfg.num_heads,
178 head_dim: model_cfg.embed_dim / model_cfg.num_heads,
179 hidden_dim: (model_cfg.embed_dim as f64 * model_cfg.mlp_ratio) as usize,
180 norm_eps: model_cfg.norm_eps as f32,
181 };
182
183 let attn_layout = resolve_attn_layout(*device)?;
184 let graph = build_encoder_graph(&spec, attn_layout);
185 let session = rlx::Session::new(*device);
186 let mut compiled = session.compile(graph);
187 apply_params(&mut compiled, ¶ms);
188 compiled.set_param("pos_embed", &pos);
189
190 if !matches!(*device, rlx::Device::Cpu) {
192 let x_warm = vec![0.0f32; 1 * 1 * n_rois * n_time];
193 warmup_run(&mut compiled, &x_warm);
194 }
195
196 Ok((
197 Self {
198 model_cfg: model_cfg.clone(),
199 data_cfg: data_cfg.clone(),
200 device: *device,
201 params,
202 compiled,
203 n_rois,
204 n_time,
205 n_time_patches,
206 },
207 ms_weights,
208 ))
209 }
210
211 pub fn describe(&self) -> String {
212 format!(
213 "Brain-JEPA encoder (RLX, {}) embed_dim={} depth={} heads={} patch={}",
214 super::device::display_name(self.device),
215 self.model_cfg.embed_dim,
216 self.model_cfg.depth,
217 self.model_cfg.num_heads,
218 self.model_cfg.patch_size
219 )
220 }
221
222 pub fn encode_safetensors(&mut self, fmri_path: &str) -> anyhow::Result<EmbeddingResult> {
223 let input = data::load_fmri_safetensors_f32(fmri_path)
224 .with_context(|| format!("loading fmri safetensors: {fmri_path}"))?;
225 self.encode_f32(input.data, input.n_rois, input.n_time)
226 }
227
228 pub fn encode_csv(&mut self, csv_path: &str) -> anyhow::Result<EmbeddingResult> {
229 let input = data::load_fmri_csv_f32(csv_path)
230 .with_context(|| format!("loading fmri csv: {csv_path}"))?;
231 self.encode_f32(input.data, input.n_rois, input.n_time)
232 }
233
234 fn encode_f32(
235 &mut self,
236 mut x: Vec<f32>, n_rois: usize,
238 n_time: usize,
239 ) -> anyhow::Result<EmbeddingResult> {
240 x = data::preprocess_fmri_f32(
242 x,
243 n_rois,
244 n_time,
245 self.data_cfg.crop_size.1,
246 self.data_cfg.downsample,
247 )?;
248
249 let t = Instant::now();
250 let slots = self.compiled.run_slots(&[&x]);
251 let embeddings = if let Some(&(out_off, out_len)) = slots.first() {
252 read_output_f32(&self.compiled, out_off, out_len)?
253 } else {
254 self.compiled
256 .run(&[("x", &x)])
257 .into_iter()
258 .next()
259 .ok_or_else(|| anyhow::anyhow!("encoder graph produced no output"))?
260 };
261 let ms_encode = t.elapsed().as_secs_f64() * 1000.0;
262
263 Ok(EmbeddingResult {
264 embeddings,
265 shape: vec![self.n_rois * self.n_time_patches, self.model_cfg.embed_dim],
266 n_rois: self.n_rois,
267 n_time_patches: self.n_time_patches,
268 ms_encode,
269 })
270 }
271}