use crate::gpt_neox::config::GPTNeoXConfig;
use crate::llama::model::RotaryEmbedding; use std::io::Read;
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result},
layers::{Embedding, LayerNorm, Linear},
ops::activations::gelu,
tensor::Tensor,
traits::{Config, Layer, Model},
};
/// Feed-forward block of a GPT-NeoX layer: hidden -> intermediate (GELU) -> hidden.
pub struct GPTNeoXMLP {
    /// Up-projection from `hidden_size` to `intermediate_size` (with bias).
    pub dense_h_to_4h: Linear,
    /// Down-projection from `intermediate_size` back to `hidden_size` (with bias).
    pub dense_4h_to_h: Linear,
}
impl GPTNeoXMLP {
    /// Creates the two MLP projections (hidden -> intermediate -> hidden) on the CPU.
    pub fn new(config: &GPTNeoXConfig) -> Result<Self> {
        let dense_h_to_4h = Linear::new(config.hidden_size, config.intermediate_size, true);
        let dense_4h_to_h = Linear::new(config.intermediate_size, config.hidden_size, true);
        Ok(Self {
            dense_h_to_4h,
            dense_4h_to_h,
        })
    }

    /// Same as [`GPTNeoXMLP::new`] but places both projections on `device`.
    pub fn new_with_device(config: &GPTNeoXConfig, device: Device) -> Result<Self> {
        let up =
            Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
        let down =
            Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
        Ok(Self {
            dense_h_to_4h: up,
            dense_4h_to_h: down,
        })
    }

    /// Total learnable parameters across both projections.
    pub fn parameter_count(&self) -> usize {
        [&self.dense_h_to_4h, &self.dense_4h_to_h]
            .iter()
            .map(|l| l.parameter_count())
            .sum()
    }

    /// Uploads both projection weight sets to the Metal GPU cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        for linear in [&mut self.dense_h_to_4h, &mut self.dense_4h_to_h] {
            linear.weights_to_gpu(device)?;
        }
        Ok(())
    }

    /// Uploads both projection weight sets to the CUDA GPU cache.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        for linear in [&mut self.dense_h_to_4h, &mut self.dense_4h_to_h] {
            linear.weights_to_gpu_cuda(device)?;
        }
        Ok(())
    }
}
impl Layer for GPTNeoXMLP {
    type Input = Tensor;
    type Output = Tensor;

    /// Up-projects, applies GELU, then down-projects back to `hidden_size`.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let up = self.dense_h_to_4h.forward(input)?;
        self.dense_4h_to_h.forward(gelu(&up)?)
    }
}
/// GPT-NeoX multi-head self-attention with partial rotary position embeddings.
pub struct GPTNeoXAttention {
    /// Fused Q/K/V projection: `hidden_size` -> `3 * hidden_size` (with bias).
    pub query_key_value: Linear,
    /// Output projection applied after the per-head results are concatenated.
    pub dense: Linear,
    /// Rotary embedding parameters (only `base` is read directly in `forward`).
    pub _rotary_emb: RotaryEmbedding,
    /// Number of attention heads.
    pub _num_heads: usize,
    /// Per-head dimension (`hidden_size / num_heads`).
    pub _head_dim: usize,
    /// Leading dims of each head that receive rotation (`head_dim * rotary_pct`).
    pub _rotary_ndims: usize,
}
impl GPTNeoXAttention {
    /// Creates CPU-resident attention weights sized from `config`.
    pub fn new(config: &GPTNeoXConfig) -> Result<Self> {
        let head_dim = config.hidden_size / config.num_attention_heads;
        // Only a fraction of each head is rotated (GPT-NeoX partial rotary).
        let rotary_ndims = (head_dim as f32 * config.rotary_pct) as usize;
        let query_key_value = Linear::new(config.hidden_size, config.hidden_size * 3, true);
        let dense = Linear::new(config.hidden_size, config.hidden_size, true);
        let rotary = RotaryEmbedding::new(
            rotary_ndims,
            config.max_position_embeddings,
            config.rotary_emb_base,
        );
        Ok(Self {
            query_key_value,
            dense,
            _rotary_emb: rotary,
            _num_heads: config.num_attention_heads,
            _head_dim: head_dim,
            _rotary_ndims: rotary_ndims,
        })
    }

    /// Same as [`GPTNeoXAttention::new`] but places the linear layers on `device`.
    pub fn new_with_device(config: &GPTNeoXConfig, device: Device) -> Result<Self> {
        let head_dim = config.hidden_size / config.num_attention_heads;
        let rotary_ndims = (head_dim as f32 * config.rotary_pct) as usize;
        let query_key_value =
            Linear::new_with_device(config.hidden_size, config.hidden_size * 3, true, device);
        let dense = Linear::new_with_device(config.hidden_size, config.hidden_size, true, device);
        Ok(Self {
            query_key_value,
            dense,
            _rotary_emb: RotaryEmbedding::new(
                rotary_ndims,
                config.max_position_embeddings,
                config.rotary_emb_base,
            ),
            _num_heads: config.num_attention_heads,
            _head_dim: head_dim,
            _rotary_ndims: rotary_ndims,
        })
    }

