use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use tsai_core::TSClassificationModel;
use crate::cnn::{FCN, InceptionTimePlus, OmniScaleCNN, ResNetPlus, XCMPlus, XceptionTime};
use crate::rnn::{RNNAttention, RNNPlus};
use crate::rocket::HydraPlus;
use crate::transformer::{PatchTST, TSPerceiver, TSiTPlus, TSTPlus};
impl<B: AutodiffBackend> TSClassificationModel<B> for InceptionTimePlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for XceptionTime<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for FCN<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for ResNetPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for OmniScaleCNN<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for XCMPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for TSTPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for TSiTPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for TSPerceiver<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for PatchTST<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for HydraPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for RNNPlus<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
impl<B: AutodiffBackend> TSClassificationModel<B> for RNNAttention<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
self.forward(x)
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_traits_compile() {
}
}