use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::{
BinaryOps, MatmulOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Chunk-scan output stage of the SSD (state-space duality) scan: combines the
/// per-chunk hidden `states` with the `C` projection and the cumulative decay
/// to produce the output, plus an optional `D * x` skip connection.
///
/// Expected shapes (inferred from the reshapes below — confirm against callers):
/// - `x`:         `[batch, seqlen, nheads, headdim]`
/// - `states`:    `[batch, nchunks, nheads, headdim, dstate]`
/// - `c`:         `[batch, seqlen, ngroups, dstate]`, with `ngroups` dividing `nheads`
/// - `dA_cumsum`: `[batch, nheads, nchunks, chunk_size]`
/// - `d`:         `nheads` elements (reshaped to `[1, 1, nheads, 1]` for broadcast)
///
/// Returns a tensor of shape `[batch, seqlen, nheads, headdim]`.
///
/// # Errors
///
/// Propagates any failure from the underlying `numr` ops (shape mismatch,
/// kernel error) as [`Error::Numr`].
#[allow(non_snake_case)]
pub fn ssd_chunk_scan_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    states: &Tensor<R>,
    c: &Tensor<R>,
    dA_cumsum: &Tensor<R>,
    d: Option<&Tensor<R>>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ShapeOps<R>
        + ScalarOps<R>
        + ReduceOps<R>
        + MatmulOps<R>
        + UtilityOps<R>
        + TensorOps<R>,
{
    let x_shape = x.shape();
    let s_shape = states.shape();
    let c_shape = c.shape();
    let da_shape = dA_cumsum.shape();
    let batch = x_shape[0];
    let seqlen = x_shape[1];
    let nheads = x_shape[2];
    let headdim = x_shape[3];
    let nchunks = s_shape[1];
    let dstate = s_shape[4];
    let ngroups = c_shape[2];
    let chunk_size = da_shape[3];
    // Grouped-query style sharing: each of the `ngroups` C projections is used
    // by `heads_per_group` consecutive heads. Assumes `ngroups | nheads`.
    let heads_per_group = nheads / ngroups;
    let device = x.device();
    let dtype = x.dtype();

    // Zero-pad C along the sequence axis so it splits evenly into `nchunks`
    // chunks of `chunk_size` positions each.
    let padded_len = nchunks * chunk_size;
    let c_padded = if padded_len > seqlen {
        let pad = Tensor::<R>::zeros(
            &[batch, padded_len - seqlen, ngroups, dstate],
            dtype,
            device,
        );
        client.cat(&[c, &pad], 1).map_err(Error::Numr)?
    } else {
        c.clone()
    };

    // decay = exp(dA_cumsum) with the exponent clamped to (-inf, 0], so the
    // result stays in (0, 1] and exp cannot overflow.
    let dA_clamped = client
        .clamp(dA_cumsum, f64::NEG_INFINITY, 0.0)
        .map_err(Error::Numr)?;
    let decay = client.exp(&dA_clamped).map_err(Error::Numr)?;

    // [batch, nchunks, chunk_size, ngroups, dstate]
    let c_chunks = c_padded
        .reshape(&[batch, nchunks, chunk_size, ngroups, dstate])
        .map_err(Error::Numr)?;
    // Expand the group axis so every head has its own copy of C. When each
    // head already has its own group (heads_per_group == 1, which includes
    // ngroups == nheads) the tensor is head-shaped as-is; the former separate
    // `ngroups == nheads` branch was an identity permute and is folded in here.
    let c_for_heads = if heads_per_group == 1 {
        c_chunks
    } else {
        c_chunks
            .reshape(&[batch, nchunks, chunk_size, ngroups, 1, dstate])
            .map_err(Error::Numr)?
            .broadcast_to(&[batch, nchunks, chunk_size, ngroups, heads_per_group, dstate])
            .map_err(Error::Numr)?
            .contiguous()
            .reshape(&[batch, nchunks, chunk_size, nheads, dstate])
            .map_err(Error::Numr)?
    };

    // Batched matmul over (batch, chunk, head):
    //   y[chunk_size, headdim] = C[chunk_size, dstate] @ states[headdim, dstate]^T
    let c_t = c_for_heads
        .permute(&[0, 1, 3, 2, 4])
        .map_err(Error::Numr)?
        .contiguous()
        .reshape(&[batch * nchunks * nheads, chunk_size, dstate])
        .map_err(Error::Numr)?;
    let states_t = states
        .reshape(&[batch * nchunks * nheads, headdim, dstate])
        .map_err(Error::Numr)?
        .permute(&[0, 2, 1])
        .map_err(Error::Numr)?
        .contiguous();
    let y_flat = client.matmul(&c_t, &states_t).map_err(Error::Numr)?;
    let y_chunked = y_flat
        .reshape(&[batch, nchunks, nheads, chunk_size, headdim])
        .map_err(Error::Numr)?;

    // Scale each position's output by its cumulative decay; the trailing
    // singleton axis broadcasts over headdim.
    let decay_t = decay
        .permute(&[0, 2, 1, 3])
        .map_err(Error::Numr)?
        .contiguous()
        .reshape(&[batch, nchunks, nheads, chunk_size, 1])
        .map_err(Error::Numr)?;
    let y_decayed = client.mul(&y_chunked, &decay_t).map_err(Error::Numr)?;

    // Flatten chunks back into the sequence axis and drop the padding.
    let y_full = y_decayed
        .permute(&[0, 1, 3, 2, 4])
        .map_err(Error::Numr)?
        .contiguous()
        .reshape(&[batch, padded_len, nheads, headdim])
        .map_err(Error::Numr)?;
    let y_trimmed = if padded_len > seqlen {
        y_full.narrow(1, 0, seqlen).map_err(Error::Numr)?
    } else {
        y_full
    };

    // Optional per-head skip connection: y += D * x.
    let output = if let Some(d_param) = d {
        let d_broad = d_param.reshape(&[1, 1, nheads, 1]).map_err(Error::Numr)?;
        let d_x = client.mul(&d_broad, x).map_err(Error::Numr)?;
        client.add(&y_trimmed, &d_x).map_err(Error::Numr)?
    } else {
        y_trimmed
    };
    Ok(output)
}