use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, Variable};
use axonml_nn::layers::ternary::TernaryLinear;
use axonml_nn::{Embedding, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use crate::llama::{RMSNorm, RepeatKVBackward, RotaryEmbedding};
/// Hyperparameters for a Trident ternary-weight transformer language model.
#[derive(Debug, Clone)]
pub struct TridentConfig {
    /// Vocabulary size (token-embedding rows and lm_head output width).
    pub vocab_size: usize,
    /// Model/hidden dimension.
    pub d_model: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Number of query attention heads; must evenly divide `d_model`.
    pub num_heads: usize,
    /// Number of key/value heads (< `num_heads` enables grouped-query attention).
    pub num_kv_heads: usize,
    /// MLP inner dimension.
    pub intermediate_size: usize,
    /// Maximum sequence length (sizes the rotary-embedding cache).
    pub max_seq_len: usize,
    /// Epsilon used by all RMSNorm layers.
    pub rms_norm_eps: f32,
    /// Whether to apply rotary position embeddings to Q/K.
    pub use_rope: bool,
    /// RoPE base frequency (theta).
    pub rope_theta: f32,
    /// Use gated ReLU^2 MLP (adds a gate projection) instead of plain SiLU MLP.
    pub use_squared_relu: bool,
    /// Add sub-layer RMSNorms after attention output and inside the MLP.
    pub use_sub_ln: bool,
}
impl TridentConfig {
    /// ~150M-parameter baseline: no RoPE, plain SiLU MLP, no sub-layer norms.
    pub fn default_150m() -> Self {
        Self {
            vocab_size: 32000,
            d_model: 512,
            num_layers: 12,
            num_heads: 8,
            num_kv_heads: 8,
            intermediate_size: 2048,
            max_seq_len: 2048,
            rms_norm_eps: 1e-6,
            use_rope: false,
            rope_theta: 10_000.0,
            use_squared_relu: false,
            use_sub_ln: false,
        }
    }

    /// Minimal configuration for fast unit tests.
    pub fn tiny() -> Self {
        Self {
            vocab_size: 1000,
            d_model: 64,
            num_layers: 2,
            num_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 256,
            max_seq_len: 128,
            rms_norm_eps: 1e-6,
            use_rope: false,
            rope_theta: 10_000.0,
            use_squared_relu: false,
            use_sub_ln: false,
        }
    }

    /// Mid-size configuration (~16 layers, 768-wide).
    pub fn medium() -> Self {
        Self {
            vocab_size: 32000,
            d_model: 768,
            num_layers: 16,
            num_heads: 12,
            num_kv_heads: 12,
            intermediate_size: 3072,
            max_seq_len: 2048,
            rms_norm_eps: 1e-6,
            use_rope: false,
            rope_theta: 10_000.0,
            use_squared_relu: false,
            use_sub_ln: false,
        }
    }

    /// ~1B-parameter configuration: RoPE, grouped-query attention (16 Q / 4 KV
    /// heads), gated ReLU^2 MLP, and sub-layer norms.
    pub fn trident_1b(vocab_size: usize) -> Self {
        Self {
            vocab_size,
            d_model: 2048,
            num_layers: 24,
            num_heads: 16,
            num_kv_heads: 4,
            intermediate_size: 5504,
            max_seq_len: 4096,
            rms_norm_eps: 1e-5,
            use_rope: true,
            rope_theta: 500_000.0,
            use_squared_relu: true,
            use_sub_ln: true,
        }
    }

    /// ~3B-parameter configuration; same feature set as `trident_1b`.
    pub fn trident_3b(vocab_size: usize) -> Self {
        Self {
            vocab_size,
            d_model: 3200,
            num_layers: 26,
            num_heads: 32,
            num_kv_heads: 8,
            intermediate_size: 8640,
            max_seq_len: 4096,
            rms_norm_eps: 1e-5,
            use_rope: true,
            rope_theta: 500_000.0,
            use_squared_relu: true,
            use_sub_ln: true,
        }
    }

    /// Tiny smoke-test configuration with the full 1B feature set enabled
    /// (RoPE, GQA, ReLU^2 gate, sub-layer norms).
    pub fn smoke(vocab_size: usize) -> Self {
        Self {
            vocab_size,
            d_model: 256,
            num_layers: 4,
            num_heads: 8,
            num_kv_heads: 2,
            intermediate_size: 688,
            max_seq_len: 512,
            rms_norm_eps: 1e-5,
            use_rope: true,
            rope_theta: 500_000.0,
            use_squared_relu: true,
            use_sub_ln: true,
        }
    }

    /// Per-head dimension. Assumes `num_heads` evenly divides `d_model`;
    /// otherwise the truncated quotient will break reshapes downstream.
    pub fn head_dim(&self) -> usize {
        self.d_model / self.num_heads
    }

    /// Number of ternary linear layers in one transformer block:
    /// 4 attention projections plus 2 (up/down) or 3 (gate/up/down) MLP linears.
    fn linears_per_layer(&self) -> usize {
        4 + if self.use_squared_relu { 3 } else { 2 }
    }

    /// Total ternary weight count in one transformer block (attention + MLP).
    /// Shared by `estimated_params` and `ternary_storage_bytes` so the two
    /// accountings cannot drift apart.
    fn ternary_weights_per_layer(&self) -> usize {
        let kv_hidden = self.num_kv_heads * self.head_dim();
        // Q and O projections are d_model x d_model; K and V are d_model x kv_hidden.
        let attn = 2 * self.d_model * self.d_model + 2 * self.d_model * kv_hidden;
        let mlp_linears = if self.use_squared_relu { 3 } else { 2 };
        let mlp = mlp_linears * self.d_model * self.intermediate_size;
        attn + mlp
    }

    /// Estimated total parameter count: embedding + lm_head + per-layer
    /// ternary weights and RMSNorm gains.
    pub fn estimated_params(&self) -> usize {
        let embedding = self.vocab_size * self.d_model;
        let lm_head = self.d_model * self.vocab_size;
        // Two pre-norms per block, plus optional attention/FFN sub-norms.
        let mut norms = 2 * self.d_model;
        if self.use_sub_ln {
            norms += self.d_model + self.intermediate_size;
        }
        let per_layer = self.ternary_weights_per_layer() + norms;
        embedding + lm_head + self.num_layers * per_layer
    }

    /// Storage footprint in bytes when ternary weights are bit-packed:
    /// 2 bits per ternary weight (4 per byte), one f32 scale per linear, and
    /// full f32 storage for the embedding, lm_head, and norm gains.
    pub fn ternary_storage_bytes(&self) -> usize {
        let total_ternary_weights = self.num_layers * self.ternary_weights_per_layer();
        let packed_bytes = total_ternary_weights.div_ceil(4);
        let scale_bytes = self.num_layers * self.linears_per_layer() * 4;
        let mut fp32_bytes = (self.vocab_size * self.d_model
            + self.d_model * self.vocab_size
            + self.num_layers * 2 * self.d_model)
            * 4;
        if self.use_sub_ln {
            fp32_bytes += self.num_layers * (self.d_model + self.intermediate_size) * 4;
        }
        packed_bytes + scale_bytes + fp32_bytes
    }

    /// Storage footprint in bytes with every parameter kept as f32.
    pub fn fp32_storage_bytes(&self) -> usize {
        self.estimated_params() * 4
    }
}
/// Causal self-attention with ternary projections and optional RoPE / GQA /
/// post-attention sub-norm.
#[derive(Debug)]
struct TridentAttention {
    // Bias-free ternary projections. Q and O map d_model -> d_model;
    // K and V map d_model -> num_kv_heads * head_dim.
    q_proj: TernaryLinear,
    k_proj: TernaryLinear,
    v_proj: TernaryLinear,
    o_proj: TernaryLinear,
    // Present only when `config.use_rope` is set.
    rotary_emb: Option<RotaryEmbedding>,
    // Present only when `config.use_sub_ln` is set; applied before o_proj.
    attn_sub_norm: Option<RMSNorm>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    // Cached d_model, used when merging heads back together.
    hidden_size: usize,
}
impl TridentAttention {
    /// Builds the attention sub-module from `config`. Rotary embeddings and
    /// the post-attention sub-norm are only constructed when the
    /// corresponding config flags are set.
    fn new(config: &TridentConfig) -> Self {
        let head_dim = config.head_dim();
        // Width of the K/V projections; smaller than d_model under
        // grouped-query attention (num_kv_heads < num_heads).
        let kv_hidden = config.num_kv_heads * head_dim;
        let rotary_emb = if config.use_rope {
            Some(RotaryEmbedding::new(
                head_dim,
                config.max_seq_len,
                config.rope_theta,
            ))
        } else {
            None
        };
        let attn_sub_norm = if config.use_sub_ln {
            Some(RMSNorm::new(config.d_model, config.rms_norm_eps))
        } else {
            None
        };
        Self {
            // All four projections are bias-free ternary linears.
            q_proj: TernaryLinear::with_bias(config.d_model, config.d_model, false),
            k_proj: TernaryLinear::with_bias(config.d_model, kv_hidden, false),
            v_proj: TernaryLinear::with_bias(config.d_model, kv_hidden, false),
            o_proj: TernaryLinear::with_bias(config.d_model, config.d_model, false),
            rotary_emb,
            attn_sub_norm,
            num_heads: config.num_heads,
            num_kv_heads: config.num_kv_heads,
            head_dim,
            hidden_size: config.d_model,
        }
    }

    /// Causal self-attention over the full sequence (no KV cache).
    ///
    /// Assumes `hidden_states` is `[batch, seq_len, d_model]` — the first two
    /// dims are read directly and the reshapes below rely on the third.
    fn forward(&self, hidden_states: &Variable) -> Variable {
        let data = hidden_states.data();
        let shape = data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let q = self.q_proj.forward(hidden_states);
        let k = self.k_proj.forward(hidden_states);
        let v = self.v_proj.forward(hidden_states);
        // Split into heads: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim].
        let q = q
            .reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])
            .transpose(1, 2);
        // Position offset is always 0: without a KV cache every sequence
        // starts at position 0.
        let (q, k) = if let Some(rope) = &self.rotary_emb {
            rope.apply(&q, &k, 0)
        } else {
            (q, k)
        };
        // Grouped-query attention: expand K/V heads to match the Q head count.
        let (k, v) = if self.num_kv_heads != self.num_heads {
            let n_rep = self.num_heads / self.num_kv_heads;
            (self.repeat_kv(&k, n_rep), self.repeat_kv(&v, n_rep))
        } else {
            (k, v)
        };
        // Scaled dot-product attention with an additive causal mask.
        let scale = 1.0 / (self.head_dim as f32).sqrt();
        let attn_weights = q.matmul(&k.transpose(2, 3)).mul_scalar(scale);
        let mask = self.create_causal_mask(seq_len);
        // The mask is a constant, so it does not require gradients.
        let attn_weights = attn_weights.add(&Variable::new(mask, false));
        let attn_weights = attn_weights.softmax(-1);
        let attn_output = attn_weights.matmul(&v);
        // Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, d_model].
        let attn_output =
            attn_output
                .transpose(1, 2)
                .reshape(&[batch_size, seq_len, self.hidden_size]);
        // Optional sub-layer norm before the output projection.
        let attn_output = if let Some(norm) = &self.attn_sub_norm {
            norm.forward(&attn_output)
        } else {
            attn_output
        };
        self.o_proj.forward(&attn_output)
    }

    /// Repeats each K/V head `n_rep` times along the head axis so K/V match
    /// the query head count: `[b, kv, s, d]` -> `[b, kv * n_rep, s, d]`.
    ///
    /// The copy is done manually on the host; the backward pass is wired up
    /// through `RepeatKVBackward` when gradients are enabled.
    fn repeat_kv(&self, x: &Variable, n_rep: usize) -> Variable {
        if n_rep == 1 {
            return x.clone();
        }
        let data = x.data();
        let shape = data.shape();
        let batch = shape[0];
        let num_kv_heads = shape[1];
        let seq_len = shape[2];
        let head_dim = shape[3];
        let data_vec = data.to_vec();
        let mut output = Vec::with_capacity(data_vec.len() * n_rep);
        // Each kv head's [seq, head_dim] slab is emitted n_rep times in a
        // row, i.e. a repeat-interleave along the head dimension.
        for b in 0..batch {
            for h in 0..num_kv_heads {
                for _ in 0..n_rep {
                    for s in 0..seq_len {
                        let offset = ((b * num_kv_heads + h) * seq_len + s) * head_dim;
                        output.extend_from_slice(&data_vec[offset..offset + head_dim]);
                    }
                }
            }
        }
        let output_tensor =
            Tensor::from_vec(output, &[batch, num_kv_heads * n_rep, seq_len, head_dim]).unwrap();
        if x.requires_grad() && is_grad_enabled() {
            // Register the custom backward op so gradients sum back over the
            // repeated heads.
            let grad_fn = GradFn::new(RepeatKVBackward {
                next_fns: vec![x.grad_fn().cloned()],
                num_kv_heads,
                n_rep,
            });
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            // No autograd needed: return a plain constant Variable.
            Variable::new(output_tensor, false)
        }
    }

    /// Additive causal mask of shape `[1, 1, seq, seq]`: 0 on and below the
    /// diagonal, -inf strictly above; broadcasts over batch and heads.
    fn create_causal_mask(&self, seq_len: usize) -> Tensor<f32> {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in (i + 1)..seq_len {
                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
            }
        }
        Tensor::from_vec(mask_data, &[1, 1, seq_len, seq_len]).unwrap()
    }

    /// All trainable parameters: the four projections plus the optional sub-norm.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.o_proj.parameters());
        if let Some(n) = &self.attn_sub_norm {
            params.extend(n.parameters());
        }
        params
    }

    /// Switches the ternary projections to inference-mode quantization
    /// (see `TernaryLinear::quantize_for_inference`).
    fn quantize_for_inference(&mut self) {
        self.q_proj.quantize_for_inference();
        self.k_proj.quantize_for_inference();
        self.v_proj.quantize_for_inference();
        self.o_proj.quantize_for_inference();
    }
}
/// Feed-forward block with ternary linears. Two variants, selected by
/// `use_squared_relu`: gated ReLU^2 (gate/up/down) or plain SiLU (up/down).
#[derive(Debug)]
struct TridentMLP {
    up_proj: TernaryLinear,
    // Present only in the squared-ReLU variant.
    gate_proj: Option<TernaryLinear>,
    down_proj: TernaryLinear,
    // Optional sub-layer norm applied to the gated activation, before down_proj.
    ffn_sub_norm: Option<RMSNorm>,
    use_squared_relu: bool,
}
impl TridentMLP {
    /// Builds the MLP from `config`. The gate projection and FFN sub-norm
    /// exist only when the corresponding config flags are set.
    fn new(config: &TridentConfig) -> Self {
        // Optional components are created first, preserving construction order.
        let gate_proj = config.use_squared_relu.then(|| {
            TernaryLinear::with_bias(config.d_model, config.intermediate_size, false)
        });
        let ffn_sub_norm = config
            .use_sub_ln
            .then(|| RMSNorm::new(config.intermediate_size, config.rms_norm_eps));
        Self {
            up_proj: TernaryLinear::with_bias(config.d_model, config.intermediate_size, false),
            gate_proj,
            down_proj: TernaryLinear::with_bias(config.intermediate_size, config.d_model, false),
            ffn_sub_norm,
            use_squared_relu: config.use_squared_relu,
        }
    }

