//! burn_dragon_core 0.5.0
//!
//! Burn Dragon core model and utilities.
mod coefficients;
mod config;
mod reference;

pub use coefficients::{
    ManifoldHyperConnectionCoefficients, ManifoldHyperConnectionStreamCoefficients,
    ManifoldHyperConnectionStreamOutput, ManifoldHyperConnectionWidthOutput,
};
pub use config::{ManifoldHyperConnectionCoefficientPolicy, ManifoldHyperConnectionsConfig};
pub use reference::ManifoldHyperConnections;

use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

pub fn mhc_split<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    residuals: Tensor<B, 4>,
) -> (Tensor<B, 4>, Tensor<B, 4>, Option<Tensor<B, 2>>) {
    mhc_split_with_coefficients(mhc, residuals, None)
}

/// Runs the width (split) half of a hyper-connection, optionally with caller-provided coefficients.
///
/// Dispatch:
/// - `mhc` is `None`: identity split. The residuals are returned twice (once cloned) and the
///   third element is `None`.
/// - `mhc` is `Some` and `coefficients` is `Some`: uses `width_connection_with_coefficients`.
/// - `mhc` is `Some` and `coefficients` is `None`: uses `width_connection`.
///
/// In both `Some` cases, the output is converted with `into_legacy()`.
pub fn mhc_split_with_coefficients<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    residuals: Tensor<B, 4>,
    coefficients: Option<&ManifoldHyperConnectionCoefficients<B>>,
) -> (Tensor<B, 4>, Tensor<B, 4>, Option<Tensor<B, 2>>) {
    match (mhc, coefficients) {
        (None, _) => (residuals.clone(), residuals, None),
        // Each arm calls `into_legacy()` itself: the two width-connection methods are not
        // assumed to return the same type.
        (Some(module), Some(coeffs)) => module
            .width_connection_with_coefficients(residuals, coeffs)
            .into_legacy(),
        (Some(module), None) => module.width_connection(residuals).into_legacy(),
    }
}

pub fn mhc_merge<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    branch_output: Tensor<B, 4>,
    residuals: Tensor<B, 4>,
    beta: Option<Tensor<B, 2>>,
) -> Tensor<B, 4> {
    mhc_merge_with_coefficients(mhc, branch_output, residuals, None, beta)
}

/// Runs the depth (merge) half of a hyper-connection, optionally with caller-provided coefficients.
///
/// Dispatch:
/// - `mhc` is `None`: returns `branch_output` as-is. Both `residuals` and `beta` are discarded.
/// - `mhc` is `Some` and `coefficients` is `Some`: uses `depth_connection_with_coefficients`.
///   `beta` is ignored on this path.
/// - `mhc` is `Some` and `coefficients` is `None`: uses `depth_connection` with `beta`.
///
/// NOTE(review): in the `None` case, no residual is added here. Presumably the caller (or
/// the branch itself) handles that; confirm.
pub fn mhc_merge_with_coefficients<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    branch_output: Tensor<B, 4>,
    residuals: Tensor<B, 4>,
    coefficients: Option<&ManifoldHyperConnectionCoefficients<B>>,
    beta: Option<Tensor<B, 2>>,
) -> Tensor<B, 4> {
    match (mhc, coefficients) {
        (None, _) => branch_output,
        (Some(module), Some(coeffs)) => {
            module.depth_connection_with_coefficients(branch_output, residuals, coeffs)
        }
        (Some(module), None) => module.depth_connection(branch_output, residuals, beta),
    }
}

pub fn mhc_passthrough<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    residuals: Tensor<B, 4>,
) -> Tensor<B, 4> {
    mhc_passthrough_with_coefficients(mhc, residuals, None)
}

/// Pushes `residuals` through a hyper-connection with no branch computation, optionally using
/// caller-provided coefficients.
///
/// Dispatch:
/// - `mhc` is `None`: returns `residuals` unchanged.
/// - `mhc` is `Some` and `coefficients` is `Some`: runs the width connection, then feeds its
///   `branch_input` straight back as the branch output of the depth connection, together
///   with `residuals_out`. Both steps use the same coefficients.
/// - `mhc` is `Some` and `coefficients` is `None`: delegates to `mhc.passthrough`.
pub fn mhc_passthrough_with_coefficients<B: Backend>(
    mhc: Option<&ManifoldHyperConnections<B>>,
    residuals: Tensor<B, 4>,
    coefficients: Option<&ManifoldHyperConnectionCoefficients<B>>,
) -> Tensor<B, 4> {
    match (mhc, coefficients) {
        (None, _) => residuals,
        (Some(module), Some(coeffs)) => {
            let width = module.width_connection_with_coefficients(residuals, coeffs);
            // Identity branch: the branch input is used directly as the branch output.
            module.depth_connection_with_coefficients(
                width.branch_input,
                width.residuals_out,
                coeffs,
            )
        }
        (Some(module), None) => module.passthrough(residuals),
    }
}

#[cfg(test)]
mod tests;