use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// One-hot encodes a 1-D tensor of class indices into a `[n, num_classes]` tensor.
///
/// Passing `num_classes < 0` means "infer": the class count becomes
/// `max(data) + 1` (at least 1, since the fold seeds with 0.0).
/// Indices at or above the class count leave their row all-zero,
/// matching the original silent-skip behavior.
///
/// # Errors
/// Returns `InvalidArgument` if the input is not 1-D or contains a
/// negative index.
pub fn one_hot(tensor: &Tensor, num_classes: i32) -> TorshResult<Tensor> {
    if tensor.shape().ndim() != 1 {
        return Err(TorshError::InvalidArgument(
            "one_hot expects 1D tensor input".to_string(),
        ));
    }
    let data = tensor.data()?;
    let n = data.len();
    let nc = if num_classes < 0 {
        // Infer the class count from the largest index present.
        let max_val = data.iter().fold(0.0_f32, |a, &b| a.max(b));
        (max_val as usize) + 1
    } else {
        num_classes as usize
    };
    let mut one_hot_data = vec![0.0_f32; n * nc];
    for (i, &class_idx) in data.iter().enumerate() {
        // Fix: a negative index would saturate to 0 through `as usize`
        // (Rust float-to-int casts saturate) and silently light up
        // class 0 — reject it instead.
        if class_idx < 0.0 {
            return Err(TorshError::InvalidArgument(format!(
                "one_hot received negative class index {}",
                class_idx
            )));
        }
        let class = class_idx as usize;
        if class < nc {
            one_hot_data[i * nc + class] = 1.0;
        }
    }
    Tensor::from_data(one_hot_data, vec![n, nc], tensor.device())
}
/// Applies an affine transform: `output = input.matmul(weightᵀ)` plus an
/// optional broadcast-added bias.
///
/// `weight` must be 2-D (`[out_features, in_features]`); `bias`, when
/// given, must be 1-D.
///
/// # Errors
/// Returns `InvalidArgument` if `weight` is not 2-D or `bias` is not 1-D;
/// otherwise propagates errors from `matmul`/`transpose`/`add`.
pub fn linear(input: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> TorshResult<Tensor> {
    if weight.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "Weight must be a 2D tensor".to_string(),
        ));
    }
    // Fix: validate the bias shape *before* the (potentially expensive)
    // matmul so invalid arguments fail fast instead of after the work.
    if let Some(b) = bias {
        if b.shape().ndim() != 1 {
            return Err(TorshError::InvalidArgument(
                "Bias must be a 1D tensor".to_string(),
            ));
        }
    }
    let output = input.matmul(&weight.transpose(-1, -2)?)?;
    match bias {
        Some(b) => output.add(b),
        None => Ok(output),
    }
}
/// Computes the p-norm distance between `x1` and `x2` along the last
/// dimension, with specialized fast paths for `p = 2` (Euclidean) and
/// `p = 1` (Manhattan).
///
/// NOTE(review): `eps` here is used as the float-comparison tolerance for
/// selecting the p = 1 / p = 2 fast paths, not as an additive numerical
/// stabilizer as in some other frameworks — confirm this is intended.
///
/// # Errors
/// Returns `InvalidArgument` when the input shapes differ.
pub fn pairwise_distance(x1: &Tensor, x2: &Tensor, p: f32, eps: f32) -> TorshResult<Tensor> {
    if x1.shape().dims() != x2.shape().dims() {
        return Err(TorshError::InvalidArgument(
            "Input tensors must have the same shape".to_string(),
        ));
    }
    let delta = x1.sub(x2)?;
    if (p - 2.0).abs() < eps {
        // Euclidean: sqrt(sum(diff^2)).
        delta.pow_scalar(2.0)?.sum_dim(&[-1], false)?.sqrt()
    } else if (p - 1.0).abs() < eps {
        // Manhattan: sum(|diff|).
        delta.abs()?.sum_dim(&[-1], false)
    } else {
        // General Minkowski: (sum(|diff|^p))^(1/p).
        delta
            .abs()?
            .pow_scalar(p)?
            .sum_dim(&[-1], false)?
            .pow_scalar(1.0 / p)
    }
}
/// Computes cosine similarity between `x1` and `x2` along dimension `dim`:
/// `<x1, x2> / max(||x1|| * ||x2||, eps)`, with the denominator clamped
/// below by `eps` to avoid division by zero.
///
/// # Errors
/// Returns `InvalidArgument` when the input shapes differ.
pub fn cosine_similarity(x1: &Tensor, x2: &Tensor, dim: i32, eps: f32) -> TorshResult<Tensor> {
    if x1.shape().dims() != x2.shape().dims() {
        return Err(TorshError::InvalidArgument(
            "Input tensors must have the same shape".to_string(),
        ));
    }
    // Numerator: inner product along `dim`.
    let numerator = x1.mul(x2)?.sum_dim(&[dim], false)?;
    // Denominator: product of the L2 norms, kept at least `eps`.
    let norm1 = x1.pow_scalar(2.0)?.sum_dim(&[dim], false)?.sqrt()?;
    let norm2 = x2.pow_scalar(2.0)?.sum_dim(&[dim], false)?.sqrt()?;
    let denominator = norm1.mul(&norm2)?.clamp(eps, f32::MAX)?;
    numerator.div(&denominator)
}
/// Gathers rows of the embedding table `weight` (`[num_embeddings,
/// embedding_dim]`) at the positions given by `indices`. The output has
/// the shape of `indices` with `embedding_dim` appended.
///
/// # Errors
/// Returns `InvalidArgument` if `weight` is not 2-D, or if any index is
/// negative or at/above `num_embeddings`.
pub fn embedding(weight: &Tensor, indices: &Tensor) -> TorshResult<Tensor> {
    if weight.shape().ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "Weight must be a 2D tensor [num_embeddings, embedding_dim]".to_string(),
        ));
    }
    let num_embeddings = weight.shape().dims()[0];
    let embedding_dim = weight.shape().dims()[1];
    let indices_data = indices.data()?;
    let indices_shape_binding = indices.shape();
    let indices_shape = indices_shape_binding.dims();
    // Output keeps the index shape and appends the embedding dimension.
    let mut output_shape = indices_shape.to_vec();
    output_shape.push(embedding_dim);
    let weight_data = weight.data()?;
    let mut output_data = Vec::with_capacity(indices_data.len() * embedding_dim);
    for &idx in indices_data.iter() {
        // Fix: a negative index would saturate to 0 through `as usize`
        // (Rust float-to-int casts saturate) and silently return row 0 —
        // reject it explicitly instead.
        if idx < 0.0 {
            return Err(TorshError::InvalidArgument(format!(
                "Negative embedding index {} is not allowed",
                idx
            )));
        }
        let idx_usize = idx as usize;
        if idx_usize >= num_embeddings {
            return Err(TorshError::InvalidArgument(format!(
                "Index {} out of bounds for embedding with {} entries",
                idx_usize, num_embeddings
            )));
        }
        // Copy the contiguous row for this index.
        let start = idx_usize * embedding_dim;
        let end = start + embedding_dim;
        output_data.extend_from_slice(&weight_data[start..end]);
    }
    Tensor::from_data(output_data, output_shape, weight.device())
}
/// Rearranges a 4-D tensor `[B, C·r², H, W]` into `[B, C, H·r, W·r]`
/// (sub-pixel "pixel shuffle" upsampling), where `r = upscale_factor`:
/// output[b][c][h·r + i][w·r + j] = input[b][c·r² + i·r + j][h][w].
///
/// # Errors
/// Returns `InvalidArgument` if the input is not 4-D or the channel
/// count is not divisible by `upscale_factor²`.
pub fn pixel_shuffle(input: &Tensor, upscale_factor: usize) -> TorshResult<Tensor> {
    if input.shape().ndim() != 4 {
        return Err(TorshError::InvalidArgument(
            "pixel_shuffle expects 4D input [B, C, H, W]".to_string(),
        ));
    }
    let shape_binding = input.shape();
    let shape = shape_binding.dims();
    let batch_size = shape[0];
    let channels = shape[1];
    let height = shape[2];
    let width = shape[3];
    let r = upscale_factor;
    let r_squared = r * r;
    // Each output channel consumes r² input channels.
    if channels % r_squared != 0 {
        return Err(TorshError::InvalidArgument(format!(
            "Channels {} must be divisible by upscale_factor^2 = {}",
            channels, r_squared
        )));
    }
    let output_channels = channels / r_squared;
    let output_height = height * r;
    let output_width = width * r;
    let data = input.data()?;
    let mut output_data =
        vec![0.0_f32; batch_size * output_channels * output_height * output_width];
    for b in 0..batch_size {
        for c in 0..output_channels {
            for h in 0..height {
                for w in 0..width {
                    // (r_h, r_w) selects the sub-pixel offset inside the
                    // r×r block that input channel `input_c` maps onto.
                    for r_h in 0..r {
                        for r_w in 0..r {
                            let input_c = c * r_squared + r_h * r + r_w;
                            // Row-major flat index into [B, C, H, W].
                            let input_idx = ((b * channels + input_c) * height + h) * width + w;
                            let output_h = h * r + r_h;
                            let output_w = w * r + r_w;
                            let output_idx = ((b * output_channels + c) * output_height + output_h)
                                * output_width
                                + output_w;
                            output_data[output_idx] = data[input_idx];
                        }
                    }
                }
            }
        }
    }
    Tensor::from_data(
        output_data,
        vec![batch_size, output_channels, output_height, output_width],
        input.device(),
    )
}
/// Inverse of `pixel_shuffle`: rearranges a 4-D tensor `[B, C, H, W]` into
/// `[B, C·r², H/r, W/r]`, where `r = downscale_factor`:
/// output[b][c·r² + i·r + j][h][w] = input[b][c][h·r + i][w·r + j].
///
/// # Errors
/// Returns `InvalidArgument` if the input is not 4-D or if height/width
/// are not divisible by `downscale_factor`.
pub fn pixel_unshuffle(input: &Tensor, downscale_factor: usize) -> TorshResult<Tensor> {
    if input.shape().ndim() != 4 {
        return Err(TorshError::InvalidArgument(
            "pixel_unshuffle expects 4D input [B, C, H, W]".to_string(),
        ));
    }
    let shape_binding = input.shape();
    let shape = shape_binding.dims();
    let batch_size = shape[0];
    let channels = shape[1];
    let height = shape[2];
    let width = shape[3];
    let r = downscale_factor;
    // Spatial dims must split evenly into r×r blocks.
    if height % r != 0 || width % r != 0 {
        return Err(TorshError::InvalidArgument(format!(
            "Height {} and width {} must be divisible by downscale_factor {}",
            height, width, r
        )));
    }
    let output_channels = channels * r * r;
    let output_height = height / r;
    let output_width = width / r;
    let data = input.data()?;
    let mut output_data =
        vec![0.0_f32; batch_size * output_channels * output_height * output_width];
    for b in 0..batch_size {
        for c in 0..channels {
            for h in 0..output_height {
                for w in 0..output_width {
                    // (r_h, r_w) walks the r×r spatial block that folds
                    // into r² separate output channels.
                    for r_h in 0..r {
                        for r_w in 0..r {
                            let input_h = h * r + r_h;
                            let input_w = w * r + r_w;
                            // Row-major flat index into [B, C, H, W].
                            let input_idx =
                                ((b * channels + c) * height + input_h) * width + input_w;
                            let output_c = c * r * r + r_h * r + r_w;
                            let output_idx = ((b * output_channels + output_c) * output_height + h)
                                * output_width
                                + w;
                            output_data[output_idx] = data[input_idx];
                        }
                    }
                }
            }
        }
    }
    Tensor::from_data(
        output_data,
        vec![batch_size, output_channels, output_height, output_width],
        input.device(),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes four labels into a 4×3 one-hot matrix and checks the
    /// first two rows element by element.
    #[test]
    fn test_one_hot_basic() -> TorshResult<()> {
        let labels = Tensor::from_vec(vec![0.0, 1.0, 2.0, 0.0], &[4])?;
        let onehot = one_hot(&labels, 3)?;
        assert_eq!(onehot.shape().dims(), &[4, 3]);
        let values = onehot.data()?;
        // Row 0 encodes class 0; row 1 encodes class 1.
        let expected_head = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        for (got, want) in values.iter().zip(expected_head.iter()) {
            assert_eq!(got, want);
        }
        Ok(())
    }

    /// `linear` with a [2, 3] weight and no bias: each output is the
    /// dot product of the input row with a weight row.
    #[test]
    fn test_linear_without_bias() -> TorshResult<()> {
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3])?;
        let w = Tensor::from_vec(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], &[2, 3])?;
        let y = linear(&x, &w, None)?;
        assert_eq!(y.shape().dims(), &[1, 2]);
        let out = y.data()?;
        // 1+2+3 = 6 and 2*(1+2+3) = 12.
        assert!((out[0] - 6.0).abs() < 1e-5);
        assert!((out[1] - 12.0).abs() < 1e-5);
        Ok(())
    }

    /// Euclidean (p = 2) distances row-by-row against the origin.
    #[test]
    fn test_pairwise_distance_euclidean() -> TorshResult<()> {
        let a = Tensor::from_vec(vec![0.0, 0.0, 3.0, 4.0], &[2, 2])?;
        let origin = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], &[2, 2])?;
        let dist = pairwise_distance(&a, &origin, 2.0, 1e-6)?;
        assert_eq!(dist.shape().dims(), &[2]);
        let vals = dist.data()?;
        // (0,0) is at distance 0; (3,4) is the classic 3-4-5 triangle.
        assert!(vals[0].abs() < 1e-5);
        assert!((vals[1] - 5.0).abs() < 1e-5);
        Ok(())
    }

    /// Parallel vectors (one a positive scalar multiple of the other)
    /// have cosine similarity exactly 1.
    #[test]
    fn test_cosine_similarity_basic() -> TorshResult<()> {
        let u = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3])?;
        let v = Tensor::from_vec(vec![2.0, 4.0, 6.0], &[1, 3])?;
        let sim = cosine_similarity(&u, &v, 1, 1e-8)?;
        assert!((sim.data()?[0] - 1.0).abs() < 1e-5);
        Ok(())
    }

    /// Looks up rows 0, 2, 1 of a 4×3 table and verifies the first
    /// two gathered rows.
    #[test]
    fn test_embedding_basic() -> TorshResult<()> {
        let table = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[4, 3],
        )?;
        let ids = Tensor::from_vec(vec![0.0, 2.0, 1.0], &[3])?;
        let gathered = embedding(&table, &ids)?;
        assert_eq!(gathered.shape().dims(), &[3, 3]);
        let rows = gathered.data()?;
        // First gathered row is table[0]; second is table[2].
        assert_eq!(rows[0..3], [1.0, 2.0, 3.0]);
        assert_eq!(rows[3..6], [7.0, 8.0, 9.0]);
        Ok(())
    }
}