1use anyhow::Result;
19use rlx_flow::BuiltModel;
20
21use super::config::Vjepa2Config;
22use super::predictor::Vjepa2PredictorLayout;
23use super::weights::{Vjepa2EncoderWeights, Vjepa2PoolerWeights, Vjepa2PredictorWeights};
24use rlx_core::flow_util::built_from_hir;
25
26#[derive(Clone)]
27pub struct Vjepa2EncoderFlow<'a> {
28 cfg: &'a Vjepa2Config,
29 encoder: &'a Vjepa2EncoderWeights,
30 batch: usize,
31}
32
33impl<'a> Vjepa2EncoderFlow<'a> {
34 pub fn new(cfg: &'a Vjepa2Config, encoder: &'a Vjepa2EncoderWeights, batch: usize) -> Self {
35 Self {
36 cfg,
37 encoder,
38 batch,
39 }
40 }
41
42 pub fn build(self) -> Result<Vjepa2EncoderBuilt> {
43 let (hir, params, preprocess) =
44 super::builder::build_vjepa2_encoder_hir_sized(self.cfg, self.encoder, self.batch)?;
45 Ok(Vjepa2EncoderBuilt {
46 model: built_from_hir(hir, params)?,
47 preprocess,
48 })
49 }
50}
51
52pub struct Vjepa2EncoderBuilt {
53 pub model: BuiltModel,
54 pub preprocess: super::builder::Vjepa2GraphPreprocess,
55}
56
57#[derive(Clone)]
58pub struct Vjepa2PredictorFlow<'a> {
59 cfg: &'a Vjepa2Config,
60 predictor: &'a Vjepa2PredictorWeights,
61 layout: &'a Vjepa2PredictorLayout,
62 mask_rows: &'a [f32],
63 batch: usize,
64}
65
66impl<'a> Vjepa2PredictorFlow<'a> {
67 pub fn new(
68 cfg: &'a Vjepa2Config,
69 predictor: &'a Vjepa2PredictorWeights,
70 layout: &'a Vjepa2PredictorLayout,
71 mask_rows: &'a [f32],
72 batch: usize,
73 ) -> Self {
74 Self {
75 cfg,
76 predictor,
77 layout,
78 mask_rows,
79 batch,
80 }
81 }
82
83 pub fn build(self) -> Result<BuiltModel> {
84 let (hir, params) = super::builder::build_vjepa2_predictor_hir_sized(
85 self.cfg,
86 self.predictor,
87 self.layout,
88 self.mask_rows,
89 self.batch,
90 )?;
91 built_from_hir(hir, params.f32)
92 }
93}
94
95#[derive(Clone)]
96pub struct Vjepa2PoolerFlow<'a> {
97 cfg: &'a Vjepa2Config,
98 pooler: &'a Vjepa2PoolerWeights,
99 batch: usize,
100}
101
102impl<'a> Vjepa2PoolerFlow<'a> {
103 pub fn new(cfg: &'a Vjepa2Config, pooler: &'a Vjepa2PoolerWeights, batch: usize) -> Self {
104 Self { cfg, pooler, batch }
105 }
106
107 pub fn build(self) -> Result<BuiltModel> {
108 let (hir, params) =
109 super::builder::build_vjepa2_pooler_hir_sized(self.cfg, self.pooler, self.batch)?;
110 built_from_hir(hir, params.f32)
111 }
112}