use burn_core::module::{Module, Param};
use burn_core::prelude::*;
use burn_core::tensor::activation;
use burn_nn as nn;
/// Linear layer combined with a low-rank (LoRA) adapter branch.
///
/// While the adapter is active the forward pass returns
/// `base(x) + scaling * lora_b(lora_a(x))`; otherwise it returns `base(x)`.
#[derive(Module, Debug)]
pub struct LoraLinear<B: Backend> {
    /// Full-size projection that the adapter output is added to.
    base: nn::Linear<B>,
    /// Projection from the input width down to the adapter rank.
    lora_a: nn::Linear<B>,
    /// Projection from the adapter rank up to the output width.
    lora_b: nn::Linear<B>,
    /// Multiplier applied to the adapter output (`alpha / rank` at construction).
    #[module(skip)]
    scaling: f32,
    /// Whether the adapter branch takes part in the forward pass.
    #[module(skip)]
    active: bool,
}
/// Configuration used to build a [`LoraLinear`] layer.
#[derive(Config, Debug)]
pub struct LoraLinearConfig {
    /// Width of the layer input.
    pub in_features: usize,
    /// Width of the layer output.
    pub out_features: usize,
    /// Inner dimension of the low-rank adapter.
    #[config(default = "16")]
    pub rank: usize,
    /// Numerator of the adapter scaling factor (`alpha / rank`).
    #[config(default = "32.0")]
    pub alpha: f32,
}
impl LoraLinearConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> LoraLinear<B> {
let base = nn::LinearConfig::new(self.in_features, self.out_features)
.with_bias(false)
.init(device);
let lora_a = nn::LinearConfig::new(self.in_features, self.rank)
.with_bias(false)
.init(device);
let lora_b_config = nn::LinearConfig::new(self.rank, self.out_features).with_bias(false);
let mut lora_b = lora_b_config.init(device);
lora_b.weight = lora_b.weight.map(|w| w.zeros_like());
LoraLinear {
base,
lora_a,
lora_b,
scaling: self.alpha / self.rank as f32,
active: true,
}
}
pub fn init_with_base_weights<B: Backend>(
&self,
base_weight: Tensor<B, 2>,
device: &B::Device,
) -> LoraLinear<B> {
let base = nn::Linear {
weight: Param::from_tensor(base_weight),
bias: None,
};
let lora_a = nn::LinearConfig::new(self.in_features, self.rank)
.with_bias(false)
.init(device);
let lora_b_config = nn::LinearConfig::new(self.rank, self.out_features).with_bias(false);
let mut lora_b = lora_b_config.init(device);
lora_b.weight = lora_b.weight.map(|w| w.zeros_like());
LoraLinear {
base,
lora_a,
lora_b,
scaling: self.alpha / self.rank as f32,
active: true,
}
}
}
impl<B: Backend> LoraLinear<B> {
    /// Returns the current weight of the down-projection (`lora_a`).
    pub fn lora_a_weight(&self) -> Tensor<B, 2> {
        self.lora_a.weight.val()
    }

    /// Returns the current weight of the up-projection (`lora_b`).
    pub fn lora_b_weight(&self) -> Tensor<B, 2> {
        self.lora_b.weight.val()
    }

