use axonml_autograd::Variable;
use axonml_nn::{Dropout, Embedding, Module, Parameter};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{ones, zeros};
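/// Learned token embedding: maps integer token ids to dense vectors of size `embed_dim`.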
#[derive(Debug)]
pub struct TokenEmbedding {
pub embedding: Embedding,
}
impl TokenEmbedding {
pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
Self {
embedding: Embedding::new(vocab_size, embed_dim),
}
}
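    /// Embeds a tensor of integer token ids. Ids are cast to f32 because
    /// `Embedding::forward` takes an f32 `Variable`.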
pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
let ids_f32: Vec<f32> = input_ids.to_vec().iter().map(|&x| x as f32).collect();
let ids_var = Variable::new(Tensor::from_vec(ids_f32, input_ids.shape()).unwrap(), false);
self.embedding.forward(&ids_var)
}
}
impl Module for TokenEmbedding {
fn forward(&self, input: &Variable) -> Variable {
self.embedding.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.embedding.parameters()
}
}
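/// Learned positional embedding: every position index in `0..max_len` has its own trainable vector.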
#[derive(Debug)]
pub struct PositionalEmbedding {
pub embedding: Embedding,
pub max_len: usize,
}
impl PositionalEmbedding {
pub fn new(max_len: usize, embed_dim: usize) -> Self {
Self {
embedding: Embedding::new(max_len, embed_dim),
max_len,
}
}
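    /// Returns position embeddings of shape `[batch_size, seq_len, embed_dim]`
    /// for positions `0..seq_len`.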
pub fn forward_positions(&self, seq_len: usize, batch_size: usize) -> Variable {
let embed_dim = self.embedding.embedding_dim();
let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let position_tensor = Tensor::from_vec(positions, &[1, seq_len]).unwrap();
let position_var = Variable::new(position_tensor, false);
let pos_embeds = self.embedding.forward(&position_var);
if batch_size > 1 {
pos_embeds.expand(&[batch_size, seq_len, embed_dim])
} else {
pos_embeds
}
}
}
impl Module for PositionalEmbedding {
fn forward(&self, input: &Variable) -> Variable {
self.embedding.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.embedding.parameters()
}
}
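/// Fixed (non-trainable) sinusoidal positional encoding in the style of
/// "Attention Is All You Need":
/// `PE(pos, 2i) = sin(pos / 10000^(2i/d))`, `PE(pos, 2i + 1) = cos(pos / 10000^(2i/d))`.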
#[derive(Debug)]
pub struct SinusoidalPositionalEncoding {
pub encodings: Tensor<f32>,
pub max_len: usize,
pub embed_dim: usize,
}
impl SinusoidalPositionalEncoding {
pub fn new(max_len: usize, embed_dim: usize) -> Self {
let mut encodings = vec![0.0f32; max_len * embed_dim];
for pos in 0..max_len {
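            // Fill (sin, cos) pairs per position; if `embed_dim` is odd the last column stays zero.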
for i in 0..embed_dim / 2 {
let div_term = (10000.0f32).powf(2.0 * i as f32 / embed_dim as f32);
let angle = pos as f32 / div_term;
encodings[pos * embed_dim + 2 * i] = angle.sin();
encodings[pos * embed_dim + 2 * i + 1] = angle.cos();
}
}
Self {
encodings: Tensor::from_vec(encodings, &[max_len, embed_dim]).unwrap(),
max_len,
embed_dim,
}
}
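    /// Returns the first `seq_len` rows of the precomputed table as a constant
    /// `Variable` (the full table if `seq_len >= max_len`).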
pub fn forward_seq(&self, seq_len: usize) -> Variable {
if seq_len >= self.max_len {
Variable::new(self.encodings.clone(), false)
} else {
let sliced = self.encodings.narrow(0, 0, seq_len).unwrap();
Variable::new(sliced, false)
}
}
}
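/// BERT-style embedding block: word, position, and token-type embeddings are
/// summed, then passed through layer normalization and dropout.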
#[derive(Debug)]
pub struct BertEmbedding {
pub word_embeddings: Embedding,
pub position_embeddings: Embedding,
pub token_type_embeddings: Embedding,
pub layer_norm: LayerNorm,
pub dropout: Dropout,
pub embed_dim: usize,
}
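/// Minimal layer normalization over the last dimension with a learnable
/// scale (`weight`) and shift (`bias`).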
#[derive(Debug)]
pub struct LayerNorm {
weight: Parameter,
bias: Parameter,
eps: f32,
}
impl LayerNorm {
fn new(dim: usize, eps: f32) -> Self {
let weight = Parameter::new(ones::<f32>(&[dim]), true);
let bias = Parameter::new(zeros::<f32>(&[dim]), true);
Self { weight, bias, eps }
}
fn forward(&self, x: &Variable) -> Variable {
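        // Normalize over the last dimension: (x - mean) / sqrt(var + eps), then scale and shift.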
let mean = x.mean_dim(-1, true);
let variance = x.var_dim(-1, true);
let x_normalized = x.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());
let weight_var = Variable::from_tensor_with_grad(
self.weight.data().clone(),
self.weight.requires_grad(),
);
let bias_var =
Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());
x_normalized.mul(&weight_var).add(&bias_var)
}
fn parameters(&self) -> Vec<Parameter> {
vec![self.weight.clone(), self.bias.clone()]
}
}
impl BertEmbedding {
pub fn new(
vocab_size: usize,
max_position_embeddings: usize,
type_vocab_size: usize,
hidden_size: usize,
layer_norm_eps: f32,
dropout_prob: f32,
) -> Self {
Self {
word_embeddings: Embedding::new(vocab_size, hidden_size),
position_embeddings: Embedding::new(max_position_embeddings, hidden_size),
token_type_embeddings: Embedding::new(type_vocab_size, hidden_size),
layer_norm: LayerNorm::new(hidden_size, layer_norm_eps),
dropout: Dropout::new(dropout_prob),
embed_dim: hidden_size,
}
}
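    /// Embeds integer id tensors directly. Missing `position_ids` default to
    /// `0..seq_len` per example; missing `token_type_ids` default to all zeros.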
pub fn forward_with_ids(
&self,
input_ids: &Tensor<u32>,
token_type_ids: Option<&Tensor<u32>>,
position_ids: Option<&Tensor<u32>>,
) -> Variable {
let batch_size = input_ids.shape()[0];
let seq_len = input_ids.shape()[1];
let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
let word_embeds = self
.word_embeddings
.forward(&Variable::new(input_ids_f32, false));
let pos_ids = if let Some(ids) = position_ids {
Self::u32_to_f32_tensor(ids)
} else {
let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
let pos_data: Vec<f32> = (0..batch_size)
.flat_map(|_| positions.iter().cloned())
.collect();
Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap()
};
let position_embeds = self
.position_embeddings
.forward(&Variable::new(pos_ids, false));
let type_ids = if let Some(ids) = token_type_ids {
Self::u32_to_f32_tensor(ids)
} else {
zeros::<f32>(&[batch_size, seq_len])
};
let token_type_embeds = self
.token_type_embeddings
.forward(&Variable::new(type_ids, false));
let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);
let embeddings = self.layer_norm.forward(&embeddings);
self.dropout.forward(&embeddings)
}
fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
Tensor::from_vec(data, t.shape()).unwrap()
}
}
impl Module for BertEmbedding {
fn forward(&self, input: &Variable) -> Variable {
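        // Default path: positions are 0..seq_len for every example and all token-type ids are zero.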
let input_data = input.data();
let shape = input_data.shape();
let batch_size = shape[0];
let seq_len = shape[1];
let word_embeds = self.word_embeddings.forward(input);
let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
let pos_data: Vec<f32> = (0..batch_size)
.flat_map(|_| positions.iter().cloned())
.collect();
let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
let position_embeds = self
.position_embeddings
.forward(&Variable::new(pos_tensor, false));
let type_tensor = zeros::<f32>(&[batch_size, seq_len]);
let token_type_embeds = self
.token_type_embeddings
.forward(&Variable::new(type_tensor, false));
let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);
let embeddings = self.layer_norm.forward(&embeddings);
self.dropout.forward(&embeddings)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.word_embeddings.parameters());
params.extend(self.position_embeddings.parameters());
params.extend(self.token_type_embeddings.parameters());
params.extend(self.layer_norm.parameters());
params
}
fn train(&mut self) {
self.dropout.train();
}
fn eval(&mut self) {
self.dropout.eval();
}
}
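/// GPT-2 style embedding block: token embeddings (`wte`) plus learned
/// positional embeddings (`wpe`), followed by dropout.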
#[derive(Debug)]
pub struct GPT2Embedding {
pub wte: Embedding,
pub wpe: Embedding,
pub dropout: Dropout,
pub n_embd: usize,
}
impl GPT2Embedding {
pub fn new(vocab_size: usize, n_ctx: usize, n_embd: usize, dropout: f32) -> Self {
Self {
wte: Embedding::new(vocab_size, n_embd),
wpe: Embedding::new(n_ctx, n_embd),
dropout: Dropout::new(dropout),
n_embd,
}
}
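    /// Embeds integer token ids: `wte(ids) + wpe(0..seq_len)`, followed by dropout.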
pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
let batch_size = input_ids.shape()[0];
let seq_len = input_ids.shape()[1];
let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
let token_embeds = self.wte.forward(&Variable::new(input_ids_f32, false));
let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
let pos_data: Vec<f32> = (0..batch_size)
.flat_map(|_| positions.iter().cloned())
.collect();
let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));
let embeddings = token_embeds.add(&position_embeds);
self.dropout.forward(&embeddings)
}
fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
Tensor::from_vec(data, t.shape()).unwrap()
}
}
impl Module for GPT2Embedding {
fn forward(&self, input: &Variable) -> Variable {
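        // Same as `forward_ids`, but the token ids arrive already wrapped in a `Variable`.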
let input_data = input.data();
let shape = input_data.shape();
let batch_size = shape[0];
let seq_len = shape[1];
let token_embeds = self.wte.forward(input);
let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
let pos_data: Vec<f32> = (0..batch_size)
.flat_map(|_| positions.iter().cloned())
.collect();
let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));
let embeddings = token_embeds.add(&position_embeds);
self.dropout.forward(&embeddings)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.wte.parameters());
params.extend(self.wpe.parameters());
params
}
fn train(&mut self) {
self.dropout.train();
}
fn eval(&mut self) {
self.dropout.eval();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_embedding() {
let embed = TokenEmbedding::new(1000, 64);
let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
let output = embed.forward_ids(&input_ids);
assert_eq!(output.data().shape(), &[2, 2, 64]);
}
#[test]
fn test_positional_embedding() {
let embed = PositionalEmbedding::new(128, 64);
let output = embed.forward_positions(16, 2);
assert_eq!(output.data().shape(), &[2, 16, 64]);
}
#[test]
fn test_sinusoidal_encoding() {
let encoding = SinusoidalPositionalEncoding::new(100, 64);
let output = encoding.forward_seq(16);
assert_eq!(output.data().shape(), &[16, 64]);
}
#[test]
fn test_gpt2_embedding() {
let embed = GPT2Embedding::new(1000, 128, 64, 0.0);
let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
let output = embed.forward_ids(&input_ids);
assert_eq!(output.data().shape(), &[2, 2, 64]);
}
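    // Smoke test for BertEmbedding, added as a sketch: mirrors the GPT-2 test above,
    // calling `forward_with_ids` with default (None) position and token-type ids and
    // a token-type vocabulary of 2.
    #[test]
    fn test_bert_embedding() {
        let embed = BertEmbedding::new(1000, 128, 2, 64, 1e-12, 0.0);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_with_ids(&input_ids, None, None);
        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }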
}