use std::any::Any;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_nn::{Embedding, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use rand::Rng;
use crate::llama::RMSNorm;
use crate::ssm::{SSMBlock, SSMConfig};
/// Configuration for the Hydra hybrid model: a stack of alternating
/// SSM and sliding-window-attention blocks (see `HydraModel::new`).
#[derive(Debug, Clone)]
pub struct HydraConfig {
    /// Size of the token vocabulary (embedding rows / LM-head columns).
    pub vocab_size: usize,
    /// Hidden (residual stream) width.
    pub d_model: usize,
    /// Total number of blocks; even indices are SSM, odd are attention.
    pub num_layers: usize,
    /// Attention heads per attention block; must divide `d_model` evenly
    /// for `head_dim()` to be exact.
    pub num_heads: usize,
    /// SSM state dimension.
    pub d_state: usize,
    /// SSM depthwise-convolution kernel width.
    pub d_conv: usize,
    /// SSM inner-width multiplier: `d_inner = d_model * ssm_expansion`.
    pub ssm_expansion: usize,
    /// MLP hidden width.
    pub intermediate_size: usize,
    /// Maximum supported sequence length.
    pub max_seq_len: usize,
    /// Causal attention window (each position sees at most this many
    /// predecessors, itself included).
    pub window_size: usize,
    /// Epsilon used by every RMSNorm in the model.
    pub rms_norm_eps: f32,
    /// Attention-weight dropout probability (0.0 disables dropout).
    pub dropout: f32,
}
impl HydraConfig {
    /// Reference "base" scale: 24 layers, width 768, 12 heads, 8k context.
    pub fn base() -> Self {
        Self {
            vocab_size: 32000,
            d_model: 768,
            num_layers: 24,
            num_heads: 12,
            d_state: 16,
            d_conv: 4,
            ssm_expansion: 2,
            intermediate_size: 3072,
            max_seq_len: 8192,
            window_size: 256,
            rms_norm_eps: 1e-5,
            dropout: 0.0,
        }
    }

    /// "Small" scale: base config shrunk to 8 layers at width 256.
    pub fn small() -> Self {
        Self {
            d_model: 256,
            num_layers: 8,
            num_heads: 4,
            intermediate_size: 1024,
            max_seq_len: 2048,
            window_size: 128,
            ..Self::base()
        }
    }

    /// "Tiny" scale for tests: 4 layers, width 64, 1k vocabulary.
    pub fn tiny() -> Self {
        Self {
            vocab_size: 1000,
            d_model: 64,
            num_layers: 4,
            num_heads: 4,
            d_state: 8,
            intermediate_size: 256,
            max_seq_len: 512,
            window_size: 64,
            ..Self::base()
        }
    }

    /// Per-head width (integer division of `d_model` by `num_heads`).
    pub fn head_dim(&self) -> usize {
        self.d_model / self.num_heads
    }

    /// SSM inner width: `d_model * ssm_expansion`.
    pub fn d_inner(&self) -> usize {
        self.d_model * self.ssm_expansion
    }
}
/// Multi-head causal sliding-window self-attention.
/// Projections are all `d_model x d_model`; the attention score loop
/// itself is computed on the CPU in `forward` with a custom grad node.
#[derive(Debug)]
pub struct WindowedAttention {
    // Query / key / value / output projections.
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    // Head layout: num_heads * head_dim == d_model.
    num_heads: usize,
    head_dim: usize,
    // Maximum lookback per position (inclusive of the position itself).
    window_size: usize,
    // Dropout probability applied to attention weights.
    attn_drop_p: f32,
}
impl WindowedAttention {
    /// Builds the four `d_model x d_model` projections and copies the
    /// head / window / dropout hyperparameters from `config`.
    pub fn new(config: &HydraConfig) -> Self {
        let head_dim = config.head_dim();
        Self {
            q_proj: Linear::new(config.d_model, config.d_model),
            k_proj: Linear::new(config.d_model, config.d_model),
            v_proj: Linear::new(config.d_model, config.d_model),
            o_proj: Linear::new(config.d_model, config.d_model),
            num_heads: config.num_heads,
            head_dim,
            window_size: config.window_size,
            attn_drop_p: config.dropout,
        }
    }

    /// Causal sliding-window attention over `x` of shape
    /// `[batch, seq, d_model]`, returning the same shape.
    /// Each position `i` attends to positions `max(0, i - window + 1)..=i`.
    /// Scores are computed in a plain f32 loop (not via tensor ops), so
    /// this path is CPU-only and has a matching hand-written backward
    /// (`WindowedAttnBackward`).
    pub fn forward(&self, x: &Variable) -> Variable {
        let x_data = x.data();
        let shape = x_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let d_model = shape[2];
        // Project and split heads: [b, s, d] -> [b, h, s, head_dim].
        let q = self.q_proj.forward(x);
        let k = self.k_proj.forward(x);
        let v = self.v_proj.forward(x);
        let q = q
            .reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        // Standard scaled-dot-product scale factor.
        let scale = 1.0 / (self.head_dim as f32).sqrt();
        let q_data = q.data();
        let k_data = k.data();
        let v_data = v.data();
        let q_vec = q_data.to_vec();
        let k_vec = k_data.to_vec();
        let v_vec = v_data.to_vec();
        let w = self.window_size;
        // Output accumulator, laid out as [b, h, s, head_dim] row-major.
        let mut output = vec![0.0f32; batch_size * self.num_heads * seq_len * self.head_dim];
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    // Causal window [win_start, i]; at most `w` positions.
                    let win_start = if i >= w { i - w + 1 } else { 0 };
                    let win_end = i + 1; let win_len = win_end - win_start;
                    let mut scores = vec![0.0f32; win_len];
                    let q_off = ((b * self.num_heads + h) * seq_len + i) * self.head_dim;
                    // Raw scaled dot products q_i . k_j.
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let k_off = ((b * self.num_heads + h) * seq_len + j) * self.head_dim;
                        let mut dot = 0.0f32;
                        for d in 0..self.head_dim {
                            dot += q_vec[q_off + d] * k_vec[k_off + d];
                        }
                        scores[wi] = dot * scale;
                    }
                    // Numerically-stable softmax over the window.
                    let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let mut exp_sum = 0.0f32;
                    for s in &mut scores {
                        *s = (*s - max_score).exp();
                        exp_sum += *s;
                    }
                    if exp_sum > 0.0 {
                        for s in &mut scores {
                            *s /= exp_sum;
                        }
                    }
                    // Inverted dropout on attention weights.
                    // NOTE(review): gated on is_grad_enabled() as a proxy for
                    // "training mode" (HydraModel::train/eval are no-ops), and
                    // the drop mask is not saved for the backward pass — so
                    // gradients with attn_drop_p > 0 will not match this
                    // forward. Confirm intended (dropout defaults to 0.0).
                    if self.attn_drop_p > 0.0 && is_grad_enabled() {
                        let mut rng = rand::thread_rng();
                        let keep_scale = 1.0 / (1.0 - self.attn_drop_p);
                        for s in &mut scores {
                            if rng.r#gen::<f32>() < self.attn_drop_p {
                                *s = 0.0;
                            } else {
                                *s *= keep_scale;
                            }
                        }
                    }
                    // Weighted sum of values into the output row for (b,h,i).
                    let o_off = ((b * self.num_heads + h) * seq_len + i) * self.head_dim;
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let v_off = ((b * self.num_heads + h) * seq_len + j) * self.head_dim;
                        for d in 0..self.head_dim {
                            output[o_off + d] += scores[wi] * v_vec[v_off + d];
                        }
                    }
                }
            }
        }
        let attn_out = Tensor::from_vec(
            output,
            &[batch_size, self.num_heads, seq_len, self.head_dim],
        )
        .unwrap();
        // Register the custom backward node only when gradients are needed.
        let requires_grad = x.requires_grad() && is_grad_enabled();
        let attn_var = if requires_grad {
            // NOTE(review): next_fns holds only x's grad fn, and the backward
            // emits combined q/k/v gradients straight to it — the q/k/v
            // projection backward appears to be bypassed on this edge.
            // Verify against the autograd wiring conventions.
            let grad_fn = GradFn::new(WindowedAttnBackward {
                next_fns: vec![x.grad_fn().cloned()],
                saved_q: q_data.clone(),
                saved_k: k_data.clone(),
                saved_v: v_data.clone(),
                num_heads: self.num_heads,
                head_dim: self.head_dim,
                window_size: self.window_size,
                scale,
            });
            Variable::from_operation(attn_out, grad_fn, true)
        } else {
            Variable::new(attn_out, false)
        };
        // Merge heads back: [b, h, s, hd] -> [b, s, d], then output-project.
        let reshaped = attn_var
            .transpose(1, 2)
            .reshape(&[batch_size, seq_len, d_model]);
        self.o_proj.forward(&reshaped)
    }

    /// All trainable parameters of the four projections.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.o_proj.parameters());
        params
    }
}
/// Hand-written backward node for `WindowedAttention::forward`.
/// Saves the post-projection, head-split Q/K/V activations
/// (`[batch, heads, seq, head_dim]`) and recomputes the windowed softmax
/// during `apply`.
#[derive(Debug)]
struct WindowedAttnBackward {
    // Single entry: the grad fn of the attention input `x`.
    next_fns: Vec<Option<GradFn>>,
    saved_q: Tensor<f32>,
    saved_k: Tensor<f32>,
    saved_v: Tensor<f32>,
    num_heads: usize,
    head_dim: usize,
    window_size: usize,
    // 1 / sqrt(head_dim), same scale as used in the forward pass.
    scale: f32,
}
impl GradientFunction for WindowedAttnBackward {
    /// Backward of windowed softmax-attention. Recomputes the per-window
    /// attention weights from the saved Q/K, then applies the standard
    /// softmax Jacobian: dS = A * (dA - sum(dA * A)).
    /// NOTE(review): the forward's dropout mask is not saved/replayed here,
    /// so when attn_drop_p > 0 these gradients correspond to the un-dropped
    /// forward — confirm acceptable (default dropout is 0.0).
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // grad_output is [batch, heads, seq, head_dim].
        let shape = grad_output.shape();
        let batch_size = shape[0];
        let num_heads = self.num_heads;
        let seq_len = shape[2];
        let head_dim = self.head_dim;
        let w = self.window_size;
        let scale = self.scale;
        let q_vec = self.saved_q.to_vec();
        let k_vec = self.saved_k.to_vec();
        let v_vec = self.saved_v.to_vec();
        let go_vec = grad_output.to_vec();
        let total = batch_size * num_heads * seq_len * head_dim;
        let mut grad_q = vec![0.0f32; total];
        let mut grad_k = vec![0.0f32; total];
        let mut grad_v = vec![0.0f32; total];
        for b in 0..batch_size {
            for h in 0..num_heads {
                for i in 0..seq_len {
                    // Same causal window as the forward pass.
                    let win_start = if i >= w { i - w + 1 } else { 0 };
                    let win_end = i + 1;
                    let win_len = win_end - win_start;
                    let q_off = ((b * num_heads + h) * seq_len + i) * head_dim;
                    // Recompute raw scaled scores for this window.
                    let mut scores = vec![0.0f32; win_len];
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let k_off = ((b * num_heads + h) * seq_len + j) * head_dim;
                        let mut dot = 0.0f32;
                        for d in 0..head_dim {
                            dot += q_vec[q_off + d] * k_vec[k_off + d];
                        }
                        scores[wi] = dot * scale;
                    }
                    // Recompute the stable softmax weights A.
                    let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let mut attn_w = vec![0.0f32; win_len];
                    let mut exp_sum = 0.0f32;
                    for (wi, &s) in scores.iter().enumerate() {
                        attn_w[wi] = (s - max_s).exp();
                        exp_sum += attn_w[wi];
                    }
                    if exp_sum > 0.0 {
                        for a in &mut attn_w {
                            *a /= exp_sum;
                        }
                    }
                    let go_off = ((b * num_heads + h) * seq_len + i) * head_dim;
                    // dV_j += A_ij * dO_i
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let v_off = ((b * num_heads + h) * seq_len + j) * head_dim;
                        for d in 0..head_dim {
                            grad_v[v_off + d] += attn_w[wi] * go_vec[go_off + d];
                        }
                    }
                    // dA_ij = dO_i . V_j
                    let mut d_attn = vec![0.0f32; win_len];
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let v_off = ((b * num_heads + h) * seq_len + j) * head_dim;
                        for d in 0..head_dim {
                            d_attn[wi] += go_vec[go_off + d] * v_vec[v_off + d];
                        }
                    }
                    // Softmax Jacobian: dS = A * (dA - sum_j dA_j * A_j).
                    let sum_da_aw: f32 = d_attn
                        .iter()
                        .zip(attn_w.iter())
                        .map(|(da, aw)| da * aw)
                        .sum();
                    let mut d_scores = vec![0.0f32; win_len];
                    for wi in 0..win_len {
                        d_scores[wi] = attn_w[wi] * (d_attn[wi] - sum_da_aw);
                    }
                    // dQ_i += scale * sum_j dS_ij * K_j
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let k_off = ((b * num_heads + h) * seq_len + j) * head_dim;
                        for d in 0..head_dim {
                            grad_q[q_off + d] += d_scores[wi] * k_vec[k_off + d] * scale;
                        }
                    }
                    // dK_j += scale * dS_ij * Q_i
                    for (wi, j) in (win_start..win_end).enumerate() {
                        let k_off = ((b * num_heads + h) * seq_len + j) * head_dim;
                        for d in 0..head_dim {
                            grad_k[k_off + d] += d_scores[wi] * q_vec[q_off + d] * scale;
                        }
                    }
                }
            }
        }
        // NOTE(review): the three gradients are summed element-wise and
        // returned as a single tensor toward x's grad fn (see forward's
        // next_fns wiring) — this skips the separate q/k/v projection
        // backward; confirm against the framework's expectations.
        let mut grad_input = vec![0.0f32; total];
        for i in 0..total {
            grad_input[i] = grad_q[i] + grad_k[i] + grad_v[i];
        }
        let gi = Tensor::from_vec(grad_input, shape).unwrap();
        vec![Some(gi)]
    }

    fn name(&self) -> &'static str {
        "WindowedAttnBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// SwiGLU-style feed-forward block: `down(silu(gate(x)) * up(x))`.