    /// Applies the layer to `input`.
    ///
    /// Returns the base projection alone when the adapter is inactive, and
    /// the base projection plus the scaled adapter output otherwise.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        if !self.active {
            return self.base.forward(input);
        }

        let base_out = self.base.forward(input.clone());
        let adapter_out = self
            .lora_b
            .forward(self.lora_a.forward(input))
            .mul_scalar(self.scaling);
        base_out + adapter_out
    }

    /// Enables or disables the adapter branch.
    pub fn set_active(&mut self, active: bool) {
        self.active = active;
    }

    /// Number of scalar values held by the two adapter matrices.
    pub fn trainable_param_count(&self) -> usize {
        let [a_rows, a_cols] = self.lora_a.weight.val().dims();
        let [b_rows, b_cols] = self.lora_b.weight.val().dims();
        a_rows * a_cols + b_rows * b_cols
    }
}
/// Root-mean-square normalization with a learned per-feature scale.
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
    /// Per-feature scale multiplied onto the normalized input.
    weight: Param<Tensor<B, 1>>,
    /// Constant added to the mean square before taking the square root.
    #[module(skip)]
    eps: f64,
}
/// Configuration used to build an [`RmsNorm`] layer.
#[derive(Config, Debug)]
pub struct RmsNormConfig {
    /// Length of the scale vector, i.e. the width of the normalized axis.
    pub hidden_size: usize,
    /// Constant added to the mean square before taking the square root.
    #[config(default = "1e-5")]
    pub eps: f64,
}
impl RmsNormConfig {
    /// Creates an [`RmsNorm`] whose scale vector starts at all ones.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
        let unit_scale = Tensor::ones([self.hidden_size], device);
        let weight = Param::from_tensor(unit_scale);
        let eps = self.eps;
        RmsNorm { weight, eps }
    }
}
impl<B: Backend> RmsNorm<B> {
    /// Normalizes each row of `x` by its root mean square and rescales it.
    ///
    /// Computes `x / sqrt(mean(x^2, dim = 1) + eps) * weight`, with `weight`
    /// broadcast across rows.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let mean_square = x.clone().powf_scalar(2.0).mean_dim(1);
        let rms = (mean_square + self.eps).sqrt();
        let unit = x / rms;
        // Lift the scale vector to `[1, hidden]` so it broadcasts over rows.
        let scale: Tensor<B, 2> = self.weight.val().unsqueeze_dim(0);
        unit * scale
    }
}
/// Gated feed-forward network: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
#[derive(Module, Debug)]
pub struct SwiGluFfn<B: Backend> {
    /// Projection whose SiLU-activated output gates the `up_proj` output.
    gate_proj: nn::Linear<B>,
    /// Projection from the hidden width to the intermediate width.
    up_proj: nn::Linear<B>,
    /// Projection from the intermediate width back to the hidden width.
    down_proj: nn::Linear<B>,
}
/// Configuration used to build a [`SwiGluFfn`].
#[derive(Config, Debug)]
pub struct SwiGluFfnConfig {
    /// Width of the network input and output.
    pub hidden_size: usize,
    /// Width of the gated intermediate representation.
    pub intermediate_size: usize,
}
impl SwiGluFfnConfig {
    /// Creates a [`SwiGluFfn`] with three bias-free linear projections.
    pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGluFfn<B> {
        // Gate and up projections share the same `[hidden, intermediate]` shape.
        let widen = || -> nn::Linear<B> {
            nn::LinearConfig::new(self.hidden_size, self.intermediate_size)
                .with_bias(false)
                .init(device)
        };

        let gate_proj = widen();
        let up_proj = widen();
        let down_proj: nn::Linear<B> =
            nn::LinearConfig::new(self.intermediate_size, self.hidden_size)
                .with_bias(false)
                .init(device);

        SwiGluFfn {
            gate_proj,
            up_proj,
            down_proj,
        }
    }
}
impl<B: Backend> SwiGluFfn<B> {
    /// Applies the gated feed-forward network to `x`.
    ///
    /// The SiLU-activated gate projection is multiplied element-wise with the
    /// up projection, and the product is passed through the down projection.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let gate_logits = self.gate_proj.forward(x.clone());
        let gate = activation::silu(gate_logits);
        let gated = gate * self.up_proj.forward(x);
        self.down_proj.forward(gated)
    }
}
/// Mean negative log-likelihood of `targets` under `logits`.
///
/// `logits` is log-softmaxed along dimension 1, the entry at each row's target
/// index is gathered, and the negated values are averaged into a one-element
/// tensor.
pub fn cross_entropy_loss<B: Backend>(
    logits: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
    let [batch_size] = targets.dims();
    // `gather` needs an index tensor of the same rank as the log-probabilities.
    let target_column = targets.reshape([batch_size, 1]);
    let log_probs = activation::log_softmax(logits, 1);
    let picked = log_probs.gather(1, target_column);
    picked.neg().mean()
}
/// Values reported for one training step.
#[derive(Debug)]
pub struct TrainStepOutput<B: Backend> {
    /// Loss tensor for the step.
    pub loss: Tensor<B, 1>,
    /// Token count for the step.
    // NOTE(review): presumably the number of tokens the loss covers — the
    // producer of this struct is not visible here; confirm against the caller.
    pub num_tokens: usize,
}
/// Linear layer with a weight-decomposed low-rank (DoRA) adapter.
///
/// The forward pass merges the base weight with the scaled adapter product,
/// normalizes each column of the merged matrix, rescales the columns by
/// `magnitude`, and multiplies the input by the result.
#[derive(Module, Debug)]
pub struct DoraLinear<B: Backend> {
    /// Full-size projection whose weight is merged with the adapter update.
    base: nn::Linear<B>,
    /// Projection from the input width down to the adapter rank.
    lora_a: nn::Linear<B>,
    /// Projection from the adapter rank up to the output width.
    lora_b: nn::Linear<B>,
    /// Per-output-column scale applied to the normalized merged weight.
    magnitude: Param<Tensor<B, 1>>,
    /// Multiplier applied to the adapter product (`alpha / rank` at construction).
    #[module(skip)]
    scaling: f32,
}
/// Configuration used to build a [`DoraLinear`] layer.
#[derive(Config, Debug)]
pub struct DoraLinearConfig {
    /// Width of the layer input.
    pub in_features: usize,
    /// Width of the layer output.
    pub out_features: usize,
    /// Inner dimension of the low-rank adapter.
    #[config(default = "16")]
    pub rank: usize,
    /// Numerator of the adapter scaling factor (`alpha / rank`).
    #[config(default = "32.0")]
    pub alpha: f32,
}
impl DoraLinearConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> DoraLinear<B> {
let base = nn::LinearConfig::new(self.in_features, self.out_features)
.with_bias(false)
.init(device);
let lora_a = nn::LinearConfig::new(self.in_features, self.rank)
.with_bias(false)
.init(device);
let lora_b_config = nn::LinearConfig::new(self.rank, self.out_features).with_bias(false);
let mut lora_b = lora_b_config.init(device);
lora_b.weight = lora_b.weight.map(|w| w.zeros_like());
let base_w = base.weight.val();
let col_norms = base_w
.clone()
.powf_scalar(2.0)
.sum_dim(0)
.sqrt()
.squeeze::<1>();
let magnitude = Param::from_tensor(col_norms);
DoraLinear {
base,
lora_a,
lora_b,
magnitude,
scaling: self.alpha / self.rank as f32,
}
}
pub fn init_with_base_weights<B: Backend>(
&self,
base_weight: Tensor<B, 2>,
device: &B::Device,
) -> DoraLinear<B> {
let col_norms = base_weight
.clone()
.powf_scalar(2.0)
.sum_dim(0)
.sqrt()
.squeeze::<1>();
let magnitude = Param::from_tensor(col_norms);
let base = nn::Linear {
weight: Param::from_tensor(base_weight),
bias: None,
};
let lora_a = nn::LinearConfig::new(self.in_features, self.rank)
.with_bias(false)
.init(device);
let lora_b_config = nn::LinearConfig::new(self.rank, self.out_features).with_bias(false);
let mut lora_b = lora_b_config.init(device);
lora_b.weight = lora_b.weight.map(|w| w.zeros_like());
DoraLinear {
base,
lora_a,
lora_b,
magnitude,
scaling: self.alpha / self.rank as f32,
}
}
}
impl<B: Backend> DoraLinear<B> {
    /// Returns the current weight of the down-projection (`lora_a`).
    pub fn lora_a_weight(&self) -> Tensor<B, 2> {
        self.lora_a.weight.val()
    }

