use crate::utils::*;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Example linear (fully connected) layer: `output = input.matmul(weight^T) + bias`.
///
/// # Arguments
/// * `input` - Input tensor whose trailing dimension is `in_features`.
/// * `weight` - 2-D weight tensor of shape `(out_features, in_features)`.
/// * `bias` - Optional 1-D bias tensor of shape `(out_features,)`.
///
/// # Errors
/// Returns an error when a tensor is empty, the input has no dimensions,
/// feature/bias sizes are incompatible, or an underlying tensor op fails.
pub fn example_linear(
    input: &Tensor,
    weight: &Tensor,
    bias: Option<&Tensor>,
) -> TorshResult<Tensor> {
    let context = function_context("example_linear");
    validate_non_empty(input, &context)?;
    validate_non_empty(weight, &context)?;
    validate_tensor_dims(weight, 2, &context)?;
    let input_shape_binding = input.shape();
    let input_shape = input_shape_binding.dims();
    let weight_shape_binding = weight.shape();
    let weight_shape = weight_shape_binding.dims();
    // Use `last()` instead of `[len - 1]` so a 0-dimensional input yields a
    // proper error instead of an index-underflow panic.
    let input_features = match input_shape.last() {
        Some(&features) => features,
        None => {
            return Err(invalid_argument_error(
                "Input tensor must have at least one dimension",
                &context,
            ))
        }
    };
    let weight_features = weight_shape[1]; // safe: rank 2 validated above
    if input_features != weight_features {
        return Err(invalid_argument_error(
            &format!(
                "Input features ({}) don't match weight features ({})",
                input_features, weight_features
            ),
            &context,
        ));
    }
    if let Some(bias_tensor) = bias {
        validate_non_empty(bias_tensor, &context)?;
        validate_tensor_dims(bias_tensor, 1, &context)?;
        let bias_shape_binding = bias_tensor.shape();
        let bias_shape = bias_shape_binding.dims();
        // Bias must match the output feature count (weight rows).
        if bias_shape[0] != weight_shape[0] {
            return Err(invalid_argument_error(
                &format!(
                    "Bias size ({}) doesn't match output features ({})",
                    bias_shape[0], weight_shape[0]
                ),
                &context,
            ));
        }
    }
    // (out, in) -> (in, out) so the input's trailing axis contracts with it.
    let weight_t = weight.transpose(0, 1).map_err(|e| {
        TorshError::InvalidOperation(format!("Weight transpose failed: {}", e))
            .with_context(&context)
    })?;
    let output = input.matmul(&weight_t).map_err(|e| {
        TorshError::InvalidOperation(format!("Matrix multiplication failed: {}", e))
            .with_context(&context)
    })?;
    if let Some(bias_tensor) = bias {
        output.add(bias_tensor).map_err(|e| {
            TorshError::InvalidOperation(format!("Bias addition failed: {}", e))
                .with_context(&context)
        })
    } else {
        Ok(output)
    }
}
/// Example activation wrapper that applies ReLU, optionally in place.
///
/// NOTE(review): the type parameter `T` and the `alpha` argument are part of
/// the public signature but are not used by the body beyond validating that
/// `alpha` is positive — presumably reserved for a parameterized activation
/// (e.g. leaky ReLU); confirm intended use before extending.
pub fn example_activation<T>(input: &Tensor, alpha: f32, inplace: bool) -> TorshResult<Tensor>
where
    T: Copy + PartialOrd + From<f32>,
{
    let context = function_context("example_activation");
    // Reject empty tensors and non-positive alpha up front.
    validate_non_empty(input, &context)?;
    validate_positive(alpha, "alpha", &context)?;
    // Delegate the in-place vs. copy decision to the shared helper.
    handle_inplace_operation(input, inplace, |t| t.relu(), &context)
}
/// Example 2-D pooling parameter validation.
///
/// Validates that `input` is 4-D and that the kernel/stride/padding values are
/// consistent, then returns a copy of the input unchanged (placeholder — no
/// actual pooling is performed here).
///
/// `stride` defaults to `kernel_size` when `None`; `_dilation` is accepted for
/// signature compatibility but not used.
pub fn example_pool2d(
    input: &Tensor,
    kernel_size: (usize, usize),
    stride: Option<(usize, usize)>,
    padding: (usize, usize),
    _dilation: (usize, usize),
) -> TorshResult<Tensor> {
    let context = function_context("example_pool2d");
    // Pooling expects NCHW-style 4-D input.
    validate_tensor_dims(input, 4, &context)?;
    let (kh, kw) = kernel_size;
    let (sh, sw) = stride.unwrap_or(kernel_size);
    let (ph, pw) = padding;
    validate_pooling_params(input, &[kh, kw], &[sh, sw], &[ph, pw], &context)?;
    // Placeholder: real pooling would compute a reduced output here.
    Ok(input.clone())
}
pub fn example_loss(
input: &Tensor,
target: &Tensor,
reduction: crate::loss::ReductionType,
weight: Option<&Tensor>,
) -> TorshResult<Tensor> {
let context = function_context("example_loss");
validate_non_empty(input, &context)?;
validate_non_empty(target, &context)?;
validate_broadcastable_shapes(input, target, &context)?;
if let Some(weight_tensor) = weight {
validate_non_empty(weight_tensor, &context)?;
}
let loss = input.sub(target)?;
reduction.apply(loss)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Happy path: input (2, 3) x weight^T (3, 2) + bias (2) -> output (2, 2).
    #[test]
    fn test_example_linear() -> TorshResult<()> {
        let input = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )?;
        let weight = Tensor::from_data(
            vec![0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6],
            vec![2, 3],
            DeviceType::Cpu,
        )?;
        let bias = Tensor::from_data(vec![0.1f32, 0.2], vec![2], DeviceType::Cpu)?;
        let result = example_linear(&input, &weight, Some(&bias))?;
        assert_eq!(result.shape().dims(), &[2, 2]);
        Ok(())
    }

    /// Validation path: an empty input tensor must be rejected.
    #[test]
    fn test_validation_patterns() -> TorshResult<()> {
        let empty = Tensor::from_data(Vec::<f32>::new(), vec![0], DeviceType::Cpu)?;
        let valid = Tensor::from_data(vec![1.0f32], vec![1], DeviceType::Cpu)?;
        assert!(example_linear(&empty, &valid, None).is_err());
        Ok(())
    }
}