use burn::tensor::Element;
use burn::{nn::LeakyReluConfig, prelude::*};
use faer::MatRef;
use nn::{LeakyRelu, Linear, LinearConfig};
use num_traits::{Float, FromPrimitive};
use std::marker::PhantomData;
/// Layer sizes describing a [`UmapMlp`].
///
/// Derives burn's `Config`, so it can be used with the rest of burn's configuration
/// machinery; build the model itself with `UmapMlpConfig::init`.
#[derive(Config, Debug)]
pub struct UmapMlpConfig {
    /// Number of input features per sample (input width of the first linear layer).
    pub input_size: usize,
    /// Widths of the hidden linear layers, in order. May be empty, in which case the
    /// model is a single linear layer from `input_size` to `output_size`.
    pub hidden_sizes: Vec<usize>,
    /// Output width of the final linear layer, i.e. the embedding dimension.
    pub output_size: usize,
}
/// Feed-forward network mapping `input_size` features to an `output_size`-wide output.
///
/// Built by `UmapMlp::new` as a chain of biased `Linear` layers with a LeakyReLU between
/// consecutive layers.
#[derive(Module, Debug)]
pub struct UmapMlp<B: Backend> {
    /// Linear layers applied in order. `new` creates `hidden_sizes.len() + 1` of them;
    /// no activation follows the last one.
    layers: Vec<Linear<B>>,
    /// Activation applied after every layer except the last (`new` uses slope 0.1).
    activation: LeakyRelu,
}
impl<B: Backend> UmapMlp<B> {
    /// Builds the network described by `config`, initialising all parameters on `device`.
    ///
    /// Layer widths follow `input_size -> hidden_sizes[0] -> ... -> output_size`, and every
    /// layer has a bias. With an empty `hidden_sizes` the model is a single linear layer.
    /// Between layers a LeakyReLU with negative slope 0.1 is applied (see [`Self::forward`]).
    pub fn new(config: &UmapMlpConfig, device: &Device<B>) -> UmapMlp<B> {
        // Full width sequence including both ends; each adjacent pair defines one layer,
        // so exactly `hidden_sizes.len() + 1` layers are created, in network order.
        let widths: Vec<usize> = std::iter::once(config.input_size)
            .chain(config.hidden_sizes.iter().copied())
            .chain(std::iter::once(config.output_size))
            .collect();
        let layers: Vec<Linear<B>> = widths
            .windows(2)
            .map(|pair| {
                LinearConfig::new(pair[0], pair[1])
                    .with_bias(true)
                    .init(device)
            })
            .collect();
        let activation = LeakyReluConfig {
            negative_slope: 0.1,
        }
        .init();
        Self { layers, activation }
    }

    /// Runs the network on a batch of samples (`[batch, input_size]` -> `[batch, output_size]`).
    ///
    /// Every layer except the last is followed by the LeakyReLU activation; the output of
    /// the final layer is returned without an activation.
    ///
    /// # Panics
    /// Panics if the model holds no layers. Models built with [`Self::new`] always have at
    /// least one.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let (output_layer, hidden_layers) = self
            .layers
            .split_last()
            .expect("UmapMlp must contain at least one linear layer");
        let hidden = hidden_layers
            .iter()
            .fold(input, |x, layer| self.activation.forward(layer.forward(x)));
        output_layer.forward(hidden)
    }
}
impl UmapMlpConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> UmapMlp<B> {
UmapMlp::new(self, device)
}
pub fn from_params(input_size: usize, hidden_sizes: Vec<usize>, output_size: usize) -> Self {
Self {
input_size,
hidden_sizes,
output_size,
}
}
}
/// Inference wrapper around a [`UmapMlp`] that accepts `faer` matrices with element type `T`.
///
/// The wrapped model is presumably already trained; the type itself does not enforce this.
pub struct TrainedUmapModel<B: Backend, T> {
    /// The network used for inference.
    pub model: UmapMlp<B>,
    /// Device on which `predict` creates its input tensor.
    device: B::Device,
    /// Ties the wrapper to the element type `T` used by `predict`; no `T` is stored.
    _phantom: PhantomData<T>,
}
impl<B: Backend, T> TrainedUmapModel<B, T>
where
    T: Element + Float + FromPrimitive,
{
    /// Wraps `model` so that [`Self::predict`] creates its input tensors on `device`.
    pub fn new(model: UmapMlp<B>, device: B::Device) -> Self {
        Self {
            model,
            device,
            _phantom: PhantomData,
        }
    }

    /// Embeds every row of `data` with the wrapped model.
    ///
    /// `data` is read as `n_samples x n_features`, one sample per row. `n_features` is
    /// presumably expected to equal the model's `input_size`; TODO confirm how burn reports
    /// a mismatch.
    ///
    /// Returns the embedding in component-major layout: `result[c][i]` is component `c` of
    /// sample `i`. The outer `Vec` therefore has one entry per output dimension, and each
    /// inner `Vec` has one entry per sample.
    pub fn predict(&self, data: MatRef<T>) -> Vec<Vec<T>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        // Flatten row by row to match the `[n_samples, n_features]` shape given below.
        let mut data_flat: Vec<T> = Vec::with_capacity(n_samples * n_features);
        for i in 0..n_samples {
            data_flat.extend((0..n_features).map(|j| data[(i, j)]));
        }

        let input = Tensor::<B, 2>::from_data(
            TensorData::new(data_flat, [n_samples, n_features]),
            &self.device,
        );
        let embeddings = self.model.forward(input);
        let n_components = embeddings.dims()[1];

        // The backend may hold activations in a float type other than `T` (e.g. an f32
        // backend with T = f64). `to_vec::<T>` rejects a dtype mismatch, so convert first;
        // the conversion is a no-op when the types already agree.
        let embedding_data: Vec<T> = embeddings
            .into_data()
            .convert::<T>()
            .to_vec()
            .expect("tensor data was just converted to T, so reading it as T cannot fail");

        // `embedding_data` is row-major `[n_samples, n_components]`; gather each column.
        // With `n_components == 0` the range is empty, so `step_by` never sees a zero step.
        (0..n_components)
            .map(|component| {
                embedding_data
                    .iter()
                    .skip(component)
                    .step_by(n_components)
                    .copied()
                    .collect()
            })
            .collect()
    }
}
#[cfg(test)]
mod model_tests {
    use super::*;
    use burn::backend::flex::FlexDevice;
    use burn::backend::Flex;