    /// Returns the current weight of the up-projection (`lora_b`).
    pub fn lora_b_weight(&self) -> Tensor<B, 2> {
        self.lora_b.weight.val()
    }

    /// Returns the current per-column magnitude vector.
    pub fn magnitude_data(&self) -> Tensor<B, 1> {
        self.magnitude.val()
    }

    /// Applies the layer to `input`.
    ///
    /// The effective weight is `magnitude * W' / (||W'|| + 1e-8)`, where
    /// `W' = base + scaling * (lora_a · lora_b)` and the norm is taken over
    /// dimension 0 of `W'` (one value per output column).
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let adapter_delta = self
            .lora_a
            .weight
            .val()
            .matmul(self.lora_b.weight.val())
            .mul_scalar(self.scaling);
        let merged = self.base.weight.val() + adapter_delta;

        // The small constant keeps the division finite for all-zero columns.
        let eps: f32 = 1e-8;
        let column_norms = merged.clone().powf_scalar(2.0).sum_dim(0).sqrt() + eps;
        let unit_columns = merged / column_norms;

        // Lift the magnitude to `[1, out]` so it broadcasts over rows.
        let magnitude: Tensor<B, 2> = self.magnitude.val().unsqueeze_dim(0);
        input.matmul(unit_columns * magnitude)
    }

    /// Number of scalar values held by the adapter matrices and the magnitude.
    pub fn trainable_param_count(&self) -> usize {
        let [a_rows, a_cols] = self.lora_a.weight.val().dims();
        let [b_rows, b_cols] = self.lora_b.weight.val().dims();
        let [magnitude_len] = self.magnitude.val().dims();
        a_rows * a_cols + b_rows * b_cols + magnitude_len
    }
}
/// Linear layer with a low-rank adapter, built by [`QLoraLinearConfig`].
///
/// While the adapter is active the forward pass returns
/// `base(x) + scaling * lora_b(lora_a(x))`; otherwise it returns `base(x)`.
#[derive(Module, Debug)]
pub struct QLoraLinear<B: Backend> {
    /// Full-size projection that the adapter output is added to.
    base: nn::Linear<B>,
    /// Projection from the input width down to the adapter rank.
    lora_a: nn::Linear<B>,
    /// Projection from the adapter rank up to the output width.
    lora_b: nn::Linear<B>,
    /// Multiplier applied to the adapter output (`alpha / rank` at construction).
    #[module(skip)]
    scaling: f32,
    /// Whether the adapter branch takes part in the forward pass.
    #[module(skip)]
    active: bool,
}
/// Configuration used to build a [`QLoraLinear`] layer.
#[derive(Config, Debug)]
pub struct QLoraLinearConfig {
    /// Width of the layer input.
    pub in_features: usize,
    /// Width of the layer output.
    pub out_features: usize,
    /// Inner dimension of the low-rank adapter.
    #[config(default = "16")]
    pub rank: usize,
    /// Numerator of the adapter scaling factor (`alpha / rank`).
    #[config(default = "32.0")]
    pub alpha: f32,
    /// Bit width associated with the base weights.
    // NOTE(review): no constructor in this file reads this field; confirm
    // whether quantization is applied elsewhere.
    #[config(default = "4")]
    pub bits: u8,
}
impl QLoraLinearConfig {
    /// Builds a [`QLoraLinear`] whose base weight is taken from a flat slice.
    ///
    /// The slice is loaded as a float tensor and reshaped to
    /// `[in_features, out_features]`. The values are stored as given; the
    /// `bits` setting is not consulted here.
    ///
    /// # Panics
    ///
    /// Panics if `base_weight_f32.len() != in_features * out_features` or if
    /// `rank` is zero.
    pub fn init_quantized<B: Backend>(
        &self,
        base_weight_f32: &[f32],
        device: &B::Device,
    ) -> QLoraLinear<B> {
        let expected_len = self.in_features * self.out_features;
        // Fail with a clear message instead of an opaque reshape error.
        assert_eq!(
            base_weight_f32.len(),
            expected_len,
            "base_weight_f32 must hold in_features * out_features values"
        );

        let weight_tensor = Tensor::<B, 1>::from_floats(
            burn_core::tensor::TensorData::new(base_weight_f32.to_vec(), [expected_len]),
            device,
        )
        .reshape([self.in_features, self.out_features]);
        let base = nn::Linear {
            weight: Param::from_tensor(weight_tensor),
            bias: None,
        };
        self.wrap_base(base, device)
    }

