use std::any::Any;
use crate::backend::hip_embedding::run_rocm_hip_embedding_fwd;
use crate::backend::hip_gelu::run_rocm_hip_gelu_fwd;
use crate::backend::hip_gelu_bw::run_rocm_hip_gelu_bwd;
use crate::backend::hip_gemm_bw::run_rocm_hip_gemm_bw_grad_b;
use crate::backend::hip_gemm_f16::run_rocm_hip_gemm_f16;
use crate::backend::hip_layernorm::{run_rocm_hip_layernorm_bwd, run_rocm_hip_layernorm_fwd};
use crate::backend::hip_softmax::{run_rocm_hip_grad_loss_wrt_logits, run_rocm_hip_softmax_fwd};
use crate::domain::DomainId;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};
use super::parameter::Parameter;
use super::util::{f16_to_f32, f32_to_f16, fp16_bits_to_tensor, tensor_full, tensor_to_fp16_bits};
/// Domain name used for every tensor and parameter this file creates (see `domain()`).
pub const MODEL_DOMAIN: &str = "f32_model";
/// Builds the `DomainId` for `MODEL_DOMAIN`.
///
/// Called wherever this file constructs a tensor or a parameter.
fn domain() -> DomainId {
DomainId::new(MODEL_DOMAIN)
}
/// Returns the product of all static dimensions of `shape`.
///
/// Yields `0` as soon as a non-static dimension is encountered, and `1` for a
/// shape with no dimensions.
#[allow(dead_code)]
pub(crate) fn numel(shape: &Shape) -> usize {
    shape
        .dims
        .iter()
        .try_fold(1usize, |count, dim| match dim {
            crate::object::Dim::Static(extent) => Some(count * extent),
            _ => None,
        })
        .unwrap_or(0)
}
/// Common interface of the trainable building blocks in this file.
///
/// `forward` returns the output together with an opaque cache; the caller hands
/// that same cache back to `backward`, which downcasts it to the layer's own
/// cache type.
pub trait Layer: Send {
/// Static, human-readable name of the layer.
fn name(&self) -> &'static str;
/// Shared references to the layer's parameters (empty for parameter-free layers).
fn parameters(&self) -> Vec<&Parameter>;
/// Mutable references to the layer's parameters, in the same order as `parameters`.
fn parameters_mut(&mut self) -> Vec<&mut Parameter>;
/// Runs the layer on `input` and returns `(output, cache)`.
fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)>;
/// Propagates `grad_output` back through the layer using the `cache` produced by `forward`.
///
/// Returns `(grad_input, param_grads)`. In the implementations in this file
/// `param_grads` follows the order of `parameters()` and is empty for layers
/// without parameters.
fn backward(
&self,
grad_output: &Tensor<f32>,
cache: &dyn Any,
) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)>;
}
/// Fully connected layer that keeps fp16 copies of its parameters next to the
/// f32 originals.
pub struct Linear {
    /// f32 weight, created with shape `[out_features, in_features]`.
    pub weight: Parameter,
    /// f32 bias, created with shape `[out_features]`.
    pub bias: Parameter,
    /// `out_features` rounded up to the next multiple of 16.
    pub n_padded: usize,
    /// fp16 bits of `weight`, row-major `[n_padded, in_features]`; rows past
    /// `out_features` are zero.
    pub weight_padded: std::cell::RefCell<Vec<u16>>,
    /// Transpose of `weight_padded`, row-major `[in_features, n_padded]`.
    pub weight_t_padded: std::cell::RefCell<Vec<u16>>,
    /// fp16 bits of `bias`, length `n_padded`; entries past `out_features`
    /// hold fp16 zero.
    pub bias_padded: std::cell::RefCell<Vec<u16>>,
}
impl Linear {
    /// Creates a layer with an `[out_features, in_features]` weight and a zero
    /// bias of length `out_features`, then builds the padded fp16 caches.
    ///
    /// The weight comes from `Parameter::uniform`, called with the range
    /// arguments `-bound` and `bound` where `bound = (1 / sqrt(in_features)) * 0.5`,
    /// and with `seed`.
    pub fn new(in_features: usize, out_features: usize, seed: u32) -> Self {
        let bound = (1.0 / (in_features as f32).sqrt()) * 0.5;
        let weight_shape = Shape::from(vec![out_features, in_features]);
        let weight = Parameter::uniform(weight_shape, -bound, bound, seed, domain());
        let bias = Parameter::zeros(Shape::from(vec![out_features]), domain());

        // The fp16 caches pad the output dimension to a multiple of 16.
        let n_padded = out_features.next_multiple_of(16);
        let (weight_bits, weight_t_bits, bias_bits) = build_padded_caches(
            &weight.data.data,
            &bias.data.data,
            out_features,
            in_features,
            n_padded,
        );

        Self {
            weight,
            bias,
            n_padded,
            weight_padded: std::cell::RefCell::new(weight_bits),
            weight_t_padded: std::cell::RefCell::new(weight_t_bits),
            bias_padded: std::cell::RefCell::new(bias_bits),
        }
    }

