pub mod attention;
pub mod block;
pub mod decoder;
pub mod encoder;
pub mod feedforward;
pub mod flash_attn;
pub mod fused_attn;
pub mod norm;
pub mod patch_embed;
pub mod pos_embed;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;
pub fn linear_zeros<B: Backend>(
d_input: usize,
d_output: usize,
bias: bool,
device: &B::Device,
) -> Linear<B> {
let weight = Param::initialized(
ParamId::new(),
Tensor::zeros([d_input, d_output], device),
);
let bias = if bias {
Some(Param::initialized(
ParamId::new(),
Tensor::zeros([d_output], device),
))
} else {
None
};
Linear { weight, bias }
}