use crate::{CooTensor, CsrTensor, SparseTensor, TorshResult};
use torsh_core::{Shape, TorshError};
use torsh_tensor::{creation::zeros, Tensor};
/// 2D convolution layer with a sparse (CSR) weight kernel.
///
/// The logically 4D kernel (out_channels, in_channels, kh, kw) is stored as a
/// 2D sparse matrix of shape (out_channels, in_channels * kh * kw); see
/// `flat_to_4d_kernel` for the index mapping.
#[derive(Debug, Clone)]
pub struct SparseConv2d {
/// Sparse kernel weights in CSR form, shape (out_channels, in_channels * kh * kw).
kernel: CsrTensor,
/// Optional per-output-channel bias, shape [out_channels].
bias: Option<Tensor>,
/// Number of channels expected in the input tensor.
in_channels: usize,
/// Number of channels produced in the output tensor.
out_channels: usize,
/// Kernel spatial extent as (height, width).
kernel_size: (usize, usize),
/// Stride as (height, width); defaults to (1, 1) in `new`.
stride: (usize, usize),
/// Implicit zero-padding as (height, width); defaults to (0, 0) in `new`.
padding: (usize, usize),
/// Dilation as (height, width); defaults to (1, 1) in `new`.
dilation: (usize, usize),
}
impl SparseConv2d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: Option<(usize, usize)>,
dilation: Option<(usize, usize)>,
sparsity: f32,
use_bias: bool,
) -> TorshResult<Self> {
if !(0.0..=1.0).contains(&sparsity) {
return Err(TorshError::InvalidArgument(
"Sparsity must be between 0.0 and 1.0".to_string(),
));
}
if kernel_size.0 == 0 || kernel_size.1 == 0 {
return Err(TorshError::InvalidArgument(
"Kernel size must be greater than 0".to_string(),
));
}
let stride = stride.unwrap_or((1, 1));
let padding = padding.unwrap_or((0, 0));
let dilation = dilation.unwrap_or((1, 1));
if stride.0 == 0 || stride.1 == 0 {
return Err(TorshError::InvalidArgument(
"Stride must be greater than 0".to_string(),
));
}
if dilation.0 == 0 || dilation.1 == 0 {
return Err(TorshError::InvalidArgument(
"Dilation must be greater than 0".to_string(),
));
}
let kernel =
Self::generate_sparse_kernel(out_channels, in_channels, kernel_size, sparsity)?;
let bias = if use_bias {
Some(zeros::<f32>(&[out_channels])?)
} else {
None
};
Ok(Self {
kernel,
bias,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
})
}
pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
let input_shape = input.shape();
if input_shape.ndim() != 4 {
return Err(TorshError::InvalidArgument(
"Input must be 4D tensor (batch_size, in_channels, height, width)".to_string(),
));
}
let batch_size = input_shape.dims()[0];
let input_channels = input_shape.dims()[1];
let input_height = input_shape.dims()[2];
let input_width = input_shape.dims()[3];
if input_channels != self.in_channels {
return Err(TorshError::InvalidArgument(format!(
"Input channels {} don't match layer input channels {}",
input_channels, self.in_channels
)));
}
let output_height =
(input_height + 2 * self.padding.0 - self.dilation.0 * (self.kernel_size.0 - 1) - 1)
/ self.stride.0
+ 1;
let output_width =
(input_width + 2 * self.padding.1 - self.dilation.1 * (self.kernel_size.1 - 1) - 1)
/ self.stride.1
+ 1;
let mut output =
zeros::<f32>(&[batch_size, self.out_channels, output_height, output_width])?;
for b in 0..batch_size {
self.conv2d_single(
input,
&mut output,
b,
input_height,
input_width,
output_height,
output_width,
)?;
}
if let Some(ref bias) = self.bias {
for b in 0..batch_size {
for c in 0..self.out_channels {
let bias_val = bias.get(&[c])?;
for h in 0..output_height {
for w in 0..output_width {
let current = output.get(&[b, c, h, w])?;
output.set(&[b, c, h, w], current + bias_val)?;
}
}
}
}
}
Ok(output)
}
#[allow(clippy::too_many_arguments)]
fn conv2d_single(
&self,
input: &Tensor,
output: &mut Tensor,
batch_idx: usize,
input_height: usize,
input_width: usize,
output_height: usize,
output_width: usize,
) -> TorshResult<()> {
let kernel_coo = self.kernel.to_coo()?;
let kernel_triplets = kernel_coo.triplets();
for (kernel_flat_idx, _, kernel_val) in kernel_triplets {
let (out_c, in_c, kh, kw) = self.flat_to_4d_kernel(kernel_flat_idx);
for out_h in 0..output_height {
for out_w in 0..output_width {
let in_h = out_h * self.stride.0 + kh * self.dilation.0;
let in_w = out_w * self.stride.1 + kw * self.dilation.1;
if in_h >= self.padding.0 && in_w >= self.padding.1 {
let padded_in_h = in_h - self.padding.0;
let padded_in_w = in_w - self.padding.1;
if padded_in_h < input_height && padded_in_w < input_width {
let input_val =
input.get(&[batch_idx, in_c, padded_in_h, padded_in_w])?;
let current_out = output.get(&[batch_idx, out_c, out_h, out_w])?;
output.set(
&[batch_idx, out_c, out_h, out_w],
current_out + kernel_val * input_val,
)?;
}
}
}
}
}
Ok(())
}
fn flat_to_4d_kernel(&self, flat_idx: usize) -> (usize, usize, usize, usize) {
let kernel_size_total = self.kernel_size.0 * self.kernel_size.1;
let channel_size = self.in_channels * kernel_size_total;
let out_c = flat_idx / channel_size;
let remaining = flat_idx % channel_size;
let in_c = remaining / kernel_size_total;
let remaining = remaining % kernel_size_total;
let kh = remaining / self.kernel_size.1;
let kw = remaining % self.kernel_size.1;
(out_c, in_c, kh, kw)
}
#[allow(dead_code)]
fn kernel_4d_to_flat(&self, out_c: usize, in_c: usize, kh: usize, kw: usize) -> usize {
let kernel_size_total = self.kernel_size.0 * self.kernel_size.1;
let channel_size = self.in_channels * kernel_size_total;
out_c * channel_size + in_c * kernel_size_total + kh * self.kernel_size.1 + kw
}
fn generate_sparse_kernel(
out_channels: usize,
in_channels: usize,
kernel_size: (usize, usize),
sparsity: f32,
) -> TorshResult<CsrTensor> {
let total_elements = out_channels * in_channels * kernel_size.0 * kernel_size.1;
let nnz = ((total_elements as f32) * (1.0 - sparsity)) as usize;
let mut row_indices = Vec::with_capacity(nnz);
let mut col_indices = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
let mut positions = std::collections::HashSet::new();
while positions.len() < nnz {
let mut rng = scirs2_core::random::thread_rng();
let flat_idx = rng.gen_range(0..total_elements);
positions.insert(flat_idx);
}
let kernel_size_total = kernel_size.0 * kernel_size.1;
let channel_size = in_channels * kernel_size_total;
for flat_idx in positions {
let out_c = flat_idx / channel_size;
let col_idx = flat_idx % channel_size;
row_indices.push(out_c);
col_indices.push(col_idx);
let fan_in = in_channels * kernel_size.0 * kernel_size.1;
let std_dev = (2.0 / fan_in as f32).sqrt();
let mut rng = scirs2_core::random::thread_rng();
values.push((rng.random::<f32>() * 2.0 - 1.0) * std_dev);
}
let shape = Shape::new(vec![out_channels, channel_size]);
let coo = CooTensor::new(row_indices, col_indices, values, shape)?;
CsrTensor::from_coo(&coo)
}
pub fn num_parameters(&self) -> usize {
let kernel_params = self.kernel.nnz();
let bias_params = self.bias.as_ref().map_or(0, |b| b.shape().numel());
kernel_params + bias_params
}
pub fn kernel_sparsity(&self) -> f32 {
let total_elements =
self.out_channels * self.in_channels * self.kernel_size.0 * self.kernel_size.1;
let nnz = self.kernel.nnz();
1.0 - (nnz as f32 / total_elements as f32)
}
pub fn in_channels(&self) -> usize {
self.in_channels
}
pub fn out_channels(&self) -> usize {
self.out_channels
}
pub fn kernel_size(&self) -> (usize, usize) {
self.kernel_size
}
pub fn stride(&self) -> (usize, usize) {
self.stride
}
pub fn padding(&self) -> (usize, usize) {
self.padding
}
pub fn dilation(&self) -> (usize, usize) {
self.dilation
}
}
/// 1D convolution layer with a sparse (CSR) weight kernel.
///
/// The logically 3D kernel (out_channels, in_channels, k) is stored as a 2D
/// sparse matrix of shape (out_channels, in_channels * k); see
/// `flat_to_3d_kernel` for the index mapping.
#[derive(Debug, Clone)]
pub struct SparseConv1d {
/// Sparse kernel weights in CSR form, shape (out_channels, in_channels * k).
kernel: CsrTensor,
/// Optional per-output-channel bias, shape [out_channels].
bias: Option<Tensor>,
/// Number of channels expected in the input tensor.
in_channels: usize,
/// Number of channels produced in the output tensor.
out_channels: usize,
/// Kernel length; defaults are validated non-zero in `new`.
kernel_size: usize,
/// Stride; defaults to 1 in `new`.
stride: usize,
/// Implicit zero-padding; defaults to 0 in `new`.
padding: usize,
/// Dilation; defaults to 1 in `new`.
dilation: usize,
}
impl SparseConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: Option<usize>,
padding: Option<usize>,
dilation: Option<usize>,
sparsity: f32,
use_bias: bool,
) -> TorshResult<Self> {
if !(0.0..=1.0).contains(&sparsity) {
return Err(TorshError::InvalidArgument(
"Sparsity must be between 0.0 and 1.0".to_string(),
));
}
if kernel_size == 0 {
return Err(TorshError::InvalidArgument(
"Kernel size must be greater than 0".to_string(),
));
}
let stride = stride.unwrap_or(1);
let padding = padding.unwrap_or(0);
let dilation = dilation.unwrap_or(1);
if stride == 0 {
return Err(TorshError::InvalidArgument(
"Stride must be greater than 0".to_string(),
));
}
if dilation == 0 {
return Err(TorshError::InvalidArgument(
"Dilation must be greater than 0".to_string(),
));
}
let kernel =
Self::generate_sparse_kernel_1d(out_channels, in_channels, kernel_size, sparsity)?;
let bias = if use_bias {
Some(zeros::<f32>(&[out_channels])?)
} else {
None
};
Ok(Self {
kernel,
bias,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
})
}
pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
let input_shape = input.shape();
if input_shape.ndim() != 3 {
return Err(TorshError::InvalidArgument(
"Input must be 3D tensor (batch_size, in_channels, length)".to_string(),
));
}
let batch_size = input_shape.dims()[0];
let input_channels = input_shape.dims()[1];
let input_length = input_shape.dims()[2];
if input_channels != self.in_channels {
return Err(TorshError::InvalidArgument(format!(
"Input channels {} don't match layer input channels {}",
input_channels, self.in_channels
)));
}
let output_length =
(input_length + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
/ self.stride
+ 1;
let mut output = zeros::<f32>(&[batch_size, self.out_channels, output_length])?;
for b in 0..batch_size {
self.conv1d_single(input, &mut output, b, input_length, output_length)?;
}
if let Some(ref bias) = self.bias {
for b in 0..batch_size {
for c in 0..self.out_channels {
let bias_val = bias.get(&[c])?;
for l in 0..output_length {
let current = output.get(&[b, c, l])?;
output.set(&[b, c, l], current + bias_val)?;
}
}
}
}
Ok(output)
}
fn conv1d_single(
&self,
input: &Tensor,
output: &mut Tensor,
batch_idx: usize,
input_length: usize,
output_length: usize,
) -> TorshResult<()> {
let kernel_coo = self.kernel.to_coo()?;
let kernel_triplets = kernel_coo.triplets();
for (kernel_flat_idx, _, kernel_val) in kernel_triplets {
let (out_c, in_c, k_pos) = self.flat_to_3d_kernel(kernel_flat_idx);
for out_pos in 0..output_length {
let in_pos = out_pos * self.stride + k_pos * self.dilation;
if in_pos >= self.padding {
let padded_in_pos = in_pos - self.padding;
if padded_in_pos < input_length {
let input_val = input.get(&[batch_idx, in_c, padded_in_pos])?;
let current_out = output.get(&[batch_idx, out_c, out_pos])?;
output.set(
&[batch_idx, out_c, out_pos],
current_out + kernel_val * input_val,
)?;
}
}
}
}
Ok(())
}
fn flat_to_3d_kernel(&self, flat_idx: usize) -> (usize, usize, usize) {
let channel_size = self.in_channels * self.kernel_size;
let out_c = flat_idx / channel_size;
let remaining = flat_idx % channel_size;
let in_c = remaining / self.kernel_size;
let k_pos = remaining % self.kernel_size;
(out_c, in_c, k_pos)
}
fn generate_sparse_kernel_1d(
out_channels: usize,
in_channels: usize,
kernel_size: usize,
sparsity: f32,
) -> TorshResult<CsrTensor> {
let total_elements = out_channels * in_channels * kernel_size;
let nnz = ((total_elements as f32) * (1.0 - sparsity)) as usize;
let mut row_indices = Vec::with_capacity(nnz);
let mut col_indices = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
let mut positions = std::collections::HashSet::new();
while positions.len() < nnz {
let mut rng = scirs2_core::random::thread_rng();
let flat_idx = rng.gen_range(0..total_elements);
positions.insert(flat_idx);
}
let channel_size = in_channels * kernel_size;
for flat_idx in positions {
let out_c = flat_idx / channel_size;
let col_idx = flat_idx % channel_size;
row_indices.push(out_c);
col_indices.push(col_idx);
let fan_in = in_channels * kernel_size;
let std_dev = (2.0 / fan_in as f32).sqrt();
let mut rng = scirs2_core::random::thread_rng();
values.push((rng.random::<f32>() * 2.0 - 1.0) * std_dev);
}
let shape = Shape::new(vec![out_channels, channel_size]);
let coo = CooTensor::new(row_indices, col_indices, values, shape)?;
CsrTensor::from_coo(&coo)
}
pub fn num_parameters(&self) -> usize {
let kernel_params = self.kernel.nnz();
let bias_params = self.bias.as_ref().map_or(0, |b| b.shape().numel());
kernel_params + bias_params
}
pub fn kernel_sparsity(&self) -> f32 {
let total_elements = self.out_channels * self.in_channels * self.kernel_size;
let nnz = self.kernel.nnz();
1.0 - (nnz as f32 / total_elements as f32)
}
pub fn in_channels(&self) -> usize {
self.in_channels
}
pub fn out_channels(&self) -> usize {
self.out_channels
}
pub fn kernel_size(&self) -> usize {
self.kernel_size
}
pub fn stride(&self) -> usize {
self.stride
}
pub fn padding(&self) -> usize {
self.padding
}
pub fn dilation(&self) -> usize {
self.dilation
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    #[test]
    fn test_sparse_conv2d_creation() {
        let conv = SparseConv2d::new(3, 16, (3, 3), Some((1, 1)), Some((1, 1)), None, 0.5, true)
            .expect("operation should succeed");
        assert_eq!(conv.in_channels(), 3);
        assert_eq!(conv.out_channels(), 16);
        assert_eq!(conv.kernel_size(), (3, 3));
        assert!(conv.num_parameters() > 0);
    }

    #[test]
    fn test_sparse_conv2d_forward() {
        let conv = SparseConv2d::new(2, 4, (3, 3), None, Some((1, 1)), None, 0.3, false)
            .expect("operation should succeed");
        let input = ones::<f32>(&[1, 2, 5, 5]).expect("operation should succeed");
        let output = conv.forward(&input).expect("forward pass should succeed");
        // 3x3 kernel with padding 1 and stride 1 preserves spatial dims.
        assert_eq!(output.shape().dims(), &[1, 4, 5, 5]);
    }

    #[test]
    fn test_sparse_conv1d_creation() {
        let conv = SparseConv1d::new(8, 16, 5, None, None, None, 0.7, true)
            .expect("Sparse Conv1d should succeed");
        assert_eq!(conv.in_channels(), 8);
        assert_eq!(conv.out_channels(), 16);
        assert_eq!(conv.kernel_size(), 5);
        assert!(conv.num_parameters() > 0);
    }

    #[test]
    fn test_sparse_conv1d_forward() {
        let conv = SparseConv1d::new(4, 8, 3, None, Some(1), None, 0.4, false)
            .expect("operation should succeed");
        let input = ones::<f32>(&[2, 4, 10]).expect("operation should succeed");
        let output = conv.forward(&input).expect("forward pass should succeed");
        // Kernel 3 with padding 1 and stride 1 preserves the length.
        assert_eq!(output.shape().dims(), &[2, 8, 10]);
    }

    #[test]
    fn test_output_size_calculation() {
        // 2D: (8 - 3) / 2 + 1 = 3.
        let conv = SparseConv2d::new(1, 1, (3, 3), Some((2, 2)), None, None, 0.0, false)
            .expect("operation should succeed");
        let input = ones::<f32>(&[1, 1, 8, 8]).expect("operation should succeed");
        let output = conv.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[1, 1, 3, 3]);
        // 1D: (10 - 3) / 2 + 1 = 4.
        let conv1d = SparseConv1d::new(1, 1, 3, Some(2), None, None, 0.0, false)
            .expect("operation should succeed");
        let input1d = ones::<f32>(&[1, 1, 10]).expect("operation should succeed");
        let output1d = conv1d
            .forward(&input1d)
            .expect("forward pass should succeed");
        assert_eq!(output1d.shape().dims(), &[1, 1, 4]);
    }

    #[test]
    fn test_invalid_parameters() {
        assert!(SparseConv2d::new(1, 1, (0, 3), None, None, None, 0.5, false).is_err());
        assert!(SparseConv2d::new(1, 1, (3, 3), Some((0, 1)), None, None, 0.5, false).is_err());
        assert!(SparseConv2d::new(1, 1, (3, 3), None, None, None, 1.5, false).is_err());
        assert!(SparseConv1d::new(1, 1, 0, None, None, None, 0.5, false).is_err());
        assert!(SparseConv1d::new(1, 1, 3, Some(0), None, None, 0.5, false).is_err());
        assert!(SparseConv1d::new(1, 1, 3, None, None, None, -0.1, false).is_err());
    }

    #[test]
    fn test_dimension_validation() {
        let conv = SparseConv2d::new(3, 16, (3, 3), None, None, None, 0.5, false)
            .expect("operation should succeed");
        let wrong_input = ones::<f32>(&[1, 2, 5, 5]).expect("operation should succeed");
        assert!(conv.forward(&wrong_input).is_err());
        let conv1d = SparseConv1d::new(4, 8, 3, None, None, None, 0.5, false)
            .expect("Sparse Conv1d should succeed");
        let wrong_input1d = ones::<f32>(&[1, 3, 10]).expect("operation should succeed");
        assert!(conv1d.forward(&wrong_input1d).is_err());
    }

    #[test]
    fn test_sparsity_measurement() {
        // nnz is floor-rounded, so accept a small band around the target.
        let conv = SparseConv2d::new(2, 4, (3, 3), None, None, None, 0.8, false)
            .expect("operation should succeed");
        assert!((0.7..=0.9).contains(&conv.kernel_sparsity()));
        let conv1d = SparseConv1d::new(2, 4, 5, None, None, None, 0.6, false)
            .expect("Sparse Conv1d should succeed");
        assert!((0.5..=0.7).contains(&conv1d.kernel_sparsity()));
    }

    #[test]
    fn test_bias_addition() {
        let conv = SparseConv2d::new(1, 2, (1, 1), None, None, None, 0.0, true)
            .expect("operation should succeed");
        let input = ones::<f32>(&[1, 1, 3, 3]).expect("operation should succeed");
        let output = conv.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[1, 2, 3, 3]);
        let conv1d = SparseConv1d::new(1, 2, 1, None, None, None, 0.0, true)
            .expect("Sparse Conv1d should succeed");
        let input1d = ones::<f32>(&[1, 1, 5]).expect("operation should succeed");
        let output1d = conv1d
            .forward(&input1d)
            .expect("forward pass should succeed");
        assert_eq!(output1d.shape().dims(), &[1, 2, 5]);
    }
}