jepa_core/encoder.rs
1//! Encoder trait for JEPA.
2//!
3//! Implements RFC-002 (Encoder Module).
4//!
5//! An encoder maps raw input (images, video, or already-embedded tokens)
6//! into a [`Representation`] in embedding space.
7//!
8//! In a JEPA training loop **two** encoder instances exist:
9//!
10//! | Role | Gradients? | Weight update |
11//! |------|-----------|---------------|
12//! | **Context encoder** (θ) | Yes | Backpropagation |
13//! | **Target encoder** (ξ) | No | EMA of θ (see [`crate::ema::Ema`]) |
14//!
15//! Both share the same architecture and implement this trait. The
16//! asymmetric update (EMA on the target) is what prevents collapse.
17//!
18//! See [`crate::collapse`] for the regularizers that complement EMA.
19
20use burn::tensor::backend::Backend;
21
22use crate::types::Representation;
23
24/// Trait for JEPA encoders.
25///
26/// An encoder maps raw input to a [`Representation`] with shape
27/// `[batch, seq_len, embed_dim]`. Concrete implementations include:
28///
29/// - [`jepa_vision::VitEncoder`](../../jepa_vision/vit/struct.VitEncoder.html) — Vision Transformer for images
30/// - [`jepa_vision::VitVideoEncoder`](../../jepa_vision/video/struct.VitVideoEncoder.html) — Vision Transformer for video
31///
32/// # Type parameters
33///
34/// - `B` — burn backend (e.g. `NdArray`, `Wgpu`, `Tch`)
35///
36/// # Associated types
37///
38/// - `Input` — the raw input type this encoder accepts. For vision
39/// encoders this is typically a `Tensor<B, 4>` (images) or
40/// `Tensor<B, 5>` (video). Higher-level wrappers may accept
41/// [`Representation<B>`] so that levels in a hierarchy can chain.
pub trait Encoder<B: Backend> {
    /// The raw input type this encoder accepts.
    ///
    /// For vision encoders this is typically a `Tensor<B, 4>` (images)
    /// or `Tensor<B, 5>` (video); hierarchical wrappers may take a
    /// [`Representation<B>`] so levels can chain.
    type Input;

    /// Encode input into a representation.
    ///
    /// Takes `&self`, so encoding must not mutate the encoder — both
    /// the context encoder (θ) and the EMA-updated target encoder (ξ)
    /// call this identically; only their weight-update rules differ.
    ///
    /// # Arguments
    /// * `input` - The raw input to encode
    ///
    /// # Returns
    /// A [`Representation`] with shape `[batch, seq_len, embed_dim]`,
    /// where the last dimension must equal [`Encoder::embed_dim`].
    fn encode(&self, input: &Self::Input) -> Representation<B>;

    /// Get the output embedding dimension.
    ///
    /// Invariant: this must match the trailing dimension of every
    /// representation returned by [`Encoder::encode`].
    fn embed_dim(&self) -> usize;
}
58
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// Minimal pass-through encoder used only to exercise the trait contract.
    struct IdentityEncoder {
        dim: usize,
    }

    impl Encoder<TestBackend> for IdentityEncoder {
        type Input = Tensor<TestBackend, 3>;

        fn encode(&self, input: &Self::Input) -> Representation<TestBackend> {
            // Identity mapping: wrap the input tensor unchanged.
            Representation::new(input.clone())
        }

        fn embed_dim(&self) -> usize {
            self.dim
        }
    }

    #[test]
    fn test_encoder_trait_is_implementable() {
        const DIM: usize = 64;

        let device = burn_ndarray::NdArrayDevice::Cpu;
        let encoder = IdentityEncoder { dim: DIM };

        // A [batch=2, seq_len=8, embed_dim=64] dummy input.
        let input: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, DIM], &device);
        let repr = encoder.encode(&input);

        // The identity encoder must preserve all three dimensions.
        assert_eq!(
            (repr.batch_size(), repr.seq_len(), repr.embed_dim()),
            (2, 8, DIM)
        );
        assert_eq!(encoder.embed_dim(), DIM);
    }
}
95}