// jepa_core/encoder.rs
//! Encoder trait for JEPA.
//!
//! Implements RFC-002 (Encoder Module).
//!
//! An encoder maps raw input (images, video, or already-embedded tokens)
//! into a [`Representation`] in embedding space.
//!
//! In a JEPA training loop **two** encoder instances exist:
//!
//! | Role | Gradients? | Weight update |
//! |------|-----------|---------------|
//! | **Context encoder** (θ) | Yes | Backpropagation |
//! | **Target encoder** (ξ) | No | EMA of θ (see [`crate::ema::Ema`]) |
//!
//! Both share the same architecture and implement this trait. The
//! asymmetric update (EMA on the target) is what prevents collapse.
//!
//! See [`crate::collapse`] for the regularizers that complement EMA.

use burn::tensor::backend::Backend;

use crate::types::Representation;

24/// Trait for JEPA encoders.
25///
26/// An encoder maps raw input to a [`Representation`] with shape
27/// `[batch, seq_len, embed_dim]`. Concrete implementations include:
28///
29/// - [`jepa_vision::VitEncoder`](../../jepa_vision/vit/struct.VitEncoder.html) — Vision Transformer for images
30/// - [`jepa_vision::VitVideoEncoder`](../../jepa_vision/video/struct.VitVideoEncoder.html) — Vision Transformer for video
31///
32/// # Type parameters
33///
34/// - `B` — burn backend (e.g. `NdArray`, `Wgpu`, `Tch`)
35///
36/// # Associated types
37///
38/// - `Input` — the raw input type this encoder accepts. For vision
39///   encoders this is typically a `Tensor<B, 4>` (images) or
40///   `Tensor<B, 5>` (video). Higher-level wrappers may accept
41///   [`Representation<B>`] so that levels in a hierarchy can chain.
42pub trait Encoder<B: Backend> {
43    /// The type of input this encoder accepts.
44    type Input;
45
46    /// Encode input into a representation.
47    ///
48    /// # Arguments
49    /// * `input` - The raw input to encode
50    ///
51    /// # Returns
52    /// A [`Representation`] with shape `[batch, seq_len, embed_dim]`
53    fn encode(&self, input: &Self::Input) -> Representation<B>;
54
55    /// Get the output embedding dimension.
56    fn embed_dim(&self) -> usize;
57}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// A trivial encoder for testing the trait definition: it wraps the
    /// input tensor in a [`Representation`] without transforming it.
    struct IdentityEncoder {
        // Reported embedding dimension; not enforced against the input shape.
        dim: usize,
    }

    impl Encoder<TestBackend> for IdentityEncoder {
        // Input is already token-shaped: `[batch, seq_len, embed_dim]`.
        type Input = Tensor<TestBackend, 3>;

        fn encode(&self, input: &Self::Input) -> Representation<TestBackend> {
            Representation::new(input.clone())
        }

        fn embed_dim(&self) -> usize {
            self.dim
        }
    }

    /// The trait can be implemented and used generically: shapes round-trip
    /// through `encode` and `embed_dim` reports the configured dimension.
    #[test]
    fn test_encoder_trait_is_implementable() {
        let encoder = IdentityEncoder { dim: 64 };
        let device = burn_ndarray::NdArrayDevice::Cpu;
        let input: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 64], &device);
        let repr = encoder.encode(&input);
        assert_eq!(repr.batch_size(), 2);
        assert_eq!(repr.seq_len(), 8);
        assert_eq!(repr.embed_dim(), 64);
        assert_eq!(encoder.embed_dim(), 64);
    }
}