    /// Forward pass: gated ReLU^2 variant when enabled, plain SiLU otherwise.
    fn forward(&self, x: &Variable) -> Variable {
        if !self.use_squared_relu {
            // Plain variant: up projection with SiLU activation, then down.
            let activated = self.up_proj.forward(x).silu();
            return self.down_proj.forward(&activated);
        }
        // Gated variant: relu(gate)^2 * up, optional sub-norm, then down.
        let gate = self
            .gate_proj
            .as_ref()
            .expect("gate_proj set when use_squared_relu")
            .forward(x);
        let up = self.up_proj.forward(x);
        let gated = gate.relu().pow(2.0).mul(&up);
        let normed = match &self.ffn_sub_norm {
            Some(norm) => norm.forward(&gated),
            None => gated,
        };
        self.down_proj.forward(&normed)
    }

    /// Trainable parameters in declaration order: up, gate?, down, sub-norm?.
    fn parameters(&self) -> Vec<Parameter> {
        let mut out = self.up_proj.parameters();
        if let Some(gate) = &self.gate_proj {
            out.extend(gate.parameters());
        }
        out.extend(self.down_proj.parameters());
        if let Some(norm) = &self.ffn_sub_norm {
            out.extend(norm.parameters());
        }
        out
    }

    /// Switches every ternary linear to inference-mode quantization.
    fn quantize_for_inference(&mut self) {
        self.up_proj.quantize_for_inference();
        if let Some(gate) = &mut self.gate_proj {
            gate.quantize_for_inference();
        }
        self.down_proj.quantize_for_inference();
    }
}
/// One pre-norm transformer block: RMSNorm -> attention -> residual, then
/// RMSNorm -> MLP -> residual.
#[derive(Debug)]
struct TridentBlock {
    attn_norm: RMSNorm,
    attention: TridentAttention,
    mlp_norm: RMSNorm,
    mlp: TridentMLP,
}
impl TridentBlock {
    /// Builds one transformer block from `config`.
    fn new(config: &TridentConfig) -> Self {
        Self {
            attn_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            attention: TridentAttention::new(config),
            mlp_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            mlp: TridentMLP::new(config),
        }
    }

