mod convolution;
mod pooling;
use crate::TruenoError;
use super::super::Matrix;
impl Matrix<f32> {
/// Gathers rows of this matrix (treated as an embedding table) by index.
///
/// Row `r` of the returned matrix is a copy of row `indices[r]` of `self`,
/// so the result has `indices.len()` rows and `self.cols` columns and
/// carries the same `backend` as `self`. Storage is indexed row-major
/// (`row * cols + col`).
///
/// # Errors
///
/// Returns `TruenoError::InvalidInput` naming the first index (and its
/// position in `indices`) that is `>= self.rows`.
pub fn embedding_lookup(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
    // Project-defined contract macro; presumably asserts preconditions on
    // `indices` — confirm against its definition.
    contract_pre_embedding_lookup!(indices);

    // Validate every index up front so no copying happens on bad input.
    for (i, &idx) in indices.iter().enumerate() {
        if idx >= self.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Index {} at position {} is out of bounds for embedding table with {} rows",
                idx, i, self.rows
            )));
        }
    }

    if indices.is_empty() {
        return Ok(Matrix::zeros_with_backend(0, self.cols, self.backend));
    }

    let seq_len = indices.len();
    let embed_dim = self.cols;

    // Reserve the full output once, then append each selected row. This
    // avoids exposing uninitialised memory (no `unsafe set_len`) while still
    // performing exactly one allocation and one memcpy per row.
    let mut data: Vec<f32> = Vec::with_capacity(seq_len * embed_dim);
    for &idx in indices {
        let src_start = idx * embed_dim;
        data.extend_from_slice(&self.data[src_start..src_start + embed_dim]);
    }

    let result = Matrix { rows: seq_len, cols: embed_dim, data, backend: self.backend };
    // Project-defined post-condition check on the produced buffer.
    contract_post_embedding_lookup!(&result.data);
    Ok(result)
}
/// Performs [`Self::embedding_lookup`] and additionally reports which table
/// rows were touched.
///
/// Returns the gathered embeddings together with the distinct values of
/// `indices`, sorted in ascending order.
///
/// # Errors
///
/// Propagates any error from `embedding_lookup` (the lookup runs first, so
/// the unique-index list is only built on success).
pub fn embedding_lookup_sparse(
    &self,
    indices: &[usize],
) -> Result<(Matrix<f32>, Vec<usize>), TruenoError> {
    self.embedding_lookup(indices).map(|embeddings| {
        // Sort + dedup collapses repeated indices into one entry each.
        let mut touched = indices.to_vec();
        touched.sort_unstable();
        touched.dedup();
        (embeddings, touched)
    })
}
/// Returns the `k` largest elements of the flattened storage together with
/// their flat indices (positions in `self.data`).
///
/// Both returned vectors have the same length, `min(k, self.data.len())`,
/// and are ordered from largest to smallest value. NaN values are ranked
/// below every number, so they are only returned when `k` exceeds the count
/// of non-NaN elements. The relative order of equal values is unspecified.
///
/// An empty matrix or `k == 0` yields two empty vectors.
///
/// # Errors
///
/// This implementation never returns `Err`; the `Result` is kept for
/// interface compatibility.
pub fn topk(&self, k: usize) -> Result<(Vec<f32>, Vec<usize>), TruenoError> {
    // Descending comparator that is a total order: NaN compares equal to NaN
    // and sorts after every number. A plain `partial_cmp(..).unwrap_or(Equal)`
    // is not transitive when NaN is present, which the std sorts do not
    // tolerate (unspecified order, possibly a panic).
    fn descending_nan_last(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        }
    }

    // Clamp first: `select_nth_unstable_by` panics on an empty slice, so an
    // empty matrix must take the early return even when the caller's k > 0.
    let k = k.min(self.data.len());
    if k == 0 {
        return Ok((vec![], vec![]));
    }

    let mut indexed: Vec<(usize, f32)> = self.data.iter().copied().enumerate().collect();
    // Partition so the k largest occupy the front, then order just those k.
    indexed.select_nth_unstable_by(k - 1, descending_nan_last);
    indexed.truncate(k);
    indexed.sort_by(descending_nan_last);

    let (indices, values): (Vec<usize>, Vec<f32>) = indexed.into_iter().unzip();
    Ok((values, indices))
}
/// Gathers rows (`axis == 0`) or columns (`axis == 1`) selected by `indices`.
///
/// # Errors
///
/// Returns `TruenoError::InvalidInput` when `axis` is neither 0 nor 1, and
/// propagates any error from the row/column gather helper.
pub fn gather(&self, indices: &[usize], axis: usize) -> Result<Matrix<f32>, TruenoError> {
    if axis == 0 {
        return self.gather_rows(indices);
    }
    if axis == 1 {
        return self.gather_cols(indices);
    }
    Err(TruenoError::InvalidInput(format!(
        "Axis {} not supported for 2D matrix (use 0 or 1)",
        axis
    )))
}
/// Builds a matrix whose row `r` is a copy of row `indices[r]` of `self`.
///
/// The output is created with `Matrix::new(indices.len(), self.cols)` and
/// every row of it is overwritten.
///
/// # Errors
///
/// Returns `TruenoError::InvalidInput` for the first index that is
/// `>= self.rows`.
fn gather_rows(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
    let width = self.cols;
    let mut gathered: Matrix<f32> = Matrix::new(indices.len(), width);

    for (slot, &row) in indices.iter().enumerate() {
        if row >= self.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Index {} out of bounds for axis 0 with size {}",
                row, self.rows
            )));
        }
        // Row-major layout: a whole row is one contiguous span of `width` values.
        let source = &self.data[row * width..(row + 1) * width];
        let dst_start = slot * width;
        gathered.data[dst_start..dst_start + width].copy_from_slice(source);
    }

    Ok(gathered)
}
/// Builds a matrix whose column `c` is a copy of column `indices[c]` of
/// `self`; the result has `self.rows` rows and `indices.len()` columns.
///
/// # Errors
///
/// Returns `TruenoError::InvalidInput` for the first index that is
/// `>= self.cols`. Indices are validated once, before any copying, so an
/// out-of-range index is rejected even when the matrix has zero rows.
fn gather_cols(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
    // Validate up front: the check does not depend on the row, so doing it
    // inside the row loop would repeat it `rows` times and skip it entirely
    // for a zero-row matrix.
    if let Some(&idx) = indices.iter().find(|&&idx| idx >= self.cols) {
        return Err(TruenoError::InvalidInput(format!(
            "Index {} out of bounds for axis 1 with size {}",
            idx, self.cols
        )));
    }

    let out_cols = indices.len();
    let mut result: Matrix<f32> = Matrix::new(self.rows, out_cols);
    for i in 0..self.rows {
        let src_base = i * self.cols;
        let dst_base = i * out_cols;
        for (out_j, &idx) in indices.iter().enumerate() {
            result.data[dst_base + out_j] = self.data[src_base + idx];
        }
    }
    Ok(result)
}
/// Returns a copy of this matrix surrounded by a border filled with `value`.
///
/// `padding` is `((top, bottom), (left, right))`: the number of rows added
/// above/below and columns added to the left/right. The original contents
/// land at row offset `top` and column offset `left` in the result.
///
/// # Errors
///
/// Propagates any error returned by `Matrix::from_vec` when building the
/// padded matrix.
pub fn pad(
    &self,
    padding: ((usize, usize), (usize, usize)),
    value: f32,
) -> Result<Matrix<f32>, TruenoError> {
    let (vertical, horizontal) = padding;
    let (top, bottom) = vertical;
    let (left, right) = horizontal;

    let out_rows = self.rows + top + bottom;
    let out_cols = self.cols + left + right;

    // Start from a buffer that is entirely `value`, then overwrite the
    // interior with the original rows.
    let fill = vec![value; out_rows * out_cols];
    let mut padded = Matrix::from_vec(out_rows, out_cols, fill)?;

    let width = self.cols;
    for r in 0..self.rows {
        let src_start = r * width;
        let dst_start = (r + top) * out_cols + left;
        padded.data[dst_start..dst_start + width]
            .copy_from_slice(&self.data[src_start..src_start + width]);
    }

    Ok(padded)
}
}
#[cfg(test)]
mod tests;