Skip to main content

rlx_vjepa2/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 V-JEPA2 flows — encoder, predictor, pooler.
17
18use anyhow::Result;
19use rlx_flow::BuiltModel;
20
21use super::config::Vjepa2Config;
22use super::predictor::Vjepa2PredictorLayout;
23use super::weights::{Vjepa2EncoderWeights, Vjepa2PoolerWeights, Vjepa2PredictorWeights};
24use rlx_core::flow_util::built_from_hir;
25
26#[derive(Clone)]
27pub struct Vjepa2EncoderFlow<'a> {
28    cfg: &'a Vjepa2Config,
29    encoder: &'a Vjepa2EncoderWeights,
30    batch: usize,
31}
32
33impl<'a> Vjepa2EncoderFlow<'a> {
34    pub fn new(cfg: &'a Vjepa2Config, encoder: &'a Vjepa2EncoderWeights, batch: usize) -> Self {
35        Self {
36            cfg,
37            encoder,
38            batch,
39        }
40    }
41
42    pub fn build(self) -> Result<Vjepa2EncoderBuilt> {
43        let (hir, params, preprocess) =
44            super::builder::build_vjepa2_encoder_hir_sized(self.cfg, self.encoder, self.batch)?;
45        Ok(Vjepa2EncoderBuilt {
46            model: built_from_hir(hir, params)?,
47            preprocess,
48        })
49    }
50}
51
/// Result of [`Vjepa2EncoderFlow::build`]: the built encoder model together
/// with the preprocess value produced alongside its HIR.
pub struct Vjepa2EncoderBuilt {
    /// Model obtained by passing the encoder HIR and parameters to
    /// `built_from_hir`.
    pub model: BuiltModel,
    /// Preprocess value returned by the encoder HIR builder, stored as-is.
    // NOTE(review): exact semantics live in `super::builder` — not visible here.
    pub preprocess: super::builder::Vjepa2GraphPreprocess,
}
56
57#[derive(Clone)]
58pub struct Vjepa2PredictorFlow<'a> {
59    cfg: &'a Vjepa2Config,
60    predictor: &'a Vjepa2PredictorWeights,
61    layout: &'a Vjepa2PredictorLayout,
62    mask_rows: &'a [f32],
63    batch: usize,
64}
65
66impl<'a> Vjepa2PredictorFlow<'a> {
67    pub fn new(
68        cfg: &'a Vjepa2Config,
69        predictor: &'a Vjepa2PredictorWeights,
70        layout: &'a Vjepa2PredictorLayout,
71        mask_rows: &'a [f32],
72        batch: usize,
73    ) -> Self {
74        Self {
75            cfg,
76            predictor,
77            layout,
78            mask_rows,
79            batch,
80        }
81    }
82
83    pub fn build(self) -> Result<BuiltModel> {
84        let (hir, params) = super::builder::build_vjepa2_predictor_hir_sized(
85            self.cfg,
86            self.predictor,
87            self.layout,
88            self.mask_rows,
89            self.batch,
90        )?;
91        built_from_hir(hir, params.f32)
92    }
93}
94
95#[derive(Clone)]
96pub struct Vjepa2PoolerFlow<'a> {
97    cfg: &'a Vjepa2Config,
98    pooler: &'a Vjepa2PoolerWeights,
99    batch: usize,
100}
101
102impl<'a> Vjepa2PoolerFlow<'a> {
103    pub fn new(cfg: &'a Vjepa2Config, pooler: &'a Vjepa2PoolerWeights, batch: usize) -> Self {
104        Self { cfg, pooler, batch }
105    }
106
107    pub fn build(self) -> Result<BuiltModel> {
108        let (hir, params) =
109            super::builder::build_vjepa2_pooler_hir_sized(self.cfg, self.pooler, self.batch)?;
110        built_from_hir(hir, params.f32)
111    }
112}