    /// Pre-norm residual forward: x + attn(norm(x)), then y + mlp(norm(y)).
    fn forward(&self, hidden_states: &Variable) -> Variable {
        // Attention sub-layer with residual connection.
        let attn_out = self
            .attention
            .forward(&self.attn_norm.forward(hidden_states));
        let after_attn = hidden_states.add(&attn_out);
        // MLP sub-layer with residual connection.
        let mlp_out = self.mlp.forward(&self.mlp_norm.forward(&after_attn));
        after_attn.add(&mlp_out)
    }

    /// Trainable parameters of both norms and both sub-modules, in order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut out = self.attn_norm.parameters();
        out.extend(self.attention.parameters());
        out.extend(self.mlp_norm.parameters());
        out.extend(self.mlp.parameters());
        out
    }

    /// Quantizes the attention and MLP ternary linears for inference.
    fn quantize_for_inference(&mut self) {
        self.attention.quantize_for_inference();
        self.mlp.quantize_for_inference();
    }
}
/// Full Trident language model: token embedding, a stack of ternary
/// transformer blocks, a final RMSNorm, and an fp32 lm_head.
pub struct TridentModel {
    embed_tokens: Embedding,
    blocks: Vec<TridentBlock>,
    final_norm: RMSNorm,
    // Bias-free fp32 projection to vocabulary logits (not ternary).
    lm_head: Linear,
    // Retained for reporting and introspection.
    config: TridentConfig,
}
impl TridentModel {
    /// Builds the model from `config`; the config is cloned and retained.
    pub fn new(config: &TridentConfig) -> Self {
        let blocks = (0..config.num_layers)
            .map(|_| TridentBlock::new(config))
            .collect();
        Self {
            embed_tokens: Embedding::new(config.vocab_size, config.d_model),
            blocks,
            final_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            lm_head: Linear::with_bias(config.d_model, config.vocab_size, false),
            config: config.clone(),
        }
    }