#[derive(Debug)]
pub struct HydraMLP {
    // d_model -> intermediate_size
    gate_proj: Linear,
    // d_model -> intermediate_size
    up_proj: Linear,
    // intermediate_size -> d_model
    down_proj: Linear,
}
impl HydraMLP {
    /// Allocates the three projections of the SwiGLU feed-forward block.
    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
        let gate_proj = Linear::new(d_model, intermediate_size);
        let up_proj = Linear::new(d_model, intermediate_size);
        let down_proj = Linear::new(intermediate_size, d_model);
        Self {
            gate_proj,
            up_proj,
            down_proj,
        }
    }

    /// Computes `down(silu(gate(x)) * up(x))`.
    pub fn forward(&self, x: &Variable) -> Variable {
        let gated = self
            .gate_proj
            .forward(x)
            .silu()
            .mul(&self.up_proj.forward(x));
        self.down_proj.forward(&gated)
    }

    /// Trainable parameters of all three projections.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.gate_proj
            .parameters()
            .into_iter()
            .chain(self.up_proj.parameters())
            .chain(self.down_proj.parameters())
            .collect()
    }
}
/// One layer of the Hydra stack, in pre-norm residual form:
/// `x + mixer(norm(x))` followed by `h + mlp(mlp_norm(h))`, where the
/// mixer is either an SSM block or windowed attention.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum HydraBlock {
    SSM {
        norm: RMSNorm,
        ssm: SSMBlock,
        mlp_norm: RMSNorm,
        mlp: HydraMLP,
    },
    Attention {
        norm: RMSNorm,
        attn: WindowedAttention,
        mlp_norm: RMSNorm,
        mlp: HydraMLP,
    },
}
impl HydraBlock {
    /// Builds an SSM-mixer layer; the inner SSM is configured as a
    /// single-layer block derived from the Hydra config.
    pub fn new_ssm(config: &HydraConfig) -> Self {
        let ssm_config = SSMConfig {
            d_model: config.d_model,
            d_state: config.d_state,
            d_inner: config.d_inner(),
            d_conv: config.d_conv,
            dt_rank: config.d_model.div_ceil(16),
            vocab_size: config.vocab_size,
            num_layers: 1,
            rms_norm_eps: config.rms_norm_eps,
        };
        HydraBlock::SSM {
            norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            ssm: SSMBlock::new(&ssm_config),
            mlp_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            mlp: HydraMLP::new(config.d_model, config.intermediate_size),
        }
    }