    /// Total learnable parameters in the QKV and output projections.
    pub fn parameter_count(&self) -> usize {
        [&self.query_key_value, &self.dense]
            .iter()
            .map(|l| l.parameter_count())
            .sum()
    }

    /// Uploads both projection weight sets to the Metal GPU cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        for linear in [&mut self.query_key_value, &mut self.dense] {
            linear.weights_to_gpu(device)?;
        }
        Ok(())
    }

    /// Uploads both projection weight sets to the CUDA GPU cache.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        for linear in [&mut self.query_key_value, &mut self.dense] {
            linear.weights_to_gpu_cuda(device)?;
        }
        Ok(())
    }
}
impl Layer for GPTNeoXAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Multi-head self-attention with partial rotary embeddings (GPT-NeoX style:
    /// only the first `_rotary_ndims` of each head are rotated).
    ///
    /// Expects a 2D `[seq_len, hidden_size]` F32 tensor after the QKV projection;
    /// GPU-resident inputs are staged back to the CPU first because the attention
    /// math below runs on `ndarray` on the host. Returns the `dense` projection of
    /// the concatenated per-head outputs. Errors on non-2D or non-F32 QKV tensors.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        use scirs2_core::ndarray::{s, Array2};
        // Stage GPU tensors back to CPU: the per-head math is host-side.
        #[cfg(all(target_os = "macos", feature = "metal"))]
        let input = match &input {
            Tensor::Metal(_) => input.to_device_enum(&trustformers_core::device::Device::CPU)?,
            _ => input,
        };
        #[cfg(feature = "cuda")]
        let input = match &input {
            Tensor::CUDA(_) => input.to_device_enum(&trustformers_core::device::Device::CPU)?,
            _ => input,
        };
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        let input = input;
        // Sequence length: first axis for 2D input, second axis for batched 3D input.
        let seq_len = {
            let shape = input.shape();
            if shape.len() == 2 { shape[0] } else { shape[1] }
        };
        // Fused QKV projection. `input` is not used again, so move it instead of
        // cloning the whole tensor.
        let qkv = self.query_key_value.forward(input)?;
        match qkv {
            Tensor::F32(arr) => {
                let shape = arr.shape();
                if shape.len() != 2 {
                    return Err(tensor_op_error(
                        "GPTNeoXAttention::forward",
                        format!("Expected 2D tensor, got shape: {:?}", shape),
                    ));
                }
                let num_heads = self._num_heads;
                let head_dim = self._head_dim;
                // [seq, 3*hidden] -> [seq, heads, 3*head_dim] -> [heads, seq, 3*head_dim]
                let qkv_reshaped = arr
                    .to_shape((seq_len, num_heads, 3 * head_dim))
                    .map_err(|_| tensor_op_error("GPTNeoXAttention", "QKV reshape failed"))?;
                let qkv_transposed = qkv_reshaped.permuted_axes([1, 0, 2]);
                let q = qkv_transposed.slice(s![.., .., 0..head_dim]).to_owned();
                let k = qkv_transposed.slice(s![.., .., head_dim..2 * head_dim]).to_owned();
                let v = qkv_transposed.slice(s![.., .., 2 * head_dim..3 * head_dim]).to_owned();
                let rotary_ndims = self._rotary_ndims;
                // Back to [seq, heads, head_dim] for in-place rotary application.
                // `permuted_axes` on an owned array just permutes strides, so no
                // extra copy is needed before mutating.
                let mut q_rope = q.permuted_axes([1, 0, 2]);
                let mut k_rope = k.permuted_axes([1, 0, 2]);
                let half_rotary_ndims = rotary_ndims / 2;
                // Precompute (cos, sin) once per (position, frequency-index) pair.
                // The rotation is invariant across heads, so hoisting this table
                // out of the head loop avoids recomputing powf/cos/sin
                // `num_heads` times; the math is identical to computing inline.
                let mut rope_table = Vec::with_capacity(seq_len * half_rotary_ndims);
                for pos in 0..seq_len {
                    for i in 0..half_rotary_ndims {
                        let freq = 1.0
                            / self._rotary_emb.base.powf(2.0 * i as f32 / rotary_ndims as f32);
                        let angle = pos as f32 * freq;
                        rope_table.push((angle.cos(), angle.sin()));
                    }
                }
                // Rotate the first `rotary_ndims` dims of each head, pairing dim i
                // with dim i + half (rotate-half formulation).
                for pos in 0..seq_len {
                    for h in 0..num_heads {
                        for i in 0..half_rotary_ndims {
                            let j = i + half_rotary_ndims;
                            let (cos_val, sin_val) = rope_table[pos * half_rotary_ndims + i];
                            let q_i = q_rope[[pos, h, i]];
                            let q_j = q_rope[[pos, h, j]];
                            q_rope[[pos, h, i]] = q_i * cos_val - q_j * sin_val;
                            q_rope[[pos, h, j]] = q_i * sin_val + q_j * cos_val;
                            let k_i = k_rope[[pos, h, i]];
                            let k_j = k_rope[[pos, h, j]];
                            k_rope[[pos, h, i]] = k_i * cos_val - k_j * sin_val;
                            k_rope[[pos, h, j]] = k_i * sin_val + k_j * cos_val;
                        }
                    }
                }
                let v_reshaped = v.permuted_axes([1, 0, 2]);
                let q_reshaped = q_rope;
                let k_reshaped = k_rope;
                let scale = (head_dim as f32).sqrt();
                let hidden_size = num_heads * head_dim;
                // Per-head outputs are written into their own column range.
                let mut attn_output = Array2::<f32>::zeros((seq_len, hidden_size));
                // Best-effort Metal path; any failure falls back to the CPU path below.
                #[cfg(all(target_os = "macos", feature = "metal"))]
                let try_metal_attention = |q_head: &Array2<f32>,
                                           k_head: &Array2<f32>,
                                           v_head: &Array2<f32>,
                                           _h: usize|
                 -> Option<Array2<f32>> {
                    use trustformers_core::gpu_ops::metal::get_metal_backend;
                    let backend = get_metal_backend().ok()?;
                    let q_vec: Vec<f32> = q_head.iter().copied().collect();
                    let k_t = k_head.t();
                    let k_t_vec: Vec<f32> = k_t.iter().copied().collect();
                    let v_vec: Vec<f32> = v_head.iter().copied().collect();
                    let scores_vec =
                        backend.matmul_f32(&q_vec, &k_t_vec, seq_len, head_dim, seq_len).ok()?;
                    let scores_scaled: Vec<f32> = scores_vec.iter().map(|&x| x / scale).collect();
                    let attn_weights_vec =
                        backend.softmax_causal_f32(&scores_scaled, seq_len).ok()?;
                    let output_vec = backend
                        .matmul_f32(&attn_weights_vec, &v_vec, seq_len, seq_len, head_dim)
                        .ok()?;
                    Array2::from_shape_vec((seq_len, head_dim), output_vec).ok()
                };
                // CPU scaled-dot-product attention with a causal mask and a
                // numerically-stable (max-subtracted) row softmax.
                let cpu_attention = |q_head: Array2<f32>,
                                     k_head: Array2<f32>,
                                     v_head: Array2<f32>|
                 -> Array2<f32> {
                    let k_t = k_head.t();
                    let mut scores = q_head.dot(&k_t) / scale;
                    // Causal mask: a position may not attend to later positions.
                    for i in 0..seq_len {
                        for j in (i + 1)..seq_len {
                            scores[[i, j]] = f32::NEG_INFINITY;
                        }
                    }
                    let mut attn_weights = Array2::<f32>::zeros(scores.dim());
                    for i in 0..seq_len {
                        let row = scores.row(i);
                        let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                        let exp_row: Vec<f32> = row.iter().map(|&x| (x - max_val).exp()).collect();
                        let sum: f32 = exp_row.iter().sum();
                        for (j, &val) in exp_row.iter().enumerate() {
                            attn_weights[[i, j]] = val / sum;
                        }
                    }
                    attn_weights.dot(&v_head)
                };
                for h in 0..num_heads {
                    let q_head = q_reshaped.slice(s![.., h, ..]).to_owned();
                    let k_head = k_reshaped.slice(s![.., h, ..]).to_owned();
                    let v_head = v_reshaped.slice(s![.., h, ..]).to_owned();
                    let head_output: Array2<f32> = {
                        #[cfg(all(target_os = "macos", feature = "metal"))]
                        {
                            try_metal_attention(&q_head, &k_head, &v_head, h).unwrap_or_else(|| {
                                cpu_attention(q_head.clone(), k_head.clone(), v_head.clone())
                            })
                        }
                        #[cfg(not(all(target_os = "macos", feature = "metal")))]
                        {
                            cpu_attention(q_head, k_head, v_head)
                        }
                    };
                    let start_idx = h * head_dim;
                    let end_idx = start_idx + head_dim;
                    attn_output.slice_mut(s![.., start_idx..end_idx]).assign(&head_output);
                }
                self.dense.forward(Tensor::F32(attn_output.into_dyn()))
            },
            _ => Err(tensor_op_error(
                "GPTNeoXAttention::forward",
                "Unsupported tensor type",
            )),
        }
    }
}
/// One GPT-NeoX transformer block: two layer norms, self-attention, an MLP,
/// and a flag selecting parallel- vs sequential-residual wiring.
pub struct GPTNeoXLayer {
    /// Layer norm applied to the block input before attention.
    pub input_layernorm: LayerNorm,
    /// Layer norm applied before the MLP (to the block input when
    /// `use_parallel_residual`, otherwise to the attention residual).
    pub post_attention_layernorm: LayerNorm,
    /// Multi-head self-attention sublayer.
    pub attention: GPTNeoXAttention,
    /// Feed-forward sublayer.
    pub mlp: GPTNeoXMLP,
    /// When true, attention and MLP both read the block input and are summed
    /// with it (GPT-NeoX / Pythia style); when false, they run sequentially.
    pub use_parallel_residual: bool,
}
impl GPTNeoXLayer {
    /// Builds one transformer block (two layer norms + attention + MLP) on the CPU.
    pub fn new(config: &GPTNeoXConfig) -> Result<Self> {
        let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
        let post_attention_layernorm =
            LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attention: GPTNeoXAttention::new(config)?,
            mlp: GPTNeoXMLP::new(config)?,
            use_parallel_residual: config.use_parallel_residual,
        })
    }

    /// Like [`GPTNeoXLayer::new`] but places attention and MLP weights on `device`.
    ///
    /// NOTE(review): the two layer norms are still created with the plain
    /// `LayerNorm::new` (CPU), unlike `GPTNeoXModel::new_with_device` which uses
    /// `LayerNorm::new_with_device` — confirm this asymmetry is intentional.
    pub fn new_with_device(config: &GPTNeoXConfig, device: Device) -> Result<Self> {
        let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
        let post_attention_layernorm =
            LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attention: GPTNeoXAttention::new_with_device(config, device)?,
            mlp: GPTNeoXMLP::new_with_device(config, device)?,
            use_parallel_residual: config.use_parallel_residual,
        })
    }

    /// Learnable parameters of the attention and MLP sublayers.
    ///
    /// NOTE(review): the two layer norms are not counted here — confirm that
    /// excluding them is intentional.
    pub fn parameter_count(&self) -> usize {
        self.attention.parameter_count() + self.mlp.parameter_count()
    }

    /// Uploads all sublayer weights to the Metal GPU cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        self.input_layernorm.weights_to_gpu(device)?;
        self.attention.weights_to_gpu(device)?;
        self.post_attention_layernorm.weights_to_gpu(device)?;
        self.mlp.weights_to_gpu(device)?;
        Ok(())
    }

    /// Uploads all sublayer weights to the CUDA GPU cache.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        self.input_layernorm.weights_to_gpu_cuda(device)?;
        self.attention.weights_to_gpu_cuda(device)?;
        self.post_attention_layernorm.weights_to_gpu_cuda(device)?;
        self.mlp.weights_to_gpu_cuda(device)?;
        Ok(())
    }
}
impl Layer for GPTNeoXLayer {
    type Input = Tensor;
    type Output = Tensor;
    /// Runs one transformer block.
    ///
    /// Parallel residual: `out = input + attn(ln1(input)) + mlp(ln2(input))`.
    /// Sequential residual: `r = input + attn(ln1(input));
    /// out = r + mlp(ln2(r))`.
    ///
    /// With the `cuda` feature, sublayer outputs are moved back to the
    /// layer-norm's device before each residual add when the block input is a
    /// CUDA tensor (the attention path computes on CPU).
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        if self.use_parallel_residual {
            // Attention branch reads the normalized block input.
            let ln1_out = self.input_layernorm.forward(input.clone())?;
            let mut attn_out = self.attention.forward(ln1_out)?;
            #[cfg(feature = "cuda")]
            if matches!(input, Tensor::CUDA(_)) {
                attn_out = attn_out.to_device_enum(&self.input_layernorm.device())?;
            }
            // MLP branch also reads the ORIGINAL input (parallel wiring).
            let ln2_out = self.post_attention_layernorm.forward(input.clone())?;
            let mut mlp_out = self.mlp.forward(ln2_out)?;
            #[cfg(feature = "cuda")]
            if matches!(input, Tensor::CUDA(_)) {
                mlp_out = mlp_out.to_device_enum(&self.input_layernorm.device())?;
            }
            // Single shared residual: input + attention + MLP.
            let temp = input.add(&attn_out)?;
            temp.add(&mlp_out)
        } else {
            let ln1_out = self.input_layernorm.forward(input.clone())?;
            let mut attn_out = self.attention.forward(ln1_out)?;
            #[cfg(feature = "cuda")]
            if matches!(input, Tensor::CUDA(_)) {
                attn_out = attn_out.to_device_enum(&self.input_layernorm.device())?;
            }
            // Sequential wiring: the MLP reads the attention residual.
            let residual = input.add(&attn_out)?;
            let ln2_out = self.post_attention_layernorm.forward(residual.clone())?;
            let mut mlp_out = self.mlp.forward(ln2_out)?;
            #[cfg(feature = "cuda")]
            if matches!(input, Tensor::CUDA(_)) {
                mlp_out = mlp_out.to_device_enum(&self.input_layernorm.device())?;
            }
            residual.add(&mlp_out)
        }
    }
}
/// GPT-NeoX base model: token embedding, a stack of transformer blocks, and a
/// final layer norm (no language-modeling head).
pub struct GPTNeoXModel {
    /// Token embedding table (`vocab_size` x `hidden_size`).
    pub embed_in: Embedding,
    /// Transformer blocks, applied in order.
    pub layers: Vec<GPTNeoXLayer>,
    /// Layer norm applied to the last block's output.
    pub final_layer_norm: LayerNorm,
    // Retained so `get_config` can expose the construction-time configuration.
    config: GPTNeoXConfig,
}
impl GPTNeoXModel {
    /// Builds a randomly-initialized GPT-NeoX model on the CPU after
    /// validating `config`.
    pub fn new(config: GPTNeoXConfig) -> Result<Self> {
        config.validate()?;
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for _ in 0..config.num_hidden_layers {
            layers.push(GPTNeoXLayer::new(&config)?);
        }
        Ok(Self {
            embed_in: Embedding::new(config.vocab_size, config.hidden_size, None)?,
            layers,
            final_layer_norm: LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?,
            config,
        })
    }