    type TestBackend = Flex<f32>;

    /// Samples a `[rows, cols]` tensor uniformly from the range (-1, 1).
    fn uniform_input(rows: usize, cols: usize, device: &FlexDevice) -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::random(
            [rows, cols],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            device,
        )
    }

    /// Reads a 2-D tensor back into a flat `Vec<f32>`.
    fn read_back(tensor: Tensor<TestBackend, 2>) -> Vec<f32> {
        tensor.to_data().to_vec().unwrap()
    }

    #[test]
    fn test_model_forward_shape() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> =
            UmapMlpConfig::from_params(10, vec![64, 32], 2).init(&device);
        let batch_size = 16;
        let output = model.forward(uniform_input(batch_size, 10, &device));
        assert_eq!(output.dims(), [batch_size, 2]);
    }

    #[test]
    fn test_model_no_hidden_layers() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> = UmapMlpConfig::from_params(10, vec![], 2).init(&device);
        let output = model.forward(uniform_input(5, 10, &device));
        assert_eq!(output.dims(), [5, 2]);
    }

    #[test]
    fn test_model_single_hidden_layer() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> =
            UmapMlpConfig::from_params(10, vec![64], 2).init(&device);
        let output = model.forward(uniform_input(8, 10, &device));
        assert_eq!(output.dims(), [8, 2]);
    }

    #[test]
    fn test_model_output_is_finite() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> =
            UmapMlpConfig::from_params(10, vec![64, 32], 2).init(&device);
        let values = read_back(model.forward(uniform_input(16, 10, &device)));
        for (idx, &value) in values.iter().enumerate() {
            assert!(
                value.is_finite(),
                "Output at index {} is not finite: {}",
                idx,
                value
            );
        }
    }

    #[test]
    fn test_model_batch_size_one() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> = UmapMlpConfig::from_params(5, vec![32], 3).init(&device);
        let output = model.forward(uniform_input(1, 5, &device));
        assert_eq!(output.dims(), [1, 3]);
    }

    #[test]
    fn test_model_deterministic() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> =
            UmapMlpConfig::from_params(10, vec![64], 2).init(&device);
        let batch =
            Tensor::<TestBackend, 2>::from_floats([[1.0; 10], [2.0; 10], [3.0; 10]], &device);
        let first = read_back(model.forward(batch.clone()));
        let second = read_back(model.forward(batch));
        assert_eq!(first, second, "Model should be deterministic");
    }

    #[test]
    fn test_model_different_inputs_different_outputs() {
        let device = FlexDevice;
        let model: UmapMlp<TestBackend> =
            UmapMlpConfig::from_params(10, vec![64], 2).init(&device);
        let ones = Tensor::<TestBackend, 2>::from_floats([[1.0; 10]], &device);
        let twos = Tensor::<TestBackend, 2>::from_floats([[2.0; 10]], &device);
        let from_ones = read_back(model.forward(ones));
        let from_twos = read_back(model.forward(twos));
        assert_ne!(
            from_ones, from_twos,
            "Different inputs should produce different outputs"
        );
    }

    #[test]
    fn test_model_config_builder() {
        let cfg = UmapMlpConfig::from_params(20, vec![128, 64, 32], 5);
        assert_eq!(cfg.input_size, 20);
        assert_eq!(cfg.hidden_sizes, vec![128, 64, 32]);
        assert_eq!(cfg.output_size, 5);
    }
}