    /// Builds a windowed-attention-mixer layer.
    pub fn new_attention(config: &HydraConfig) -> Self {
        HydraBlock::Attention {
            norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            attn: WindowedAttention::new(config),
            mlp_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            mlp: HydraMLP::new(config.d_model, config.intermediate_size),
        }
    }

    /// Pre-norm residual forward:
    /// `h = x + mixer(norm(x)); out = h + mlp(mlp_norm(h))`.
    pub fn forward(&self, x: &Variable) -> Variable {
        match self {
            HydraBlock::SSM {
                norm,
                ssm,
                mlp_norm,
                mlp,
            } => {
                let mixed = x.clone().add(&ssm.forward(&norm.forward(x)));
                let ff = mlp.forward(&mlp_norm.forward(&mixed));
                mixed.clone().add(&ff)
            }
            HydraBlock::Attention {
                norm,
                attn,
                mlp_norm,
                mlp,
            } => {
                let mixed = x.clone().add(&attn.forward(&norm.forward(x)));
                let ff = mlp.forward(&mlp_norm.forward(&mixed));
                mixed.clone().add(&ff)
            }
        }
    }

    /// Trainable parameters of the norms, the mixer, and the MLP.
    pub fn parameters(&self) -> Vec<Parameter> {
        match self {
            HydraBlock::SSM {
                norm,
                ssm,
                mlp_norm,
                mlp,
            } => norm
                .parameters()
                .into_iter()
                .chain(ssm.parameters())
                .chain(mlp_norm.parameters())
                .chain(mlp.parameters())
                .collect(),
            HydraBlock::Attention {
                norm,
                attn,
                mlp_norm,
                mlp,
            } => norm
                .parameters()
                .into_iter()
                .chain(attn.parameters())
                .chain(mlp_norm.parameters())
                .chain(mlp.parameters())
                .collect(),
        }
    }
}
/// Full Hydra language model: token embedding, a stack of alternating
/// SSM / attention blocks, a final RMSNorm, and a vocabulary LM head.
#[derive(Debug)]
pub struct HydraModel {
    embed_tokens: Embedding,
    // Even indices are SSM blocks, odd indices attention (see `new`).
    blocks: Vec<HydraBlock>,
    final_norm: RMSNorm,
    lm_head: Linear,
    config: HydraConfig,
}
impl HydraModel {
    /// Builds the model with `config.num_layers` blocks, alternating
    /// SSM (even index) and windowed attention (odd index).
    pub fn new(config: &HydraConfig) -> Self {
        let mut blocks = Vec::with_capacity(config.num_layers);
        for i in 0..config.num_layers {
            if i % 2 == 0 {
                blocks.push(HydraBlock::new_ssm(config));
            } else {
                blocks.push(HydraBlock::new_attention(config));
            }
        }
        Self {
            embed_tokens: Embedding::new(config.vocab_size, config.d_model),
            blocks,
            final_norm: RMSNorm::new(config.d_model, config.rms_norm_eps),
            lm_head: Linear::new(config.d_model, config.vocab_size),
            config: config.clone(),
        }
    }