    /// Rebuilds all three fp16 caches from the current f32 weight and bias.
    ///
    /// A non-static weight dimension is treated as extent `0`.
    fn refresh_padded_caches(&self) {
        let static_extent = |axis: usize| match &self.weight.data.meta.shape.dims[axis] {
            crate::object::Dim::Static(extent) => *extent,
            _ => 0,
        };
        let out_features = static_extent(0);
        let in_features = static_extent(1);

        let (weight_bits, weight_t_bits, bias_bits) = build_padded_caches(
            &self.weight.data.data,
            &self.bias.data.data,
            out_features,
            in_features,
            self.n_padded,
        );

        *self.weight_padded.borrow_mut() = weight_bits;
        *self.weight_t_padded.borrow_mut() = weight_t_bits;
        *self.bias_padded.borrow_mut() = bias_bits;
    }
}
/// Converts an `[n, k]` row-major f32 weight and a length-`n` f32 bias into
/// fp16 bit buffers whose output dimension is padded to `n_padded`.
///
/// Returns `(weight_padded, weight_t_padded, bias_padded)`:
/// - `weight_padded`: `n_padded * k` values, row-major, rows `n..n_padded` zero;
/// - `weight_t_padded`: the transpose, `k * n_padded` values, columns
///   `n..n_padded` zero;
/// - `bias_padded`: `n_padded` values, entries `n..n_padded` fp16 zero.
///
/// Panics on out-of-range indexing if the slices are shorter than `n * k` /
/// `n`, or if `n` exceeds `n_padded`.
fn build_padded_caches(
    weight: &[f32],
    bias: &[f32],
    n: usize,
    k: usize,
    n_padded: usize,
) -> (Vec<u16>, Vec<u16>, Vec<u16>) {
    let mut weight_padded = vec![0u16; n_padded * k];
    let mut weight_t_padded = vec![0u16; k * n_padded];

    // One pass fills both layouts so each element is converted only once.
    for row in 0..n {
        for col in 0..k {
            let bits = f32_to_f16(weight[row * k + col]);
            weight_padded[row * k + col] = bits;
            weight_t_padded[col * n_padded + row] = bits;
        }
    }

    let mut bias_padded = vec![f32_to_f16(0.0); n_padded];
    for (slot, value) in bias_padded[..n].iter_mut().zip(&bias[..n]) {
        *slot = f32_to_f16(*value);
    }

    (weight_padded, weight_t_padded, bias_padded)
}
/// Values saved by `Linear::forward` for use in `Linear::backward`.
pub struct LinearCache {
    /// Clone of the forward input tensor.
    pub input: Tensor<f32>,
    /// fp16 bits of the forward input.
    pub input_bits: Vec<u16>,
    /// First (batch) dimension of the forward input.
    pub m: usize,
    /// Second dimension of the forward input (`in_features`).
    pub k: usize,
    /// First dimension of the weight (`out_features`).
    pub n: usize,
    /// The layer's `n_padded` at forward time.
    pub n_padded: usize,
}
impl Layer for Linear {
    /// Static layer name.
    fn name(&self) -> &'static str {
        "Linear"
    }

    /// Weight first, then bias.
    fn parameters(&self) -> Vec<&Parameter> {
        vec![&self.weight, &self.bias]
    }

    /// Weight first, then bias.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Computes the layer output for a 2D `[M, K]` input via the fp16 GEMM
    /// kernel and adds the fp16-rounded bias.
    ///
    /// # Errors
    /// Returns `Error::shape` when the input is not 2D, has non-static dims,
    /// `K * N` does not match the weight element count, `M` or `K` is not a
    /// multiple of 16, or `M` / `n_padded` exceeds 65535. Kernel errors are
    /// propagated unchanged.
    fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
        let shape = &input.meta.shape;
        if shape.dims.len() != 2 {
            return Err(Error::shape(format!(
                "Linear expects 2D input [B, K], got rank {}",
                shape.dims.len()
            )));
        }
        let (m, k) = match (&shape.dims[0], &shape.dims[1]) {
            (crate::object::Dim::Static(m), crate::object::Dim::Static(k)) => (*m, *k),
            _ => return Err(Error::shape("Linear input dims must be static")),
        };
        let n = match &self.weight.data.meta.shape.dims[0] {
            crate::object::Dim::Static(v) => *v,
            _ => return Err(Error::shape("Linear weight dim 0 must be static")),
        };
        if k * n != self.weight.numel() {
            return Err(Error::shape(format!(
                "Linear weight numel {} != in_features*out_features={}*{}",
                self.weight.numel(),
                k,
                n
            )));
        }
        if m % 16 != 0 {
            return Err(Error::shape(format!(
                "Linear requires batch dim M={} to be a multiple of 16",
                m
            )));
        }
        if k % 16 != 0 {
            return Err(Error::shape(format!(
                "Linear requires in_features K={} to be a multiple of 16",
                k
            )));
        }
        if self.n_padded > 65535 || m > 65535 {
            return Err(Error::shape(
                "Linear M/N must fit in u16 for the fp16 GEMM kernel's stdin protocol",
            ));
        }

        let input_bits = tensor_to_fp16_bits(input);
        // The f32 parameters may have changed since the last call, so the fp16
        // copies are rebuilt on every forward pass.
        self.refresh_padded_caches();
        let weight_padded = self.weight_padded.borrow();
        let bias_padded = self.bias_padded.borrow();

        // NOTE(review): `weight_t_padded` is rebuilt above but never read in this
        // file, while this call and the grad_x GEMM in `backward` pass the same
        // `[n_padded, k]` buffer with N and K swapped. Confirm the operand layout
        // `run_rocm_hip_gemm_f16` expects.
        let report = run_rocm_hip_gemm_f16(&input_bits, &weight_padded, m, self.n_padded, k)?;

        // The kernel output is read as `[m, n_padded]`; only the first `n`
        // columns of each row are kept.
        let mut out_data: Vec<f32> = Vec::with_capacity(m * n);
        for row in 0..m {
            for col in 0..n {
                out_data
                    .push(report.outputs[row * self.n_padded + col] + f16_to_f32(bias_padded[col]));
            }
        }
        let output = Tensor::dense_cpu(domain(), Shape::from(vec![m, n]), out_data);
        let cache = LinearCache {
            input: input.clone(),
            input_bits,
            m,
            k,
            n,
            n_padded: self.n_padded,
        };
        Ok((output, Box::new(cache)))
    }

    /// Computes `(grad_input, [grad_weight, grad_bias])` from `grad_output` and
    /// the `LinearCache` produced by `forward`.
    ///
    /// `grad_weight` has shape `[n, k]`, `grad_bias` `[n]`, `grad_input` `[m, k]`.
    /// The fp16 weight used here is the copy built by the most recent `forward`.
    ///
    /// # Errors
    /// Returns `Error::backend` when `cache` is not a `LinearCache`, and
    /// `Error::shape` when `grad_output` does not hold exactly `m * n`
    /// elements (previously this case panicked on an out-of-range index).
    /// Kernel errors are propagated unchanged.
    fn backward(
        &self,
        grad_output: &Tensor<f32>,
        cache: &dyn Any,
    ) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        let c = cache
            .downcast_ref::<LinearCache>()
            .ok_or_else(|| Error::backend("Linear backward cache downcast failed"))?;
        let m = c.m;
        let k = c.k;
        let n = c.n;
        let n_padded = c.n_padded;

        // Every loop below indexes `grad_output.data` as an `[m, n]` matrix.
        if grad_output.data.len() != m * n {
            return Err(Error::shape(format!(
                "Linear backward expects grad_output with {}*{} elements, got {}",
                m,
                n,
                grad_output.data.len()
            )));
        }

        // Widen the gradient to `n_padded` columns; padding columns stay zero.
        let mut grad_output_padded = vec![f32_to_f16(0.0); m * n_padded];
        for row in 0..m {
            for col in 0..n {
                grad_output_padded[row * n_padded + col] =
                    f32_to_f16(grad_output.data[row * n + col]);
            }
        }

        let grad_w_report =
            run_rocm_hip_gemm_bw_grad_b(&grad_output_padded, &c.input_bits, m, n_padded, k)?;
        // The kernel result is read as `[k, n_padded]` and transposed into `[n, k]`.
        let mut grad_w_data: Vec<f32> = Vec::with_capacity(n * k);
        for n_idx in 0..n {
            for k_idx in 0..k {
                grad_w_data.push(f16_to_f32(grad_w_report.outputs[k_idx * n_padded + n_idx]));
            }
        }
        let grad_w = Tensor::dense_cpu(domain(), Shape::from(vec![n, k]), grad_w_data);

        let weight_padded = self.weight_padded.borrow();
        let grad_x_report =
            run_rocm_hip_gemm_f16(&grad_output_padded, &weight_padded, m, k, n_padded)?;
        let grad_x_data: Vec<f32> = grad_x_report.outputs;
        let grad_x = Tensor::dense_cpu(domain(), Shape::from(vec![m, k]), grad_x_data);

        // The bias gradient is the column-wise sum of `grad_output`, done in f32.
        let mut grad_b_data = vec![0.0f32; n];
        for row in 0..m {
            for col in 0..n {
                grad_b_data[col] += grad_output.data[row * n + col];
            }
        }
        let grad_b = Tensor::dense_cpu(domain(), Shape::from(vec![n]), grad_b_data);

        Ok((grad_x, vec![grad_w, grad_b]))
    }
}
/// Layer normalisation over rows of width `n_cols`, with learnable scale and shift.
pub struct LayerNorm {
    /// Scale parameter, created as a length-`n_cols` tensor filled with `1.0`.
    pub gamma: Parameter,
    /// Shift parameter, created as a length-`n_cols` zero tensor.
    pub beta: Parameter,
    /// fp16 bits of `gamma` captured at construction.
    pub gamma_cache: Vec<u16>,
    /// fp16 bits of `beta` captured at construction.
    pub beta_cache: Vec<u16>,
    /// Number of columns the layer was created for.
    pub n_cols: usize,
    /// Epsilon value passed to the forward kernel.
    pub eps: f32,
}
impl LayerNorm {
    /// Creates a layer for rows of width `n_cols`: `gamma` filled with `1.0`,
    /// `beta` zeroed, and fp16 snapshots of both stored alongside.
    pub fn new(n_cols: usize, eps: f32) -> Self {
        let ones = tensor_full(Shape::from(vec![n_cols]), 1.0, domain());
        let gamma = Parameter::from_tensor(ones);
        let beta = Parameter::zeros(Shape::from(vec![n_cols]), domain());

        Self {
            gamma_cache: tensor_to_fp16_bits(&gamma.data),
            beta_cache: tensor_to_fp16_bits(&beta.data),
            gamma,
            beta,
            n_cols,
            eps,
        }
    }
}
/// Values saved by `LayerNorm::forward` for use in `LayerNorm::backward`.
pub struct LayerNormCache {
/// Clone of the forward input tensor.
pub input: Tensor<f32>,
/// fp16 bits of the forward input.
pub input_bits: Vec<u16>,
/// `mean` field of the forward kernel report (presumably one value per row — confirm against the kernel).
pub mean: Vec<f32>,
/// `rstd` field of the forward kernel report (presumably the reciprocal standard deviation per row — confirm against the kernel).
pub rstd: Vec<f32>,
/// fp16 bits of the forward output.
pub output_bits: Vec<u16>,
/// First dimension of the forward input.
pub n_rows: usize,
}
impl Layer for LayerNorm {
fn name(&self) -> &'static str {
"LayerNorm"
}
fn parameters(&self) -> Vec<&Parameter> {
vec![&self.gamma, &self.beta]
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
vec![&mut self.gamma, &mut self.beta]
}
fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
let shape = &input.meta.shape;
if shape.dims.len() != 2 {
return Err(Error::shape(format!(
"LayerNorm expects 2D input, got rank {}",
shape.dims.len()
)));
}
let n_rows = match &shape.dims[0] {
crate::object::Dim::Static(v) => *v,
_ => return Err(Error::shape("LayerNorm n_rows must be static")),
};
let input_bits = tensor_to_fp16_bits(input);
let report = run_rocm_hip_layernorm_fwd(
&input_bits,
&self.gamma_cache,
&self.beta_cache,
n_rows,
self.n_cols,
self.eps,
)?;
let output = Tensor::dense_cpu(
domain(),
Shape::from(vec![n_rows, self.n_cols]),
report.output.iter().copied().map(f16_to_f32).collect(),
);
let cache = LayerNormCache {
input: input.clone(),
input_bits,
mean: report.mean,
rstd: report.rstd,
output_bits: report.output,
n_rows,
};
Ok((output, Box::new(cache)))
}
fn backward(
&self,
grad_output: &Tensor<f32>,
cache: &dyn Any,
) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
let c = cache
.downcast_ref::<LayerNormCache>()
.ok_or_else(|| Error::shape("LayerNorm backward cache downcast failed"))?;
let grad_output_bits = tensor_to_fp16_bits(grad_output);
let report = run_rocm_hip_layernorm_bwd(
&grad_output_bits,
&c.input_bits,
&self.gamma_cache,
&c.mean,
&c.rstd,
c.n_rows,
self.n_cols,
)?;
let grad_input = Tensor::dense_cpu(
domain(),
Shape::from(vec![c.n_rows, self.n_cols]),
report.grad_input.iter().copied().map(f16_to_f32).collect(),
);
let grad_gamma = Tensor::dense_cpu(
domain(),
Shape::from(vec![self.n_cols]),
report.grad_gamma.iter().copied().map(f16_to_f32).collect(),
);
let grad_beta = Tensor::dense_cpu(
domain(),
Shape::from(vec![self.n_cols]),
report.grad_beta.iter().copied().map(f16_to_f32).collect(),
);
Ok((grad_input, vec![grad_gamma, grad_beta]))
}
}
/// Parameter-free GELU activation layer backed by the HIP GELU kernels.
pub struct GELU;
/// Values saved by `GELU::forward` for use in `GELU::backward`.
pub struct GELUCache {
/// Clone of the forward input tensor.
pub input: Tensor<f32>,
/// fp16 bits of the forward input.
pub input_bits: Vec<u16>,
}
impl Layer for GELU {
fn name(&self) -> &'static str {
"GELU"
}
fn parameters(&self) -> Vec<&Parameter> {
Vec::new()
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
Vec::new()
}
fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
let n = input.data.len();
let input_bits = tensor_to_fp16_bits(input);
let output_bits = run_rocm_hip_gelu_fwd(&input_bits, n)?;
let output = fp16_bits_to_tensor(&output_bits, input.meta.shape.clone(), domain());
Ok((
output,
Box::new(GELUCache {
input: input.clone(),
input_bits,
}),
))
}
fn backward(
&self,
grad_output: &Tensor<f32>,
cache: &dyn Any,
) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
let c = cache
.downcast_ref::<GELUCache>()
.ok_or_else(|| Error::backend("GELU backward cache downcast failed"))?;
let n = grad_output.data.len();
let grad_output_bits = tensor_to_fp16_bits(grad_output);
let grad_input_bits = run_rocm_hip_gelu_bwd(&grad_output_bits, &c.input_bits, n)?;
let grad_input =
fp16_bits_to_tensor(&grad_input_bits, grad_output.meta.shape.clone(), domain());
Ok((grad_input, Vec::new()))
}
}
/// Parameter-free softmax layer for 2D inputs, backed by the HIP softmax kernels.
pub struct Softmax;
/// Values saved by `Softmax::forward` for use in `Softmax::backward`.
pub struct SoftmaxCache {
/// fp16 bits of the forward output.
pub output_bits: Vec<u16>,
/// First dimension of the forward input.
pub n_rows: usize,
/// Second dimension of the forward input.
pub n_cols: usize,
}
impl Layer for Softmax {
fn name(&self) -> &'static str {
"Softmax"
}
fn parameters(&self) -> Vec<&Parameter> {
Vec::new()
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
Vec::new()
}
fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
let shape = &input.meta.shape;
if shape.dims.len() != 2 {
return Err(Error::shape(format!(
"Softmax expects 2D input, got rank {}",
shape.dims.len()
)));
}
let n_rows = match &shape.dims[0] {
crate::object::Dim::Static(v) => *v,
_ => return Err(Error::shape("Softmax n_rows must be static")),
};
let n_cols = match &shape.dims[1] {
crate::object::Dim::Static(v) => *v,
_ => return Err(Error::shape("Softmax n_cols must be static")),
};
let input_bits = tensor_to_fp16_bits(input);
let report = run_rocm_hip_softmax_fwd(&input_bits, n_rows, n_cols)?;
let output = Tensor::dense_cpu(
domain(),
Shape::from(vec![n_rows, n_cols]),
report.outputs.iter().copied().map(f16_to_f32).collect(),
);
Ok((
output,
Box::new(SoftmaxCache {
output_bits: report.outputs,
n_rows,
n_cols,
}),
))
}
fn backward(
&self,
grad_output: &Tensor<f32>,
cache: &dyn Any,
) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
let c = cache
.downcast_ref::<SoftmaxCache>()
.ok_or_else(|| Error::backend("Softmax backward cache downcast failed"))?;
let grad_output_bits = tensor_to_fp16_bits(grad_output);
let report = run_rocm_hip_grad_loss_wrt_logits(
&grad_output_bits,
&c.output_bits,
c.n_rows,
c.n_cols,
)?;
let grad_input = Tensor::dense_cpu(
domain(),
Shape::from(vec![c.n_rows, c.n_cols]),
report.outputs.iter().copied().map(f16_to_f32).collect(),
);
Ok((grad_input, Vec::new()))
}
}
/// Parameter-free layer that sums the two halves of each row of a `[B, 2*D]` input.
pub struct Add;
/// Empty cache returned by `Add::forward`; `Add::backward` does not read it.
pub struct AddCache {
}
impl Layer for Add {
    /// Static layer name.
    fn name(&self) -> &'static str {
        "Add"
    }

    /// Add has no parameters.
    fn parameters(&self) -> Vec<&Parameter> {
        Vec::new()
    }

    /// Add has no parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        Vec::new()
    }

    /// Splits each row of a `[B, 2*D]` input into two halves of width `D` and
    /// returns their element-wise sum as a `[B, D]` tensor.
    ///
    /// # Errors
    /// Returns `Error::shape` when the input is not 2D, has non-static dims,
    /// or has an odd width.
    fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
        let shape = &input.meta.shape;
        if shape.dims.len() != 2 {
            return Err(Error::shape(format!(
                "Add expects 2D input [B, 2*D], got rank {}",
                shape.dims.len()
            )));
        }
        let (b, two_d) = match (&shape.dims[0], &shape.dims[1]) {
            (crate::object::Dim::Static(b), crate::object::Dim::Static(d)) => (*b, *d),
            _ => return Err(Error::shape("Add dims must be static")),
        };
        if two_d % 2 != 0 {
            return Err(Error::shape(format!(
                "Add expects 2*D width, got {}",
                two_d
            )));
        }
        let d = two_d / 2;
        let mut out_data = Vec::with_capacity(b * d);
        for row in 0..b {
            for col in 0..d {
                out_data.push(input.data[row * two_d + col] + input.data[row * two_d + d + col]);
            }
        }
        let output = Tensor::dense_cpu(domain(), Shape::from(vec![b, d]), out_data);
        Ok((output, Box::new(AddCache {})))
    }

    /// Copies each row of a `[B, D]` gradient into both halves of the
    /// corresponding `[B, 2*D]` input-gradient row. The cache is not read.
    ///
    /// # Errors
    /// Returns `Error::shape` when `grad_output` is not 2D (previously this
    /// panicked on an out-of-range dimension index) or has non-static dims.
    fn backward(
        &self,
        grad_output: &Tensor<f32>,
        _cache: &dyn Any,
    ) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        let shape = &grad_output.meta.shape;
        if shape.dims.len() != 2 {
            return Err(Error::shape(format!(
                "Add backward expects 2D grad_output [B, D], got rank {}",
                shape.dims.len()
            )));
        }
        let (b, d) = match (&shape.dims[0], &shape.dims[1]) {
            (crate::object::Dim::Static(b), crate::object::Dim::Static(d)) => (*b, *d),
            _ => return Err(Error::shape("Add backward dims must be static")),
        };
        let two_d = 2 * d;
        let mut data = Vec::with_capacity(b * two_d);
        for row in 0..b {
            // Both summands receive the same gradient, so the row is emitted twice.
            let grad_row = &grad_output.data[row * d..(row + 1) * d];
            data.extend_from_slice(grad_row);
            data.extend_from_slice(grad_row);
        }
        let grad_input = Tensor::dense_cpu(domain(), Shape::from(vec![b, two_d]), data);
        Ok((grad_input, Vec::new()))
    }
}
/// Routing layer: a `Linear` projection to one logit per expert, followed by
/// softmax and top-k selection.
pub struct Router {
/// Projection from the input features to one logit per expert.
pub linear: Linear,
/// Number of highest-probability experts kept per row.
pub top_k: usize,
}
/// Values saved by `Router::forward` for use in `Router::backward`.
pub struct RouterCache {
/// Clone of the forward input tensor.
pub input: Tensor<f32>,
/// Cache returned by the inner `Linear::forward`.
pub linear_cache: Box<dyn Any + Send>,
/// fp16 bits of the full (unmasked) softmax output.
pub softmax_output_bits: Vec<u16>,
/// Number of rows of the logits.
pub n_rows: usize,
/// Number of columns of the logits (one per expert).
pub n_experts: usize,
/// Selected expert indices, appended row by row in descending-probability order.
pub top_k_indices: Vec<usize>, }
impl Router {
    /// Creates a router whose inner `Linear` maps `in_features` inputs to
    /// `n_experts` logits, keeping `top_k` experts per row.
    pub fn new(in_features: usize, n_experts: usize, top_k: usize, seed: u32) -> Self {
        let linear = Linear::new(in_features, n_experts, seed);
        Self { linear, top_k }
    }
}
impl Layer for Router {
    /// Static layer name.
    fn name(&self) -> &'static str {
        "Router"
    }

    /// The parameters of the inner `Linear`.
    fn parameters(&self) -> Vec<&Parameter> {
        self.linear.parameters()
    }

    /// The parameters of the inner `Linear`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        self.linear.parameters_mut()
    }

    /// Projects `input` to per-expert logits, applies softmax, and returns a
    /// tensor of the logits' shape that holds the softmax probability at each
    /// row's `top_k` highest-probability positions and `0.0` elsewhere.
    ///
    /// Ties keep the lower expert index first because the sort is stable.
    ///
    /// # Errors
    /// Propagates errors from `Linear::forward` and the softmax kernel, and
    /// returns `Error::shape` when the logits' dims are not static.
    fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
        let (logits, linear_cache) = self.linear.forward(input)?;
        let shape = logits.meta.shape.clone();
        let n_rows = match &shape.dims[0] {
            crate::object::Dim::Static(v) => *v,
            _ => return Err(Error::shape("Router n_rows must be static")),
        };
        let n_experts = match &shape.dims[1] {
            crate::object::Dim::Static(v) => *v,
            _ => return Err(Error::shape("Router n_experts must be static")),
        };
        let logits_bits = tensor_to_fp16_bits(&logits);
        let sm_report = run_rocm_hip_softmax_fwd(&logits_bits, n_rows, n_experts)?;
        let mut masked_data = vec![0.0f32; n_rows * n_experts];
        let probs: Vec<f32> = sm_report.outputs.iter().copied().map(f16_to_f32).collect();
        let mut top_k_indices: Vec<usize> = Vec::with_capacity(n_rows * self.top_k);
        for r in 0..n_rows {
            let row = &probs[r * n_experts..(r + 1) * n_experts];
            // Rank the experts of this row by descending probability;
            // incomparable values (NaN) are treated as equal.
            let mut indexed: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            for &(idx, _) in indexed.iter().take(self.top_k) {
                top_k_indices.push(idx);
                masked_data[r * n_experts + idx] = row[idx];
            }
        }
        let output = Tensor::dense_cpu(domain(), shape, masked_data);
        let cache = RouterCache {
            input: input.clone(),
            linear_cache,
            softmax_output_bits: sm_report.outputs,
            n_rows,
            n_experts,
            top_k_indices,
        };
        Ok((output, Box::new(cache)))
    }

    /// Back-propagates through the softmax (using the cached unmasked softmax
    /// output) and then through the inner `Linear`, returning the `Linear`'s
    /// `(grad_input, param_grads)` unchanged.
    ///
    /// # Errors
    /// Returns `Error::backend` when `cache` is not a `RouterCache`; kernel and
    /// `Linear::backward` errors are propagated unchanged.
    fn backward(
        &self,
        grad_output: &Tensor<f32>,
        cache: &dyn Any,
    ) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        let c = cache
            .downcast_ref::<RouterCache>()
            .ok_or_else(|| Error::backend("Router backward cache downcast failed"))?;
        // NOTE(review): `grad_output` is forwarded unmasked although `forward`
        // zeroed every probability outside `c.top_k_indices`. Confirm whether
        // those entries should be zeroed here before the softmax backward.
        let grad_output_bits = tensor_to_fp16_bits(grad_output);
        let grad_logits_report = run_rocm_hip_grad_loss_wrt_logits(
            &grad_output_bits,
            &c.softmax_output_bits,
            c.n_rows,
            c.n_experts,
        )?;
        let grad_logits = Tensor::dense_cpu(
            domain(),
            Shape::from(vec![c.n_rows, c.n_experts]),
            grad_logits_report
                .outputs
                .iter()
                .copied()
                .map(f16_to_f32)
                .collect(),
        );
        // The inner layer already returns the gradients in the required form.
        self.linear.backward(&grad_logits, c.linear_cache.as_ref())
    }
}
/// Two-layer feed-forward block: `fc1`, then GELU, then `fc2`.
pub struct Expert {
/// First projection, `in_features` to `hidden_dim`.
pub fc1: Linear,
/// Second projection, `hidden_dim` to `out_features`.
pub fc2: Linear,
/// Width of the hidden activation between `fc1` and `fc2`.
pub hidden_dim: usize,
}
impl Expert {
    /// Creates the two projections; `fc1` is seeded with `seed` and `fc2` with
    /// `seed.wrapping_add(1)`.
    pub fn new(in_features: usize, out_features: usize, hidden_dim: usize, seed: u32) -> Self {
        let fc1 = Linear::new(in_features, hidden_dim, seed);
        let fc2_seed = seed.wrapping_add(1);
        let fc2 = Linear::new(hidden_dim, out_features, fc2_seed);
        Self {
            fc1,
            fc2,
            hidden_dim,
        }
    }
}
/// Values saved by `Expert::forward` for use in `Expert::backward`.
pub struct ExpertCache {
/// Cache returned by `fc1.forward`.
pub fc1_cache: Box<dyn Any + Send>,
/// Despite the name, `Expert::forward` stores a boxed `ExpertSubCache` here
/// (the `fc2` cache together with the GELU cache).
pub fc2_cache: Box<dyn Any + Send>,
}
impl Layer for Expert {
fn name(&self) -> &'static str {
"Expert"
}
fn parameters(&self) -> Vec<&Parameter> {
let mut p = self.fc1.parameters();
p.extend(self.fc2.parameters());
p
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
let mut p = self.fc1.parameters_mut();
p.extend(self.fc2.parameters_mut());
p
}
fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
let (hidden, fc1_cache) = self.fc1.forward(input)?;
let act = GELU;
let (hidden_act, gelu_cache) = act.forward(&hidden)?;
let (out, fc2_cache) = self.fc2.forward(&hidden_act)?;
Ok((
out,
Box::new(ExpertCache {
fc1_cache,
fc2_cache: Box::new(ExpertSubCache {
fc2: fc2_cache,
gelu: gelu_cache,
}),
}),
))
}
fn backward(
&self,
grad_output: &Tensor<f32>,
cache: &dyn Any,
) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
let c = cache
.downcast_ref::<ExpertCache>()
.ok_or_else(|| Error::backend("Expert backward cache downcast failed"))?;
let sub = c
.fc2_cache
.downcast_ref::<ExpertSubCache>()
.ok_or_else(|| Error::backend("Expert sub-cache downcast failed"))?;
let (grad_hidden_act, fc2_pg) = self.fc2.backward(grad_output, sub.fc2.as_ref())?;
let (grad_hidden, _gelu_pg) = GELU.backward(&grad_hidden_act, sub.gelu.as_ref())?;
let (grad_input, fc1_pg) = self.fc1.backward(&grad_hidden, c.fc1_cache.as_ref())?;
let mut all_pg = fc1_pg;
all_pg.extend(fc2_pg);
Ok((grad_input, all_pg))
}
}
/// Bundle of the `fc2` and GELU caches that `Expert::forward` stores inside
/// `ExpertCache::fc2_cache`.
pub struct ExpertSubCache {
/// Cache returned by `fc2.forward`.
pub fc2: Box<dyn Any + Send>,
/// Cache returned by the GELU forward pass.
pub gelu: Box<dyn Any + Send>,
}
/// Embedding table looked up through the HIP embedding kernel.
pub struct Embedding {
    /// f32 table, created with shape `[vocab_size, embedding_dim]`.
    pub weight: Parameter,
    /// fp16 bits of `weight` captured at construction.
    pub weight_cache: Vec<u16>,
    /// Number of rows in the table.
    pub vocab_size: usize,
    /// Width of each embedding row.
    pub embedding_dim: usize,
}
impl Embedding {
    /// Creates a `[vocab_size, embedding_dim]` table and an fp16 snapshot of it.
    ///
    /// The table comes from `Parameter::uniform`, called with the range
    /// arguments `-bound` and `bound` where `bound = (1 / sqrt(embedding_dim)) * 0.5`,
    /// and with `seed`.
    pub fn new(vocab_size: usize, embedding_dim: usize, seed: u32) -> Self {
        let bound = (1.0 / (embedding_dim as f32).sqrt()) * 0.5;
        let weight = Parameter::uniform(
            Shape::from(vec![vocab_size, embedding_dim]),
            -bound,
            bound,
            seed,
            domain(),
        );
        let weight_cache = tensor_to_fp16_bits(&weight.data);
        Self {
            weight,
            weight_cache,
            vocab_size,
            embedding_dim,
        }
    }

