use core::f32;
use burn::{
module::Module,
nn::{
self,
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
},
tensor::{backend::Backend, Tensor},
};
use super::{helpers::Sigmoid, INPUT_SPACE_SIZE, NUM_CLASSES};
#[cfg(feature = "ml_train")]
use crate::ml::train::{data::KordBatch, helpers::KordClassificationOutput};
/// Neural network model made of a multi-head attention layer, a linear
/// projection, and a sigmoid activation, applied in that order by `forward`.
#[derive(Module, Debug)]
pub struct KordModel<B: Backend> {
/// Multi-head attention layer; `forward` feeds it the same tensor as query, key, and value.
mha: MultiHeadAttention<B>,
/// Linear layer configured from `INPUT_SPACE_SIZE` inputs to `NUM_CLASSES` outputs.
output: nn::Linear<B>,
/// Project-local sigmoid helper constructed from a strength parameter
/// (exact scaling semantics live in `helpers::Sigmoid` — not visible here).
sigmoid: Sigmoid<B>,
}
impl<B: Backend> KordModel<B> {
    /// Creates a new model.
    ///
    /// * `mha_heads` - number of heads for the multi-head attention layer.
    /// * `mha_dropout` - dropout probability passed to the attention configuration.
    /// * `sigmoid_strength` - strength value handed to the project-local `Sigmoid` helper.
    ///
    /// The attention layer operates on `INPUT_SPACE_SIZE` features and the linear layer
    /// maps `INPUT_SPACE_SIZE` features to `NUM_CLASSES` outputs.
    pub fn new(mha_heads: usize, mha_dropout: f64, sigmoid_strength: f32) -> Self {
        let attention_config = MultiHeadAttentionConfig::new(INPUT_SPACE_SIZE, mha_heads).with_dropout(mha_dropout);
        let projection_config = nn::LinearConfig::new(INPUT_SPACE_SIZE, NUM_CLASSES);

        // Fields are initialized in declaration order: attention, projection, sigmoid.
        Self {
            mha: attention_config.init::<B>(),
            output: projection_config.init::<B>(),
            sigmoid: Sigmoid::new(sigmoid_strength),
        }
    }

    /// Runs the model on a batch of samples.
    ///
    /// `input` is a rank-2 tensor of shape `[batch_size, input_size]`. Each sample is
    /// treated as a sequence of length one for self-attention, then passed through the
    /// linear layer and the sigmoid helper. Presumably `input_size` must equal
    /// `INPUT_SPACE_SIZE` for the layers to accept it — confirm against callers.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let [batch_size, input_size] = input.dims();

        // Insert a sequence axis of length 1 so the attention layer receives a rank-3 tensor.
        let sequence = input.reshape([batch_size, 1, input_size]);
        let query = sequence.clone();
        let key = sequence.clone();
        let attended = self.mha.forward(MhaInput::new(query, key, sequence));

        // Remove the sequence axis again before the linear projection.
        let flattened = attended.context.reshape([batch_size, input_size]);
        let projected = self.output.forward(flattened);
        self.sigmoid.forward(projected)
    }

    /// Runs a forward pass on a training batch and computes the loss.
    ///
    /// The loss is the mean-squared-error loss between the model output and the batch
    /// targets, using the `Sum` reduction. Returns the loss together with the output and
    /// the targets.
    #[cfg(feature = "ml_train")]
    pub fn forward_classification(&self, item: KordBatch<B>) -> KordClassificationOutput<B> {
        use burn::nn::loss::MSELoss;

        let output = self.forward(item.samples);
        let targets = item.targets;

        // Clones are required here: both tensors are also returned in the output struct.
        let loss = MSELoss::default().forward(output.clone(), targets.clone(), nn::loss::Reduction::Sum);

        KordClassificationOutput { loss, output, targets }
    }
}