    /// Builds a [`QLoraLinear`] around a freshly initialized, bias-free base
    /// projection of shape `[in_features, out_features]`.
    ///
    /// # Panics
    ///
    /// Panics if `rank` is zero.
    pub fn init<B: Backend>(&self, device: &B::Device) -> QLoraLinear<B> {
        let base: nn::Linear<B> = nn::LinearConfig::new(self.in_features, self.out_features)
            .with_bias(false)
            .init(device);
        self.wrap_base(base, device)
    }

    /// Attaches a new adapter to `base` and assembles the layer.
    fn wrap_base<B: Backend>(&self, base: nn::Linear<B>, device: &B::Device) -> QLoraLinear<B> {
        // A zero rank would make `alpha / rank` infinite and the output NaN.
        assert!(self.rank > 0, "LoRA rank must be greater than zero");

        let lora_a = nn::LinearConfig::new(self.in_features, self.rank)
            .with_bias(false)
            .init(device);

        // Zero up-projection: the adapter contributes nothing at first.
        // Created through `Param::from_tensor` so the zero matrix is
        // registered as a gradient-tracking parameter.
        let lora_b = nn::Linear {
            weight: Param::from_tensor(Tensor::zeros([self.rank, self.out_features], device)),
            bias: None,
        };

        QLoraLinear {
            // Only the adapter is meant to be trained (see
            // `trainable_param_count`), so the base projection is frozen.
            base: base.no_grad(),
            lora_a,
            lora_b,
            scaling: self.alpha / self.rank as f32,
            active: true,
        }
    }
}
impl<B: Backend> QLoraLinear<B> {
    /// Applies the layer to `input`.
    ///
    /// Returns the base projection alone when the adapter is inactive, and
    /// the base projection plus the scaled adapter output otherwise.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        if !self.active {
            return self.base.forward(input);
        }

        let base_out = self.base.forward(input.clone());
        let adapter_out = self
            .lora_b
            .forward(self.lora_a.forward(input))
            .mul_scalar(self.scaling);
        base_out + adapter_out
    }

    /// Returns the current weight of the down-projection (`lora_a`).
    pub fn lora_a_weight(&self) -> Tensor<B, 2> {
        self.lora_a.weight.val()
    }

    /// Returns the current weight of the up-projection (`lora_b`).
    pub fn lora_b_weight(&self) -> Tensor<B, 2> {
        self.lora_b.weight.val()
    }

    /// Number of scalar values held by the two adapter matrices.
    pub fn trainable_param_count(&self) -> usize {
        let [a_rows, a_cols] = self.lora_a.weight.val().dims();
        let [b_rows, b_cols] = self.lora_b.weight.val().dims();
        a_rows * a_cols + b_rows * b_cols
    }

    /// Enables or disables the adapter branch.
    pub fn set_active(&mut self, active: bool) {
        self.active = active;
    }
}
pub fn dpo_loss<B: Backend>(
chosen_logps: Tensor<B, 1>, rejected_logps: Tensor<B, 1>, ref_chosen_logps: Tensor<B, 1>, ref_rejected_logps: Tensor<B, 1>, beta: f32,
) -> Tensor<B, 1> {
let chosen_rewards = (chosen_logps - ref_chosen_logps).mul_scalar(beta);
let rejected_rewards = (rejected_logps - ref_rejected_logps).mul_scalar(beta);
let logits = chosen_rewards - rejected_rewards;
let neg_logits = logits.neg();
let loss = (neg_logits.clone().exp() + 1.0).log();
loss.mean()
}
pub fn orpo_alignment_loss<B: Backend>(
chosen_probs: Tensor<B, 1>, rejected_probs: Tensor<B, 1>, ) -> Tensor<B, 1> {
let eps: f32 = 1e-10;
let one_minus_eps: f32 = 1.0 - eps;
let chosen_clamped = chosen_probs.clamp(eps, one_minus_eps);
let rejected_clamped = rejected_probs.clamp(eps, one_minus_eps);
let ones = Tensor::ones_like(&chosen_clamped);
let chosen_odds = chosen_clamped.clone() / (ones.clone() - chosen_clamped);
let rejected_odds = rejected_clamped.clone() / (ones - rejected_clamped);
let log_odds_ratio = (chosen_odds / rejected_odds).log();
let neg_lor = log_odds_ratio.neg();
let loss = (neg_lor.exp() + 1.0).log();
loss.mean()
}
/// Full ORPO objective: `sft_loss + lambda * orpo_alignment_loss(chosen, rejected)`.
pub fn orpo_loss<B: Backend>(
    sft_loss: Tensor<B, 1>,
    chosen_probs: Tensor<B, 1>,
    rejected_probs: Tensor<B, 1>,
    lambda: f32,
) -> Tensor<B, 1> {
    let weighted_alignment =
        orpo_alignment_loss(chosen_probs, rejected_probs).mul_scalar(lambda);
    sft_loss + weighted_alignment
}
/// Pre-norm transformer block: self-attention followed by a SwiGLU
/// feed-forward network, each wrapped in a residual connection.
#[derive(Module, Debug)]
pub struct BurnTransformerBlock<B: Backend> {
    /// Normalization applied before the attention projections.
    pre_norm: RmsNorm<B>,
    /// Query projection.
    q_proj: nn::Linear<B>,
    /// Key projection.
    k_proj: nn::Linear<B>,
    /// Value projection.
    v_proj: nn::Linear<B>,
    /// Output projection applied to the attention result.
    o_proj: nn::Linear<B>,
    /// Normalization applied before the feed-forward network.
    post_norm: RmsNorm<B>,
    /// Feed-forward network.
    ffn: SwiGluFfn<B>,
    /// Number of attention heads from the configuration.
    #[module(skip)]
    num_heads: usize,
    /// `hidden_size / num_heads`, computed at construction.
    #[module(skip)]
    head_dim: usize,
}
/// Configuration used to build a [`BurnTransformerBlock`].
#[derive(Config, Debug)]
pub struct BurnTransformerBlockConfig {
    /// Width of the block input and output.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of the feed-forward intermediate representation.
    pub intermediate_size: usize,
}
impl BurnTransformerBlockConfig {
    /// Creates a [`BurnTransformerBlock`] with bias-free projections.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero or does not divide `hidden_size`; the
    /// per-head width `hidden_size / num_heads` would otherwise be undefined
    /// or silently truncated.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BurnTransformerBlock<B> {
        assert!(self.num_heads > 0, "num_heads must be greater than zero");
        assert_eq!(
            self.hidden_size % self.num_heads,
            0,
            "hidden_size must be divisible by num_heads"
        );
        let head_dim = self.hidden_size / self.num_heads;

