use anyhow::Result;
use rlx_flow::BuiltModel;
use super::config::Vjepa2Config;
use super::predictor::Vjepa2PredictorLayout;
use super::weights::{Vjepa2EncoderWeights, Vjepa2PoolerWeights, Vjepa2PredictorWeights};
use rlx_core::flow_util::built_from_hir;
/// Borrowed inputs needed to build the V-JEPA2 encoder graph.
///
/// Holds references only; nothing is built until
/// [`Vjepa2EncoderFlow::build`] is called.
#[derive(Clone)]
pub struct Vjepa2EncoderFlow<'a> {
    /// Model configuration forwarded to the encoder HIR builder.
    cfg: &'a Vjepa2Config,
    /// Encoder weights forwarded to the encoder HIR builder.
    encoder: &'a Vjepa2EncoderWeights,
    /// Batch value forwarded to the encoder HIR builder.
    batch: usize,
}
impl<'a> Vjepa2EncoderFlow<'a> {
    /// Bundles the borrowed configuration, encoder weights and batch value
    /// for a later call to [`Self::build`].
    pub fn new(cfg: &'a Vjepa2Config, encoder: &'a Vjepa2EncoderWeights, batch: usize) -> Self {
        Self { cfg, encoder, batch }
    }

    /// Builds the encoder HIR and turns it into a [`BuiltModel`].
    ///
    /// The HIR, its parameters and the preprocess description come from
    /// `super::builder::build_vjepa2_encoder_hir_sized`; the HIR and parameters
    /// are then passed to `built_from_hir`, while the preprocess value is
    /// returned unchanged alongside the model.
    ///
    /// # Errors
    ///
    /// Propagates any error from the HIR builder or from `built_from_hir`.
    pub fn build(self) -> Result<Vjepa2EncoderBuilt> {
        let Self {
            cfg,
            encoder,
            batch,
        } = self;
        let (graph, weights, preprocess) =
            super::builder::build_vjepa2_encoder_hir_sized(cfg, encoder, batch)?;
        let model = built_from_hir(graph, weights)?;
        Ok(Vjepa2EncoderBuilt { model, preprocess })
    }
}
/// Result of [`Vjepa2EncoderFlow::build`]: the built encoder model together
/// with the preprocess description produced by the HIR builder.
pub struct Vjepa2EncoderBuilt {
    /// Model produced by `built_from_hir` from the encoder HIR and parameters.
    pub model: BuiltModel,
    /// Preprocess value returned by the encoder HIR builder, passed through as-is.
    pub preprocess: super::builder::Vjepa2GraphPreprocess,
}
/// Borrowed inputs needed to build the V-JEPA2 predictor graph.
///
/// Holds references only; nothing is built until
/// [`Vjepa2PredictorFlow::build`] is called.
#[derive(Clone)]
pub struct Vjepa2PredictorFlow<'a> {
    /// Model configuration forwarded to the predictor HIR builder.
    cfg: &'a Vjepa2Config,
    /// Predictor weights forwarded to the predictor HIR builder.
    predictor: &'a Vjepa2PredictorWeights,
    /// Predictor layout forwarded to the predictor HIR builder.
    layout: &'a Vjepa2PredictorLayout,
    /// `f32` slice forwarded to the predictor HIR builder.
    /// NOTE(review): presumably flattened mask rows; confirm layout against the builder.
    mask_rows: &'a [f32],
    /// Batch value forwarded to the predictor HIR builder.
    batch: usize,
}
impl<'a> Vjepa2PredictorFlow<'a> {
    /// Bundles the borrowed configuration, predictor weights, layout,
    /// mask rows and batch value for a later call to [`Self::build`].
    pub fn new(
        cfg: &'a Vjepa2Config,
        predictor: &'a Vjepa2PredictorWeights,
        layout: &'a Vjepa2PredictorLayout,
        mask_rows: &'a [f32],
        batch: usize,
    ) -> Self {
        Self {
            cfg,
            predictor,
            layout,
            mask_rows,
            batch,
        }
    }

    /// Builds the predictor HIR and turns it into a [`BuiltModel`].
    ///
    /// The HIR and its parameters come from
    /// `super::builder::build_vjepa2_predictor_hir_sized`; only the `f32`
    /// field of the returned parameters is handed to `built_from_hir`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the HIR builder or from `built_from_hir`.
    pub fn build(self) -> Result<BuiltModel> {
        let Self {
            cfg,
            predictor,
            layout,
            mask_rows,
            batch,
        } = self;
        let (graph, weights) = super::builder::build_vjepa2_predictor_hir_sized(
            cfg, predictor, layout, mask_rows, batch,
        )?;
        built_from_hir(graph, weights.f32)
    }
}
/// Borrowed inputs needed to build the V-JEPA2 pooler graph.
///
/// Holds references only; nothing is built until
/// [`Vjepa2PoolerFlow::build`] is called.
#[derive(Clone)]
pub struct Vjepa2PoolerFlow<'a> {
    /// Model configuration forwarded to the pooler HIR builder.
    cfg: &'a Vjepa2Config,
    /// Pooler weights forwarded to the pooler HIR builder.
    pooler: &'a Vjepa2PoolerWeights,
    /// Batch value forwarded to the pooler HIR builder.
    batch: usize,
}
impl<'a> Vjepa2PoolerFlow<'a> {
    /// Bundles the borrowed configuration, pooler weights and batch value
    /// for a later call to [`Self::build`].
    pub fn new(cfg: &'a Vjepa2Config, pooler: &'a Vjepa2PoolerWeights, batch: usize) -> Self {
        Self {
            cfg,
            pooler,
            batch,
        }
    }

    /// Builds the pooler HIR and turns it into a [`BuiltModel`].
    ///
    /// The HIR and its parameters come from
    /// `super::builder::build_vjepa2_pooler_hir_sized`; only the `f32`
    /// field of the returned parameters is handed to `built_from_hir`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the HIR builder or from `built_from_hir`.
    pub fn build(self) -> Result<BuiltModel> {
        let Self {
            cfg,
            pooler,
            batch,
        } = self;
        let (graph, weights) = super::builder::build_vjepa2_pooler_hir_sized(cfg, pooler, batch)?;
        built_from_hir(graph, weights.f32)
    }
}