#![allow(dead_code)]
use scirs2_core::ndarray::{s, Array2, ArrayD, Axis, IxDyn};
use std::io::Read;
use trustformers_core::{
device::Device,
errors::{Result, TrustformersError},
layers::{Embedding, Linear},
tensor::Tensor,
traits::{Config, Layer, Model, TokenizedInput, WeightReader},
};
use super::config::T5Config;
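/// Encoder-decoder T5 transformer (shared token embedding, encoder stack,
/// decoder stack) without a language-modeling head.
///
/// A minimal usage sketch; the unit tests at the bottom of this file are
/// runnable equivalents:
///
/// ```ignore
/// let model = T5Model::new(T5Config::default())?;
/// let input = T5Input {
///     input_ids: TokenizedInput::new(vec![1, 2, 3], vec![1, 1, 1]),
///     decoder_input_ids: None, // encoder-only pass
///     encoder_outputs: None,
/// };
/// let output = model.forward(input)?;
/// assert!(output.encoder_last_hidden_state.is_some());
/// ```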
#[derive(Clone)]
pub struct T5Model {
config: T5Config,
shared: Embedding,
encoder: T5Stack,
decoder: T5Stack,
device: Device,
}
impl T5Model {
pub fn new(config: T5Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: T5Config, device: Device) -> Result<Self> {
config.validate()?;
let shared = Embedding::new_with_device(config.vocab_size, config.d_model, None, device)?;
let encoder_config = config.clone();
let decoder_config = config.clone();
Ok(Self {
config,
shared,
encoder: T5Stack::new_with_device(encoder_config, true, device)?,
decoder: T5Stack::new_with_device(decoder_config, false, device)?,
device,
})
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn load_weights_from_reader(&mut self, reader: &mut dyn WeightReader) -> Result<()> {
self.shared.set_weight(reader.read_tensor("shared.weight")?)?;
self.encoder.load_weights(reader, "encoder")?;
self.decoder.load_weights(reader, "decoder")?;
Ok(())
}
}
impl Model for T5Model {
type Config = T5Config;
type Input = T5Input;
type Output = T5Output;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let encoder_hidden_states = if let Some(encoder_outputs) = input.encoder_outputs {
encoder_outputs
} else {
let input_embeds = self.shared.forward(input.input_ids.input_ids)?;
self.encoder.forward(input_embeds, None)?
};
let decoder_outputs = if let Some(decoder_input) = input.decoder_input_ids {
let decoder_embeds = self.shared.forward(decoder_input.input_ids)?;
let decoder_hidden =
self.decoder.forward(decoder_embeds, Some(&encoder_hidden_states))?;
Some(decoder_hidden)
} else {
None
};
Ok(T5Output {
last_hidden_state: decoder_outputs.unwrap_or_else(|| encoder_hidden_states.clone()),
encoder_last_hidden_state: Some(encoder_hidden_states),
past_key_values: None,
})
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Err(TrustformersError::model_error(
"Use load_weights_from_reader instead".to_string(),
))
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
self.shared.parameter_count()
+ self.encoder.parameter_count()
+ self.decoder.parameter_count()
}
}
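/// T5 with a language-modeling head tied to the shared embedding, suitable
/// for text-to-text generation.
///
/// A hedged sketch of generation; the token ids are placeholders, since
/// tokenization lives outside this module:
///
/// ```ignore
/// let model = T5ForConditionalGeneration::new(T5Config::default())?;
/// let input_ids: Vec<u32> = vec![3, 17, 8, 1]; // hypothetical encoded prompt
/// let output_ids = model.generate(input_ids, 32, 4)?; // beam search, 4 beams
/// ```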
#[derive(Clone)]
#[allow(dead_code)]
pub struct T5ForConditionalGeneration {
transformer: T5Model,
lm_head: Linear,
device: Device,
}
impl T5ForConditionalGeneration {
pub fn new(config: T5Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: T5Config, device: Device) -> Result<Self> {
let transformer = T5Model::new_with_device(config.clone(), device)?;
let lm_head = Linear::new_with_device(config.d_model, config.vocab_size, false, device);
Ok(Self {
transformer,
lm_head,
device,
})
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn load_weights_from_reader(&mut self, reader: &mut dyn WeightReader) -> Result<()> {
self.transformer.load_weights_from_reader(reader)?;
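// T5 ties the LM head to the shared embedding matrix, so the same
// "shared.weight" tensor is loaded into both.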
let shared_weight = reader.read_tensor("shared.weight")?;
self.lm_head.set_weight(shared_weight)?;
Ok(())
}
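/// Encodes `input_ids` once, then decodes autoregressively for up to
/// `max_length` steps. `num_beams == 1` selects greedy decoding; larger
/// values select a simple beam search. Decoding starts from the pad token
/// (id 0) and stops at EOS (id 1), per the T5 convention.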
pub fn generate(
&self,
input_ids: Vec<u32>,
max_length: usize,
num_beams: usize,
) -> Result<Vec<u32>> {
if input_ids.is_empty() {
return Err(TrustformersError::model_error(
"Empty input_ids provided".to_string(),
));
}
let encoder_input = T5Input {
input_ids: TokenizedInput {
input_ids: input_ids.clone(),
attention_mask: vec![1; input_ids.len()],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
},
decoder_input_ids: None,
encoder_outputs: None,
};
let encoder_output = self.transformer.forward(encoder_input)?;
let encoder_hidden_states = encoder_output.encoder_last_hidden_state.ok_or_else(|| {
TrustformersError::model_error("No encoder outputs available".to_string())
})?;
if num_beams > 1 {
self.beam_search_generate(encoder_hidden_states, max_length, num_beams)
} else {
self.greedy_generate(encoder_hidden_states, max_length)
}
}
fn greedy_generate(
&self,
encoder_hidden_states: Tensor,
max_length: usize,
) -> Result<Vec<u32>> {
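// Start from the decoder start token (T5 reuses pad, id 0). No key/value
// cache is kept, so each step re-runs the decoder over the whole prefix.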
let mut generated_ids = vec![0];
for _ in 0..max_length {
let decoder_input = T5Input {
input_ids: TokenizedInput {
input_ids: vec![],
attention_mask: vec![],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
},
decoder_input_ids: Some(TokenizedInput {
input_ids: generated_ids.clone(),
attention_mask: vec![1; generated_ids.len()],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
}),
encoder_outputs: Some(encoder_hidden_states.clone()),
};
let output = self.forward(decoder_input)?;
let logits = match &output.logits {
Tensor::F32(arr) => {
let shape = arr.shape();
let seq_len = shape[shape.len() - 2];
let last_token_slice = if shape.len() == 3 {
arr.slice(s![0, seq_len - 1, ..])
} else {
arr.slice(s![seq_len - 1, ..])
};
last_token_slice.to_owned()
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Logits must be F32 tensor",
"tensor_operation",
))
},
};
let next_token_id = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx as u32)
.ok_or_else(|| {
TrustformersError::model_error("Failed to find next token".to_string())
})?;
if next_token_id == 1 {
break;
}
generated_ids.push(next_token_id);
}
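// Drop the leading decoder start token before returning.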
if !generated_ids.is_empty() {
generated_ids.remove(0);
}
Ok(generated_ids)
}
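/// Simple beam search: each live beam is expanded with its `num_beams` best
/// next tokens, candidates are ranked by cumulative log-probability (no
/// length normalization), the pool is pruned back to `num_beams`, and beams
/// that emit EOS (id 1) move to the finished set.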
fn beam_search_generate(
&self,
encoder_hidden_states: Tensor,
max_length: usize,
num_beams: usize,
) -> Result<Vec<u32>> {
#[derive(Clone)]
struct Beam {
tokens: Vec<u32>,
score: f32,
}
let mut beams = vec![Beam {
tokens: vec![0],
score: 0.0,
}];
let mut finished_beams = Vec::new();
for _step in 0..max_length {
let mut all_candidates = Vec::new();
for beam in &beams {
let decoder_input = T5Input {
input_ids: TokenizedInput {
input_ids: vec![],
attention_mask: vec![],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
},
decoder_input_ids: Some(TokenizedInput {
input_ids: beam.tokens.clone(),
attention_mask: vec![1; beam.tokens.len()],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
}),
encoder_outputs: Some(encoder_hidden_states.clone()),
};
let output = self.forward(decoder_input)?;
let logits = match &output.logits {
Tensor::F32(arr) => {
let shape = arr.shape();
let seq_len = shape[shape.len() - 2];
let last_token_slice = if shape.len() == 3 {
arr.slice(s![0, seq_len - 1, ..])
} else {
arr.slice(s![seq_len - 1, ..])
};
last_token_slice.to_owned()
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Logits must be F32 tensor",
"tensor_operation",
))
},
};
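// Log-softmax over the vocabulary: beam scores accumulate per-step
// log-probabilities.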
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
let mut token_scores: Vec<(usize, f32)> =
probs.iter().enumerate().map(|(idx, &prob)| (idx, prob.ln())).collect();
token_scores
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
for &(token_id, log_prob) in token_scores.iter().take(num_beams) {
let mut new_tokens = beam.tokens.clone();
new_tokens.push(token_id as u32);
let new_score = beam.score + log_prob;
if token_id == 1 {
finished_beams.push(Beam {
tokens: new_tokens,
score: new_score,
});
} else {
all_candidates.push(Beam {
tokens: new_tokens,
score: new_score,
});
}
}
}
all_candidates
.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
beams = all_candidates.into_iter().take(num_beams).collect();
if beams.is_empty() {
break;
}
}
finished_beams.extend(beams);
if finished_beams.is_empty() {
return Err(TrustformersError::model_error(
"No valid sequences generated".to_string(),
));
}
finished_beams
.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
let best_beam = &finished_beams[0];
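// Return the best beam without the start token and any trailing EOS.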
let mut result = best_beam.tokens.clone();
if !result.is_empty() {
result.remove(0);
}
if let Some(&1) = result.last() {
result.pop();
}
Ok(result)
}
}
impl Model for T5ForConditionalGeneration {
type Config = T5Config;
type Input = T5Input;
type Output = T5LMOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let base_output = self.transformer.forward(input)?;
let logits = self.lm_head.forward(base_output.last_hidden_state)?;
Ok(T5LMOutput {
logits,
past_key_values: base_output.past_key_values,
encoder_last_hidden_state: base_output.encoder_last_hidden_state,
})
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
self.transformer.load_pretrained(reader)
}
fn get_config(&self) -> &Self::Config {
self.transformer.get_config()
}
fn num_parameters(&self) -> usize {
self.transformer.num_parameters() + self.lm_head.parameter_count()
}
}
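/// A stack of `T5Block`s with a final RMS layer norm. Decoder stacks
/// (`is_encoder == false`) add cross-attention over the encoder output and
/// use a causal self-attention mask.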
#[derive(Clone)]
#[allow(dead_code)]
struct T5Stack {
#[allow(dead_code)]
config: T5Config,
is_encoder: bool,
embed_tokens: Option<Embedding>,
block: Vec<T5Block>,
final_layer_norm: T5LayerNorm,
dropout: f32,
device: Device,
}
impl T5Stack {
fn new(config: T5Config, is_encoder: bool) -> Result<Self> {
Self::new_with_device(config, is_encoder, Device::CPU)
}
fn new_with_device(config: T5Config, is_encoder: bool, device: Device) -> Result<Self> {
let num_layers = if is_encoder {
config.num_layers
} else {
config.num_decoder_layers.unwrap_or(config.num_layers)
};
let mut block = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
block.push(T5Block::new_with_device(&config, is_encoder, device)?);
}
Ok(Self {
config: config.clone(),
is_encoder,
embed_tokens: None,
block,
final_layer_norm: T5LayerNorm::new_with_device(
config.d_model,
config.layer_norm_epsilon,
device,
),
dropout: config.dropout_rate,
device,
})
}
fn device(&self) -> &Device {
&self.device
}
fn load_weights(&mut self, reader: &mut dyn WeightReader, prefix: &str) -> Result<()> {
for (i, block) in self.block.iter_mut().enumerate() {
block.load_weights(reader, &format!("{}.block.{}", prefix, i))?;
}
self.final_layer_norm
.load_weights(reader, &format!("{}.final_layer_norm", prefix))?;
Ok(())
}
fn forward(
&self,
hidden_states: Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let mut hidden_states = hidden_states;
// Encoder self-attention is bidirectional; decoder self-attention is causal.
let attention_mask = create_attention_mask(&hidden_states, !self.is_encoder)?;
for block in &self.block {
hidden_states = block.forward(
hidden_states,
Some(&attention_mask),
encoder_hidden_states,
None,
)?;
}
self.final_layer_norm.forward(hidden_states)
}
fn parameter_count(&self) -> usize {
let mut total = 0;
if let Some(ref embed) = self.embed_tokens {
total += embed.parameter_count();
}
for block in &self.block {
total += block.parameter_count();
}
total += self.final_layer_norm.parameter_count();
total
}
}
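/// One transformer block: self-attention, optional cross-attention (decoder
/// only), then the feed-forward sub-layer, each pre-normed with a residual
/// connection.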
#[derive(Clone)]
struct T5Block {
is_encoder: bool,
self_attention: T5Attention,
cross_attention: Option<T5Attention>,
feed_forward: T5DenseReluDense,
device: Device,
}
impl T5Block {
fn new(config: &T5Config, is_encoder: bool) -> Result<Self> {
Self::new_with_device(config, is_encoder, Device::CPU)
}
fn new_with_device(config: &T5Config, is_encoder: bool, device: Device) -> Result<Self> {
let cross_attention = if !is_encoder {
Some(T5Attention::new_with_device(config, true, device)?)
} else {
None
};
Ok(Self {
is_encoder,
self_attention: T5Attention::new_with_device(config, false, device)?,
cross_attention,
feed_forward: T5DenseReluDense::new_with_device(config, device)?,
device,
})
}
fn device(&self) -> &Device {
&self.device
}
fn load_weights(&mut self, reader: &mut dyn WeightReader, prefix: &str) -> Result<()> {
self.self_attention.load_weights(reader, &format!("{}.layer.0", prefix))?;
if let Some(ref mut cross_attn) = self.cross_attention {
cross_attn.load_weights(reader, &format!("{}.layer.1", prefix))?;
}
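// Checkpoint layout: encoder blocks keep the feed-forward at "layer.1";
// decoder blocks at "layer.2", after the cross-attention at "layer.1".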
let ff_idx = if self.is_encoder { 1 } else { 2 };
self.feed_forward
.load_weights(reader, &format!("{}.layer.{}", prefix, ff_idx))?;
Ok(())
}
fn forward(
&self,
hidden_states: Tensor,
attention_mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
cross_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mut hidden_states = self.self_attention.forward(
hidden_states,
None,
attention_mask,
)?;
if let Some(ref cross_attn) = self.cross_attention {
if let Some(encoder_hidden) = encoder_hidden_states {
hidden_states = cross_attn.forward(
hidden_states,
Some(encoder_hidden),
cross_attention_mask,
)?;
}
}
self.feed_forward.forward(hidden_states)
}
fn parameter_count(&self) -> usize {
let mut total = self.self_attention.parameter_count() + self.feed_forward.parameter_count();
if let Some(ref cross_attn) = self.cross_attention {
total += cross_attn.parameter_count();
}
total
}
}
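/// Pre-norm multi-head attention with an optional learned relative-position
/// bias (self-attention only) and a residual connection around the output
/// projection.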
#[derive(Clone)]
struct T5Attention {
is_cross_attention: bool,
layer_norm: T5LayerNorm,
q: Linear,
k: Linear,
v: Linear,
o: Linear,
n_heads: usize,
d_kv: usize,
#[allow(dead_code)]
dropout: f32,
has_relative_attention_bias: bool,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
relative_attention_bias: Option<Embedding>,
device: Device,
}
impl T5Attention {
fn new(config: &T5Config, is_cross_attention: bool) -> Result<Self> {
Self::new_with_device(config, is_cross_attention, Device::CPU)
}
fn new_with_device(
config: &T5Config,
is_cross_attention: bool,
device: Device,
) -> Result<Self> {
let has_relative_bias = !is_cross_attention;
let relative_attention_bias = if has_relative_bias {
Some(Embedding::new_with_device(
config.relative_attention_num_buckets,
config.num_heads,
None,
device,
)?)
} else {
None
};
Ok(Self {
is_cross_attention,
layer_norm: T5LayerNorm::new_with_device(
config.d_model,
config.layer_norm_epsilon,
device,
),
q: Linear::new_with_device(
config.d_model,
config.num_heads * config.d_kv,
false,
device,
),
k: Linear::new_with_device(
config.d_model,
config.num_heads * config.d_kv,
false,
device,
),
v: Linear::new_with_device(
config.d_model,
config.num_heads * config.d_kv,
false,
device,
),
o: Linear::new_with_device(
config.num_heads * config.d_kv,
config.d_model,
false,
device,
),
n_heads: config.num_heads,
d_kv: config.d_kv,
dropout: config.dropout_rate,
has_relative_attention_bias: has_relative_bias,
relative_attention_num_buckets: config.relative_attention_num_buckets,
relative_attention_max_distance: config.relative_attention_max_distance,
relative_attention_bias,
device,
})
}
fn device(&self) -> &Device {
&self.device
}
fn load_weights(&mut self, reader: &mut dyn WeightReader, prefix: &str) -> Result<()> {
self.layer_norm.load_weights(reader, &format!("{}.layer_norm", prefix))?;
let attn_name = if self.is_cross_attention { "EncDecAttention" } else { "SelfAttention" };
let attn_prefix = format!("{}.{}", prefix, attn_name);
self.q.set_weight(reader.read_tensor(&format!("{}.q.weight", attn_prefix))?)?;
self.k.set_weight(reader.read_tensor(&format!("{}.k.weight", attn_prefix))?)?;
self.v.set_weight(reader.read_tensor(&format!("{}.v.weight", attn_prefix))?)?;
self.o.set_weight(reader.read_tensor(&format!("{}.o.weight", attn_prefix))?)?;
if let Some(ref mut bias) = self.relative_attention_bias {
if let Ok(bias_weight) =
reader.read_tensor(&format!("{}.relative_attention_bias.weight", attn_prefix))
{
bias.set_weight(bias_weight)?;
}
}
Ok(())
}
fn forward(
&self,
hidden_states: Tensor,
key_value_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let normed_hidden = self.layer_norm.forward(hidden_states.clone())?;
let shape = normed_hidden.shape();
let batch_size = if shape.len() == 3 { shape[0] } else { 1 };
let _seq_len = if shape.len() == 3 { shape[1] } else { shape[0] };
let normed_hidden = match &normed_hidden {
Tensor::F32(arr) => {
if arr.ndim() == 2 {
Tensor::F32(arr.clone().insert_axis(Axis(0)).to_owned())
} else {
normed_hidden
}
},
_ => normed_hidden,
};
let query = self.q.forward(normed_hidden.clone())?;
let key = if let Some(kv_states) = key_value_states {
self.k.forward(kv_states.clone())?
} else {
self.k.forward(normed_hidden.clone())?
};
let value = if let Some(kv_states) = key_value_states {
self.v.forward(kv_states.clone())?
} else {
self.v.forward(normed_hidden)?
};
let attention_output = match (&query, &key, &value) {
(Tensor::F32(q_arr), Tensor::F32(k_arr), Tensor::F32(v_arr)) => {
let q_shape = q_arr.shape();
let k_shape = k_arr.shape();
let v_shape = v_arr.shape();
let q_seq_len = if q_shape.len() == 3 { q_shape[1] } else { q_shape[0] };
let k_seq_len = if k_shape.len() == 3 { k_shape[1] } else { k_shape[0] };
let v_seq_len = if v_shape.len() == 3 { v_shape[1] } else { v_shape[0] };
let q_arr = if q_arr.ndim() == 2 {
q_arr.clone().insert_axis(Axis(0)).to_owned()
} else {
q_arr.clone()
};
let k_arr = if k_arr.ndim() == 2 {
k_arr.clone().insert_axis(Axis(0)).to_owned()
} else {
k_arr.clone()
};
let v_arr = if v_arr.ndim() == 2 {
v_arr.clone().insert_axis(Axis(0)).to_owned()
} else {
v_arr.clone()
};
let q = q_arr
.to_shape(IxDyn(&[batch_size, q_seq_len, self.n_heads, self.d_kv]))
.map_err(|e| {
TrustformersError::shape_error(format!(
"Failed to reshape Q from {:?} to [{}, {}, {}, {}]: {}",
q_arr.shape(),
batch_size,
q_seq_len,
self.n_heads,
self.d_kv,
e
))
})?
.to_owned();
let k = k_arr
.to_shape(IxDyn(&[batch_size, k_seq_len, self.n_heads, self.d_kv]))
.map_err(|e| {
TrustformersError::shape_error(format!(
"Failed to reshape K from {:?} to [{}, {}, {}, {}]: {}",
k_arr.shape(),
batch_size,
k_seq_len,
self.n_heads,
self.d_kv,
e
))
})?
.to_owned();
let v = v_arr
.to_shape(IxDyn(&[batch_size, v_seq_len, self.n_heads, self.d_kv]))
.map_err(|e| {
TrustformersError::shape_error(format!(
"Failed to reshape V from {:?} to [{}, {}, {}, {}]: {}",
v_arr.shape(),
batch_size,
v_seq_len,
self.n_heads,
self.d_kv,
e
))
})?
.to_owned();
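// [batch, seq, heads, d_kv] -> [batch, heads, seq, d_kv], so each head
// below is an independent 2-D matmul.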
let q = q.permuted_axes(vec![0, 2, 1, 3]);
let k = k.permuted_axes(vec![0, 2, 1, 3]);
let v = v.permuted_axes(vec![0, 2, 1, 3]);
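// Standard 1/sqrt(d_kv) scaling. Note: the reference T5 omits this
// rescaling (it is folded into the weight initialization); it is applied
// explicitly here.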
let scale = 1.0 / (self.d_kv as f32).sqrt();
let key_seq_len = k.shape()[2];
let k_t = k.permuted_axes(vec![0, 1, 3, 2]);
let mut scores = ArrayD::<f32>::zeros(IxDyn(&[
batch_size,
self.n_heads,
q_seq_len,
key_seq_len,
]));
for b in 0..batch_size {
for h in 0..self.n_heads {
let q_head = q.slice(s![b, h, .., ..]);
let k_head_t = k_t.slice(s![b, h, .., ..]);
let score = q_head.dot(&k_head_t);
scores.slice_mut(s![b, h, .., ..]).assign(&score);
}
}
scores *= scale;
if self.has_relative_attention_bias {
let position_bias =
self.compute_relative_position_bias(q_seq_len, key_seq_len)?;
match position_bias {
Tensor::F32(bias_arr) => {
for b in 0..batch_size {
let mut slice = scores.slice_mut(s![b, .., .., ..]);
slice += &bias_arr;
}
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Position bias must be F32",
"tensor_operation",
))
},
}
}
if let Some(mask) = attention_mask {
match mask {
Tensor::F32(mask_arr) => {
scores += mask_arr;
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Attention mask must be F32",
"tensor_operation",
))
},
}
}
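// Numerically stable softmax over the key dimension, one row at a time.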
let mut attention_probs = scores.clone();
for b in 0..batch_size {
for h in 0..self.n_heads {
for i in 0..q_seq_len {
let mut row = attention_probs.slice_mut(s![b, h, i, ..]);
let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
row.mapv_inplace(|x| (x - max_val).exp());
let sum: f32 = row.iter().sum();
if sum > 0.0 {
row.mapv_inplace(|x| x / sum);
}
}
}
}
let mut output =
ArrayD::<f32>::zeros(IxDyn(&[batch_size, self.n_heads, q_seq_len, self.d_kv]));
for b in 0..batch_size {
for h in 0..self.n_heads {
let attn_probs_head = attention_probs.slice(s![b, h, .., ..]);
let v_head = v.slice(s![b, h, .., ..]);
let out = attn_probs_head.dot(&v_head);
output.slice_mut(s![b, h, .., ..]).assign(&out);
}
}
let output = output.permuted_axes(vec![0, 2, 1, 3]);
let output = output
.to_shape(IxDyn(&[batch_size, q_seq_len, self.n_heads * self.d_kv]))
.map_err(|_| {
TrustformersError::shape_error("Failed to reshape output".to_string())
})?
.to_owned();
Tensor::F32(output)
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type",
"tensor_operation",
))
},
};
let output = self.o.forward(attention_output)?;
let output = if shape.len() == 2 {
match output {
Tensor::F32(arr) => Tensor::F32(arr.remove_axis(Axis(0))),
_ => output,
}
} else {
output
};
hidden_states.add(&output)
}
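/// Builds the learned relative-position bias of shape
/// [n_heads, query_len, key_len] by bucketing each offset (j - i) and
/// looking it up in the bias embedding.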
fn compute_relative_position_bias(&self, query_len: usize, key_len: usize) -> Result<Tensor> {
let mut relative_positions = Array2::<i32>::zeros((query_len, key_len));
for i in 0..query_len {
for j in 0..key_len {
relative_positions[[i, j]] = j as i32 - i as i32;
}
}
let buckets = self.relative_position_bucket(relative_positions);
if let Some(ref bias_embedding) = self.relative_attention_bias {
let bucket_indices: Vec<u32> = buckets.iter().cloned().map(|b| b as u32).collect();
let bias = bias_embedding.forward(bucket_indices)?;
match bias {
Tensor::F32(bias_arr) => {
let reshaped = bias_arr
.to_shape(IxDyn(&[query_len, key_len, self.n_heads]))
.map_err(|_| {
TrustformersError::shape_error("Failed to reshape bias".to_string())
})?
.to_owned();
let transposed = reshaped.permuted_axes(vec![2, 0, 1]);
Ok(Tensor::F32(transposed))
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor type",
"tensor_operation",
)),
}
} else {
let zeros = ArrayD::<f32>::zeros(IxDyn(&[self.n_heads, query_len, key_len]));
Ok(Tensor::F32(zeros))
}
}
fn relative_position_bucket(&self, relative_positions: Array2<i32>) -> Array2<i32> {
// Bidirectional bucketing as in the reference T5: half the buckets cover
// forward offsets and half backward; within each half the first
// `max_exact` buckets index exact small distances and the remainder grow
// logarithmically up to `max_distance`.
let num_buckets = self.relative_attention_num_buckets;
let max_distance = self.relative_attention_max_distance;
let boundary = num_buckets as i32 / 2;
let max_exact = boundary / 2;
let mut buckets = Array2::<i32>::zeros(relative_positions.dim());
for ((i, j), val) in relative_positions.indexed_iter() {
let mut bucket = if *val > 0 { boundary } else { 0 };
let abs_val = val.abs();
if abs_val < max_exact {
bucket += abs_val;
} else {
let log_val = ((abs_val as f32 / max_exact as f32).ln()
/ (max_distance as f32 / max_exact as f32).ln()
* (boundary - max_exact) as f32) as i32;
bucket += max_exact + log_val.min(boundary - max_exact - 1);
}
buckets[[i, j]] = bucket;
}
buckets
}
fn parameter_count(&self) -> usize {
let mut total = self.layer_norm.parameter_count()
+ self.q.parameter_count()
+ self.k.parameter_count()
+ self.v.parameter_count()
+ self.o.parameter_count();
if let Some(ref bias) = self.relative_attention_bias {
total += bias.parameter_count();
}
total
}
}
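/// Pre-norm feed-forward block: `x + wo(relu(wi(norm(x))))`.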
#[derive(Clone)]
struct T5DenseReluDense {
layer_norm: T5LayerNorm,
wi: Linear,
wo: Linear,
#[allow(dead_code)]
dropout: f32,
device: Device,
}
impl T5DenseReluDense {
fn new(config: &T5Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
fn new_with_device(config: &T5Config, device: Device) -> Result<Self> {
Ok(Self {
layer_norm: T5LayerNorm::new_with_device(
config.d_model,
config.layer_norm_epsilon,
device,
),
wi: Linear::new_with_device(config.d_model, config.d_ff, false, device),
wo: Linear::new_with_device(config.d_ff, config.d_model, false, device),
dropout: config.dropout_rate,
device,
})
}
fn device(&self) -> &Device {
&self.device
}
fn load_weights(&mut self, reader: &mut dyn WeightReader, prefix: &str) -> Result<()> {
self.layer_norm.load_weights(reader, &format!("{}.layer_norm", prefix))?;
let dense_prefix = format!("{}.DenseReluDense", prefix);
self.wi
.set_weight(reader.read_tensor(&format!("{}.wi.weight", dense_prefix))?)?;
self.wo
.set_weight(reader.read_tensor(&format!("{}.wo.weight", dense_prefix))?)?;
Ok(())
}
fn forward(&self, hidden_states: Tensor) -> Result<Tensor> {
let normed = self.layer_norm.forward(hidden_states.clone())?;
let ff_output = self.wi.forward(normed)?;
let ff_output = relu(ff_output)?;
let ff_output = self.wo.forward(ff_output)?;
hidden_states.add(&ff_output)
}
fn parameter_count(&self) -> usize {
self.layer_norm.parameter_count() + self.wi.parameter_count() + self.wo.parameter_count()
}
}
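/// T5-style layer norm (RMSNorm): `y = w * x / sqrt(mean(x^2) + eps)`,
/// computed over the hidden dimension, with no mean subtraction and no bias.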
#[derive(Clone)]
struct T5LayerNorm {
weight: Tensor,
epsilon: f32,
device: Device,
}
impl T5LayerNorm {
fn new(hidden_size: usize, epsilon: f32) -> Self {
Self::new_with_device(hidden_size, epsilon, Device::CPU)
}
fn new_with_device(hidden_size: usize, epsilon: f32, device: Device) -> Self {
Self {
weight: Tensor::ones(&[hidden_size]).expect("operation failed"),
epsilon,
device,
}
}
fn device(&self) -> &Device {
&self.device
}
fn load_weights(&mut self, reader: &mut dyn WeightReader, prefix: &str) -> Result<()> {
self.weight = reader.read_tensor(&format!("{}.weight", prefix))?;
Ok(())
}
fn forward(&self, hidden_states: Tensor) -> Result<Tensor> {
match (&hidden_states, &self.weight) {
(Tensor::F32(x), Tensor::F32(w)) => {
// RMSNorm is applied per hidden vector: each position is scaled by
// the inverse root-mean-square of its own components, not by a
// single statistic computed over the whole tensor.
let last_axis = Axis(x.ndim() - 1);
let variance = x.mapv(|v| v * v).mean_axis(last_axis).ok_or_else(|| {
TrustformersError::tensor_op_error(
"Cannot normalize an empty tensor",
"tensor_operation",
)
})?;
let inv_rms = variance.mapv(|v| 1.0 / (v + self.epsilon).sqrt());
let normed = x * &inv_rms.insert_axis(last_axis);
Ok(Tensor::F32(&normed * w))
},
_ => Err(TrustformersError::tensor_op_error(
"T5LayerNorm only supports F32 tensors",
"tensor_operation",
)),
}
}
fn parameter_count(&self) -> usize {
self.weight.len()
}
}
fn relu(x: Tensor) -> Result<Tensor> {
match &x {
Tensor::F32(arr) => {
let result = arr.mapv(|val| val.max(0.0));
Ok(Tensor::F32(result))
},
_ => Err(TrustformersError::tensor_op_error(
"ReLU only supports F32 tensors",
"tensor_operation",
)),
}
}
fn create_attention_mask(hidden_states: &Tensor, causal: bool) -> Result<Tensor> {
// Additive mask: 0.0 = attend, -1e9 = masked. The causal variant blocks
// decoder positions from attending to later positions.
let shape = hidden_states.shape();
let seq_len = shape[shape.len() - 2];
let mask = ArrayD::from_shape_fn(IxDyn(&[1, 1, seq_len, seq_len]), |idx| {
if causal && idx[3] > idx[2] { -1e9f32 } else { 0.0 }
});
Ok(Tensor::F32(mask))
}
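/// Input bundle for the T5 forward pass. `encoder_outputs` lets callers
/// reuse a previously computed encoder state, as the generation loop does.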
pub struct T5Input {
pub input_ids: TokenizedInput,
pub decoder_input_ids: Option<TokenizedInput>,
pub encoder_outputs: Option<Tensor>,
}
pub struct T5Output {
pub last_hidden_state: Tensor,
pub encoder_last_hidden_state: Option<Tensor>,
pub past_key_values: Option<Vec<Tensor>>,
}
pub struct T5LMOutput {
pub logits: Tensor,
pub past_key_values: Option<Vec<Tensor>>,
pub encoder_last_hidden_state: Option<Tensor>,
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::traits::Config;
fn small_t5_config() -> T5Config {
T5Config {
vocab_size: 100,
d_model: 32,
d_kv: 8,
d_ff: 64,
num_layers: 2,
num_decoder_layers: None,
num_heads: 4,
relative_attention_num_buckets: 16,
relative_attention_max_distance: 64,
dropout_rate: 0.0,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: "relu".to_string(),
is_encoder_decoder: true,
use_cache: false,
pad_token_id: 0,
eos_token_id: 1,
model_type: "t5".to_string(),
}
}
#[test]
fn test_t5_config_default() {
let config = T5Config::default();
assert_eq!(config.vocab_size, 32128);
assert_eq!(config.d_model, 512);
assert_eq!(config.num_layers, 6);
assert_eq!(config.num_heads, 8);
assert!(config.is_encoder_decoder);
}
#[test]
fn test_t5_config_validate() {
let config = small_t5_config();
assert!(config.validate().is_ok());
}
#[test]
fn test_t5_model_creation() {
let config = small_t5_config();
let result = T5Model::new(config);
assert!(result.is_ok());
}
#[test]
fn test_t5_model_with_device() {
let config = small_t5_config();
let result = T5Model::new_with_device(config, Device::CPU);
assert!(result.is_ok());
let model = result.expect("model creation should succeed");
assert!(matches!(model.device(), &Device::CPU));
}
#[test]
fn test_t5_model_config() {
let config = small_t5_config();
let model = T5Model::new(config.clone()).expect("model creation should succeed");
let mc = model.get_config();
assert_eq!(mc.vocab_size, config.vocab_size);
assert_eq!(mc.d_model, config.d_model);
}
#[test]
fn test_t5_model_num_parameters() {
let config = small_t5_config();
let model = T5Model::new(config).expect("model creation should succeed");
assert!(model.num_parameters() > 0);
}
#[test]
fn test_t5_model_forward_encoder_only() {
let config = small_t5_config();
let model = T5Model::new(config).expect("model creation should succeed");
let input = T5Input {
input_ids: TokenizedInput::new(vec![1, 2, 3], vec![1, 1, 1]),
decoder_input_ids: None,
encoder_outputs: None,
};
let result = model.forward(input);
assert!(result.is_ok());
let output = result.expect("forward should succeed");
assert!(output.encoder_last_hidden_state.is_some());
}
#[test]
fn test_t5_model_forward_with_decoder() {
let config = small_t5_config();
let model = T5Model::new(config).expect("model creation should succeed");
let input = T5Input {
input_ids: TokenizedInput::new(vec![1, 2, 3], vec![1, 1, 1]),
decoder_input_ids: Some(TokenizedInput::new(vec![1, 5], vec![1, 1])),
encoder_outputs: None,
};
let result = model.forward(input);
assert!(result.is_ok());
}
#[test]
fn test_t5_conditional_generation_creation() {
let config = small_t5_config();
let result = T5ForConditionalGeneration::new(config);
assert!(result.is_ok());
}
#[test]
fn test_t5_conditional_generation_num_parameters() {
let config = small_t5_config();
let model = T5ForConditionalGeneration::new(config).expect("model creation should succeed");
assert!(model.num_parameters() > 0);
}
#[test]
fn test_t5_config_custom_decoder_layers() {
let mut config = small_t5_config();
config.num_decoder_layers = Some(1);
let result = T5Model::new(config);
assert!(result.is_ok());
}
#[test]
fn test_relu_function() {
let arr = ArrayD::from_shape_vec(IxDyn(&[4]), vec![-2.0f32, -1.0, 0.0, 1.0])
.expect("array creation should succeed");
let tensor = Tensor::F32(arr);
let result = relu(tensor);
assert!(result.is_ok());
let out = result.expect("relu should succeed");
let v0 = out.get_scalar(&[0]).expect("get scalar should succeed");
let v1 = out.get_scalar(&[1]).expect("get scalar should succeed");
let v2 = out.get_scalar(&[2]).expect("get scalar should succeed");
let v3 = out.get_scalar(&[3]).expect("get scalar should succeed");
assert!((v0 - 0.0).abs() < f32::EPSILON);
assert!((v1 - 0.0).abs() < f32::EPSILON);
assert!((v2 - 0.0).abs() < f32::EPSILON);
assert!((v3 - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_create_attention_mask() {
let arr = ArrayD::from_shape_vec(IxDyn(&[1, 4, 32]), vec![0.0f32; 128])
.expect("array creation should succeed");
let hidden = Tensor::F32(arr);
let result = create_attention_mask(&hidden, false);
assert!(result.is_ok());
let mask = result.expect("mask should succeed");
assert_eq!(mask.shape().len(), 4);
}
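#[test]
fn test_t5_layer_norm_rms() {
// Sanity check of the RMSNorm semantics: with unit weights, the constant
// vector [2, 2] has root-mean-square 2 and normalizes to ~[1, 1].
let norm = T5LayerNorm::new(2, 1e-6);
let arr = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![2.0f32, 2.0])
.expect("array creation should succeed");
let out = norm.forward(Tensor::F32(arr)).expect("forward should succeed");
match out {
Tensor::F32(a) => {
assert!((a[[0, 0]] - 1.0).abs() < 1e-3);
assert!((a[[0, 1]] - 1.0).abs() < 1e-3);
},
_ => panic!("expected F32 tensor"),
}
}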
#[test]
fn test_t5_input_output_types() {
let input = T5Input {
input_ids: TokenizedInput::new(vec![1], vec![1]),
decoder_input_ids: None,
encoder_outputs: None,
};
assert!(input.decoder_input_ids.is_none());
assert!(input.encoder_outputs.is_none());
}
#[test]
fn test_t5_model_load_pretrained_error() {
let config = small_t5_config();
let mut model = T5Model::new(config).expect("model creation should succeed");
let mut reader: &[u8] = &[];
let result = model.load_pretrained(&mut reader);
assert!(result.is_err());
}
#[test]
fn test_t5_conditional_generation_with_device() {
let config = small_t5_config();
let result = T5ForConditionalGeneration::new_with_device(config, Device::CPU);
assert!(result.is_ok());
}
#[test]
fn test_t5_model_clone() {
let config = small_t5_config();
let model = T5Model::new(config).expect("model creation should succeed");
let cloned = model.clone();
assert_eq!(cloned.num_parameters(), model.num_parameters());
}
#[test]
fn test_t5_config_feed_forward_proj() {
let mut config = small_t5_config();
config.feed_forward_proj = "gated-gelu".to_string();
let result = T5Model::new(config);
assert!(result.is_ok());
}
#[test]
fn test_t5_config_model_type() {
let config = small_t5_config();
assert_eq!(config.model_type, "t5");
}
#[test]
fn test_t5_output_types() {
let output = T5Output {
last_hidden_state: Tensor::zeros(&[1, 3, 32]).expect("tensor creation should succeed"),
encoder_last_hidden_state: None,
past_key_values: None,
};
assert!(output.encoder_last_hidden_state.is_none());
assert!(output.past_key_values.is_none());
}
#[test]
fn test_t5_lm_output_types() {
let output = T5LMOutput {
logits: Tensor::zeros(&[1, 3, 100]).expect("tensor creation should succeed"),
past_key_values: None,
encoder_last_hidden_state: None,
};
assert!(output.past_key_values.is_none());
}
#[test]
fn test_t5_conditional_generation_config() {
let config = small_t5_config();
let model =
T5ForConditionalGeneration::new(config.clone()).expect("model creation should succeed");
let mc = model.get_config();
assert_eq!(mc.vocab_size, config.vocab_size);
}
#[test]
fn test_t5_config_use_cache() {
let mut config = small_t5_config();
config.use_cache = true;
let result = T5Model::new(config);
assert!(result.is_ok());
}
#[test]
fn test_t5_conditional_generation_params_gt_base() {
let config = small_t5_config();
let base = T5Model::new(config.clone()).expect("base model creation should succeed");
let cond =
T5ForConditionalGeneration::new(config).expect("cond gen creation should succeed");
assert!(cond.num_parameters() >= base.num_parameters());
}
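#[test]
fn test_relative_position_bucket_values() {
// Sanity check of the bucketing scheme under small_t5_config(): 16 buckets
// and max distance 64 give 8 buckets per direction, the first 4 of which
// index exact small offsets.
let config = small_t5_config();
let attn = T5Attention::new(&config, false).expect("attention creation should succeed");
let positions = Array2::from_shape_fn((1, 4), |(_, j)| match j {
0 => 0,   // same position
1 => 2,   // small forward offset
2 => -1,  // small backward offset
_ => -20, // large backward offset -> logarithmic bucket
});
let buckets = attn.relative_position_bucket(positions);
assert_eq!(buckets[[0, 0]], 0);
assert_eq!(buckets[[0, 1]], 10); // forward half starts at 8, exact +2
assert_eq!(buckets[[0, 2]], 1);
assert_eq!(buckets[[0, 3]], 6); // 4 exact + floor(log-scaled offset) = 2
}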
}