        // All four attention projections are square `[hidden, hidden]` maps.
        let square_proj = || -> nn::Linear<B> {
            nn::LinearConfig::new(self.hidden_size, self.hidden_size)
                .with_bias(false)
                .init(device)
        };

        BurnTransformerBlock {
            pre_norm: RmsNormConfig::new(self.hidden_size).init(device),
            q_proj: square_proj(),
            k_proj: square_proj(),
            v_proj: square_proj(),
            o_proj: square_proj(),
            post_norm: RmsNormConfig::new(self.hidden_size).init(device),
            ffn: SwiGluFfnConfig::new(self.hidden_size, self.intermediate_size).init(device),
            num_heads: self.num_heads,
            head_dim,
        }
    }
}
impl<B: Backend> BurnTransformerBlock<B> {
    /// Applies the block to `x`, returning a tensor of the same shape.
    ///
    /// The rows of `x` attend to each other (no mask is applied). Queries,
    /// keys and values are split into `num_heads` heads of width `head_dim`;
    /// each head computes `softmax(q · kᵀ / sqrt(head_dim)) · v`, the heads
    /// are concatenated and passed through the output projection. The
    /// attention and feed-forward results are each added back to their input.
    ///
    /// With a single head this is the same computation as attending over the
    /// full hidden width.
    ///
    /// # Panics
    ///
    /// Panics if the width of `x` differs from `num_heads * head_dim`.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let [seq_len, hidden] = x.dims();
        assert_eq!(
            hidden,
            self.num_heads * self.head_dim,
            "input width must equal num_heads * head_dim"
        );

