use super::super::Tensor;
use crate::errors::{Result, TrustformersError};
use scirs2_core::ndarray::{ArrayD, Axis, IxDyn, Zip};
impl Tensor {
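    /// Element-wise less-than comparison.
    ///
    /// Returns a tensor of the same shape and dtype as the inputs, holding
    /// `1` where `self < other` and `0` elsewhere. Errors if the shapes or
    /// dtypes differ.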
pub fn less(&self, other: &Tensor) -> Result<Tensor> {
match (self, other) {
(Tensor::F32(a), Tensor::F32(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for less comparison".to_string(),
));
}
let result =
Zip::from(a).and(b).map_collect(|&x, &y| if x < y { 1.0f32 } else { 0.0f32 });
Ok(Tensor::F32(result))
},
(Tensor::F64(a), Tensor::F64(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for less comparison".to_string(),
));
}
let result =
Zip::from(a).and(b).map_collect(|&x, &y| if x < y { 1.0f64 } else { 0.0f64 });
Ok(Tensor::F64(result))
},
(Tensor::I64(a), Tensor::I64(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for less comparison".to_string(),
));
}
let result =
Zip::from(a).and(b).map_collect(|&x, &y| if x < y { 1i64 } else { 0i64 });
Ok(Tensor::I64(result))
},
_ => Err(TrustformersError::tensor_op_error(
"Less operation not implemented for this tensor type combination",
"less",
)),
}
}
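    /// Element-wise equality comparison.
    ///
    /// Returns `1` where the elements are equal and `0` elsewhere. The
    /// floating-point variants treat elements as equal when they differ by
    /// less than machine epsilon; note this is an absolute tolerance, which
    /// can be too strict or too loose depending on the magnitudes involved.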
pub fn equal(&self, other: &Tensor) -> Result<Tensor> {
match (self, other) {
(Tensor::F32(a), Tensor::F32(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for equal comparison".to_string(),
));
}
let result = Zip::from(a).and(b).map_collect(|&x, &y| {
if (x - y).abs() < f32::EPSILON {
1.0f32
} else {
0.0f32
}
});
Ok(Tensor::F32(result))
},
(Tensor::F64(a), Tensor::F64(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for equal comparison".to_string(),
));
}
let result = Zip::from(a).and(b).map_collect(|&x, &y| {
if (x - y).abs() < f64::EPSILON {
1.0f64
} else {
0.0f64
}
});
Ok(Tensor::F64(result))
},
(Tensor::I64(a), Tensor::I64(b)) => {
if a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"Tensors must have the same shape for equal comparison".to_string(),
));
}
let result =
Zip::from(a).and(b).map_collect(|&x, &y| if x == y { 1i64 } else { 0i64 });
Ok(Tensor::I64(result))
},
_ => Err(TrustformersError::tensor_op_error(
"Equal operation not implemented for this tensor type combination",
"equal",
)),
}
}
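    /// Element-wise selection: returns `self` where `condition` holds and
    /// `other` elsewhere.
    ///
    /// For floating-point tensors the condition is treated as true when its
    /// value exceeds 0.5; for integer tensors, when it is positive. All three
    /// tensors must have the same shape and dtype.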
pub fn where_cond(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor> {
match (self, condition, other) {
(Tensor::F32(a), Tensor::F32(cond), Tensor::F32(b)) => {
if a.shape() != cond.shape() || a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"All tensors must have the same shape for where operation".to_string(),
));
}
let result = Zip::from(cond)
.and(a)
.and(b)
.map_collect(|&c, &x, &y| if c > 0.5 { x } else { y });
Ok(Tensor::F32(result))
},
(Tensor::F64(a), Tensor::F64(cond), Tensor::F64(b)) => {
if a.shape() != cond.shape() || a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"All tensors must have the same shape for where operation".to_string(),
));
}
let result = Zip::from(cond)
.and(a)
.and(b)
.map_collect(|&c, &x, &y| if c > 0.5 { x } else { y });
Ok(Tensor::F64(result))
},
(Tensor::I64(a), Tensor::I64(cond), Tensor::I64(b)) => {
if a.shape() != cond.shape() || a.shape() != b.shape() {
return Err(TrustformersError::shape_error(
"All tensors must have the same shape for where operation".to_string(),
));
}
let result = Zip::from(cond)
.and(a)
.and(b)
.map_collect(|&c, &x, &y| if c > 0 { x } else { y });
Ok(Tensor::I64(result))
},
_ => Err(TrustformersError::tensor_op_error(
"Where operation not implemented for this tensor type combination",
"where_cond",
)),
}
}
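    /// Layer normalization along `axis` (currently only the last dimension
    /// is supported).
    ///
    /// Each lane along `axis` is normalized to zero mean and unit variance as
    /// `(x - mean) / sqrt(var + epsilon)`, using the biased variance.
    /// Negative `axis` values count from the end.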
pub fn layer_norm(&self, axis: i32, epsilon: f32) -> Result<Tensor> {
match self {
Tensor::F32(a) => {
let ndim = a.ndim();
let axis = if axis < 0 { (ndim as i32 + axis) as usize } else { axis as usize };
if axis >= ndim {
return Err(TrustformersError::tensor_op_error(
&format!(
"Axis {} is out of bounds for tensor with {} dimensions",
axis, ndim
),
"layer_norm",
));
}
                let last_dim = ndim - 1;
                if axis != last_dim {
                    return Err(TrustformersError::tensor_op_error(
                        "Layer norm currently only supports last dimension normalization",
                        "layer_norm",
                    ));
                }
                let mean = a.mean_axis(Axis(axis)).expect("axis must be valid for mean operation");
let var = a.map_axis(Axis(axis), |lane| {
let lane_mean = lane.mean().expect("Mean calculation failed");
lane.mapv(|x| (x - lane_mean).powi(2)).mean().expect("Mean calculation failed")
});
                let mut result = a.clone();
                // Normalize each lane along `axis` with its own statistics;
                // `lanes_mut` yields exactly one 1-D lane per entry of `mean`
                // and `var`, so the three producers stay in lockstep.
                Zip::from(result.lanes_mut(Axis(axis)))
                    .and(&mean)
                    .and(&var)
                    .for_each(|mut lane, &m, &v| {
                        lane.mapv_inplace(|x| (x - m) / (v + epsilon).sqrt());
                    });
Ok(Tensor::F32(result))
},
_ => Err(TrustformersError::tensor_op_error(
"Layer norm not supported for this tensor type",
"layer_norm",
)),
}
}
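    /// Cross-entropy loss between `self` (predicted probabilities) and
    /// `targets`, computed element-wise as `-target * ln(pred + 1e-8)`.
    ///
    /// `reduction` selects how the element-wise losses are combined:
    /// `"mean"`, `"sum"`, or `"none"` (no reduction). Predictions are
    /// expected to be probabilities, not logits.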
pub fn cross_entropy(&self, targets: &Tensor, reduction: &str) -> Result<Tensor> {
match (self, targets) {
(Tensor::F32(predictions), Tensor::F32(targets)) => {
                let log_preds = predictions.mapv(|x| (x + 1e-8).ln());
                let losses = Zip::from(&log_preds)
.and(targets)
.map_collect(|&log_pred, &target| -target * log_pred);
match reduction {
"mean" => {
let mean_loss = losses.mean().expect("Mean calculation failed");
Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[]), mean_loss)))
},
"sum" => {
let sum_loss = losses.sum();
Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[]), sum_loss)))
},
"none" => Ok(Tensor::F32(losses)),
_ => Err(TrustformersError::tensor_op_error(
"Invalid reduction. Use 'mean', 'sum', or 'none'",
"cross_entropy",
)),
}
},
_ => Err(TrustformersError::tensor_op_error(
"Cross entropy not supported for these tensor types",
"cross_entropy",
)),
}
}
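    /// Cosine similarity between `self` and `other` along dimension `dim`,
    /// i.e. `dot(a, b) / (||a|| * ||b||)`, with `eps` added to the squared
    /// norms for numerical stability. Negative `dim` counts from the end.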
pub fn cosine_similarity(&self, other: &Tensor, dim: i32, eps: f32) -> Result<Tensor> {
match (self, other) {
(Tensor::F32(a), Tensor::F32(b)) => {
                let ndim = a.ndim();
                let axis = if dim < 0 { (ndim as i32 + dim) as usize } else { dim as usize };
                if axis >= ndim {
                    return Err(TrustformersError::tensor_op_error(
                        &format!(
                            "Axis {} is out of bounds for tensor with {} dimensions",
                            axis, ndim
                        ),
                        "cosine_similarity",
                    ));
                }
let dot_product =
Zip::from(a).and(b).map_collect(|&x, &y| x * y).sum_axis(Axis(axis));
let norm_a = a.mapv(|x| x * x).sum_axis(Axis(axis)).mapv(|x| (x + eps).sqrt());
let norm_b = b.mapv(|x| x * x).sum_axis(Axis(axis)).mapv(|x| (x + eps).sqrt());
let similarity = Zip::from(&dot_product)
.and(&norm_a)
.and(&norm_b)
.map_collect(|&dot, &norm_a, &norm_b| dot / (norm_a * norm_b));
Ok(Tensor::F32(similarity))
},
_ => Err(TrustformersError::tensor_op_error(
"Cosine similarity not supported for these tensor types",
"cosine_similarity",
)),
}
}
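    /// Numerically stable log-softmax along dimension `dim`:
    /// `log_softmax(x) = x - max(x) - ln(sum(exp(x - max(x))))`.
    /// Negative `dim` counts from the end.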
pub fn log_softmax(&self, dim: i32) -> Result<Tensor> {
match self {
Tensor::F32(a) => {
let ndim = a.ndim();
let axis = if dim < 0 { (ndim as i32 + dim) as usize } else { dim as usize };
if axis >= ndim {
return Err(TrustformersError::tensor_op_error(
&format!(
"Axis {} is out of bounds for tensor with {} dimensions",
axis, ndim
),
"log_softmax",
));
}
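                // Shift by the per-lane max before exponentiating so that
                // exp() cannot overflow.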
let max_vals = a.map_axis(Axis(axis), |lane| {
lane.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x))
});
let mut max_shape = a.shape().to_vec();
max_shape[axis] = 1;
let max_expanded = max_vals
.into_shape_with_order(max_shape.clone())
.expect("reshape must be valid for max values");
let shifted = a - &max_expanded
.broadcast(a.raw_dim())
.expect("broadcast must succeed with compatible shapes");
let exp_shifted = shifted.mapv(|x| x.exp());
let sum_exp = exp_shifted.sum_axis(Axis(axis));
let log_sum_exp = sum_exp.mapv(|x| x.ln());
let log_sum_exp_expanded = log_sum_exp
.into_shape_with_order(max_shape)
.expect("reshape must be valid for log_sum_exp");
let result = shifted
- log_sum_exp_expanded
.broadcast(a.raw_dim())
.expect("broadcast must succeed with compatible shapes");
Ok(Tensor::F32(result))
},
_ => Err(TrustformersError::tensor_op_error(
"Log softmax not supported for this tensor type",
"log_softmax",
)),
}
}
}