    /// Forward pass from raw token ids `[batch, seq]` to logits
    /// `[batch, seq, vocab_size]`.
    ///
    /// Ids are converted to f32 (the Embedding layer's expected input form
    /// here) and then run through the `Module::forward` pipeline, so the
    /// embed -> blocks -> norm -> lm_head sequence lives in exactly one place.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let ids_f32: Vec<f32> = input_ids.to_vec().into_iter().map(|x| x as f32).collect();
        let ids_var = Variable::new(Tensor::from_vec(ids_f32, input_ids.shape()).unwrap(), false);
        self.forward(&ids_var)
    }

    /// Forward pass plus next-token cross-entropy loss.
    ///
    /// `labels` is aligned with `input_ids` (same `[batch, seq]` shape); the
    /// shift is done internally: logits at position t are scored against the
    /// label at position t+1. For `seq_len == 1` there is nothing to predict,
    /// so a constant zero loss is returned.
    pub fn forward_with_loss(
        &self,
        input_ids: &Tensor<u32>,
        labels: &Tensor<u32>,
    ) -> (Variable, Variable) {
        let logits = self.forward_ids(input_ids);
        let logits_data = logits.data();
        let shape = logits_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        if seq_len > 1 {
            // Drop the last logit position; pair position t with label t+1.
            let shift_logits = logits.narrow(1, 0, seq_len - 1);
            let labels_vec = labels.to_vec();
            let mut shift_labels_data = Vec::with_capacity(batch_size * (seq_len - 1));
            for b in 0..batch_size {
                for s in 1..seq_len {
                    shift_labels_data.push(labels_vec[b * seq_len + s]);
                }
            }
            let shift_labels =
                Tensor::from_vec(shift_labels_data, &[batch_size, seq_len - 1]).unwrap();
            let loss = Self::cross_entropy_loss(&shift_logits, &shift_labels);
            (logits, loss)
        } else {
            let zero_loss = Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
            (logits, zero_loss)
        }
    }

    /// Token-level cross-entropy over flattened `[batch * seq, vocab]` logits.
    ///
    /// NOTE(review): labels outside the vocabulary are silently clamped to
    /// token 0 rather than ignored — the loss then trains toward token 0 for
    /// those positions. Consider an ignore-index mechanism if out-of-range
    /// labels can occur in practice.
    fn cross_entropy_loss(logits: &Variable, labels: &Tensor<u32>) -> Variable {
        let logits_data = logits.data();
        let shape = logits_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let vocab_size = shape[2];
        let logits_flat = logits.reshape(&[batch_size * seq_len, vocab_size]);
        let labels_vec = labels.to_vec();
        let valid_labels: Vec<f32> = labels_vec
            .iter()
            .map(|&l| {
                let label = l as usize;
                if label < vocab_size { l as f32 } else { 0.0f32 }
            })
            .collect();
        let target_var = Variable::new(
            Tensor::from_vec(valid_labels, &[batch_size * seq_len]).unwrap(),
            false,
        );
        use axonml_nn::loss::CrossEntropyLoss;
        CrossEntropyLoss::new().compute(&logits_flat, &target_var)
    }

    /// Quantizes every ternary linear in every block for inference.
    pub fn quantize_for_inference(&mut self) {
        for block in &mut self.blocks {
            block.quantize_for_inference();
        }
    }

    /// Mean weight sparsity across all ternary linears (0.0 when the model
    /// has no blocks).
    pub fn average_sparsity(&self) -> f32 {
        let mut total_sparsity = 0.0f32;
        let mut count = 0usize;
        for block in &self.blocks {
            total_sparsity += block.attention.q_proj.weight_sparsity();
            total_sparsity += block.attention.k_proj.weight_sparsity();
            total_sparsity += block.attention.v_proj.weight_sparsity();
            total_sparsity += block.attention.o_proj.weight_sparsity();
            count += 4;
            total_sparsity += block.mlp.up_proj.weight_sparsity();
            total_sparsity += block.mlp.down_proj.weight_sparsity();
            count += 2;
            if let Some(g) = &block.mlp.gate_proj {
                total_sparsity += g.weight_sparsity();
                count += 1;
            }
        }
        if count > 0 {
            total_sparsity / count as f32
        } else {
            0.0
        }
    }

    /// The configuration this model was built with.
    pub fn config(&self) -> &TridentConfig {
        &self.config
    }

    /// Prints a human-readable summary (sizes, compression, sparsity) to stdout.
    pub fn report(&self) {
        let param_count: usize = self.parameters().iter().map(|p| p.data().numel()).sum();
        let fp32_mb = self.config.fp32_storage_bytes() as f32 / (1024.0 * 1024.0);
        let ternary_mb = self.config.ternary_storage_bytes() as f32 / (1024.0 * 1024.0);
        let compression = fp32_mb / ternary_mb;
        println!("Trident Model Report");
        println!("====================");
        println!("Layers : {}", self.config.num_layers);
        println!("d_model : {}", self.config.d_model);
        println!("Heads : {}", self.config.num_heads);
        println!("Vocab : {}", self.config.vocab_size);
        println!("Parameters : {}", param_count);
        println!("FP32 size : {:.1} MB", fp32_mb);
        println!("Ternary size : {:.1} MB", ternary_mb);
        println!("Compression : {:.1}x", compression);
        println!("Sparsity : {:.1}%", self.average_sparsity() * 100.0);
    }
}
impl Module for TridentModel {
    /// Runs the embed -> blocks -> final norm -> lm_head pipeline.
    /// `input` carries token ids as f32 values (see `forward_ids`).
    fn forward(&self, input: &Variable) -> Variable {
        let embedded = self.embed_tokens.forward(input);
        let hidden = self
            .blocks
            .iter()
            .fold(embedded, |state, block| block.forward(&state));
        self.lm_head.forward(&self.final_norm.forward(&hidden))
    }