        // [seq, hidden] -> [heads, seq, head_dim]
        let split_heads = |t: Tensor<B, 2>| -> Tensor<B, 3> {
            t.reshape([seq_len, self.num_heads, self.head_dim])
                .swap_dims(0, 1)
        };

        let normed = self.pre_norm.forward(x.clone());
        let q = split_heads(self.q_proj.forward(normed.clone()));
        let k = split_heads(self.k_proj.forward(normed.clone()));
        let v = split_heads(self.v_proj.forward(normed));

        // Each head's dot product spans `head_dim` values, which is what the
        // sqrt(head_dim) scale is meant to compensate for.
        let scale = (self.head_dim as f32).sqrt();
        let scores = q.matmul(k.transpose()).div_scalar(scale); // [heads, seq, seq]
        let attn_weights = activation::softmax(scores, 2);
        let per_head = attn_weights.matmul(v); // [heads, seq, head_dim]

        // [heads, seq, head_dim] -> [seq, hidden]
        let attn_out: Tensor<B, 2> = per_head.swap_dims(0, 1).reshape([seq_len, hidden]);

        let h = x + self.o_proj.forward(attn_out);
        let ffn_out = self.ffn.forward(self.post_norm.forward(h.clone()));
        h + ffn_out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray;
    type TestDevice = <TestBackend as Backend>::Device;

