use burn::prelude::*;
use nn::{Linear, LinearConfig, Relu};
use serde::{Deserialize, Serialize};
/// A fully connected network made of four linear layers.
///
/// A single stateless `activation` (ReLU) is shared and applied between
/// consecutive linear layers by `forward`.
#[derive(Module, Debug)]
pub struct UMAPModel<B: Backend> {
    /// Input layer: `input_size -> hidden_size`.
    linear1: Linear<B>,
    /// First hidden layer: `hidden_size -> hidden_size`.
    linear2: Linear<B>,
    /// Second hidden layer: `hidden_size -> hidden_size`.
    linear3: Linear<B>,
    /// Output layer: `hidden_size -> output_size`.
    linear4: Linear<B>,
    /// Non-linearity applied after each of the first three linear layers.
    activation: Relu,
}
impl<B: Backend> UMAPModel<B> {
pub fn new(config: &UMAPModelConfig, device: &Device<B>) -> Self {
let linear1 = LinearConfig::new(config.input_size, config.hidden_size)
.with_bias(true)
.init(device);
let linear2 = LinearConfig::new(config.hidden_size, config.hidden_size)
.with_bias(true)
.init(device);
let linear3 = LinearConfig::new(config.hidden_size, config.hidden_size)
.with_bias(true)
.init(device);
let linear4 = LinearConfig::new(config.hidden_size, config.output_size)
.with_bias(true)
.init(device);
let activation = Relu::new();
UMAPModel {
linear1,
linear2,
linear3,
linear4,
activation,
}
}
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.linear1.forward(input); let x = self.activation.forward(x); let x = self.linear2.forward(x); let x = self.activation.forward(x); let x = self.linear3.forward(x); let x = self.activation.forward(x); let x = self.linear4.forward(x); x
}
}
/// Layer-size hyperparameters consumed by `UMAPModel::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UMAPModelConfig {
    /// Number of features fed into the first linear layer.
    pub input_size: usize,
    /// Width shared by every hidden layer.
    pub hidden_size: usize,
    /// Number of features produced by the final linear layer.
    pub output_size: usize,
}
impl UMAPModelConfig {
pub fn builder() -> UMAPModelConfigBuilder {
UMAPModelConfigBuilder::default()
}
}
/// Incrementally assembles a [`UMAPModelConfig`].
///
/// Each size is held as an `Option` until `build` turns the collection into a config.
#[derive(Clone, Debug)]
pub struct UMAPModelConfigBuilder {
    /// Pending value for `UMAPModelConfig::input_size`.
    input_size: Option<usize>,
    /// Pending value for `UMAPModelConfig::hidden_size`.
    hidden_size: Option<usize>,
    /// Pending value for `UMAPModelConfig::output_size`.
    output_size: Option<usize>,
}
impl Default for UMAPModelConfigBuilder {
fn default() -> Self {
UMAPModelConfigBuilder {
input_size: Some(100),
hidden_size: Some(100),
output_size: Some(2),
}
}
}
impl UMAPModelConfigBuilder {
    /// Sets the number of features fed into the first layer.
    #[must_use]
    pub fn input_size(mut self, input_size: usize) -> Self {
        self.input_size = Some(input_size);
        self
    }

    /// Sets the width of every hidden layer.
    #[must_use]
    pub fn hidden_size(mut self, hidden_size: usize) -> Self {
        self.hidden_size = Some(hidden_size);
        self
    }

    /// Sets the number of features produced by the final layer.
    #[must_use]
    pub fn output_size(mut self, output_size: usize) -> Self {
        self.output_size = Some(output_size);
        self
    }

    /// Validates the collected sizes and produces a [`UMAPModelConfig`].
    ///
    /// # Errors
    ///
    /// Returns a message for the first offending size, checked in the order input,
    /// hidden, output. A size is rejected when it was never set
    /// (`"<Name> must be set"`) or is zero (`"<Name> must be greater than zero"`).
    /// A zero size would describe a layer with no inputs or no outputs, which is
    /// never a usable network, so it is rejected up front.
    pub fn build(self) -> Result<UMAPModelConfig, String> {
        Ok(UMAPModelConfig {
            input_size: Self::require_positive(self.input_size, "Input size")?,
            hidden_size: Self::require_positive(self.hidden_size, "Hidden size")?,
            output_size: Self::require_positive(self.output_size, "Output size")?,
        })
    }

    /// Unwraps one pending size; a missing or zero value produces an error mentioning `label`.
    fn require_positive(value: Option<usize>, label: &str) -> Result<usize, String> {
        match value {
            None => Err(format!("{label} must be set")),
            Some(0) => Err(format!("{label} must be greater than zero")),
            Some(size) => Ok(size),
        }
    }
}