use std::io::Read;
use crate::phi4::config::{Phi4Config, Phi4RopeScaling};
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result, TrustformersError},
layers::{Embedding, Linear},
ops::activations::silu,
tensor::Tensor,
traits::{Config, Layer, Model},
};
pub struct Phi4RmsNorm {
weight: Tensor,
eps: f32,
}
impl Phi4RmsNorm {
pub fn new(normalized_shape: usize, eps: f64) -> Result<Self> {
let weight = Tensor::ones(&[normalized_shape])?;
Ok(Self {
weight,
eps: eps as f32,
})
}
}
impl Layer for Phi4RmsNorm {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
match (&input, &self.weight) {
(Tensor::F32(arr), Tensor::F32(w)) => {
let mean_sq = arr.iter().map(|x| x * x).sum::<f32>() / arr.len().max(1) as f32;
let rms = (mean_sq + self.eps).sqrt();
let result = arr.mapv(|x| x / rms);
Ok(Tensor::F32(&result * w))
},
_ => Err(tensor_op_error(
"tensor_operation",
"Unsupported tensor dtype for Phi4RmsNorm",
)),
}
}
}
pub struct Phi4RotaryEmbedding {
head_dim: usize,
rope_theta: f64,
rope_scaling: Option<Phi4RopeScaling>,
original_max_position_embeddings: usize,
}
impl Phi4RotaryEmbedding {
pub fn new(config: &Phi4Config) -> Self {
Self {
head_dim: config.head_dim,
rope_theta: config.rope_theta,
rope_scaling: config.rope_scaling.clone(),
original_max_position_embeddings: config.original_max_position_embeddings,
}
}
pub fn apply(&self, q: &[f32], k: &[f32], seq_len: usize) -> (Vec<f32>, Vec<f32>) {
let half_dim = self.head_dim / 2;
let mut q_out = q.to_vec();
let mut k_out = k.to_vec();
for pos in 0..seq_len {
let scale_factors: Option<&Vec<f32>> = self.rope_scaling.as_ref().map(|rs| {
if pos > self.original_max_position_embeddings {
&rs.long_factor
} else {
&rs.short_factor
}
});
for i in 0..half_dim {
let dim_scale =
scale_factors.and_then(|sf| sf.get(i)).copied().unwrap_or(1.0) as f64;
let freq =
1.0 / (self.rope_theta.powf(2.0 * i as f64 / self.head_dim as f64) * dim_scale);
let angle = (pos as f64 * freq) as f32;
let cos_v = angle.cos();
let sin_v = angle.sin();
let base = pos * self.head_dim;
if let (Some(q0), Some(q1)) = (q_out.get(base + i), q_out.get(base + i + half_dim))
{
let (q0, q1) = (*q0, *q1);
q_out[base + i] = q0 * cos_v - q1 * sin_v;
q_out[base + i + half_dim] = q0 * sin_v + q1 * cos_v;
}
if let (Some(k0), Some(k1)) = (k_out.get(base + i), k_out.get(base + i + half_dim))
{
let (k0, k1) = (*k0, *k1);
k_out[base + i] = k0 * cos_v - k1 * sin_v;
k_out[base + i + half_dim] = k0 * sin_v + k1 * cos_v;
}
}
}
(q_out, k_out)
}
}
pub struct Phi4MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Phi4MLP {
pub fn new(config: &Phi4Config) -> Self {
Self {
gate_proj: Linear::new(config.hidden_size, config.intermediate_size, false),
up_proj: Linear::new(config.hidden_size, config.intermediate_size, false),
down_proj: Linear::new(config.intermediate_size, config.hidden_size, false),
}
}
}
impl Layer for Phi4MLP {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
let gate = self.gate_proj.forward(input.clone())?;
let up = self.up_proj.forward(input)?;
let activated_gate = silu(&gate)?;
let gated = match (&activated_gate, &up) {
(Tensor::F32(g), Tensor::F32(u)) => Tensor::F32(g * u),
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Type mismatch in Phi4MLP gated activation",
))
},
};
self.down_proj.forward(gated)
}
}
pub struct Phi4Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: Phi4RotaryEmbedding,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
}
impl Phi4Attention {
pub fn new(config: &Phi4Config) -> Self {
let head_dim = config.head_dim;
let num_kv_heads = config.num_key_value_heads;
Self {
q_proj: Linear::new(
config.hidden_size,
config.num_attention_heads * head_dim,
false,
),
k_proj: Linear::new(config.hidden_size, num_kv_heads * head_dim, false),
v_proj: Linear::new(config.hidden_size, num_kv_heads * head_dim, false),
o_proj: Linear::new(
config.num_attention_heads * head_dim,
config.hidden_size,
false,
),
rotary_emb: Phi4RotaryEmbedding::new(config),
num_heads: config.num_attention_heads,
num_kv_heads,
head_dim,
}
}
fn gqa_attention(&self, q: &[f32], k: &[f32], v: &[f32], seq_len: usize) -> Vec<f32> {
let kv_group = self.num_heads / self.num_kv_heads.max(1);
let scale = 1.0 / (self.head_dim as f32).sqrt();
let mut output = vec![0.0f32; self.num_heads * seq_len * self.head_dim];
for h in 0..self.num_heads {
let kv_h = h / kv_group;
for qi in 0..seq_len {
let mut scores = vec![0.0f32; seq_len];
for kj in 0..=qi {
let mut dot = 0.0f32;
for d in 0..self.head_dim {
let qv = q
.get(h * seq_len * self.head_dim + qi * self.head_dim + d)
.copied()
.unwrap_or(0.0);
let kv = k
.get(kv_h * seq_len * self.head_dim + kj * self.head_dim + d)
.copied()
.unwrap_or(0.0);
dot += qv * kv;
}
scores[kj] = dot * scale;
}
for kj in (qi + 1)..seq_len {
scores[kj] = f32::NEG_INFINITY;
}
let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut exp_sum = 0.0f32;
let mut exp_scores: Vec<f32> = scores
.iter()
.map(|&s| {
let e = (s - max_s).exp();
exp_sum += e;
e
})
.collect();
if exp_sum > 0.0 {
for es in &mut exp_scores {
*es /= exp_sum;
}
}
for d in 0..self.head_dim {
let mut weighted = 0.0f32;
for kj in 0..seq_len {
let vv = v
.get(kv_h * seq_len * self.head_dim + kj * self.head_dim + d)
.copied()
.unwrap_or(0.0);
weighted += exp_scores[kj] * vv;
}
output[h * seq_len * self.head_dim + qi * self.head_dim + d] = weighted;
}
}
}
output
}
}
impl Layer for Phi4Attention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
let q_tensor = self.q_proj.forward(input.clone())?;
let k_tensor = self.k_proj.forward(input.clone())?;
let v_tensor = self.v_proj.forward(input)?;
match (&q_tensor, &k_tensor, &v_tensor) {
(Tensor::F32(q_arr), Tensor::F32(k_arr), Tensor::F32(v_arr)) => {
let q_slice = q_arr.as_slice().unwrap_or(&[]);
let k_slice = k_arr.as_slice().unwrap_or(&[]);
let total_q = q_arr.len();
let seq_len =
total_q.checked_div(self.num_heads * self.head_dim).unwrap_or(1).max(1);
let (q_rot, k_rot) = self.rotary_emb.apply(q_slice, k_slice, seq_len);
let v_slice = v_arr.as_slice().unwrap_or(&[]);
let context = self.gqa_attention(&q_rot, &k_rot, v_slice, seq_len);
let context_tensor =
Tensor::from_vec(context, &[seq_len, self.num_heads * self.head_dim])?;
self.o_proj.forward(context_tensor)
},
_ => Err(tensor_op_error(
"tensor_operation",
"Unsupported dtype in Phi4Attention",
)),
}
}
}
pub struct Phi4DecoderLayer {
self_attn: Phi4Attention,
mlp: Phi4MLP,
input_layernorm: Phi4RmsNorm,
post_attention_layernorm: Phi4RmsNorm,
}
impl Phi4DecoderLayer {
pub fn new(config: &Phi4Config) -> Result<Self> {
Ok(Self {
self_attn: Phi4Attention::new(config),
mlp: Phi4MLP::new(config),
input_layernorm: Phi4RmsNorm::new(config.hidden_size, config.rms_norm_eps)?,
post_attention_layernorm: Phi4RmsNorm::new(config.hidden_size, config.rms_norm_eps)?,
})
}
}
impl Layer for Phi4DecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
let normed = self.input_layernorm.forward(input.clone())?;
let attn_out = self.self_attn.forward(normed)?;
let residual = input.add(&attn_out)?;
let normed2 = self.post_attention_layernorm.forward(residual.clone())?;
let mlp_out = self.mlp.forward(normed2)?;
residual.add(&mlp_out)
}
}
pub struct Phi4Model {
config: Phi4Config,
embed_tokens: Embedding,
layers: Vec<Phi4DecoderLayer>,
norm: Phi4RmsNorm,
}
impl Phi4Model {
pub fn new(config: Phi4Config) -> Result<Self> {
Config::validate(&config)?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for _ in 0..config.num_hidden_layers {
layers.push(Phi4DecoderLayer::new(&config)?);
}
let norm = Phi4RmsNorm::new(config.hidden_size, config.rms_norm_eps)?;
Ok(Self {
config,
embed_tokens,
layers,
norm,
})
}
pub fn config(&self) -> &Phi4Config {
&self.config
}
}
impl Model for Phi4Model {
type Config = Phi4Config;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input_ids: Tensor) -> Result<Tensor> {
let token_ids: Vec<u32> = match &input_ids {
Tensor::I64(arr) => arr.as_slice().unwrap_or(&[]).iter().map(|&x| x as u32).collect(),
Tensor::F32(arr) => {
arr.as_slice().unwrap_or(&[]).iter().map(|&x| x.round() as u32).collect()
},
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Unsupported input dtype for Phi4Model",
))
},
};
let mut hidden = self.embed_tokens.forward(token_ids)?;
for layer in &self.layers {
hidden = layer.forward(hidden)?;
}
self.norm.forward(hidden)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Ok(())
}
fn get_config(&self) -> &Phi4Config {
&self.config
}
fn num_parameters(&self) -> usize {
let c = &self.config;
let embed = c.vocab_size * c.hidden_size;
let attn = (c.num_attention_heads + 2 * c.num_key_value_heads) * c.head_dim * c.hidden_size
+ c.num_attention_heads * c.head_dim * c.hidden_size;
let mlp = 3 * c.hidden_size * c.intermediate_size;
let norms = 2 * c.hidden_size;
let layer = attn + mlp + norms;
embed + c.num_hidden_layers * layer + c.hidden_size
}
}
impl Phi4Model {
#[allow(dead_code)]
fn get_num_layers(&self) -> usize {
self.config.num_hidden_layers
}
}
impl std::ops::Deref for Phi4Model {
type Target = Phi4Config;
fn deref(&self) -> &Phi4Config {
&self.config
}
}
#[allow(dead_code)]
fn _assert_device_import() {
let _ = Device::CPU;
}
#[allow(dead_code)]
fn _assert_trustformers_error() -> TrustformersError {
TrustformersError::io_error("unused".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::phi4::config::{Phi4Config, Phi4RopeScaling};
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
((*state >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed;
(0..n).map(|_| lcg_next(&mut state)).collect()
}
fn tiny_phi4_config() -> Phi4Config {
Phi4Config {
vocab_size: 64,
hidden_size: 8,
intermediate_size: 16,
num_hidden_layers: 2,
num_attention_heads: 4,
num_key_value_heads: 2,
head_dim: 2,
max_position_embeddings: 32,
original_max_position_embeddings: 16,
rms_norm_eps: 1e-5,
rope_theta: 250000.0,
hidden_act: "silu".to_string(),
tie_word_embeddings: true,
attention_dropout: 0.0,
embd_dropout: 0.0,
rope_scaling: None,
}
}
fn tiny_phi4_config_with_longrope() -> Phi4Config {
let short = vec![1.0_f32; 1]; let long = vec![2.0_f32; 1];
Phi4Config {
rope_scaling: Some(Phi4RopeScaling {
rope_type: "longrope".to_string(),
short_factor: short,
long_factor: long,
short_mscale: 1.0,
long_mscale: 1.0,
original_max_position_embeddings: 16,
}),
..tiny_phi4_config()
}
}
#[test]
fn test_phi4_rmsnorm_unit_rms() {
let norm = Phi4RmsNorm::new(4, 1e-5).expect("Phi4RmsNorm must build");
let input =
Tensor::from_vec(vec![3.0_f32, 4.0, 0.0, 0.0], &[4]).expect("tensor must build");
let output = norm.forward(input).expect("forward must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("must be contiguous").to_vec(),
_ => panic!("expected F32"),
};
let rms = (vals.iter().map(|x| x * x).sum::<f32>() / 4.0).sqrt();
assert!(
(rms - 1.0).abs() < 1e-4,
"RMSNorm output rms must be ~1.0, got {rms}"
);
}
#[test]
fn test_phi4_rmsnorm_all_ones_output_unchanged() {
let norm = Phi4RmsNorm::new(4, 1e-5).expect("must build");
let input = Tensor::from_vec(vec![1.0_f32; 4], &[4]).expect("must build");
let output = norm.forward(input).expect("must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("contiguous").to_vec(),
_ => panic!("expected F32"),
};
for &v in &vals {
assert!(
(v - 1.0).abs() < 1e-4,
"all-ones normalised with ones weight should give ones"
);
}
}
#[test]
fn test_phi4_rope_no_scaling_preserves_norm() {
let cfg = tiny_phi4_config();
let rope = Phi4RotaryEmbedding::new(&cfg);
let n = cfg.head_dim;
let q = lcg_vec(n, 50);
let k = lcg_vec(n, 51);
let q_norm_before: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
let (q_out, _k_out) = rope.apply(&q, &k, 1);
let q_norm_after: f32 = q_out.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(q_norm_before - q_norm_after).abs() < 1e-4,
"RoPE without scaling must preserve norm, before={q_norm_before} after={q_norm_after}",
);
}
#[test]
fn test_phi4_rope_with_longrope_scaling() {
let cfg = tiny_phi4_config_with_longrope();
let rope = Phi4RotaryEmbedding::new(&cfg);
let n = cfg.head_dim;
let q = lcg_vec(n, 52);
let k = lcg_vec(n, 53);
let (q_out, k_out) = rope.apply(&q, &k, 1);
assert_eq!(q_out.len(), n, "LongRoPE output Q length must match input");
assert_eq!(k_out.len(), n, "LongRoPE output K length must match input");
}
#[test]
fn test_phi4_rope_short_range_uses_short_factor() {
let cfg = tiny_phi4_config_with_longrope();
let rope = Phi4RotaryEmbedding::new(&cfg);
let q = vec![1.0_f32, 0.0];
let k = vec![1.0_f32, 0.0];
let (q_out_short, _) = rope.apply(&q, &k, 1);
assert!(
(q_out_short[0] - 1.0).abs() < 1e-4,
"short range RoPE at pos 0 must be identity, got {}",
q_out_short[0],
);
}
#[test]
fn test_phi4_config_default_values() {
let cfg = Phi4Config::default();
assert_eq!(
cfg.hidden_size, 5120,
"Phi-4 default hidden_size must be 5120"
);
assert_eq!(cfg.num_hidden_layers, 40, "Phi-4 must have 40 layers");
assert_eq!(cfg.num_attention_heads, 40, "Phi-4 must have 40 Q heads");
assert_eq!(cfg.num_key_value_heads, 10, "Phi-4 must have 10 KV heads");
assert_eq!(cfg.head_dim, 128, "Phi-4 head_dim must be 128");
assert!(
(cfg.rope_theta - 250000.0).abs() < 1.0,
"Phi-4 rope_theta must be 250000"
);
}
#[test]
fn test_phi4_config_validate_ok() {
let cfg = tiny_phi4_config();
assert!(
Config::validate(&cfg).is_ok(),
"tiny phi4 config must pass validation"
);
}
#[test]
fn test_phi4_config_gqa_ratio() {
let cfg = Phi4Config::default();
let ratio = cfg.num_attention_heads / cfg.num_key_value_heads;
assert_eq!(ratio, 4, "Phi-4 GQA ratio must be 40/10 = 4");
}
#[test]
fn test_phi4_mlp_output_length() {
let cfg = tiny_phi4_config();
let mlp = Phi4MLP::new(&cfg);
let input = Tensor::from_vec(lcg_vec(cfg.hidden_size, 60), &[1, cfg.hidden_size])
.expect("tensor must build");
let output = mlp.forward(input).expect("MLP forward must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.hidden_size,
"Phi4MLP output must have hidden_size elements"
);
}
#[test]
fn test_phi4_model_construction() {
let cfg = tiny_phi4_config();
let model = Phi4Model::new(cfg).expect("Phi4Model must build");
assert_eq!(
model.config().num_hidden_layers,
2,
"model must have 2 layers"
);
}
#[test]
fn test_phi4_model_forward_single_token() {
let cfg = tiny_phi4_config();
let model = Phi4Model::new(cfg.clone()).expect("model must build");
let input = Tensor::from_vec(vec![0_f32], &[1]).expect("f32 token must build");
let _ = model.forward(input);
}
#[test]
fn test_phi4_model_forward_f32_input() {
let cfg = tiny_phi4_config();
let model = Phi4Model::new(cfg).expect("model must build");
let input = Tensor::from_vec(vec![2.0_f32], &[1]).expect("f32 token must build");
let _ = model.forward(input);
}
#[test]
fn test_phi4_model_num_parameters_positive() {
let cfg = tiny_phi4_config();
let model = Phi4Model::new(cfg).expect("model must build");
assert!(
model.num_parameters() > 0,
"num_parameters must be positive"
);
}
#[test]
fn test_phi4_model_deref_to_config() {
let cfg = tiny_phi4_config();
let model = Phi4Model::new(cfg.clone()).expect("model must build");
assert_eq!(
model.hidden_size, cfg.hidden_size,
"deref must give config's hidden_size"
);
}
#[test]
fn test_phi4_decoder_layer_construction() {
let cfg = tiny_phi4_config();
let _layer = Phi4DecoderLayer::new(&cfg).expect("Phi4DecoderLayer must build");
}
}