use rayon::prelude::*;
use rten_gemm::{GemmExecutor, GemmInputA, GemmInputB, GemmUninitOptions};
use rten_simd::SimdOp;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use rten_vecmath::Softmax;
use crate::buffer_pool::{AutoReturn, BufferPool};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext,
};
use crate::ops::{
binary_elementwise::broadcast_shapes, layout::expand_to, norm::NanHandling, resolve_axis,
};
use crate::value::Value;
/// Error returned when the two operator inputs cannot be broadcast together.
const BROADCAST_ERROR: OpError = OpError::IncompatibleInputShapes("Cannot broadcast inputs");
/// Add `m` to `qk` and apply softmax over the last axis, in place.
///
/// `qk` must already have the final (broadcast) output shape; only `m` is
/// broadcast to match it. Returns the modified tensor, or [`BROADCAST_ERROR`]
/// if `m` cannot be broadcast to `qk`'s shape.
fn add_softmax_in_place(
    pool: &BufferPool,
    qk: Tensor<f32>,
    m: TensorView<f32>,
    nan_handling: NanHandling,
) -> Result<Tensor, OpError> {
    // Softmax is always applied over the last dimension.
    let axis = resolve_axis(qk.ndim(), -1)?;
    let m = m.try_broadcast(qk.shape()).map_err(|_| BROADCAST_ERROR)?;
    // The lane slicing below requires the softmax axis to be contiguous
    // (stride 1). If it isn't, copy into a pool-allocated tensor and return
    // the original buffer to the pool when done.
    let mut qk = if qk.stride(axis) == 1 {
        qk
    } else {
        qk.auto_return(pool).to_tensor_in(pool)
    };
    let flush_nans = match nan_handling {
        NanHandling::KeepNans => false,
        NanHandling::FlushToZero => true,
    };
    // Process each 1D lane along the softmax axis in parallel: first add the
    // corresponding mask lane, then run the vectorized softmax over it.
    qk.lanes_mut(axis)
        .into_par_iter()
        .zip(m.lanes(axis).into_par_iter())
        .for_each(|(mut qk_inner, m_inner)| {
            // Contiguity along `axis` was ensured above, so the lane can be
            // viewed as a slice.
            let qk_inner = qk_inner.as_slice_mut().unwrap();
            for (qk, m) in qk_inner.iter_mut().zip(m_inner) {
                *qk += m;
            }
            Softmax::new_mut(qk_inner)
                .flush_nans_to_zero(flush_nans)
                .dispatch();
        });
    Ok(qk)
}
/// Fused operator computing `Softmax(X + Y, axis=-1)`.
///
/// Fuses the attention-mask addition with the softmax that follows it.
#[derive(Debug)]
pub struct AddSoftmax {
    /// When true, softmax outputs that would be NaN (eg. from a lane of
    /// all `-inf` mask values) are flushed to zero.
    pub flush_nans_to_zero: bool,
}
impl AddSoftmax {
    /// Translate the `flush_nans_to_zero` flag into the corresponding
    /// [`NanHandling`] mode used by the softmax implementation.
    fn nan_handling(&self) -> NanHandling {
        match self.flush_nans_to_zero {
            true => NanHandling::FlushToZero,
            false => NanHandling::KeepNans,
        }
    }
}
impl Operator for AddSoftmax {
    fn name(&self) -> &str {
        "AddSoftmax"
    }
    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let x: TensorView = ctx.inputs().require_as(0)?;
        let y: TensorView = ctx.inputs().require_as(1)?;
        // Treat the larger input as the scores (`qk`) and the smaller as the
        // mask (`m`). Addition is commutative and `add_softmax_in_place`
        // only broadcasts its second argument.
        let (qk, m) = if x.len() > y.len() { (x, y) } else { (y, x) };
        let out_shape = broadcast_shapes(qk.shape(), m.shape());
        let qk = match out_shape.as_deref() {
            // Materialize `qk` at the full output shape so it can be
            // modified in place.
            Some(shape) => qk.broadcast(shape).to_tensor_in(ctx.pool()),
            None => {
                return Err(BROADCAST_ERROR);
            }
        };
        add_softmax_in_place(ctx.pool(), qk, m, self.nan_handling()).into_op_result()
    }
    fn is_commutative(&self) -> bool {
        true
    }
    fn can_run_in_place(&self) -> bool {
        true
    }
    /// In-place variant which reuses `input`'s buffer when its shape already
    /// matches the broadcast output shape.
    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        let qk: Tensor = input.try_into()?;
        let m: TensorView = ctx.inputs().require_as(0)?;
        let out_shape = broadcast_shapes(qk.shape(), m.shape());
        let qk = match out_shape.as_deref() {
            // Fast path: output shape matches the input, no copy needed.
            Some(shape) if shape == qk.shape() => qk,
            Some(shape) => qk.broadcast(shape).to_tensor_in(ctx.pool()),
            None => {
                return Err(BROADCAST_ERROR);
            }
        };
        add_softmax_in_place(ctx.pool(), qk, m, self.nan_handling()).map(|qk| qk.into())
    }
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output has the same dtype as the first input.
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Repeat each entry along `axis` of `input` `repeats` times.
///
/// Errors if `axis` is out of range for `input`'s rank. The output has the
/// same shape as `input` except that the size of `axis` is multiplied by
/// `repeats`.
fn repeat_interleave<T: Copy>(
    pool: &BufferPool,
    mut input: TensorView<T>,
    axis: usize,
    repeats: usize,
) -> Result<Tensor<T>, OpError> {
    if axis >= input.ndim() {
        return Err(OpError::InvalidValue("Input has too few dims"));
    }

    // Insert a size-1 axis after `axis`, broadcast it up to `repeats`
    // entries, then merge the two axes back into one.
    input.insert_axis(axis + 1);
    let mut shape = input.shape().to_vec();
    shape[axis + 1] *= repeats;
    let mut output = expand_to(pool, input, &shape);

    shape[axis] *= repeats;
    shape.remove(axis + 1);
    output.reshape(&shape);

    Ok(output)
}
/// Operator which repeats each entry along an axis a fixed number of times.
///
/// See [`repeat_interleave`].
#[derive(Debug)]
pub struct RepeatInterleave {
    /// Axis along which entries are repeated.
    pub axis: usize,
    /// Number of times each entry is repeated.
    pub repeats: usize,
}
impl Operator for RepeatInterleave {
    fn name(&self) -> &str {
        "RepeatInterleave"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    /// Repeat each entry of the single f32 input along `self.axis`.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let x: TensorView<f32> = ctx.inputs().require_as(0)?;
        repeat_interleave(ctx.pool(), x, self.axis, self.repeats).into_op_result()
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output has the same dtype as the input.
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Fused batched matmul for grouped-query attention, where the second input
/// has fewer heads than the first and each of its heads is shared by
/// `repeats` heads of the first input.
#[derive(Debug)]
pub struct GroupedQueryAttentionMatMul {
    /// Number of heads of the first input sharing each head of the second.
    pub repeats: usize,
    /// Optional scale factor applied to the matmul output (defaults to 1.0).
    pub alpha: Option<f32>,
    /// If true, transpose the last two dims of the second input before
    /// multiplying.
    pub transpose_rhs: bool,
}
impl Operator for GroupedQueryAttentionMatMul {
    fn name(&self) -> &str {
        "GroupedQueryAttentionMatMul"
    }
    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        // LHS has shape `[batch, heads, seq, k]`; RHS has shape
        // `[batch, heads / repeats, k, n]` after the optional transposition.
        let lhs: NdTensorView<f32, 4> = ctx.inputs().require_as(0)?;
        let mut rhs: NdTensorView<f32, 4> = ctx.inputs().require_as(1)?;
        if self.transpose_rhs {
            rhs.permute([0, 1, 3, 2]);
        }
        let [batch, heads, seq, k] = lhs.shape();
        let [rhs_batch, rhs_heads, rhs_k, rhs_n] = rhs.shape();
        if batch != rhs_batch {
            return Err(OpError::IncompatibleInputShapes("Batch size mismatch"));
        }
        if k != rhs_k {
            return Err(OpError::IncompatibleInputShapes("K size mismatch"));
        }
        if rhs_heads * self.repeats != heads {
            return Err(OpError::IncompatibleInputShapes(
                "Repeated axis size mismatch",
            ));
        }
        // Each (batch, RHS head) pair produces one `[repeats * seq, rhs_n]`
        // output chunk; `out_size` is the total output element count.
        let chunk_size = self.repeats * seq * rhs_n;
        let out_size = batch * (heads / self.repeats) * chunk_size;
        let mut out_data = ctx.pool().alloc(out_size);
        let out_uninit = &mut out_data.spare_capacity_mut()[..out_size];
        let gemm = GemmExecutor::default();
        // Group the `repeats` LHS heads that share each RHS head into a
        // single `[repeats * seq, k]` matrix, so each shared head needs only
        // one GEMM call instead of `repeats`.
        let lhs_mats = lhs.reshaped_in(
            ctx.pool(),
            [batch, heads / self.repeats, self.repeats * seq, k],
        );
        let opts = GemmUninitOptions {
            alpha: self.alpha.unwrap_or(1.0),
            ..Default::default()
        };
        // Run the per-head GEMMs in parallel, each writing directly into its
        // own uninitialized output chunk.
        lhs_mats
            .inner_iter::<2>()
            .into_par_iter()
            .zip(rhs.inner_iter::<2>())
            .zip(out_uninit.par_chunks_mut(chunk_size))
            .for_each(|((lhs, rhs), out)| {
                gemm.gemm_uninit(
                    out,
                    GemmInputA::Unpacked(lhs),
                    GemmInputB::Unpacked(rhs),
                    opts.clone(),
                )
                .unwrap();
            });
        // SAFETY: The chunks above exactly cover the first `out_size`
        // elements of the spare capacity and each GEMM call initialized its
        // chunk.
        unsafe { out_data.set_len(out_size) };
        Tensor::from_data(&[batch, heads, seq, rhs_n], out_data).into_op_result()
    }
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output has the same dtype as the first input.
        Some([OutputType::CopyFromInput(0)].into())
    }
}
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::rng::XorShiftRng;
    use rten_tensor::test_util::expect_equal;
    use rten_tensor::{NdTensor, Tensor, TensorView};
    use rten_testing::TestCases;
    use super::{AddSoftmax, BROADCAST_ERROR, GroupedQueryAttentionMatMul, RepeatInterleave};
    use crate::operator::{OpError, OperatorExt};
    use crate::ops::{Add, Softmax};
    /// Reference for the fused `AddSoftmax` operator, computed with the
    /// separate `Add` and `Softmax` operators.
    fn reference_add_softmax(x: TensorView, y: TensorView) -> Result<Tensor, OpError> {
        let add = Add {};
        let softmax = Softmax {
            axis: -1,
            flush_nans_to_zero: false,
        };
        let sum: Tensor = add.run_simple((x, y))?;
        softmax.run_simple(sum.view())
    }
    #[test]
    fn test_add_softmax() {
        #[derive(Debug)]
        struct Case {
            qk_shape: Vec<usize>,
            m_shape: Vec<usize>,
            // Expected error if the shapes should fail to broadcast, or
            // `None` if the case should succeed.
            expected_err: Option<OpError>,
            // Whether to run via `run_in_place`.
            in_place: bool,
        }
        let cases = [
            // Mask broadcast along the heads dim, in place.
            Case {
                qk_shape: [1, 8, 32, 32].into(),
                m_shape: [1, 1, 32, 32].into(),
                expected_err: None,
                in_place: true,
            },
            // Incompatible shapes, in place.
            Case {
                qk_shape: [1, 8, 32, 32].into(),
                m_shape: [1, 2, 32, 32].into(),
                expected_err: Some(BROADCAST_ERROR),
                in_place: true,
            },
            // Incompatible shapes, not in place.
            Case {
                qk_shape: [1, 8, 32, 32].into(),
                m_shape: [1, 2, 32, 32].into(),
                expected_err: Some(BROADCAST_ERROR),
                in_place: false,
            },
            // Both inputs must be broadcast to the output shape.
            Case {
                qk_shape: [1, 8, 16].into(),
                m_shape: [8, 1, 16].into(),
                expected_err: None,
                in_place: true,
            },
            // First input smaller than the second, so the operands are
            // swapped internally.
            Case {
                qk_shape: [1, 1, 32, 32].into(),
                m_shape: [1, 8, 32, 32].into(),
                expected_err: None,
                in_place: false,
            },
        ];
        cases.test_each(|case| {
            let mut rng = XorShiftRng::new(1234);
            let op = AddSoftmax {
                flush_nans_to_zero: false,
            };
            let qk = Tensor::rand(&case.qk_shape, &mut rng);
            let m = Tensor::rand(&case.m_shape, &mut rng);
            let result: Result<Tensor, _> = if case.in_place {
                op.run_simple_in_place(qk.clone(), m.view())
            } else {
                op.run_simple((qk.view(), m.view()))
            };
            if let Some(expected_err) = &case.expected_err {
                assert_eq!(result.as_ref().err().unwrap(), expected_err);
            } else {
                // Compare the fused result against the unfused pipeline.
                let expected = reference_add_softmax(qk.view(), m.view()).unwrap();
                expect_equal(&result.unwrap(), &expected).unwrap();
            }
        });
    }
    #[test]
    fn test_add_softmax_flush_nans_to_zero() {
        // A fully-masked lane (all `-inf`) yields NaN softmax outputs when
        // NaN flushing is disabled...
        let qk = Tensor::from([f32::NEG_INFINITY, f32::NEG_INFINITY, f32::NEG_INFINITY]);
        let m = Tensor::from([0., 0., 0.]);
        let op = AddSoftmax {
            flush_nans_to_zero: false,
        };
        let result: Tensor = op.run_simple((qk.view(), m.view())).unwrap();
        assert!(result.iter().all(|x| x.is_nan()));
        // ...and zeros when flushing is enabled.
        let qk = Tensor::from([f32::NEG_INFINITY, f32::NEG_INFINITY, f32::NEG_INFINITY]);
        let m = Tensor::from([0., 0., 0.]);
        let op = AddSoftmax {
            flush_nans_to_zero: true,
        };
        let result: Tensor = op.run_simple((qk.view(), m.view())).unwrap();
        assert_eq!(result.to_vec(), vec![0., 0., 0.]);
    }
    #[test]
    fn test_repeat_interleave() {
        let input = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
        let op = RepeatInterleave {
            axis: 1,
            repeats: 2,
        };
        let repeated: Tensor = op.run_simple(input.view()).unwrap();
        // Each column entry is duplicated in place, not tiled.
        assert_eq!(repeated, Tensor::from([[1., 1., 2., 2.], [3., 3., 4., 4.]]));
    }
    #[test]
    fn test_grouped_query_attention_matmul() {
        let batch = 1;
        let query_heads = 8;
        let kv_heads = 2;
        let seq = 3;
        let d_model = 8;
        let query = NdTensor::<f32, 4>::zeros([batch, query_heads, seq, d_model]);
        let key = NdTensor::<f32, 4>::zeros([batch, kv_heads, seq, d_model]);
        let value = NdTensor::<f32, 4>::zeros([batch, kv_heads, seq, d_model]);
        // `Q x K^T` step: keys are transposed, output is `seq x seq`.
        let op = GroupedQueryAttentionMatMul {
            repeats: query_heads / kv_heads,
            alpha: Some(0.5),
            transpose_rhs: true,
        };
        let query_key: NdTensor<f32, 4> = op.run_simple((query.view(), key.view())).unwrap();
        assert_eq!(query_key.shape(), [batch, query_heads, seq, seq]);
        // `(Q x K^T) x V` step: no transpose, output is `seq x d_model`.
        let op = GroupedQueryAttentionMatMul {
            repeats: query_heads / kv_heads,
            alpha: None,
            transpose_rhs: false,
        };
        let qkv: NdTensor<f32, 4> = op.run_simple((query_key.view(), value.view())).unwrap();
        assert_eq!(qkv.shape(), [batch, query_heads, seq, d_model]);
    }
}