manifolds-rs 0.3.3

Embedding methods implemented in Rust: (parametric) UMAP, t-SNE, PHATE, Diffusion Maps and PaCMAP.
Documentation
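//! Model definition for the neural network that is trained during parametric
//! UMAP: the layer configuration, the MLP itself and a wrapper around a
//! trained model that can embed new data.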

use burn::tensor::Element;
use burn::{nn::LeakyReluConfig, prelude::*};
use faer::MatRef;
use nn::{LeakyRelu, Linear, LinearConfig};
use num_traits::{Float, FromPrimitive};
use std::marker::PhantomData;

//////////////////
// Model config //
//////////////////

/// Configuration structure for creating a [`UmapMlp`].
///
/// Describes the layer widths of the network: `input_size`, then every entry
/// of `hidden_sizes` in order, then `output_size`.
#[derive(Config, Debug)]
pub struct UmapMlpConfig {
    /// Number of input features.
    pub input_size: usize,
    /// Sizes of the hidden layers, ordered from the input side to the output
    /// side. An empty vector yields a model with only the output layer.
    pub hidden_sizes: Vec<usize>,
    /// Number of output features.
    pub output_size: usize,
}

/// Multi-layer perceptron with a configurable number of hidden layers.
///
/// ### Fields
///
/// * `layers` - Linear layers in application order; the last entry is the
///   output layer.
/// * `activation` - Leaky ReLU applied after every layer except the last one.
#[derive(Module, Debug)]
pub struct UmapMlp<B: Backend> {
    layers: Vec<Linear<B>>,
    activation: LeakyRelu,
}

///////////
// Model //
///////////

impl<B: Backend> UmapMlp<B> {
    /// Generate a new model based on a `UmapMlpConfig`.
    ///
    /// One linear layer (with bias) is created per entry of
    /// `config.hidden_sizes`, each taking the previous layer's width as its
    /// input width, followed by a final linear layer (with bias) that maps to
    /// `config.output_size`. The activation is a leaky ReLU with a negative
    /// slope of `0.1`.
    ///
    /// ### Params
    ///
    /// * `config` - The configuration with the model specifications
    /// * `device` - The device on which to put the model
    ///
    /// ### Return
    ///
    /// Initialised UmapMlp model.
    pub fn new(config: &UmapMlpConfig, device: &Device<B>) -> UmapMlp<B> {
        // One layer per hidden size plus the output layer.
        let mut layers: Vec<Linear<B>> = Vec::with_capacity(config.hidden_sizes.len() + 1);
        let mut in_features = config.input_size;

        for &out_features in &config.hidden_sizes {
            layers.push(
                LinearConfig::new(in_features, out_features)
                    .with_bias(true)
                    .init(device),
            );
            // The next layer consumes what this one produces.
            in_features = out_features;
        }

        // Output layer: maps the last hidden width (or the input width when
        // there are no hidden layers) to the requested output size.
        layers.push(
            LinearConfig::new(in_features, config.output_size)
                .with_bias(true)
                .init(device),
        );

        let activation = LeakyReluConfig {
            negative_slope: 0.1,
        }
        .init();

        Self { layers, activation }
    }

    /// Forward pass of the model.
    ///
    /// Every layer except the last is followed by the activation; the output
    /// layer is applied without an activation.
    ///
    /// ### Params
    ///
    /// * `input` - Tensor of [batch_size, features]
    ///
    /// ### Returns
    ///
    /// Tensor of [batch_size, embedding]
    ///
    /// ### Panics
    ///
    /// Panics if the model holds no layers. `UmapMlp::new` always creates at
    /// least the output layer.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let (output_layer, hidden_layers) = self
            .layers
            .split_last()
            .expect("UmapMlp must contain at least the output layer");

        // Hidden layers: linear transform followed by the activation.
        let hidden = hidden_layers.iter().fold(input, |x, layer| {
            self.activation.forward(layer.forward(x))
        });

        // Output layer: linear transform only, no activation.
        output_layer.forward(hidden)
    }
}

/////////////
// Builder //
/////////////

impl UmapMlpConfig {
    /// Initialise the model described by this configuration.
    ///
    /// Delegates to `UmapMlp::new`.
    ///
    /// ### Params
    ///
    /// * `device` - The device on which to run the model
    ///
    /// ### Returns
    ///
    /// Initialised model
    pub fn init<B: Backend>(&self, device: &B::Device) -> UmapMlp<B> {
        UmapMlp::new(self, device)
    }

    /// Generate a new configuration based on parameters.
    ///
    /// The arguments are stored as given; no validation is performed here.
    ///
    /// ### Params
    ///
    /// * `input_size` - Number of input features.
    /// * `hidden_sizes` - Vector of sizes for the hidden layers.
    /// * `output_size` - Number of output features.
    ///
    /// ### Returns
    ///
    /// Initialised UmapMlpConfig
    pub fn from_params(input_size: usize, hidden_sizes: Vec<usize>, output_size: usize) -> Self {
        Self {
            input_size,
            hidden_sizes,
            output_size,
        }
    }
}

///////////////////
// Trained model //
///////////////////

/// Wrapper around a trained [`UmapMlp`] together with the device used to run
/// it.
///
/// The type parameter `T` is the numeric type of the data passed to and
/// returned from `predict`; it is not stored, only tracked via `PhantomData`.
pub struct TrainedUmapModel<B: Backend, T> {
    /// The trained model
    pub model: UmapMlp<B>,
    /// Device on which the input tensors are created
    device: B::Device,
    /// Marker tying the otherwise unused type parameter `T` to the struct
    _phantom: PhantomData<T>,
}

impl<B: Backend, T> TrainedUmapModel<B, T>
where
    T: Element + Float + FromPrimitive,
{
    /// Returns an initialised TrainedUmapModel
    ///
    /// ### Params
    ///
    /// * `model` - The trained model
    /// * `device` - The device on which input tensors are created for
    ///   prediction
    ///
    /// ### Returns
    ///
    /// Initialised self
    pub fn new(model: UmapMlp<B>, device: B::Device) -> Self {
        Self {
            model,
            device,
            _phantom: PhantomData,
        }
    }

    /// Predict on new data
    ///
    /// ### Params
    ///
    /// * `data` - Matrix of shape [n_samples, n_features] to embed
    ///
    /// ### Returns
    ///
    /// The embeddings generated by the trained model, laid out as
    /// `result[component][sample]`: one inner vector of length `n_samples`
    /// per output dimension.
    pub fn predict(&self, data: MatRef<T>) -> Vec<Vec<T>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        // Flatten row by row, matching the [n_samples, n_features] shape
        // handed to `TensorData::new` below.
        let mut data_flat: Vec<T> = Vec::with_capacity(n_samples * n_features);
        for i in 0..n_samples {
            data_flat.extend((0..n_features).map(|j| data[(i, j)]));
        }

        let input = Tensor::<B, 2>::from_data(
            TensorData::new(data_flat, [n_samples, n_features]),
            &self.device,
        );

        let embeddings = self.model.forward(input);
        let [_, n_components] = embeddings.dims();

        // Convert to `T` before extracting: `to_vec` rejects a requested
        // element type that differs from the stored one, which would
        // otherwise fail whenever the backend's float type is not `T`.
        let embedding_data: Vec<T> = embeddings
            .into_data()
            .convert::<T>()
            .to_vec()
            .expect("tensor data was converted to `T`, so the element types match");

        // Transpose the flat [n_samples, n_components] buffer into one
        // vector per component.
        (0..n_components)
            .map(|j| {
                (0..n_samples)
                    .map(|i| embedding_data[i * n_components + j])
                    .collect()
            })
            .collect()
    }
}

///////////
// Tests //
///////////

#[cfg(test)]
mod model_tests {
    use super::*;
    use burn::backend::flex::FlexDevice;
    use burn::backend::Flex;
    use burn::tensor::Distribution;

    type TestBackend = Flex<f32>;

    /// Build a freshly initialised model with the given layer sizes.
    fn build_model(
        input_size: usize,
        hidden_sizes: Vec<usize>,
        output_size: usize,
        device: &FlexDevice,
    ) -> UmapMlp<TestBackend> {
        UmapMlpConfig::from_params(input_size, hidden_sizes, output_size).init(device)
    }

    /// Random [rows, cols] tensor with entries drawn uniformly from [-1, 1).
    fn uniform_input(rows: usize, cols: usize, device: &FlexDevice) -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::random([rows, cols], Distribution::Uniform(-1.0, 1.0), device)
    }

    /// Copy the tensor contents into a flat `Vec<f32>`.
    fn flat_values(tensor: &Tensor<TestBackend, 2>) -> Vec<f32> {
        tensor.to_data().to_vec().unwrap()
    }

    #[test]
    fn test_model_forward_shape() {
        let device = FlexDevice;
        let model = build_model(10, vec![64, 32], 2, &device);

        let batch_size = 16;
        let output = model.forward(uniform_input(batch_size, 10, &device));

        assert_eq!(output.dims(), [batch_size, 2]);
    }

    #[test]
    fn test_model_no_hidden_layers() {
        let device = FlexDevice;
        let model = build_model(10, vec![], 2, &device);

        let output = model.forward(uniform_input(5, 10, &device));

        assert_eq!(output.dims(), [5, 2]);
    }

    #[test]
    fn test_model_single_hidden_layer() {
        let device = FlexDevice;
        let model = build_model(10, vec![64], 2, &device);

        let output = model.forward(uniform_input(8, 10, &device));

        assert_eq!(output.dims(), [8, 2]);
    }

    #[test]
    fn test_model_output_is_finite() {
        let device = FlexDevice;
        let model = build_model(10, vec![64, 32], 2, &device);

        let output = model.forward(uniform_input(16, 10, &device));
        let values = flat_values(&output);

        // Report the first offending entry, if any.
        if let Some((i, val)) = values.iter().enumerate().find(|(_, v)| !v.is_finite()) {
            panic!("Output at index {} is not finite: {}", i, val);
        }
    }

    #[test]
    fn test_model_batch_size_one() {
        let device = FlexDevice;
        let model = build_model(5, vec![32], 3, &device);

        let output = model.forward(uniform_input(1, 5, &device));

        assert_eq!(output.dims(), [1, 3]);
    }

    #[test]
    fn test_model_deterministic() {
        let device = FlexDevice;
        let model = build_model(10, vec![64], 2, &device);

        let input =
            Tensor::<TestBackend, 2>::from_floats([[1.0; 10], [2.0; 10], [3.0; 10]], &device);

        // Two passes over the same input through the same model.
        let first = flat_values(&model.forward(input.clone()));
        let second = flat_values(&model.forward(input));

        assert_eq!(first, second, "Model should be deterministic");
    }

    #[test]
    fn test_model_different_inputs_different_outputs() {
        let device = FlexDevice;
        let model = build_model(10, vec![64], 2, &device);

        let ones = Tensor::<TestBackend, 2>::from_floats([[1.0; 10]], &device);
        let twos = Tensor::<TestBackend, 2>::from_floats([[2.0; 10]], &device);

        let from_ones = flat_values(&model.forward(ones));
        let from_twos = flat_values(&model.forward(twos));

        assert_ne!(
            from_ones, from_twos,
            "Different inputs should produce different outputs"
        );
    }

    #[test]
    fn test_model_config_builder() {
        let UmapMlpConfig {
            input_size,
            hidden_sizes,
            output_size,
        } = UmapMlpConfig::from_params(20, vec![128, 64, 32], 5);

        assert_eq!(input_size, 20);
        assert_eq!(hidden_sizes, vec![128, 64, 32]);
        assert_eq!(output_size, 5);
    }
}