Skip to main content

brainjepa/rlx/
inference.rs

1//! RLX-backed Brain-JEPA encoder inference.
2
3use std::path::Path;
4use std::time::Instant;
5
6use anyhow::Context;
7
8use crate::config::{DataConfig, ModelConfig};
9use crate::data::{self, GradientData};
10use crate::error::BrainJepaError;
11
12use super::attn_layout::resolve_attn_layout;
13use super::device::ensure_device;
14use super::graph::{build_encoder_graph, EncoderSpec};
15use super::pos_embed_cpu::build_pos_embed;
16use super::weights::{apply_params, build_encoder_params, load_safetensors, ParamMap};
17
/// Encoder embedding output.
///
/// `embeddings` is a flat buffer whose logical layout is given by `shape`;
/// the remaining fields carry the patch grid and the timing of the forward pass.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResult {
    /// Latent embeddings: row-major f32, shape [n_patches, embed_dim].
    pub embeddings: Vec<f32>,
    /// Shape of `embeddings`: [n_patches, embed_dim].
    pub shape: Vec<usize>,
    /// Number of ROI patches.
    pub n_rois: usize,
    /// Number of temporal patches.
    pub n_time_patches: usize,
    /// Encoding time in milliseconds.
    pub ms_encode: f64,
}
31
32impl EmbeddingResult {
33    pub fn n_patches(&self) -> usize {
34        self.n_rois * self.n_time_patches
35    }
36    pub fn embed_dim(&self) -> usize {
37        self.shape.get(1).copied().unwrap_or(0)
38    }
39
40    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
41        use safetensors::{Dtype, View};
42        use std::borrow::Cow;
43
44        struct RawTensor {
45            data: Vec<u8>,
46            shape: Vec<usize>,
47        }
48        impl View for RawTensor {
49            fn dtype(&self) -> Dtype {
50                Dtype::F32
51            }
52            fn shape(&self) -> &[usize] {
53                &self.shape
54            }
55            fn data(&self) -> Cow<'_, [u8]> {
56                Cow::Borrowed(&self.data)
57            }
58            fn data_len(&self) -> usize {
59                self.data.len()
60            }
61        }
62
63        let bytes: Vec<u8> = self
64            .embeddings
65            .iter()
66            .flat_map(|f| f.to_le_bytes())
67            .collect();
68        let tensor = RawTensor {
69            data: bytes,
70            shape: self.shape.clone(),
71        };
72        let pairs: Vec<(&str, RawTensor)> = vec![("embeddings", tensor)];
73        let out = safetensors::serialize(pairs, None)?;
74        std::fs::write(path, out)?;
75        Ok(())
76    }
77}
78
79/// One forward pass for warmup (wgpu has no `run_slots` yet).
80fn warmup_run(compiled: &mut rlx::CompiledGraph, x: &[f32]) {
81    if compiled.run_slots(&[x]).is_empty() {
82        let _ = compiled.run(&[("x", x)]);
83    }
84}
85
86/// Copy the sole graph output from the arena after `run_slots`.
87fn read_output_f32(
88    compiled: &rlx::CompiledGraph,
89    off: usize,
90    len: usize,
91) -> anyhow::Result<Vec<f32>> {
92    let base = compiled.arena_ptr();
93    anyhow::ensure!(len > 0, "encoder output is empty");
94    let out = unsafe { std::slice::from_raw_parts(base.add(off) as *const f32, len) };
95    Ok(out.to_vec())
96}
97
98pub struct BrainJepaEncoder {
99    pub model_cfg: ModelConfig,
100    pub data_cfg: DataConfig,
101    pub device: rlx::Device,
102
103    #[allow(dead_code)]
104    params: ParamMap,
105    compiled: rlx::CompiledGraph,
106
107    n_rois: usize,
108    #[allow(dead_code)]
109    n_time: usize,
110    n_time_patches: usize,
111}
112
113impl BrainJepaEncoder {
114    pub fn from_weights(
115        weights_path: &str,
116        gradient_csv_path: &str,
117        model_cfg: &ModelConfig,
118        data_cfg: &DataConfig,
119        device: &rlx::Device,
120    ) -> anyhow::Result<(Self, f64)> {
121        ensure_device(*device)?;
122
123        if !Path::new(weights_path).exists() {
124            return Err(BrainJepaError::FileNotFound {
125                kind: "weights",
126                path: weights_path.into(),
127            }
128            .into());
129        }
130
131        let grad = GradientData::from_csv(gradient_csv_path)?;
132        let expected_rois = data_cfg.crop_size.0;
133        if grad.n_rois != expected_rois {
134            return Err(BrainJepaError::GradientRoiMismatch {
135                expected: expected_rois,
136                got: grad.n_rois,
137            }
138            .into());
139        }
140
141        let t = Instant::now();
142        let mut raw = load_safetensors(weights_path)?;
143        let (params, grad_proj) = build_encoder_params(&mut raw, model_cfg)?;
144        let ms_weights = t.elapsed().as_secs_f64() * 1000.0;
145
146        let n_rois = data_cfg.crop_size.0;
147        let n_time = data_cfg.crop_size.1;
148        let patch = model_cfg.patch_size;
149        let n_time_patches = n_time / patch;
150        let n = n_rois * n_time_patches;
151
152        // CPU build positional embedding once.
153        let (grad_w, grad_b, grad_dim) = grad_proj
154            .map(|(w, b, gd)| (Some(w), Some(b), gd))
155            .unwrap_or((None, None, grad.grad_dim));
156
157        let pos = build_pos_embed(
158            &model_cfg.pos_mode,
159            n_rois,
160            n_time_patches,
161            model_cfg.embed_dim,
162            &grad.values,
163            grad_dim,
164            grad_w.as_deref(),
165            grad_b.as_deref(),
166        )?;
167
168        let spec = EncoderSpec {
169            b: 1,
170            h: n_rois,
171            w: n_time,
172            patch,
173            w_p: n_time_patches,
174            n,
175            dim: model_cfg.embed_dim,
176            depth: model_cfg.depth,
177            num_heads: model_cfg.num_heads,
178            head_dim: model_cfg.embed_dim / model_cfg.num_heads,
179            hidden_dim: (model_cfg.embed_dim as f64 * model_cfg.mlp_ratio) as usize,
180            norm_eps: model_cfg.norm_eps as f32,
181        };
182
183        let attn_layout = resolve_attn_layout(*device)?;
184        let graph = build_encoder_graph(&spec, attn_layout);
185        let session = rlx::Session::new(*device);
186        let mut compiled = session.compile(graph);
187        apply_params(&mut compiled, &params);
188        compiled.set_param("pos_embed", &pos);
189
190        // Warm up GPU backends (MPSGraph first-run specialization, kernel cache).
191        if !matches!(*device, rlx::Device::Cpu) {
192            let x_warm = vec![0.0f32; 1 * 1 * n_rois * n_time];
193            warmup_run(&mut compiled, &x_warm);
194        }
195
196        Ok((
197            Self {
198                model_cfg: model_cfg.clone(),
199                data_cfg: data_cfg.clone(),
200                device: *device,
201                params,
202                compiled,
203                n_rois,
204                n_time,
205                n_time_patches,
206            },
207            ms_weights,
208        ))
209    }
210
211    pub fn describe(&self) -> String {
212        format!(
213            "Brain-JEPA encoder (RLX, {})  embed_dim={}  depth={}  heads={}  patch={}",
214            super::device::display_name(self.device),
215            self.model_cfg.embed_dim,
216            self.model_cfg.depth,
217            self.model_cfg.num_heads,
218            self.model_cfg.patch_size
219        )
220    }
221
222    pub fn encode_safetensors(&mut self, fmri_path: &str) -> anyhow::Result<EmbeddingResult> {
223        let input = data::load_fmri_safetensors_f32(fmri_path)
224            .with_context(|| format!("loading fmri safetensors: {fmri_path}"))?;
225        self.encode_f32(input.data, input.n_rois, input.n_time)
226    }
227
228    pub fn encode_csv(&mut self, csv_path: &str) -> anyhow::Result<EmbeddingResult> {
229        let input = data::load_fmri_csv_f32(csv_path)
230            .with_context(|| format!("loading fmri csv: {csv_path}"))?;
231        self.encode_f32(input.data, input.n_rois, input.n_time)
232    }
233
234    fn encode_f32(
235        &mut self,
236        mut x: Vec<f32>, // [1, 1, H, W] row-major
237        n_rois: usize,
238        n_time: usize,
239    ) -> anyhow::Result<EmbeddingResult> {
240        // Optional temporal downsampling (CPU)
241        x = data::preprocess_fmri_f32(
242            x,
243            n_rois,
244            n_time,
245            self.data_cfg.crop_size.1,
246            self.data_cfg.downsample,
247        )?;
248
249        let t = Instant::now();
250        let slots = self.compiled.run_slots(&[&x]);
251        let embeddings = if let Some(&(out_off, out_len)) = slots.first() {
252            read_output_f32(&self.compiled, out_off, out_len)?
253        } else {
254            // rlx-wgpu (and some other backends) do not implement run_slots yet.
255            self.compiled
256                .run(&[("x", &x)])
257                .into_iter()
258                .next()
259                .ok_or_else(|| anyhow::anyhow!("encoder graph produced no output"))?
260        };
261        let ms_encode = t.elapsed().as_secs_f64() * 1000.0;
262
263        Ok(EmbeddingResult {
264            embeddings,
265            shape: vec![self.n_rois * self.n_time_patches, self.model_cfg.embed_dim],
266            n_rois: self.n_rois,
267            n_time_patches: self.n_time_patches,
268            ms_encode,
269        })
270    }
271}