    /// All trainable parameters: embedding, every block, final norm, lm_head.
    fn parameters(&self) -> Vec<Parameter> {
        let mut out = self.embed_tokens.parameters();
        for block in &self.blocks {
            out.extend(block.parameters());
        }
        out.extend(self.final_norm.parameters());
        out.extend(self.lm_head.parameters());
        out
    }

    fn name(&self) -> &'static str {
        "TridentModel"
    }
}
impl std::fmt::Debug for TridentModel {
    // Concise debug output: key config dimensions only, weights omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        f.debug_struct("TridentModel")
            .field("vocab_size", &cfg.vocab_size)
            .field("d_model", &cfg.d_model)
            .field("num_layers", &cfg.num_layers)
            .field("num_heads", &cfg.num_heads)
            .field("intermediate_size", &cfg.intermediate_size)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Config accessors return the expected dimensions.
    #[test]
    fn test_trident_config() {
        let config = TridentConfig::default_150m();
        assert_eq!(config.d_model, 512);
        assert_eq!(config.num_layers, 12);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim(), 64);
    }

    // Packed ternary storage must be strictly smaller than fp32 storage.
    #[test]
    fn test_trident_config_storage() {
        let config = TridentConfig::default_150m();
        let fp32 = config.fp32_storage_bytes();
        let ternary = config.ternary_storage_bytes();
        assert!(ternary < fp32);
        println!(
            "FP32: {} bytes, Ternary: {} bytes, Ratio: {:.1}x",
            fp32,
            ternary,
            fp32 as f32 / ternary as f32
        );
    }

    // Logits come out with shape [batch, seq, vocab].
    #[test]
    fn test_trident_tiny_forward() {
        let config = TridentConfig::tiny();
        let model = TridentModel::new(&config);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let logits = model.forward_ids(&input_ids);
        assert_eq!(logits.data().shape(), &[2, 2, config.vocab_size]);
    }

    // The model exposes a non-empty parameter list.
    #[test]
    fn test_trident_parameters() {
        let config = TridentConfig::tiny();
        let model = TridentModel::new(&config);
        let params = model.parameters();
        assert!(!params.is_empty());
        let total: usize = params.iter().map(|p| p.data().numel()).sum();
        println!("Tiny Trident params: {}", total);
    }

    // Loss is a positive scalar for a multi-token sequence.
    #[test]
    fn test_trident_forward_with_loss() {
        let config = TridentConfig::tiny();
        let model = TridentModel::new(&config);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6], &[2, 3]).unwrap();
        let labels = Tensor::from_vec(vec![2u32, 3, 4, 5, 6, 7], &[2, 3]).unwrap();
        let (logits, loss) = model.forward_with_loss(&input_ids, &labels);
        assert_eq!(logits.data().shape(), &[2, 3, config.vocab_size]);
        assert_eq!(loss.data().numel(), 1);
        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);
    }

    // Average sparsity is a valid fraction in [0, 1].
    #[test]
    fn test_trident_sparsity() {
        let config = TridentConfig::tiny();
        let model = TridentModel::new(&config);
        let sparsity = model.average_sparsity();
        assert!((0.0..=1.0).contains(&sparsity));
    }

    // End-to-end training sanity check: loss must drop by >= 30% in 100 steps
    // of Adam on a tiny repeating-pattern dataset.
    // NOTE(review): despite the "1b" in the name, this uses a scaled-down
    // config with the 1B feature set (RoPE + ReLU^2 + sub-LN) for CI speed.
    #[test]
    fn test_trident_1b_shape_converges() {
        use axonml_optim::{Adam, Optimizer};
        let config = TridentConfig {
            vocab_size: 64,
            d_model: 48,
            num_layers: 2,
            num_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 128,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            use_rope: true,
            rope_theta: 10_000.0,
            use_squared_relu: true,
            use_sub_ln: true,
        };
        let model = TridentModel::new(&config);
        let mut optimizer = Adam::new(model.parameters(), 3e-3);
        let seq_len = 8usize;
        let batch_size = 4usize;
        let patterns: Vec<Vec<u32>> = vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8],
            vec![9, 10, 11, 12, 13, 14, 15, 16],
            vec![17, 18, 19, 20, 21, 22, 23, 24],
        ];
        let mut flat_batch = Vec::with_capacity(batch_size * seq_len);
        for b in 0..batch_size {
            flat_batch.extend_from_slice(&patterns[b % patterns.len()]);
        }
        let ids = Tensor::from_vec(flat_batch.clone(), &[batch_size, seq_len]).unwrap();
        // Labels equal the inputs; forward_with_loss shifts them internally.
        let labels = Tensor::from_vec(flat_batch, &[batch_size, seq_len]).unwrap();
        let (_, loss0) = model.forward_with_loss(&ids, &labels);
        let start_loss = loss0.data().to_vec()[0];
        let mut last_loss = start_loss;
        for _ in 0..100 {
            optimizer.zero_grad();
            let (_, loss) = model.forward_with_loss(&ids, &labels);
            last_loss = loss.data().to_vec()[0];
            loss.backward();
            optimizer.step();
        }
        println!("[trident convergence] start_loss={start_loss:.4} end_loss={last_loss:.4}");
        assert!(
            last_loss < start_loss * 0.7,
            "Trident did not converge: start={start_loss:.4}, end={last_loss:.4}"
        );
        assert!(last_loss.is_finite(), "Loss went NaN/Inf: {last_loss}");
    }

    // The real 1B config has the documented dimensions and flags, and its
    // estimated parameter count lands in the expected band.
    #[test]
    fn test_trident_1b_config_shapes() {
        let cfg = TridentConfig::trident_1b(32_000);
        assert_eq!(cfg.d_model, 2048);
        assert_eq!(cfg.num_layers, 24);
        assert_eq!(cfg.num_heads, 16);
        assert_eq!(cfg.num_kv_heads, 4);
        assert_eq!(cfg.intermediate_size, 5504);
        assert_eq!(cfg.max_seq_len, 4096);
        assert!((cfg.rope_theta - 500_000.0).abs() < 1e-3);
        assert!(cfg.use_rope);
        assert!(cfg.use_squared_relu);
        assert!(cfg.use_sub_ln);
        let n = cfg.estimated_params();
        assert!(
            (800_000_000..1_500_000_000).contains(&n),
            "1B config estimated_params={n} outside [0.8B, 1.5B]"
        );
    }

    // Smoke config (RoPE + GQA + ReLU^2 + sub-LN) produces sane logit shapes.
    #[test]
    fn test_trident_smoke_forward() {
        let cfg = TridentConfig::smoke(64);
        let model = TridentModel::new(&cfg);
        let ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();
        let logits = model.forward_ids(&ids);
        assert_eq!(logits.data().shape(), &[1, 4, 64]);
    }

    // Quantized inference must match the training-mode forward to ~1e-4.
    #[test]
    fn test_trident_quantize_inference() {
        let config = TridentConfig::tiny();
        let mut model = TridentModel::new(&config);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3], &[1, 3]).unwrap();
        let logits_train = model.forward_ids(&input_ids);
        model.quantize_for_inference();
        let logits_infer = model.forward_ids(&input_ids);
        assert_eq!(logits_train.data().shape(), logits_infer.data().shape());
        let train_vec = logits_train.data().to_vec();
        let infer_vec = logits_infer.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-4, "Train {} vs infer {} mismatch", a, b);
        }
    }
}