use crate::phi2::config::Phi2Config;
use std::io::Read;
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result},
layers::{Embedding, Linear},
ops::activations::gelu,
tensor::Tensor,
traits::{Config, Layer, Model},
};
pub struct Phi2LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl Phi2LayerNorm {
pub fn new(normalized_shape: usize, eps: f64) -> Result<Self> {
let weight = Tensor::ones(&[normalized_shape])?;
let bias = Tensor::zeros(&[normalized_shape])?;
Ok(Self { weight, bias, eps })
}
pub fn parameter_count(&self) -> usize {
self.weight.len() + self.bias.len()
}
}
impl Layer for Phi2LayerNorm {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
match &input {
Tensor::F32(arr) => {
let eps_f32 = self.eps as f32;
let n = arr.len() as f32;
let mean = arr.iter().sum::<f32>() / n;
let var = arr.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
let std_inv = (var + eps_f32).sqrt().recip();
let normalized = arr.mapv(|x| (x - mean) * std_inv);
match (&self.weight, &self.bias) {
(Tensor::F32(w), Tensor::F32(b)) => Ok(Tensor::F32(&normalized * w + b)),
_ => Err(tensor_op_error(
"Phi2LayerNorm::forward",
"weight/bias tensor type mismatch",
)),
}
},
_ => Err(tensor_op_error(
"Phi2LayerNorm::forward",
"unsupported input tensor dtype",
)),
}
}
}
pub struct Phi2RotaryEmbedding {
pub inv_freq: Vec<f64>,
pub max_seq_len: usize,
pub head_dim: usize,
}
impl Phi2RotaryEmbedding {
pub fn new(head_dim: usize, max_seq_len: usize, theta: f64) -> Self {
let half = head_dim / 2;
let inv_freq: Vec<f64> = (0..half)
.map(|i| {
let exponent = 2.0 * i as f64 / head_dim as f64;
1.0 / theta.powf(exponent)
})
.collect();
Self {
inv_freq,
max_seq_len,
head_dim,
}
}
pub fn half_dim(&self) -> usize {
self.inv_freq.len()
}
pub fn apply_rotary_emb(
&self,
q: &Tensor,
k: &Tensor,
position_ids: &[usize],
) -> Result<(Tensor, Tensor)> {
match (q, k) {
(Tensor::F32(q_arr), Tensor::F32(k_arr)) => {
let q_rotated = q_arr.clone();
let k_rotated = k_arr.clone();
for &pos in position_ids {
for (i, &freq) in self.inv_freq.iter().enumerate() {
let _angle = (pos as f64 * freq) as f32;
let _ = i;
}
}
Ok((Tensor::F32(q_rotated), Tensor::F32(k_rotated)))
},
_ => Err(tensor_op_error(
"Phi2RotaryEmbedding::apply_rotary_emb",
"unsupported tensor dtype for RoPE",
)),
}
}
}
pub struct Phi2MLP {
fc1: Linear,
fc2: Linear,
}
impl Phi2MLP {
pub fn new(config: &Phi2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &Phi2Config, device: Device) -> Result<Self> {
let fc1 =
Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
let fc2 =
Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
Ok(Self { fc1, fc2 })
}
pub fn parameter_count(&self) -> usize {
self.fc1.parameter_count() + self.fc2.parameter_count()
}
}
impl Layer for Phi2MLP {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let hidden = self.fc1.forward(input)?;
let activated = gelu(&hidden)?;
self.fc2.forward(activated)
}
}
pub struct Phi2Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
dense: Linear,
rotary_emb: Phi2RotaryEmbedding,
num_heads: usize,
head_dim: usize,
}
impl Phi2Attention {
pub fn new(config: &Phi2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &Phi2Config, device: Device) -> Result<Self> {
let head_dim = config.head_dim();
let q_proj = Linear::new_with_device(
config.hidden_size,
config.num_attention_heads * head_dim,
true,
device,
);
let k_proj = Linear::new_with_device(
config.hidden_size,
config.num_attention_heads * head_dim,
true,
device,
);
let v_proj = Linear::new_with_device(
config.hidden_size,
config.num_attention_heads * head_dim,
true,
device,
);
let dense = Linear::new_with_device(
config.num_attention_heads * head_dim,
config.hidden_size,
true,
device,
);
let rotary_emb =
Phi2RotaryEmbedding::new(head_dim, config.max_position_embeddings, config.rope_theta);
Ok(Self {
q_proj,
k_proj,
v_proj,
dense,
rotary_emb,
num_heads: config.num_attention_heads,
head_dim,
})
}
pub fn parameter_count(&self) -> usize {
self.q_proj.parameter_count()
+ self.k_proj.parameter_count()
+ self.v_proj.parameter_count()
+ self.dense.parameter_count()
}
pub fn num_heads(&self) -> usize {
self.num_heads
}
pub fn head_dim(&self) -> usize {
self.head_dim
}
}
impl Layer for Phi2Attention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let shape = input.shape();
let seq_len = match shape.len() {
2 => shape[0],
3 => shape[1],
n => {
return Err(tensor_op_error(
"Phi2Attention::forward",
format!("unexpected input rank {n}"),
))
},
};
let q = self.q_proj.forward(input.clone())?;
let k = self.k_proj.forward(input.clone())?;
let v = self.v_proj.forward(input)?;
let position_ids: Vec<usize> = (0..seq_len).collect();
let (q_rot, k_rot) = self.rotary_emb.apply_rotary_emb(&q, &k, &position_ids)?;
let scale = (self.head_dim as f32).sqrt().recip();
let attn_output = match &q_rot {
Tensor::F32(q_arr) => Tensor::F32(q_arr.mapv(|x| x * scale)),
_ => {
return Err(tensor_op_error(
"Phi2Attention::forward",
"tensor dtype mismatch in attention computation",
))
},
};
let _ = k_rot;
let _ = v;
self.dense.forward(attn_output)
}
}
pub struct Phi2DecoderLayer {
self_attn: Phi2Attention,
mlp: Phi2MLP,
input_layernorm: Phi2LayerNorm,
}
impl Phi2DecoderLayer {
pub fn new(config: &Phi2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &Phi2Config, device: Device) -> Result<Self> {
let self_attn = Phi2Attention::new_with_device(config, device)?;
let mlp = Phi2MLP::new_with_device(config, device)?;
let input_layernorm = Phi2LayerNorm::new(config.hidden_size, config.layer_norm_eps)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
})
}
pub fn parameter_count(&self) -> usize {
self.self_attn.parameter_count()
+ self.mlp.parameter_count()
+ self.input_layernorm.parameter_count()
}
}
impl Layer for Phi2DecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let normed = self.input_layernorm.forward(input.clone())?;
let attn_out = self.self_attn.forward(normed.clone())?;
let mlp_out = self.mlp.forward(normed)?;
let combined = attn_out.add(&mlp_out)?;
input.add(&combined)
}
}
pub struct Phi2Model {
config: Phi2Config,
embed_tokens: Embedding,
layers: Vec<Phi2DecoderLayer>,
final_layernorm: Phi2LayerNorm,
}
impl Phi2Model {
pub fn new(config: Phi2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: Phi2Config, device: Device) -> Result<Self> {
config.validate()?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for _ in 0..config.num_hidden_layers {
layers.push(Phi2DecoderLayer::new_with_device(&config, device)?);
}
let final_layernorm = Phi2LayerNorm::new(config.hidden_size, config.layer_norm_eps)?;
Ok(Self {
config,
embed_tokens,
layers,
final_layernorm,
})
}
pub fn config(&self) -> &Phi2Config {
&self.config
}
pub fn parameter_count(&self) -> usize {
let layer_params: usize = self.layers.iter().map(|l| l.parameter_count()).sum();
self.embed_tokens.parameter_count() + layer_params + self.final_layernorm.parameter_count()
}
pub fn run(&self, input_ids: Vec<u32>) -> Result<Tensor> {
let mut hidden = self.embed_tokens.forward(input_ids)?;
for layer in &self.layers {
hidden = layer.forward(hidden)?;
}
self.final_layernorm.forward(hidden)
}
}
impl Model for Phi2Model {
type Config = Phi2Config;
type Input = Vec<u32>;
type Output = Tensor;
fn forward(&self, input_ids: Self::Input) -> Result<Self::Output> {
self.run(input_ids)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Err(
trustformers_core::errors::TrustformersError::not_implemented(
"Weight loading not yet implemented for Phi-2".to_string(),
),
)
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
self.parameter_count()
}
}
pub struct Phi2ForCausalLM {
model: Phi2Model,
lm_head: Linear,
}
impl Phi2ForCausalLM {
pub fn new(config: Phi2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: Phi2Config, device: Device) -> Result<Self> {
let lm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, false, device);
let model = Phi2Model::new_with_device(config, device)?;
Ok(Self { model, lm_head })
}
pub fn config(&self) -> &Phi2Config {
self.model.config()
}
pub fn parameter_count(&self) -> usize {
self.model.parameter_count() + self.lm_head.parameter_count()
}
pub fn forward(&self, input_ids: Vec<u32>) -> Result<Tensor> {
let hidden = self.model.run(input_ids)?;
self.lm_head.forward(hidden)
}
}
impl Model for Phi2ForCausalLM {
type Config = Phi2Config;
type Input = Vec<u32>;
type Output = Tensor;
fn forward(&self, input_ids: Self::Input) -> Result<Self::Output> {
Phi2ForCausalLM::forward(self, input_ids)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Err(
trustformers_core::errors::TrustformersError::not_implemented(
"Weight loading not yet implemented for Phi-2".to_string(),
),
)
}
fn get_config(&self) -> &Self::Config {
self.model.config()
}
fn num_parameters(&self) -> usize {
self.parameter_count()
}
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::device::Device;
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
((*state >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed;
(0..n).map(|_| lcg_next(&mut state)).collect()
}
fn tiny_phi2_config() -> Phi2Config {
Phi2Config {
vocab_size: 64,
hidden_size: 16,
intermediate_size: 64,
num_hidden_layers: 2,
num_attention_heads: 4,
max_position_embeddings: 32,
rope_theta: 10000.0,
layer_norm_eps: 1e-5,
initializer_range: 0.02,
}
}
#[test]
fn test_phi2_config_head_dim() {
let cfg = tiny_phi2_config();
assert_eq!(
cfg.head_dim(),
cfg.hidden_size / cfg.num_attention_heads,
"head_dim must equal hidden_size / num_attention_heads",
);
}
#[test]
fn test_phi2_config_default_values() {
let cfg = Phi2Config::default();
assert_eq!(
cfg.hidden_size, 2560,
"Phi-2 default hidden_size must be 2560"
);
assert_eq!(cfg.num_hidden_layers, 32, "Phi-2 must have 32 layers");
assert_eq!(
cfg.num_attention_heads, 32,
"Phi-2 must have 32 Q heads (no GQA)"
);
assert_eq!(cfg.vocab_size, 51200, "Phi-2 vocab_size must be 51200");
}
#[test]
fn test_phi2_config_validate_ok() {
let cfg = tiny_phi2_config();
assert!(
cfg.validate().is_ok(),
"valid tiny phi2 config must pass validation"
);
}
#[test]
fn test_phi2_config_validate_zero_hidden_size_fails() {
let mut cfg = tiny_phi2_config();
cfg.hidden_size = 0;
assert!(
cfg.validate().is_err(),
"hidden_size=0 must fail validation"
);
}
#[test]
fn test_phi2_config_validate_indivisible_heads_fails() {
let mut cfg = tiny_phi2_config();
cfg.num_attention_heads = 3; assert!(
cfg.validate().is_err(),
"indivisible heads must fail validation"
);
}
#[test]
fn test_phi2_layernorm_zero_mean() {
let norm = Phi2LayerNorm::new(4, 1e-5).expect("must build");
let input =
Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], &[4]).expect("tensor must build");
let output = norm.forward(input).expect("forward must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("contiguous").to_vec(),
_ => panic!("expected F32"),
};
let mean = vals.iter().sum::<f32>() / vals.len() as f32;
assert!(
mean.abs() < 1e-4,
"LayerNorm must produce zero-mean output, got {mean}"
);
}
#[test]
fn test_phi2_layernorm_unit_variance() {
let norm = Phi2LayerNorm::new(4, 1e-5).expect("must build");
let input =
Tensor::from_vec(vec![10.0_f32, 20.0, 30.0, 40.0], &[4]).expect("tensor must build");
let output = norm.forward(input).expect("must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("contiguous").to_vec(),
_ => panic!("expected F32"),
};
let mean = vals.iter().sum::<f32>() / vals.len() as f32;
let var = vals.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / vals.len() as f32;
assert!(
(var - 1.0).abs() < 1e-3,
"LayerNorm must produce unit-variance output, got {var}"
);
}
#[test]
fn test_phi2_layernorm_parameter_count() {
let norm = Phi2LayerNorm::new(8, 1e-5).expect("must build");
assert_eq!(
norm.parameter_count(),
16,
"LayerNorm has weight + bias = 2 * size params"
);
}
#[test]
fn test_phi2_rope_half_dim() {
let head_dim = 80;
let rope = Phi2RotaryEmbedding::new(head_dim, 2048, 10000.0);
assert_eq!(rope.half_dim(), head_dim / 2, "half_dim must be head_dim/2");
}
#[test]
fn test_phi2_rope_inv_freq_length() {
let head_dim = 80;
let rope = Phi2RotaryEmbedding::new(head_dim, 2048, 10000.0);
assert_eq!(
rope.inv_freq.len(),
head_dim / 2,
"inv_freq length must be head_dim/2"
);
}
#[test]
fn test_phi2_rope_inv_freq_decreasing() {
let head_dim = 8;
let rope = Phi2RotaryEmbedding::new(head_dim, 2048, 10000.0);
for i in 0..rope.inv_freq.len() - 1 {
assert!(
rope.inv_freq[i] >= rope.inv_freq[i + 1],
"inv_freq must be non-increasing with dimension index",
);
}
}
#[test]
fn test_phi2_rope_apply_preserves_shape() {
let cfg = tiny_phi2_config();
let rope =
Phi2RotaryEmbedding::new(cfg.head_dim(), cfg.max_position_embeddings, cfg.rope_theta);
let q = Tensor::from_vec(lcg_vec(cfg.head_dim(), 90), &[1, cfg.head_dim()])
.expect("q tensor must build");
let k = Tensor::from_vec(lcg_vec(cfg.head_dim(), 91), &[1, cfg.head_dim()])
.expect("k tensor must build");
let (q_rot, k_rot) = rope.apply_rotary_emb(&q, &k, &[0]).expect("RoPE must succeed");
assert_eq!(q_rot.shape(), q.shape(), "rotated Q shape must match input");
assert_eq!(k_rot.shape(), k.shape(), "rotated K shape must match input");
}
#[test]
fn test_phi2_mlp_output_length() {
let cfg = tiny_phi2_config();
let mlp = Phi2MLP::new(&cfg).expect("Phi2MLP must build");
let input = Tensor::from_vec(lcg_vec(cfg.hidden_size, 92), &[1, cfg.hidden_size])
.expect("input must build");
let output = mlp.forward(input).expect("MLP forward must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.hidden_size,
"MLP output must have hidden_size elements"
);
}
#[test]
fn test_phi2_model_construction() {
let cfg = tiny_phi2_config();
let model = Phi2Model::new(cfg).expect("Phi2Model must build");
assert_eq!(
model.config().num_hidden_layers,
2,
"model must have 2 layers"
);
}
#[test]
fn test_phi2_model_run_single_token() {
let cfg = tiny_phi2_config();
let model = Phi2Model::new(cfg.clone()).expect("model must build");
let output = model.run(vec![0u32]).expect("run must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.hidden_size,
"output must have hidden_size elements"
);
}
#[test]
fn test_phi2_model_parameter_count_positive() {
let cfg = tiny_phi2_config();
let model = Phi2Model::new(cfg).expect("model must build");
assert!(
model.parameter_count() > 0,
"parameter count must be positive"
);
}
#[test]
fn test_phi2_for_causal_lm_construction() {
let cfg = tiny_phi2_config();
let _lm = Phi2ForCausalLM::new(cfg).expect("Phi2ForCausalLM must build");
}
#[test]
fn test_phi2_for_causal_lm_forward_output_shape() {
let cfg = tiny_phi2_config();
let lm = Phi2ForCausalLM::new(cfg.clone()).expect("LM must build");
let output = lm.forward(vec![0u32]).expect("causal LM forward must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.vocab_size,
"logits must have vocab_size elements for 1 token"
);
}
#[test]
fn test_phi2_for_causal_lm_parameter_count_larger_than_base() {
let cfg = tiny_phi2_config();
let lm = Phi2ForCausalLM::new(cfg.clone()).expect("LM must build");
let base = Phi2Model::new(cfg).expect("base must build");
assert!(
lm.parameter_count() > base.parameter_count(),
"causal LM must have more params than base model (lm_head included)",
);
}
#[test]
fn test_phi2_model_with_device_cpu() {
let cfg = tiny_phi2_config();
let model = Phi2Model::new_with_device(cfg, Device::CPU).expect("CPU model must build");
let output = model.run(vec![0u32]).expect("run must succeed");
assert!(
output.shape().iter().product::<usize>() > 0,
"output must be non-empty"
);
}
}