    /// Returns the default device of the test backend.
    fn test_device() -> TestDevice {
        Default::default()
    }

    /// Draws a 2-D tensor from a zero-mean normal distribution.
    fn randn(shape: [usize; 2], std_dev: f64, device: &TestDevice) -> Tensor<TestBackend, 2> {
        Tensor::random(
            shape,
            burn_core::tensor::Distribution::Normal(0.0, std_dev),
            device,
        )
    }

    /// Builds a 1-D tensor from literal values.
    fn vector<const N: usize>(values: [f32; N], device: &TestDevice) -> Tensor<TestBackend, 1> {
        Tensor::from_floats(values, device)
    }

    #[test]
    fn test_lora_linear_forward() {
        let device = test_device();
        let layer = LoraLinearConfig::new(64, 128).init::<TestBackend>(&device);
        let output = layer.forward(randn([4, 64], 1.0, &device));
        assert_eq!(output.dims(), [4, 128]);
    }

    #[test]
    fn test_lora_linear_zero_init() {
        let device = test_device();
        let layer = LoraLinearConfig::new(32, 32).init::<TestBackend>(&device);
        let input = randn([2, 32], 1.0, &device);
        let base_out = layer.base.forward(input.clone());
        let full_out = layer.forward(input);
        let diff = (full_out - base_out).abs().sum().into_scalar();
        assert!(
            diff < 1e-5,
            "LoRA should contribute zero initially, diff={}",
            diff
        );
    }

    #[test]
    fn test_lora_inactive() {
        let device = test_device();
        let mut layer = LoraLinearConfig::new(32, 32).init::<TestBackend>(&device);
        layer.set_active(false);
        let input = randn([2, 32], 1.0, &device);
        let base_out = layer.base.forward(input.clone());
        let full_out = layer.forward(input);
        let diff = (full_out - base_out).abs().sum().into_scalar();
        assert!(diff < 1e-6, "Inactive LoRA should not contribute");
    }

    #[test]
    fn test_rms_norm() {
        let device = test_device();
        let norm = RmsNormConfig::new(64).init::<TestBackend>(&device);
        let output = norm.forward(randn([4, 64], 1.0, &device));
        assert_eq!(output.dims(), [4, 64]);
    }

    #[test]
    fn test_swiglu_ffn() {
        let device = test_device();
        let ffn = SwiGluFfnConfig::new(64, 128).init::<TestBackend>(&device);
        let output = ffn.forward(randn([4, 64], 1.0, &device));
        assert_eq!(output.dims(), [4, 64]);
    }

    #[test]
    fn test_trainable_params() {
        let device = test_device();
        let layer = LoraLinearConfig::new(4096, 4096)
            .with_rank(16)
            .init::<TestBackend>(&device);
        assert_eq!(layer.trainable_param_count(), 16 * 4096 + 4096 * 16);
    }

    #[test]
    fn test_dora_forward() {
        let device = test_device();
        let layer = DoraLinearConfig::new(64, 128)
            .with_rank(8)
            .init::<TestBackend>(&device);
        let output = layer.forward(randn([4, 64], 1.0, &device));
        assert_eq!(output.dims(), [4, 128]);
    }

    #[test]
    fn test_dora_trainable_params() {
        let device = test_device();
        let layer = DoraLinearConfig::new(256, 256)
            .with_rank(16)
            .init::<TestBackend>(&device);
        assert_eq!(layer.trainable_param_count(), 16 * 256 + 256 * 16 + 256);
    }

