use burn::prelude::*;
use nn::{Linear, LinearConfig, Relu};
use serde::{Deserialize, Serialize};
/// Feed-forward network made of stacked fully connected layers.
///
/// Construct it with [`UMAPModel::new`] from a [`UMAPModelConfig`].
#[derive(Module, Debug)]
pub struct UMAPModel<B: Backend> {
    /// Linear layers, applied in order; the last one produces the model output.
    layers: Vec<Linear<B>>,
    /// Non-linearity applied between consecutive layers (never after the last one).
    activation: Relu,
}
impl<B: Backend> UMAPModel<B> {
    /// Builds the network described by `config`, allocating parameters on `device`.
    ///
    /// One biased [`Linear`] layer is created per entry of `config.hidden_sizes`,
    /// followed by a final layer projecting to `config.output_size`. Consecutive
    /// layers are chained so that each layer's input width equals the previous
    /// layer's output width, starting from `config.input_size`.
    pub fn new(config: &UMAPModelConfig, device: &Device<B>) -> Self {
        // Widths at every layer boundary: input, each hidden size, output.
        // Each adjacent pair of widths describes exactly one Linear layer.
        let widths: Vec<usize> = std::iter::once(config.input_size)
            .chain(config.hidden_sizes.iter().copied())
            .chain(std::iter::once(config.output_size))
            .collect();

        let layers = widths
            .windows(2)
            .map(|pair| {
                LinearConfig::new(pair[0], pair[1])
                    .with_bias(true)
                    .init(device)
            })
            .collect();

        Self {
            layers,
            activation: Relu::new(),
        }
    }

    /// Runs `input` through every layer in order.
    ///
    /// The ReLU activation is applied after each layer except the last, so the
    /// final layer's output is returned without an activation. If the model has
    /// no layers, `input` is returned unchanged.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        match self.layers.split_last() {
            Some((output_layer, hidden_layers)) => {
                let hidden = hidden_layers.iter().fold(input, |x, layer| {
                    self.activation.forward(layer.forward(x))
                });
                output_layer.forward(hidden)
            }
            None => input,
        }
    }
}
/// Layer widths describing a [`UMAPModel`].
///
/// Usually obtained through [`UMAPModelConfig::builder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UMAPModelConfig {
    /// Input width of the first linear layer.
    pub input_size: usize,
    /// Output width of each hidden layer, in order; may be empty.
    pub hidden_sizes: Vec<usize>,
    /// Output width of the final linear layer.
    pub output_size: usize,
}
impl UMAPModelConfig {
pub fn builder() -> UMAPModelConfigBuilder {
UMAPModelConfigBuilder::default()
}
}
/// Incremental builder for [`UMAPModelConfig`].
///
/// Every field starts populated by the `Default` impl below, so only the
/// values that differ from the defaults need to be set.
#[derive(Debug, Clone)]
pub struct UMAPModelConfigBuilder {
    input_size: Option<usize>,
    hidden_sizes: Option<Vec<usize>>,
    output_size: Option<usize>,
}

impl Default for UMAPModelConfigBuilder {
    /// Defaults: 100 inputs, three hidden layers of width 100, and 2 outputs.
    fn default() -> Self {
        Self {
            input_size: Some(100),
            hidden_sizes: Some(vec![100; 3]),
            output_size: Some(2),
        }
    }
}
impl UMAPModelConfigBuilder {
    /// Sets the input width of the first linear layer.
    #[must_use]
    pub fn input_size(mut self, input_size: usize) -> Self {
        self.input_size = Some(input_size);
        self
    }

    /// Sets the hidden layer widths, in order.
    ///
    /// An empty vector yields a model with a single linear layer mapping the
    /// input width directly to the output width.
    #[must_use]
    pub fn hidden_sizes(mut self, hidden_sizes: Vec<usize>) -> Self {
        self.hidden_sizes = Some(hidden_sizes);
        self
    }

    /// Sets the output width of the final linear layer.
    #[must_use]
    pub fn output_size(mut self, output_size: usize) -> Self {
        self.output_size = Some(output_size);
        self
    }

    /// Validates the collected settings and produces a [`UMAPModelConfig`].
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when:
    /// - any field is unset;
    /// - the input width, the output width, or any hidden width is zero.
    ///
    /// A zero-width layer cannot carry any signal, so it is treated as a
    /// configuration mistake instead of being passed on to layer construction.
    pub fn build(self) -> Result<UMAPModelConfig, String> {
        let input_size = self
            .input_size
            .ok_or_else(|| "Input size must be set".to_string())?;
        let hidden_sizes = self
            .hidden_sizes
            .ok_or_else(|| "Hidden sizes must be set".to_string())?;
        let output_size = self
            .output_size
            .ok_or_else(|| "Output size must be set".to_string())?;

        if input_size == 0 {
            return Err("Input size must be greater than zero".to_string());
        }
        if let Some(index) = hidden_sizes.iter().position(|&width| width == 0) {
            return Err(format!(
                "Hidden size at index {} must be greater than zero",
                index
            ));
        }
        if output_size == 0 {
            return Err("Output size must be greater than zero".to_string());
        }

        Ok(UMAPModelConfig {
            input_size,
            hidden_sizes,
            output_size,
        })
    }
}