    /// Like [`GPTNeoXModel::new`] but places the embedding, layers, and final
    /// layer norm on `device`.
    pub fn new_with_device(config: GPTNeoXConfig, device: Device) -> Result<Self> {
        config.validate()?;
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for _ in 0..config.num_hidden_layers {
            layers.push(GPTNeoXLayer::new_with_device(&config, device)?);
        }
        Ok(Self {
            embed_in: Embedding::new_with_device(
                config.vocab_size,
                config.hidden_size,
                None,
                device,
            )?,
            layers,
            final_layer_norm: LayerNorm::new_with_device(
                vec![config.hidden_size],
                config.layer_norm_eps,
                device,
            )?,
            config,
        })
    }

    /// Uploads all model weights (embedding, layers, final layer norm) to the
    /// Metal GPU cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        self.embed_in.weights_to_gpu(device)?;
        for layer in &mut self.layers {
            layer.weights_to_gpu(device)?;
        }
        self.final_layer_norm.weights_to_gpu(device)?;
        Ok(())
    }

    /// Uploads layer and final-layer-norm weights to the CUDA GPU cache.
    ///
    /// NOTE(review): unlike the Metal variant, `embed_in` is not uploaded here —
    /// confirm whether that is intentional.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        for layer in &mut self.layers {
            layer.weights_to_gpu_cuda(device)?;
        }
        self.final_layer_norm.weights_to_gpu_cuda(device)?;
        println!("✓ GPTNeoXModel: All layer weights cached on CUDA GPU");
        Ok(())
    }

    /// Loads pretrained weights from `model_path` using HuggingFace
    /// `gpt_neox.*` tensor names.
    ///
    /// Loading is best-effort: per-layer tensors missing from the checkpoint
    /// are skipped silently, and final-layer-norm failures are logged to
    /// stderr without aborting the load. Errors are returned only for loader
    /// creation/close failures and for `set_weight`/`set_bias` failures on
    /// tensors that did load.
    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
        let model_path = model_path.as_ref();
        let config = WeightLoadingConfig {
            lazy_loading: true,
            memory_mapped: false,
            ..Default::default()
        };
        let mut loader = auto_create_loader(model_path, Some(config))?;
        if let Ok(embed_weights) = loader.load_tensor("gpt_neox.embed_in.weight") {
            self.embed_in.set_weight(embed_weights)?;
        }
        match loader.load_tensor("gpt_neox.final_layer_norm.weight") {
            Ok(final_ln_weight) => self.final_layer_norm.set_weight(final_ln_weight)?,
            Err(e) => eprintln!("[ERROR] Failed to load final_layer_norm.weight: {:?}", e),
        }
        match loader.load_tensor("gpt_neox.final_layer_norm.bias") {
            Ok(final_ln_bias) => self.final_layer_norm.set_bias(final_ln_bias)?,
            Err(e) => eprintln!("[ERROR] Failed to load final_layer_norm.bias: {:?}", e),
        }
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("gpt_neox.layers.{}", i);
            if let Ok(qkv_weights) =
                loader.load_tensor(&format!("{}.attention.query_key_value.weight", prefix))
            {
                layer.attention.query_key_value.set_weight(qkv_weights)?;
            }
            if let Ok(qkv_bias) =
                loader.load_tensor(&format!("{}.attention.query_key_value.bias", prefix))
            {
                layer.attention.query_key_value.set_bias(qkv_bias)?;
            }
            if let Ok(dense_weights) =
                loader.load_tensor(&format!("{}.attention.dense.weight", prefix))
            {
                layer.attention.dense.set_weight(dense_weights)?;
            }
            if let Ok(dense_bias) = loader.load_tensor(&format!("{}.attention.dense.bias", prefix))
            {
                layer.attention.dense.set_bias(dense_bias)?;
            }
            if let Ok(mlp_up_weights) =
                loader.load_tensor(&format!("{}.mlp.dense_h_to_4h.weight", prefix))
            {
                layer.mlp.dense_h_to_4h.set_weight(mlp_up_weights)?;
            }
            if let Ok(mlp_up_bias) =
                loader.load_tensor(&format!("{}.mlp.dense_h_to_4h.bias", prefix))
            {
                layer.mlp.dense_h_to_4h.set_bias(mlp_up_bias)?;
            }
            if let Ok(mlp_down_weights) =
                loader.load_tensor(&format!("{}.mlp.dense_4h_to_h.weight", prefix))
            {
                layer.mlp.dense_4h_to_h.set_weight(mlp_down_weights)?;
            }
            if let Ok(mlp_down_bias) =
                loader.load_tensor(&format!("{}.mlp.dense_4h_to_h.bias", prefix))
            {
                layer.mlp.dense_4h_to_h.set_bias(mlp_down_bias)?;
            }
            if let Ok(ln1_weight) =
                loader.load_tensor(&format!("{}.input_layernorm.weight", prefix))
            {
                layer.input_layernorm.set_weight(ln1_weight)?;
            }
            if let Ok(ln1_bias) = loader.load_tensor(&format!("{}.input_layernorm.bias", prefix)) {
                layer.input_layernorm.set_bias(ln1_bias)?;
            }
            if let Ok(ln2_weight) =
                loader.load_tensor(&format!("{}.post_attention_layernorm.weight", prefix))
            {
                layer.post_attention_layernorm.set_weight(ln2_weight)?;
            }
            if let Ok(ln2_bias) =
                loader.load_tensor(&format!("{}.post_attention_layernorm.bias", prefix))
            {
                layer.post_attention_layernorm.set_bias(ln2_bias)?;
            }
        }
        loader.close()?;
        Ok(())
    }
}
impl Model for GPTNeoXModel {
    type Config = GPTNeoXConfig;
    type Input = Vec<u32>;
    type Output = Tensor;

