//! burn_dragon_core 0.5.0
//!
//! Burn Dragon core model and utilities.
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Rank-2 coefficient tensors for a manifold hyper-connection.
///
/// These are carried as the `coefficients` of [`ManifoldHyperConnectionWidthOutput`].
///
/// NOTE(review): the meaning of each tensor axis is not established in this file —
/// confirm against the code that constructs these values.
#[derive(Clone, Debug)]
pub struct ManifoldHyperConnectionCoefficients<B: Backend> {
    /// Residual weights (rank-2 tensor).
    pub residual_weights: Tensor<B, 2>,
    /// Weights associated with the branch input (rank-2 tensor).
    pub branch_input_weights: Tensor<B, 2>,
    /// Optional weights associated with the branch output (rank-2 tensor when present).
    pub branch_output_weights: Option<Tensor<B, 2>>,
}

/// Output of a width-variant manifold hyper-connection.
///
/// Bundles the rank-4 tensor destined for the branch, the rank-4 residual output,
/// and the rank-2 [`ManifoldHyperConnectionCoefficients`] that accompany them.
///
/// NOTE(review): the layout of the rank-4 tensors (e.g. which axis indexes the
/// residual streams) is not fixed by this file — confirm with the producer.
#[derive(Clone, Debug)]
pub struct ManifoldHyperConnectionWidthOutput<B: Backend> {
    /// Tensor to be fed to the branch (rank 4).
    pub branch_input: Tensor<B, 4>,
    /// Residual tensor emitted by the connection (rank 4).
    pub residuals_out: Tensor<B, 4>,
    /// Coefficients associated with this output.
    pub coefficients: ManifoldHyperConnectionCoefficients<B>,
}

impl<B: Backend> ManifoldHyperConnectionWidthOutput<B> {
    /// Consumes the output and returns the legacy tuple form
    /// `(branch_input, residuals_out, branch_output_weights)`.
    ///
    /// Only `branch_output_weights` is kept from the coefficients; the
    /// `residual_weights` and `branch_input_weights` tensors are dropped.
    pub fn into_legacy(self) -> (Tensor<B, 4>, Tensor<B, 4>, Option<Tensor<B, 2>>) {
        let Self {
            branch_input,
            residuals_out,
            coefficients,
        } = self;
        let output_weights = coefficients.branch_output_weights;
        (branch_input, residuals_out, output_weights)
    }
}

/// Coefficient tensors for a stream-variant manifold hyper-connection.
///
/// Compared with [`ManifoldHyperConnectionCoefficients`], each tensor here has two
/// more dimensions (rank 4 / 3 / 3 instead of 2 / 2 / 2).
///
/// NOTE(review): the extra leading dimensions presumably index per-example or
/// per-position coefficients — confirm against the producer before relying on it.
#[derive(Clone, Debug)]
pub struct ManifoldHyperConnectionStreamCoefficients<B: Backend> {
    /// Residual weights (rank-4 tensor).
    pub residual_weights: Tensor<B, 4>,
    /// Weights associated with the branch input (rank-3 tensor).
    pub branch_input_weights: Tensor<B, 3>,
    /// Optional weights associated with the branch output (rank-3 tensor when present).
    pub branch_output_weights: Option<Tensor<B, 3>>,
}

/// Output of a stream-variant manifold hyper-connection.
///
/// Bundles the rank-4 tensor destined for the branch, the rank-4 residual output,
/// and the [`ManifoldHyperConnectionStreamCoefficients`] that accompany them.
#[derive(Clone, Debug)]
pub struct ManifoldHyperConnectionStreamOutput<B: Backend> {
    /// Tensor to be fed to the branch (rank 4).
    pub branch_input: Tensor<B, 4>,
    /// Residual tensor emitted by the connection (rank 4).
    pub residuals_out: Tensor<B, 4>,
    /// Coefficients associated with this output.
    pub coefficients: ManifoldHyperConnectionStreamCoefficients<B>,
}