    /// Embeds `input_ids` (`[batch, seq]`) and runs the full stack,
    /// returning logits of shape `[batch, seq, vocab_size]`.
    /// Ids are cast u32 -> f32 for the embedding lookup; this is exact
    /// for ids below 2^24, which covers these vocab sizes.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let ids_f32: Vec<f32> = input_ids.to_vec().iter().map(|&x| x as f32).collect();
        let ids_var = Variable::new(Tensor::from_vec(ids_f32, input_ids.shape()).unwrap(), false);
        let mut hidden = self.embed_tokens.forward(&ids_var);
        for block in &self.blocks {
            hidden = block.forward(&hidden);
        }
        let hidden = self.final_norm.forward(&hidden);
        self.lm_head.forward(&hidden)
    }

    /// Next-token language-modeling loss: logits at position t are scored
    /// against labels at position t+1 (logits truncated to seq_len-1,
    /// labels shifted left by one). Returns `(full logits, scalar loss)`.
    /// For seq_len == 1 there is nothing to predict and a constant zero
    /// loss is returned.
    pub fn forward_with_loss(
        &self,
        input_ids: &Tensor<u32>,
        labels: &Tensor<u32>,
    ) -> (Variable, Variable) {
        let logits = self.forward_ids(input_ids);
        let logits_data = logits.data();
        let shape = logits_data.shape().to_vec();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let _vocab_size = shape[2];
        // Release the data borrow before `logits` is moved/used below.
        drop(logits_data);
        if seq_len > 1 {
            // Drop the last logit position; pair logits[t] with labels[t+1].
            let shift_logits = logits.narrow(1, 0, seq_len - 1);
            let labels_vec = labels.to_vec();
            let mut shift_labels_data = Vec::with_capacity(batch_size * (seq_len - 1));
            for b in 0..batch_size {
                for s in 1..seq_len {
                    shift_labels_data.push(labels_vec[b * seq_len + s]);
                }
            }
            let shift_labels =
                Tensor::from_vec(shift_labels_data, &[batch_size, seq_len - 1]).unwrap();
            let loss = Self::cross_entropy_loss(&shift_logits, &shift_labels);
            (logits, loss)
        } else {
            let zero = Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
            (logits, zero)
        }
    }

    /// Flattens `[b, s, v]` logits to `[b*s, v]` and applies cross-entropy
    /// against the flattened labels.
    /// NOTE(review): labels >= vocab_size are clamped to class 0 rather
    /// than ignored, so they still contribute loss toward class 0 —
    /// confirm this is intended vs. ignore-index semantics.
    fn cross_entropy_loss(logits: &Variable, labels: &Tensor<u32>) -> Variable {
        let logits_data = logits.data();
        let shape = logits_data.shape().to_vec();
        drop(logits_data);
        let batch_size = shape[0];
        let seq_len = shape[1];
        let vocab_size = shape[2];
        let logits_flat = logits.reshape(&[batch_size * seq_len, vocab_size]);
        let labels_vec = labels.to_vec();
        // Targets are passed as f32 class indices.
        let valid_labels: Vec<f32> = labels_vec
            .iter()
            .map(|&l| {
                if (l as usize) < vocab_size {
                    l as f32
                } else {
                    0.0f32
                }
            })
            .collect();
        let target_var = Variable::new(
            Tensor::from_vec(valid_labels, &[batch_size * seq_len]).unwrap(),
            false,
        );
        use axonml_nn::loss::CrossEntropyLoss;
        CrossEntropyLoss::new().compute(&logits_flat, &target_var)
    }

    /// Total number of scalar parameters across the whole model.
    pub fn param_count(&self) -> usize {
        self.parameters().iter().map(|p| p.data().numel()).sum()
    }

    /// Returns the configuration this model was built from.
    pub fn config(&self) -> &HydraConfig {
        &self.config
    }

    /// No-op: there is no train/eval flag; attention dropout is gated on
    /// gradient mode instead (see `WindowedAttention::forward`).
    pub fn train(&mut self) {
    }

    /// No-op; see `train`.
    pub fn eval(&mut self) {
    }
}
impl Module for HydraModel {
fn forward(&self, input: &Variable) -> Variable {
let mut hidden = input.clone();
for block in &self.blocks {
hidden = block.forward(&hidden);
}
let hidden = self.final_norm.forward(&hidden);
self.lm_head.forward(&hidden)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.embed_tokens.parameters());
for block in &self.blocks {
params.extend(block.parameters());
}
params.extend(self.final_norm.parameters());
params.extend(self.lm_head.parameters());
params
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Logits from token ids should be [batch, seq, vocab].
    #[test]
    fn test_hydra_tiny_forward() {
        let config = HydraConfig::tiny();
        let model = HydraModel::new(&config);
        let ids: Vec<u32> = (1..=8).collect();
        let input_ids = Tensor::<u32>::from_vec(ids, &[2, 4]).unwrap();
        let logits = model.forward_ids(&input_ids);
        let dims = logits.data().shape().to_vec();
        assert_eq!(dims, vec![2, 4, config.vocab_size]);
    }

