use burn::tensor::{backend::Backend, Tensor};
use crate::types::Representation;
/// Maps an encoded context to predicted representations at requested
/// target positions, optionally conditioned on a latent variable.
///
/// Implementations (see the test `ZeroPredictor` below) produce one
/// predicted embedding per target position.
pub trait Predictor<B: Backend> {
/// Predicts target representations from `context`.
///
/// * `context` - encoded context the prediction is conditioned on.
/// * `target_positions` - rank-2 tensor of positions to predict; the test
///   impl treats its dims as `[batch, num_targets]` — TODO confirm this
///   convention holds for all implementors.
/// * `latent` - optional rank-2 conditioning tensor; `None` means
///   unconditioned prediction.
///
/// Returns a `Representation` covering the requested target positions.
fn predict(
&self,
context: &Representation<B>,
target_positions: &Tensor<B, 2>,
latent: Option<&Tensor<B, 2>>,
) -> Representation<B>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// Minimal `Predictor` impl for exercising the trait: ignores the
    /// context and latent, returning all-zero embeddings of a fixed width.
    struct ZeroPredictor {
        embed_dim: usize,
    }

    impl Predictor<TestBackend> for ZeroPredictor {
        fn predict(
            &self,
            _context: &Representation<TestBackend>,
            target_positions: &Tensor<TestBackend, 2>,
            _latent: Option<&Tensor<TestBackend, 2>>,
        ) -> Representation<TestBackend> {
            // Output shape mirrors the requested positions:
            // [batch, num_targets, embed_dim], on the same device.
            let [batch_size, target_count] = target_positions.dims();
            let zeros = Tensor::zeros(
                [batch_size, target_count, self.embed_dim],
                &target_positions.device(),
            );
            Representation::new(zeros)
        }
    }

    /// Shared fixture: a 64-dim predictor, a [2, 8, 64] context, and a
    /// [2, 4] tensor of target positions.
    fn fixture() -> (
        ZeroPredictor,
        Representation<TestBackend>,
        Tensor<TestBackend, 2>,
    ) {
        let device = burn_ndarray::NdArrayDevice::Cpu;
        let context = Representation::new(Tensor::zeros([2, 8, 64], &device));
        let positions = Tensor::zeros([2, 4], &device);
        (ZeroPredictor { embed_dim: 64 }, context, positions)
    }

    #[test]
    fn test_predictor_trait_is_implementable() {
        let (predictor, context, positions) = fixture();
        let out = predictor.predict(&context, &positions, None);
        assert_eq!(out.batch_size(), 2);
        assert_eq!(out.seq_len(), 4);
        assert_eq!(out.embed_dim(), 64);
    }

    #[test]
    fn test_predictor_with_latent() {
        let (predictor, context, positions) = fixture();
        // A latent of shape [2, 16]; ZeroPredictor ignores it, but the
        // trait must accept the Some(..) path.
        let latent: Tensor<TestBackend, 2> = Tensor::zeros([2, 16], &positions.device());
        let out = predictor.predict(&context, &positions, Some(&latent));
        assert_eq!(out.batch_size(), 2);
        assert_eq!(out.seq_len(), 4);
    }
}