use burn::tensor::backend::Backend;
use crate::types::Representation;
/// Common interface for components that turn an input value into a
/// [`Representation`] on the backend `B`.
///
/// Implementors pick their own input type through the [`Encoder::Input`]
/// associated type, so different encoders can accept different kinds of data
/// while producing the same output type.
pub trait Encoder<B: Backend> {
    /// The kind of value this encoder consumes.
    type Input;

    /// Reports the encoder's embedding dimension.
    ///
    /// NOTE(review): presumably this matches the embedding width of the
    /// representations returned by [`Encoder::encode`] — the trait itself does
    /// not enforce that; confirm against implementors.
    fn embed_dim(&self) -> usize;

    /// Encodes `input` into a [`Representation`].
    ///
    /// The input is only borrowed; the returned representation is owned by the
    /// caller.
    fn encode(&self, input: &Self::Input) -> Representation<B>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_ndarray::{NdArray, NdArrayDevice};

    /// `NdArray` backend parameterised with `f32`, used by every test here.
    type TestBackend = NdArray<f32>;

    /// Rank-3 tensor on the test backend.
    type Rank3 = Tensor<TestBackend, 3>;

    /// Smallest possible [`Encoder`] implementation: it wraps a clone of its
    /// input in a [`Representation`] and reports a hand-configured dimension.
    struct IdentityEncoder {
        dim: usize,
    }

    impl Encoder<TestBackend> for IdentityEncoder {
        type Input = Rank3;

        fn embed_dim(&self) -> usize {
            self.dim
        }

        fn encode(&self, input: &Self::Input) -> Representation<TestBackend> {
            // `Representation::new` takes the tensor by value, so hand it a
            // clone and leave the caller's tensor untouched.
            let passthrough = input.clone();
            Representation::new(passthrough)
        }
    }

    /// Checks that the trait can be implemented and that a `[2, 8, 64]` input
    /// comes back reporting batch size 2, sequence length 8 and embedding
    /// dimension 64.
    #[test]
    fn test_encoder_trait_is_implementable() {
        let cpu = NdArrayDevice::Cpu;
        let zeros = Rank3::zeros([2, 8, 64], &cpu);
        let identity = IdentityEncoder { dim: 64 };

        let encoded = identity.encode(&zeros);

        assert_eq!(encoded.batch_size(), 2);
        assert_eq!(encoded.seq_len(), 8);
        assert_eq!(encoded.embed_dim(), 64);
        assert_eq!(identity.embed_dim(), 64);
    }
}