    /// Embeds the token ids, runs every transformer block in order, and
    /// applies the final layer norm. With the `cuda` feature, the embedded
    /// states are first moved to the device of the first layer's QKV weights
    /// when those live on CUDA.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let mut hidden_states = self.embed_in.forward(input)?;
        #[cfg(feature = "cuda")]
        if let Some(first_layer) = self.layers.first() {
            let device = first_layer.attention.query_key_value.device();
            if matches!(device, Device::CUDA(_)) {
                hidden_states = hidden_states.to_device_enum(&device)?;
            }
        }
        for layer in &self.layers {
            hidden_states = layer.forward(hidden_states)?;
        }
        self.final_layer_norm.forward(hidden_states)
    }

    /// Reader-based loading is unsupported; use `load_from_path` instead.
    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
        Err(
            trustformers_core::errors::TrustformersError::not_implemented(
                "Use load_from_path for GPT-NeoX weight loading".to_string(),
            ),
        )
    }

    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Embedding + per-layer parameter counts.
    ///
    /// NOTE(review): `final_layer_norm` (and the per-layer norms, via
    /// `GPTNeoXLayer::parameter_count`) are excluded — confirm intentional.
    fn num_parameters(&self) -> usize {
        self.layers
            .iter()
            .fold(self.embed_in.parameter_count(), |acc, layer| {
                acc + layer.parameter_count()
            })
    }
}
/// GPT-NeoX with a causal language-modeling head projecting final hidden
/// states to vocabulary logits. The head has its own weights, loaded from
/// `embed_out.weight` (not tied to `embed_in` in this implementation).
pub struct GPTNeoXForCausalLM {
    /// Base transformer model.
    pub gpt_neox: GPTNeoXModel,
    /// LM head: `hidden_size` -> `vocab_size`, no bias.
    pub embed_out: Linear,
}
impl GPTNeoXForCausalLM {
    /// Builds the base model plus an untied, bias-free LM head on the CPU.
    pub fn new(config: GPTNeoXConfig) -> Result<Self> {
        let gpt_neox = GPTNeoXModel::new(config.clone())?;
        let embed_out = Linear::new(config.hidden_size, config.vocab_size, false);
        Ok(Self {
            gpt_neox,
            embed_out,
        })
    }