    /// Looks up one embedding row per entry of `indices` and returns an
    /// `[indices.data.len(), embedding_dim]` f32 tensor.
    ///
    /// NOTE(review): the lookup reads `weight_cache`, which this file fills
    /// only in `new`; confirm that whatever updates `weight` also refreshes it.
    ///
    /// # Errors
    /// Kernel errors are propagated unchanged.
    pub fn forward(&self, indices: &Tensor<i32>) -> Result<Tensor<f32>> {
        let n_queries = indices.data.len();
        let output_bits = run_rocm_hip_embedding_fwd(
            &indices.data,
            &self.weight_cache,
            n_queries,
            self.embedding_dim,
            self.vocab_size,
        )?;
        let data: Vec<f32> = output_bits.iter().copied().map(f16_to_f32).collect();
        Ok(Tensor::dense_cpu(
            domain(),
            Shape::from(vec![n_queries, self.embedding_dim]),
            data,
        ))
    }

    /// Validates the shape of `grad_output` and then always fails: the
    /// gradient cannot be computed without the indices of the forward pass.
    ///
    /// # Errors
    /// Returns `Error::shape` when `grad_output` is not 2D (previously this
    /// panicked on an out-of-range dimension index), has non-static dims, or
    /// its width differs from `embedding_dim`; otherwise returns
    /// `Error::backend`.
    pub fn backward(&self, grad_output: &Tensor<f32>) -> Result<Tensor<f32>> {
        let dims = &grad_output.meta.shape.dims;
        if dims.len() != 2 {
            return Err(Error::shape(format!(
                "Embedding backward expects 2D grad_output [n_queries, embedding_dim], got rank {}",
                dims.len()
            )));
        }
        let (_n_queries, embedding_dim) = match (&dims[0], &dims[1]) {
            (crate::object::Dim::Static(q), crate::object::Dim::Static(d)) => (*q, *d),
            _ => return Err(Error::shape("Embedding backward expects static shape")),
        };
        if embedding_dim != self.embedding_dim {
            return Err(Error::shape(format!(
                "Embedding backward embedding_dim {} != self.embedding_dim {}",
                embedding_dim, self.embedding_dim
            )));
        }
        Err(Error::backend(
            "Embedding::backward requires the indices from the forward pass; use EmbeddingContext instead",
        ))
    }
}
/// Parameter-free pass-through layer: both forward and backward return a clone of their input.
pub struct LinearResidual;
/// Empty cache returned by `LinearResidual::forward`; the backward pass does not read it.
pub struct LinearResidualCache;
impl Layer for LinearResidual {
    /// Static layer name.
    fn name(&self) -> &'static str {
        "LinearResidual"
    }

    /// LinearResidual has no parameters.
    fn parameters(&self) -> Vec<&Parameter> {
        vec![]
    }

    /// LinearResidual has no parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        vec![]
    }

    /// Returns a clone of `input` together with an empty cache.
    fn forward(&self, input: &Tensor<f32>) -> Result<(Tensor<f32>, Box<dyn Any + Send>)> {
        let passthrough = input.clone();
        let cache: Box<dyn Any + Send> = Box::new(LinearResidualCache);
        Ok((passthrough, cache))
    }

    /// Returns a clone of `grad_output` and no parameter gradients; the cache
    /// is not read.
    fn backward(
        &self,
        grad_output: &Tensor<f32>,
        _cache: &dyn Any,
    ) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        let grad_input = grad_output.clone();
        Ok((grad_input, vec![]))
    }
}