//! Model definition for the Kord model.

use core::f32;

use burn::{
    module::Module,
    nn::{
        self,
        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    },
    tensor::{backend::Backend, Tensor},
};

use super::{helpers::Sigmoid, INPUT_SPACE_SIZE, NUM_CLASSES};

#[cfg(feature = "ml_train")]
use crate::ml::train::{data::KordBatch, helpers::KordClassificationOutput};

/// The Kord model.
///
/// A transformer-style model that applies multi-head self-attention to a frequency-space
/// representation of audio in order to classify which notes are present.
#[derive(Module, Debug)]
pub struct KordModel<B: Backend> {
    /// Multi-head self-attention over the frequency-space input.
    mha: MultiHeadAttention<B>,
    /// Final linear projection from the input space to the class outputs.
    output: nn::Linear<B>,
    /// Strength-scaled sigmoid that maps the outputs to per-class probabilities.
    sigmoid: Sigmoid<B>,
}

impl<B: Backend> KordModel<B> {
    /// Creates the model from the given configuration.
    pub fn new(mha_heads: usize, mha_dropout: f64, sigmoid_strength: f32) -> Self {
        let mha = MultiHeadAttentionConfig::new(INPUT_SPACE_SIZE, mha_heads).with_dropout(mha_dropout).init::<B>();
        let output = nn::LinearConfig::new(INPUT_SPACE_SIZE, NUM_CLASSES).init::<B>();
        let sigmoid = Sigmoid::new(sigmoid_strength);

        Self { mha, output, sigmoid }
    }
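
    // A minimal construction sketch (the hyperparameter values below are hypothetical,
    // not the values used in training):
    //
    //     let model = KordModel::<B>::new(
    //         8,   // mha_heads: number of attention heads
    //         0.1, // mha_dropout: dropout probability inside the attention block
    //         1.0, // sigmoid_strength: steepness of the final sigmoid
    //     );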

    /// Applies the forward pass on the input tensor.
    ///
    /// Takes an `input` of shape `[batch_size, INPUT_SPACE_SIZE]` and returns a tensor of shape
    /// `[batch_size, NUM_CLASSES]`, where each value is an independent class probability.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = input;

        // Treat each sample as a sequence of length 1 and perform the multi-head self-attention forward pass.
        let [batch_size, input_size] = x.dims();
        let attn_input = x.reshape([batch_size, 1, input_size]);
        let attn = self.mha.forward(MhaInput::new(attn_input.clone(), attn_input.clone(), attn_input));

        // Reshape the output to remove the sequence dimension.
        let mut x = attn.context.reshape([batch_size, input_size]);

        // Apply the final linear layer to map to the output dimensions.
        x = self.output.forward(x);

        // Apply the sigmoid to the output so each class gets an independent probability (multi-label classification).
        x = self.sigmoid.forward(x);

        x
    }
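
    // A minimal forward-pass sketch (a sketch only; tensor constructors vary across Burn
    // versions, and `zeros` is used purely for illustration):
    //
    //     let input = Tensor::<B, 2>::zeros([1, INPUT_SPACE_SIZE]); // one frequency-space frame
    //     let probabilities = model.forward(input);                 // shape [1, NUM_CLASSES]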

    /// Applies the forward classification pass on the given batch, producing the loss along with
    /// the raw outputs and targets.
    #[cfg(feature = "ml_train")]
    pub fn forward_classification(&self, item: KordBatch<B>) -> KordClassificationOutput<B> {
        use burn::nn::loss::MSELoss;

        let targets = item.targets;
        let output = self.forward(item.samples);

        // Sum-reduced mean squared error between the predicted per-class probabilities and the
        // multi-hot targets.
        let loss = MSELoss::default();
        let loss = loss.forward(output.clone(), targets.clone(), nn::loss::Reduction::Sum);

        // Alternative losses that were explored:
        // let loss = MeanSquareLoss::default();
        // let loss = loss.forward(output.clone(), targets.clone());
        // let loss = BinaryCrossEntropyLoss::default();
        // let loss = loss.forward(output.clone(), targets.clone());
        // let mut loss = FocalLoss::default();
        // loss.gamma = 2.0;
        // let loss = loss.forward(output.clone(), targets.clone());
        // let loss = loss + l1_regularization(self, 1e-4);
        // let harmonic_penalty_tensor = get_harmonic_penalty_tensor().to_device(&output.device());
        // let harmonic_loss = output.clone().matmul(harmonic_penalty_tensor).sum_dim(0).mean().mul_scalar(0.0001);
        // let loss = loss + harmonic_loss;

        KordClassificationOutput { loss, output, targets }
    }
}
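
// A minimal sketch (with hypothetical values) of what the sum-reduced MSE above computes: the sum
// of squared differences over every element of the `[batch_size, NUM_CLASSES]` output.
#[cfg(test)]
mod loss_sketch {
    #[test]
    fn summed_mse_matches_manual_sum() {
        // Toy per-class probabilities and multi-hot targets for a single sample.
        let output = [0.9f32, 0.1, 0.8];
        let targets = [1.0f32, 0.0, 1.0];

        // Sum of squared errors, mirroring `MSELoss` with `Reduction::Sum`.
        let loss: f32 = output.iter().zip(targets.iter()).map(|(o, t)| (o - t).powi(2)).sum();

        assert!((loss - 0.06).abs() < 1e-6);
    }
}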