    /// Shifted cross-entropy loss should come out strictly positive.
    #[test]
    fn test_hydra_tiny_with_loss() {
        let config = HydraConfig::tiny();
        let model = HydraModel::new(&config);
        let input_ids = Tensor::<u32>::from_vec(vec![1, 2, 3, 4], &[1, 4]).unwrap();
        let labels = Tensor::<u32>::from_vec(vec![2, 3, 4, 5], &[1, 4]).unwrap();
        let (_logits, loss) = model.forward_with_loss(&input_ids, &labels);
        let value = loss.data().to_vec()[0];
        assert!(value > 0.0, "Cross-entropy loss should be positive");
    }

    /// The model must expose a non-empty parameter set.
    #[test]
    fn test_hydra_param_count() {
        let model = HydraModel::new(&HydraConfig::tiny());
        let total = model.param_count();
        assert!(total > 0, "Model should have parameters");
        println!("Hydra tiny params: {total}");
    }

    /// Even layers are SSM, odd layers are attention.
    #[test]
    fn test_hydra_alternating_blocks() {
        let model = HydraModel::new(&HydraConfig::tiny());
        for (i, block) in model.blocks.iter().enumerate() {
            let is_ssm = matches!(block, HydraBlock::SSM { .. });
            if is_ssm {
                assert_eq!(i % 2, 0, "SSM at even index");
            } else {
                assert_eq!(i % 2, 1, "Attn at odd index");
            }
        }
    }

    /// Windowed attention preserves the [batch, seq, d_model] shape.
    #[test]
    fn test_windowed_attention_shape() {
        let config = HydraConfig::tiny();
        let attn = WindowedAttention::new(&config);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1f32; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            true,
        );
        let out = attn.forward(&input);
        assert_eq!(out.data().shape(), &[2, 8, 64]);
    }
}