    /// Like [`GPTNeoXForCausalLM::new`] but places all weights on `device`.
    pub fn new_with_device(config: GPTNeoXConfig, device: Device) -> Result<Self> {
        let gpt_neox = GPTNeoXModel::new_with_device(config.clone(), device)?;
        let embed_out =
            Linear::new_with_device(config.hidden_size, config.vocab_size, false, device);
        Ok(Self {
            gpt_neox,
            embed_out,
        })
    }

    /// Uploads base-model and LM-head weights to the Metal GPU cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        self.gpt_neox.weights_to_gpu(device)?;
        self.embed_out.weights_to_gpu(device)?;
        println!("✓ All model weights uploaded to GPU");
        Ok(())
    }

    /// Uploads base-model and LM-head weights to the CUDA GPU cache.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(
        &mut self,
        device: &trustformers_core::device::Device,
    ) -> trustformers_core::errors::Result<()> {
        self.gpt_neox.weights_to_gpu_cuda(device)?;
        self.embed_out.weights_to_gpu_cuda(device)?;
        println!("✓ All model weights uploaded to CUDA GPU");
        Ok(())
    }

    /// Loads the base model's weights, then the `embed_out.weight` LM head,
    /// from `model_path`.
    ///
    /// If the LM-head tensor is missing from the checkpoint, a warning is
    /// logged to stderr and the head keeps its current (possibly
    /// uninitialized) weights rather than failing the whole load.
    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
        self.gpt_neox.load_from_path(model_path.as_ref())?;
        let config = WeightLoadingConfig {
            lazy_loading: true,
            memory_mapped: false,
            ..Default::default()
        };
        let mut loader = auto_create_loader(model_path, Some(config))?;
        match loader.load_tensor("embed_out.weight") {
            Ok(embed_out_weights) => self.embed_out.set_weight(embed_out_weights)?,
            Err(e) => {
                eprintln!("[ERROR] Failed to load embed_out.weight: {:?}", e);
                eprintln!("[WARNING] LM head will use uninitialized/default weights!");
            },
        }
        loader.close()?;
        Ok(())
    }
}
impl Model for GPTNeoXForCausalLM {
    type Config = GPTNeoXConfig;
    type Input = Vec<u32>;
    type Output = Tensor;