    #[test]
    fn test_dpo_loss_tensor() {
        let device = test_device();
        let chosen = vector([-1.0, -0.5, -0.8], &device);
        let rejected = vector([-3.0, -2.5, -2.8], &device);
        let ref_chosen = vector([-1.5, -1.0, -1.2], &device);
        let ref_rejected = vector([-1.5, -1.0, -1.2], &device);
        let val: f32 = dpo_loss(chosen, rejected, ref_chosen, ref_rejected, 0.1).into_scalar();
        assert!(val > 0.0, "DPO loss should be positive, got {}", val);
        assert!(val < 5.0, "DPO loss should be reasonable, got {}", val);
    }

    #[test]
    fn test_dpo_loss_equal_logps() {
        let device = test_device();
        let logps = vector([-2.0], &device);
        let val: f32 =
            dpo_loss(logps.clone(), logps.clone(), logps.clone(), logps, 0.1).into_scalar();
        assert!(
            (val - (2.0f32).ln()).abs() < 0.01,
            "Expected ~ln(2), got {}",
            val
        );
    }

    #[test]
    fn test_orpo_alignment_loss() {
        let device = test_device();
        let chosen = vector([0.8, 0.7], &device);
        let rejected = vector([0.3, 0.2], &device);
        let val: f32 = orpo_alignment_loss(chosen, rejected).into_scalar();
        assert!(val > 0.0, "ORPO alignment loss should be positive");
    }

    #[test]
    fn test_orpo_full_loss() {
        let device = test_device();
        let sft = vector([2.0], &device);
        let chosen = vector([0.7], &device);
        let rejected = vector([0.3], &device);
        let val: f32 = orpo_loss(sft, chosen, rejected, 0.5).into_scalar();
        assert!(val > 2.0, "Total should be > SFT loss, got {}", val);
    }

    #[test]
    fn test_transformer_block() {
        let device = test_device();
        let block = BurnTransformerBlockConfig::new(64, 4, 128).init::<TestBackend>(&device);
        let output = block.forward(randn([8, 64], 0.1, &device));
        assert_eq!(
            output.dims(),
            [8, 64],
            "Transformer block should preserve shape"
        );
    }

    #[test]
    fn test_lora_init_with_base_weights() {
        let device = test_device();
        let config = LoraLinearConfig::new(32, 64).with_rank(8);
        let base_weight = randn([32, 64], 0.1, &device);
        let layer = config.init_with_base_weights::<TestBackend>(base_weight.clone(), &device);
        let input = randn([4, 32], 1.0, &device);
        let output = layer.forward(input.clone());
        assert_eq!(output.dims(), [4, 64]);
        let expected = input.matmul(base_weight);
        let diff = (output - expected).abs().sum().into_scalar();
        assert!(
            diff < 1e-4,
            "LoRA with base weights should match base output initially, diff={}",
            diff
        );
    }

    #[test]
    fn test_dora_init_with_base_weights() {
        let device = test_device();
        let config = DoraLinearConfig::new(32, 64).with_rank(8);
        let base_weight = randn([32, 64], 0.1, &device);
        let layer = config.init_with_base_weights::<TestBackend>(base_weight, &device);
        let output = layer.forward(randn([4, 32], 1.0, &device));
        assert_eq!(output.dims(), [4, 64]);
    }

    #[test]
    fn test_qlora_forward() {
        let device = test_device();
        let layer = QLoraLinearConfig::new(64, 128)
            .with_rank(8)
            .init::<TestBackend>(&device);
        let output = layer.forward(randn([4, 64], 1.0, &device));
        assert_eq!(output.dims(), [4, 128]);
    }

    #[test]
    fn test_qlora_init_quantized() {
        let device = test_device();
        let config = QLoraLinearConfig::new(16, 32).with_rank(4).with_bits(4);
        let base_weights: Vec<f32> = (0..16 * 32).map(|i| (i as f32) * 0.001).collect();
        let layer = config.init_quantized::<TestBackend>(&base_weights, &device);
        let output = layer.forward(randn([2, 16], 1.0, &device));
        assert_eq!(output.dims(), [2, 32]);
        assert_eq!(layer.trainable_param_count(), 4 * 16 + 32 * 4);
    }
}