    /// Runs the base model, then projects final hidden states to vocab logits.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let hidden = self.gpt_neox.forward(input)?;
        self.embed_out.forward(hidden)
    }

    /// Delegates to the base model (which rejects reader-based loading).
    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
        self.gpt_neox.load_pretrained(reader)
    }

    fn get_config(&self) -> &Self::Config {
        self.gpt_neox.get_config()
    }

    /// Base-model parameters plus the LM head's.
    fn num_parameters(&self) -> usize {
        self.embed_out.parameter_count() + self.gpt_neox.num_parameters()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests: config validation, rotary-dimension math, construction,
    //! parameter counts, and forward-pass output shapes.
    use super::*;
    use trustformers_core::traits::{Config, Model};

    /// Minimal config keeping test models tiny (1 layer, hidden=16) and fast.
    fn tiny_config() -> GPTNeoXConfig {
        GPTNeoXConfig {
            vocab_size: 64,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 2,
            intermediate_size: 32,
            max_position_embeddings: 32,
            layer_norm_eps: 1e-5,
            hidden_act: "gelu".to_string(),
            rotary_emb_base: 10000.0,
            rotary_pct: 0.25,
            use_parallel_residual: false,
            tie_word_embeddings: false,
            initializer_range: 0.02,
            bos_token_id: Some(0),
            eos_token_id: Some(2),
        }
    }

    #[test]
    fn test_gpt_neox_config_default() {
        let cfg = GPTNeoXConfig::default();
        assert_eq!(cfg.hidden_size, 2048);
        assert_eq!(cfg.num_hidden_layers, 16);
        assert_eq!(cfg.num_attention_heads, 16);
        assert!((cfg.rotary_pct - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gpt_neox_config_validate_ok() {
        let cfg = GPTNeoXConfig::default();
        assert!(cfg.validate().is_ok(), "default config should validate");
    }

    #[test]
    fn test_gpt_neox_config_validate_bad_heads() {
        // 17 is not divisible by 3 heads.
        let cfg = GPTNeoXConfig {
            hidden_size: 17,
            num_attention_heads: 3,
            ..GPTNeoXConfig::default()
        };
        assert!(
            cfg.validate().is_err(),
            "hidden not divisible by heads should fail"
        );
    }

    #[test]
    fn test_gpt_neox_config_rotary_pct_invalid() {
        let cfg = GPTNeoXConfig {
            rotary_pct: 1.5,
            ..GPTNeoXConfig::default()
        };
        assert!(cfg.validate().is_err(), "rotary_pct > 1.0 should fail");
    }

    #[test]
    fn test_gpt_neox_rotary_ndims_25pct() {
        // GPT-NeoX-20B geometry: 6144 hidden / 64 heads.
        let cfg = GPTNeoXConfig {
            hidden_size: 6144,
            num_attention_heads: 64,
            rotary_pct: 0.25,
            ..GPTNeoXConfig::default()
        };
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let rotary_ndims = (head_dim as f32 * cfg.rotary_pct) as usize;
        assert_eq!(head_dim, 96);
        assert_eq!(
            rotary_ndims, 24,
            "25% of head_dim=96 should give 24 rotary dims"
        );
    }

    #[test]
    fn test_gpt_neox_pythia_160m_config() {
        let cfg = GPTNeoXConfig::pythia_160m();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.num_hidden_layers, 12);
        assert!(
            (cfg.rotary_pct - 0.25).abs() < 1e-6,
            "Pythia uses 25% rotary"
        );
        assert!(cfg.use_parallel_residual, "Pythia uses parallel residual");
    }

    #[test]
    fn test_gpt_neox_tie_embeddings_false() {
        let cfg = GPTNeoXConfig::default();
        assert!(!cfg.tie_word_embeddings, "tie_embeddings should be false");
    }

    #[test]
    fn test_gpt_neox_architecture_name() {
        let cfg = GPTNeoXConfig::default();
        assert_eq!(cfg.architecture(), "gpt_neox");
    }

    #[test]
    fn test_gpt_neox_parallel_residual_flag() {
        let cfg_parallel = GPTNeoXConfig {
            use_parallel_residual: true,
            ..tiny_config()
        };
        assert!(cfg_parallel.use_parallel_residual);
        let cfg_sequential = GPTNeoXConfig {
            use_parallel_residual: false,
            ..tiny_config()
        };
        assert!(!cfg_sequential.use_parallel_residual);
    }

    #[test]
    fn test_gpt_neox_mlp_construction() {
        let cfg = tiny_config();
        let mlp = GPTNeoXMLP::new(&cfg);
        assert!(mlp.is_ok(), "GPTNeoXMLP should construct");
    }

    #[test]
    fn test_gpt_neox_mlp_parameter_count() {
        let cfg = tiny_config();
        let mlp = GPTNeoXMLP::new(&cfg).expect("GPTNeoXMLP should construct");
        assert!(
            mlp.parameter_count() > 0,
            "MLP should have positive parameter count"
        );
    }

    #[test]
    fn test_gpt_neox_attention_construction() {
        let cfg = tiny_config();
        let attn = GPTNeoXAttention::new(&cfg);
        assert!(attn.is_ok(), "GPTNeoXAttention should construct");
    }

    #[test]
    fn test_gpt_neox_attention_rotary_ndims() {
        let cfg = tiny_config();
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let rotary_ndims = (head_dim as f32 * cfg.rotary_pct) as usize;
        assert_eq!(rotary_ndims, 2, "tiny config: 8 * 0.25 = 2 rotary dims");
    }

    #[test]
    fn test_gpt_neox_model_construction() {
        let cfg = tiny_config();
        let model = GPTNeoXModel::new(cfg);
        assert!(model.is_ok(), "GPTNeoXModel should construct");
    }

    #[test]
    fn test_gpt_neox_model_num_parameters() {
        let cfg = tiny_config();
        let model = GPTNeoXModel::new(cfg).expect("GPTNeoXModel should construct");
        assert!(
            model.num_parameters() > 0,
            "model should have positive params"
        );
    }

    #[test]
    fn test_gpt_neox_model_layers_count() {
        let cfg = tiny_config();
        let model = GPTNeoXModel::new(cfg.clone()).expect("GPTNeoXModel should construct");
        assert_eq!(model.layers.len(), cfg.num_hidden_layers);
    }

    #[test]
    fn test_gpt_neox_model_forward_output_shape() {
        // Base model output ends in hidden_size (no LM head applied).
        let cfg = tiny_config();
        let model = GPTNeoXModel::new(cfg.clone()).expect("GPTNeoXModel should construct");
        let input_ids: Vec<u32> = vec![1, 2, 3];
        let output = model.forward(input_ids).expect("GPTNeoXModel forward should succeed");
        let shape = output.shape();
        let last_dim = *shape.last().expect("output should have dimensions");
        assert_eq!(
            last_dim, cfg.hidden_size,
            "output last dim should be hidden_size"
        );
    }

    #[test]
    fn test_gpt_neox_causal_lm_construction() {
        let cfg = tiny_config();
        let model = GPTNeoXForCausalLM::new(cfg);
        assert!(model.is_ok(), "GPTNeoXForCausalLM should construct");
    }

    #[test]
    fn test_gpt_neox_causal_lm_num_params_larger_than_base() {
        let cfg = tiny_config();
        let base = GPTNeoXModel::new(cfg.clone()).expect("GPTNeoXModel should construct");
        let lm = GPTNeoXForCausalLM::new(cfg).expect("GPTNeoXForCausalLM should construct");
        assert!(
            lm.num_parameters() > base.num_parameters(),
            "LM model should have more params than base (extra embed_out)"
        );
    }

    #[test]
    fn test_gpt_neox_causal_lm_forward_vocab_size() {
        // LM head projects the final hidden states to vocab logits.
        let cfg = tiny_config();
        let model =
            GPTNeoXForCausalLM::new(cfg.clone()).expect("GPTNeoXForCausalLM should construct");
        let input_ids: Vec<u32> = vec![1, 2];
        let output = model.forward(input_ids).expect("forward should succeed");
        let shape = output.shape();
        let last_dim = *shape.last().expect("output should have dimensions");
        assert_eq!(last_dim, cfg.vocab_size, "last dim must be vocab_size");
    }
}