tensor-rs 0.5.9

A typeless tensor library
Documentation
use std::collections::BTreeMap;
use crate::tensor::PaddingMode;
use super::GenTensor;
use crate::tensor_trait::convolution::Convolution;

impl<T> Convolution for GenTensor<T> where T: num_traits::Float {

    /// 2-D convolution: forwards to [`Convolution::conv_gen`] with the two
    /// per-axis tuples unpacked into slices.
    fn conv2d(&self, filter: &GenTensor<T>,
              stride: (usize, usize),
              padding: (usize, usize),
              dilation: (usize, usize),
              padding_mode: PaddingMode
    ) -> Self {
        self.conv_gen(filter,
                      &[stride.0, stride.1],
                      &[padding.0, padding.1],
                      &[dilation.0, dilation.1],
                      padding_mode)
    }

    /// Gradient of [`Convolution::conv2d`]: forwards to
    /// [`Convolution::conv_grad_gen`] with the per-axis tuples unpacked.
    /// Returns `(filter gradient, input gradient)`.
    fn conv2d_grad(&self, filter: &GenTensor<T>,
                   stride: (usize, usize),
                   padding: (usize, usize),
                   dilation: (usize, usize),
                   padding_mode: PaddingMode,
                   output_grad: &GenTensor<T>
    ) -> (Self, Self) {
        self.conv_grad_gen(filter,
                           &[stride.0, stride.1],
                           &[padding.0, padding.1],
                           &[dilation.0, dilation.1],
                           padding_mode,
                           output_grad)
    }

    /// General N-d convolution (cross-correlation: the filter is not flipped),
    /// used for both the 2-D and the 3-D cases.
    ///
    /// * `self`   - input indexed as `[N, C_in, d_1, ..., d_k]`.
    /// * `filter` - weights indexed as `[C_out, C_in, f_1, ..., f_k]`.
    /// * `stride`, `padding`, `dilation` - one entry per inner axis (`k` entries).
    ///
    /// Returns a tensor of shape `[N, C_out, o_1, ..., o_k]` where
    /// `o_d = (d_d + 2*padding_d - dilation_d*(f_d - 1) - 1) / stride_d + 1`.
    ///
    /// # Panics
    /// On rank or channel mismatches, zero stride/dilation, a dilated filter
    /// that does not fit in the padded input, or when a filter tap lands in the
    /// padding margin while `padding_mode` is not `PaddingMode::Zeros` (only
    /// zero padding is implemented in the forward pass).
    fn conv_gen(&self, filter: &GenTensor<T>,
                stride: &[usize],
                padding: &[usize],
                dilation: &[usize],
                padding_mode: PaddingMode
    ) -> GenTensor<T> {
        let self_dim = self.size();
        let filter_dim = filter.size();
        if self_dim.len() <= 2 {
            panic!("input data for conv has not enough dim {:?}.", self_dim);
        }
        if self_dim.len() != filter_dim.len() {
            panic!("conv2d expects input and filter has the same dims, get {:?}, {:?}", self_dim, filter_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() || stride.len() != (self_dim.len() - 2) {
            panic!("stride, padding, dilation should have the same # of dims, {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.iter().any(|x| *x < 1) {
            panic!("stride should be at least 1, get {:?}", stride);
        }
        if dilation.iter().any(|x| *x < 1) {
            panic!("dilation should be at least 1, get {:?}", dilation);
        }

        let out_channels = filter_dim[0];
        let in_channels = filter_dim[1];
        let sample_size = self_dim[0];
        let data_channels = self_dim[1];
        if in_channels != data_channels {
            panic!("conv2d expects input data channel size matches depth in filter {:?}, {:?}", self_dim, filter_dim);
        }

        let spatial_in = &self_dim[2..];
        let spatial_filter = &filter_dim[2..];
        let output_size = conv_output_extent(spatial_in, spatial_filter, stride, padding, dilation);
        let output_inner_size: usize = output_size.iter().product();

        let mut output_tensor_size = vec![sample_size, out_channels];
        output_tensor_size.extend_from_slice(&output_size);
        let mut ret = GenTensor::<T>::zeros(&output_tensor_size);

        // Offset of every filter tap (in padded coordinates, relative to the
        // window's left-upper corner), in row-major filter order. The same for
        // every window, so it is computed once.
        let taps_per_channel: usize = spatial_filter.iter().product();
        let tap_offsets: Vec<Vec<usize>> = (0..taps_per_channel)
            .map(|tap| {
                conv_unravel(tap, spatial_filter)
                    .iter()
                    .zip(dilation)
                    .map(|(f, d)| f * d)
                    .collect()
            })
            .collect();

        // One filter slice `[C_in, f_1, ..., f_k]` is this many contiguous values.
        let conv_size = in_channels * taps_per_channel;
        let inner_steps = output_inner_size * out_channels;
        let filter_data = filter.get_data();
        // Input values under the current window, laid out like one filter slice.
        let mut data_block = vec![T::zero(); conv_size];

        for sample in 0..sample_size {
            for out_ch in 0..out_channels {
                let filter_block = &filter_data[out_ch * conv_size..(out_ch + 1) * conv_size];

                for out_index in 0..output_inner_size {
                    // Window origin in padded coordinates: output position * stride.
                    let left_upper: Vec<usize> = conv_unravel(out_index, &output_size)
                        .iter()
                        .zip(stride)
                        .map(|(o, s)| o * s)
                        .collect();

                    for in_ch in 0..in_channels {
                        for (tap, offsets) in tap_offsets.iter().enumerate() {
                            let mut coords = vec![sample, in_ch];
                            let mut in_margin = false;
                            for (d, offset) in offsets.iter().enumerate() {
                                let padded = left_upper[d] + offset;
                                if padded < padding[d] || padded >= padding[d] + spatial_in[d] {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            in_margin = true;
                                            break;
                                        }
                                        _ => unimplemented!(),
                                    }
                                }
                                coords.push(padded - padding[d]);
                            }
                            data_block[in_ch * taps_per_channel + tap] =
                                if in_margin { T::zero() } else { self.get(&coords) };
                        }
                    }

                    let value = data_block
                        .iter()
                        .zip(filter_block)
                        .fold(T::zero(), |acc, (x, w)| acc + (*x) * (*w));
                    ret.set_1d(sample * inner_steps + out_ch * output_inner_size + out_index, value);
                }
            }
        }

        ret
    }

    /// Gradients of [`Convolution::conv_gen`].
    ///
    /// Given `output_grad` (gradient w.r.t. the convolution output, shaped
    /// `[N, C_out, o_1, ..., o_k]`), returns `(w_grad, x_grad)`: the gradient
    /// w.r.t. `filter` (shaped like `filter`) and w.r.t. the input `self`
    /// (shaped like `self`).
    ///
    /// Padding modes route the gradient of a padded position back to the input
    /// element it copies: `Zeros` drops it, `Replicate` uses the nearest edge,
    /// `Circular` wraps around, and `Reflect` maps padded index `p < pad` to
    /// `pad - p - 1`. NOTE(review): that is symmetric padding (edge element
    /// repeated); PyTorch's `reflect` excludes the edge — confirm intent.
    ///
    /// # Panics
    /// On rank/channel/batch mismatches, zero stride/dilation, a dilated filter
    /// that does not fit in the padded input, or when `output_grad`'s inner dims
    /// differ from the convolution's output extent.
    fn conv_grad_gen(&self, filter: &GenTensor<T>,
                     stride: &[usize],
                     padding: &[usize],
                     dilation: &[usize],
                     padding_mode: PaddingMode,
                     output_grad: &GenTensor<T>,
    ) -> (GenTensor<T>, GenTensor<T>) {
        let self_dim = self.size();
        let filter_dim = filter.size();
        let output_grad_dim = output_grad.size();
        if self_dim.len() <= 2 {
            panic!("input data for conv has not enough dim {:?}.", self_dim);
        }
        if filter_dim.len() <= 2 {
            panic!("filter for conv has not enough dim {:?}.", filter_dim);
        }
        if output_grad_dim.len() <= 2 {
            panic!("output gradient for conv has not enough dim {:?}.", output_grad_dim);
        }
        if self_dim.len() != filter_dim.len() || self_dim.len() != output_grad_dim.len() {
            panic!("conv2d expects input, output gradient and filter has the same dims, get {:?}, {:?}, {:?}", self_dim, filter_dim, output_grad_dim);
        }
        if filter_dim[1] != self_dim[1] {
            panic!("conv2d expects input data channel size matches depth in filter {:?}, {:?}", self_dim, filter_dim);
        }
        if self_dim[0] != output_grad_dim[0] {
            panic!("conv2d expects input and output has the same N: {:?}, {:?}", self_dim, output_grad_dim);
        }
        if filter_dim[0] != output_grad_dim[1] {
            panic!("conv2d expects filter and output has the same Cout: {:?}, {:?}", filter_dim, output_grad_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() {
            panic!("stride, padding, dilation should have the same # of dims, {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.len() + 2 != filter_dim.len() {
            panic!("expect the same inner size, {:?}, {:?}", stride, filter_dim);
        }
        if stride.iter().any(|x| *x < 1) {
            panic!("stride should be at least 1, get {:?}", stride);
        }
        if dilation.iter().any(|x| *x < 1) {
            panic!("dilation should be at least 1, get {:?}", dilation);
        }

        let n_n = self_dim[0];
        let n_c_out = filter_dim[0];
        let n_c_in = filter_dim[1];
        let spatial_in = &self_dim[2..];
        let spatial_filter = &filter_dim[2..];

        let output_size = conv_output_extent(spatial_in, spatial_filter, stride, padding, dilation);
        if output_grad_dim[2..] != output_size[..] {
            panic!("conv output gradient has inner dims {:?}, expected {:?}", &output_grad_dim[2..], output_size);
        }
        let output_inner_size: usize = output_size.iter().product();

        // Row-major position of every filter tap inside one filter slice.
        let taps_per_channel: usize = spatial_filter.iter().product();
        let tap_positions: Vec<Vec<usize>> = (0..taps_per_channel)
            .map(|tap| conv_unravel(tap, spatial_filter))
            .collect();

        // Running sum per flat index, accumulated in visiting order, so each
        // total equals summing that index's contributions one by one from zero.
        let mut w_grad: BTreeMap<usize, T> = BTreeMap::new();
        let mut x_grad: BTreeMap<usize, T> = BTreeMap::new();

        for sample in 0..n_n {
            for out_ch in 0..n_c_out {
                for out_index in 0..output_inner_size {
                    let output_real_index = (sample * n_c_out + out_ch) * output_inner_size + out_index;
                    let output_dimpos = output_grad.index2dimpos(output_real_index);
                    let output_gradient_value = output_grad.get(&output_dimpos);

                    // Window origin in padded coordinates: output position * stride.
                    let left_upper: Vec<usize> = conv_unravel(out_index, &output_size)
                        .iter()
                        .zip(stride)
                        .map(|(o, s)| o * s)
                        .collect();

                    for in_ch in 0..n_c_in {
                        for tap_pos in &tap_positions {
                            // Map each padded coordinate back onto the input;
                            // `None` means the tap reads a zero-padding value.
                            let unpadded: Option<Vec<usize>> = (0..tap_pos.len())
                                .map(|d| conv_unpad_coord(left_upper[d] + tap_pos[d] * dilation[d],
                                                          padding[d],
                                                          spatial_in[d],
                                                          &padding_mode))
                                .collect();
                            let unpadded = match unpadded {
                                Some(coords) => coords,
                                None => continue,
                            };

                            let mut full_filter_elem = vec![out_ch, in_ch];
                            full_filter_elem.extend_from_slice(tap_pos);
                            let mut full_data_elem = vec![sample, in_ch];
                            full_data_elem.extend_from_slice(&unpadded);

                            let filter_value = filter.get(&full_filter_elem);
                            let data_value = self.get(&full_data_elem);

                            let w_slot = w_grad
                                .entry(filter.dimpos2index(&full_filter_elem))
                                .or_insert(T::zero());
                            *w_slot = *w_slot + output_gradient_value * data_value;

                            let x_slot = x_grad
                                .entry(self.dimpos2index(&full_data_elem))
                                .or_insert(T::zero());
                            *x_slot = *x_slot + output_gradient_value * filter_value;
                        }
                    }
                }
            }
        }

        let mut ret_w_grad = GenTensor::zeros(filter.size());
        let mut ret_x_grad = GenTensor::zeros(self.size());
        for (index, sum) in w_grad {
            ret_w_grad.set_1d(index, sum);
        }
        for (index, sum) in x_grad {
            ret_x_grad.set_1d(index, sum);
        }

        (ret_w_grad, ret_x_grad)
    }
}

/// Number of output positions along each inner axis of a convolution:
/// `(input + 2*padding - dilation*(filter - 1) - 1) / stride + 1`.
///
/// Panics if a filter axis is empty or the dilated filter is wider than the
/// padded input along that axis (the formula would underflow).
fn conv_output_extent(input: &[usize],
                      filter: &[usize],
                      stride: &[usize],
                      padding: &[usize],
                      dilation: &[usize]) -> Vec<usize> {
    (0..input.len())
        .map(|d| {
            if filter[d] == 0 {
                panic!("conv filter has an empty inner dim {:?}", filter);
            }
            let padded = input[d] + 2 * padding[d];
            let reach = dilation[d] * (filter[d] - 1) + 1;
            if reach > padded {
                panic!("conv filter {:?} with dilation {:?} does not fit in the padded input {:?} (padding {:?})",
                       filter, dilation, input, padding);
            }
            (padded - reach) / stride[d] + 1
        })
        .collect()
}

/// Row-major multi-index of `flat` inside an array with extents `shape`
/// (last axis varies fastest). Callers only pass `flat < product(shape)`.
fn conv_unravel(mut flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut pos = vec![0; shape.len()];
    for d in (0..shape.len()).rev() {
        pos[d] = flat % shape[d];
        flat /= shape[d];
    }
    pos
}

/// Maps a coordinate in the padded input (axis of length `len`, `pad` cells of
/// padding on each side) to the input coordinate it reads from, or `None` when
/// it reads a zero-padding value.
fn conv_unpad_coord(coord: usize, pad: usize, len: usize, mode: &PaddingMode) -> Option<usize> {
    if coord < pad {
        match mode {
            PaddingMode::Zeros => None,
            PaddingMode::Reflect => Some(pad - coord - 1),
            PaddingMode::Replicate => Some(0),
            PaddingMode::Circular => Some(len - (pad - coord)),
        }
    } else if coord >= len + pad {
        let over = coord - (len + pad);
        match mode {
            PaddingMode::Zeros => None,
            PaddingMode::Reflect => Some(len - (over + 1)),
            PaddingMode::Replicate => Some(len - 1),
            PaddingMode::Circular => Some(over),
        }
    } else {
        Some(coord - pad)
    }
}

#[cfg(test)]
mod conv_dilation_tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    // Even-sized filter with dilation > 1: every window must advance by the
    // stride, without wrapping back to the start before the last output.
    #[test]
    fn even_filter_with_dilation() {
        let data = GenTensor::<f32>::new_raw(&vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &vec![1, 1, 6]);
        let filter = GenTensor::<f32>::new_raw(&vec![0.0, 1.0], &vec![1, 1, 2]);

        let result = data.conv_gen(&filter, &[1], &[0], &[2], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![2.0, 3.0, 4.0, 5.0], &vec![1, 1, 4]));

        let output_grad = GenTensor::<f32>::new_raw(&vec![1.0; 4], &vec![1, 1, 4]);
        let (w_grad, x_grad) = data.conv_grad_gen(&filter, &[1], &[0], &[2], PaddingMode::Zeros, &output_grad);
        assert_eq!(w_grad, GenTensor::<f32>::new_raw(&vec![6.0, 14.0], &vec![1, 1, 2]));
        assert_eq!(x_grad, GenTensor::<f32>::new_raw(&vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0], &vec![1, 1, 6]));
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use crate::tensor_trait::index_slicing::IndexSlicing;
    use super::*;

    /// The ramp `0.0, 1.0, ..., (count - 1) as f32` used to fill test tensors.
    fn ramp(count: usize) -> Vec<f32> {
        (0..count).map(|v| v as f32).collect()
    }

    #[test]
    fn conv_gen() {
        // 1-D: batch of 2, 3 input channels, 2 output channels, no padding.
        {
            let input = GenTensor::<f32>::arange(30).reshape(&vec![2, 3, 5]);
            let kernel = GenTensor::<f32>::arange(18).reshape(&vec![2, 3, 3]);
            let out = input.conv_gen(&kernel, &[1], &[0], &[1], PaddingMode::Zeros);
            println!("output size: {:?}", out.size());
            println!("output size: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 798.0, 915.0, 1032.0, 852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0], &vec![2, 2, 3]));
        }

        // 2-D: square 5x5 input, 3x3 kernel.
        {
            let input = GenTensor::<f32>::new_raw(&ramp(75), &vec![1, 3, 5, 5]);
            let kernel = GenTensor::<f32>::new_raw(&ramp(54), &vec![2, 3, 3, 3]);
            let out = input.conv_gen(&kernel, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros);
            println!("output size: {:?}", out.size());
            println!("output size: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0, 18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0, 43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0], &vec![1, 2, 3, 3]));
        }

        // 2-D: non-square 5x4 input, even-width 3x2 kernel.
        {
            let input = GenTensor::<f32>::new_raw(&ramp(60), &vec![1, 3, 5, 4]);
            let kernel = GenTensor::<f32>::new_raw(&ramp(36), &vec![2, 3, 3, 2]);
            let out = input.conv_gen(&kernel, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros);
            println!("output size: {:?}", out.size());
            println!("output size: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0, 6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0, 15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0], &vec![1, 2, 3, 3]));
        }

        // 3-D: 5x5x5 volume, 3x3x3 kernel.
        {
            let input = GenTensor::<f32>::arange(375).reshape(&vec![1, 3, 5, 5, 5]);
            let kernel = GenTensor::<f32>::arange(162).reshape(&vec![2, 3, 3, 3, 3]);
            let out = input.conv_gen(&kernel, &[1, 1, 1], &[0, 0, 0], &[1, 1, 1], PaddingMode::Zeros);
            println!("output size: {:?}", out.size());
            println!("output size: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0, 733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0, 797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0, 862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0, 895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0, 1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0, 1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0, 2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0, 2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0], &vec![1, 2, 3, 3, 3]));
        }

        // 2-D with one cell of zero padding on every side.
        {
            let input = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
            let kernel = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let out = input.conv_gen(&kernel, &[1, 1], &[1, 1], &[1, 1], PaddingMode::Zeros);
            println!("final output size: {:?}", out.size());
            println!("final output: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0, 279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0, 163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0, 738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0], &vec![1, 2, 4, 4]));
        }

        // 2-D with stride 2.
        {
            let input = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let kernel = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let out = input.conv_gen(&kernel, &[2, 2], &[0, 0], &[1, 1], PaddingMode::Zeros);
            println!("final output size: {:?}", out.size());
            println!("final output: {:?}", out.get_data());
            assert_eq!(out, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }
    }

    #[test]
    fn conv_grad_gen() {
        // Square input and kernel, no padding.
        {
            let input = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let kernel = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let upstream = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);
            let (kernel_grad, input_grad) = input.conv_grad_gen(&kernel, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros, &upstream);
            println!("w_grad: {:?}", kernel_grad);
            println!("x_grad: {:?}", input_grad);
            assert_eq!(kernel_grad, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 492.0, 528.0, 564.0, 672.0, 708.0, 744.0, 1212.0, 1248.0, 1284.0, 1392.0, 1428.0, 1464.0, 1572.0, 1608.0, 1644.0, 2112.0, 2148.0, 2184.0, 2292.0, 2328.0, 2364.0, 2472.0, 2508.0, 2544.0, 798.0, 915.0, 1032.0, 1383.0, 1500.0, 1617.0, 1968.0, 2085.0, 2202.0, 3723.0, 3840.0, 3957.0, 4308.0, 4425.0, 4542.0, 4893.0, 5010.0, 5127.0, 6648.0, 6765.0, 6882.0, 7233.0, 7350.0, 7467.0, 7818.0, 7935.0, 8052.0], &vec![2, 3, 3, 3]));
        }

        // Non-square input with an even-width kernel.
        {
            let input = GenTensor::<f32>::arange(60).reshape(&vec![1, 3, 5, 4]);
            let kernel = GenTensor::<f32>::arange(36).reshape(&vec![2, 3, 3, 2]);
            let upstream = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);
            let (kernel_grad, input_grad) = input.conv_grad_gen(&kernel, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros, &upstream);
            println!("{:?}, {:?}, {:?}", kernel_grad, input_grad, upstream);
            assert_eq!(kernel_grad, GenTensor::<f32>::new_raw(&vec![258.0, 294.0, 402.0, 438.0, 546.0, 582.0, 978.0, 1014.0, 1122.0, 1158.0, 1266.0, 1302.0, 1698.0, 1734.0, 1842.0, 1878.0, 1986.0, 2022.0, 663.0, 780.0, 1131.0, 1248.0, 1599.0, 1716.0, 3003.0, 3120.0, 3471.0, 3588.0, 3939.0, 4056.0, 5343.0, 5460.0, 5811.0, 5928.0, 6279.0, 6396.0], &vec![2, 3, 3, 2]));
        }

        // Same 3x3 kernel, now with one cell of zero padding ("same" output size).
        {
            let input = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let kernel = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let upstream = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);
            let (kernel_grad, input_grad) = input.conv_grad_gen(&kernel, &[1, 1], &[1, 1], &[1, 1], PaddingMode::Zeros, &upstream);
            println!("w_grad: {:?}", kernel_grad);
            println!("x_grad: {:?}", input_grad);
            assert_eq!(kernel_grad, GenTensor::<f32>::new_raw(&vec![2680.0, 3420.0, 2760.0, 3900.0, 4900.0, 3900.0, 2760.0, 3420.0, 2680.0, 8680.0, 10670.0, 8360.0, 10150.0, 12400.0, 9650.0, 6760.0, 8170.0, 6280.0, 14680.0, 17920.0, 13960.0, 16400.0, 19900.0, 15400.0, 10760.0, 12920.0, 9880.0, 6280.0, 8170.0, 6760.0, 9650.0, 12400.0, 10150.0, 8360.0, 10670.0, 8680.0, 22280.0, 27920.0, 22360.0, 28400.0, 35525.0, 28400.0, 22360.0, 27920.0, 22280.0, 38280.0, 47670.0, 37960.0, 47150.0, 58650.0, 46650.0, 36360.0, 45170.0, 35880.0], &vec![2, 3, 3, 3]));
        }

        // 5x5 kernel with two cells of zero padding.
        {
            let input = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let kernel = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let upstream = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);
            let (kernel_grad, input_grad) = input.conv_grad_gen(&kernel, &[1, 1], &[2, 2], &[1, 1], PaddingMode::Zeros, &upstream);
            println!("w_grad: {:?}", kernel_grad);
            println!("x_grad: {:?}", input_grad);
            assert_eq!(kernel_grad, GenTensor::<f32>::new_raw(&vec![1128.0, 1580.0, 2065.0, 1700.0, 1308.0, 1964.0, 2680.0, 3420.0, 2760.0, 2084.0, 2905.0, 3900.0, 4900.0, 3900.0, 2905.0, 2084.0, 2760.0, 3420.0, 2680.0, 1964.0, 1308.0, 1700.0, 2065.0, 1580.0, 1128.0, 5178.0, 6830.0, 8440.0, 6650.0, 4908.0, 6614.0, 8680.0, 10670.0, 8360.0, 6134.0, 7780.0, 10150.0, 12400.0, 9650.0, 7030.0, 5234.0, 6760.0, 8170.0, 6280.0, 4514.0, 3108.0, 3950.0, 4690.0, 3530.0, 2478.0, 9228.0, 12080.0, 14815.0, 11600.0, 8508.0, 11264.0, 14680.0, 17920.0, 13960.0, 10184.0, 12655.0, 16400.0, 19900.0, 15400.0, 11155.0, 8384.0, 10760.0, 12920.0, 9880.0, 7064.0, 4908.0, 6200.0, 7315.0, 5480.0, 3828.0, 2478.0, 3530.0, 4690.0, 3950.0, 3108.0, 4514.0, 6280.0, 8170.0, 6760.0, 5234.0, 7030.0, 9650.0, 12400.0, 10150.0, 7780.0, 6134.0, 8360.0, 10670.0, 8680.0, 6614.0, 4908.0, 6650.0, 8440.0, 6830.0, 5178.0, 12153.0, 16280.0, 20440.0, 16400.0, 12333.0, 16664.0, 22280.0, 27920.0, 22360.0, 16784.0, 21280.0, 28400.0, 35525.0, 28400.0, 21280.0, 16784.0, 22360.0, 27920.0, 22280.0, 16664.0, 12333.0, 16400.0, 20440.0, 16280.0, 12153.0, 21828.0, 29030.0, 36190.0, 28850.0, 21558.0, 28814.0, 38280.0, 47670.0, 37960.0, 28334.0, 35530.0, 47150.0, 58650.0, 46650.0, 34780.0, 27434.0, 36360.0, 45170.0, 35880.0, 26714.0, 19758.0, 26150.0, 32440.0, 25730.0, 19128.0], &vec![2, 3, 5, 5]));
        }

        // 5x5 kernel, two cells of padding, stride 2.
        {
            let input = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let kernel = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let upstream = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);
            let (kernel_grad, input_grad) = input.conv_grad_gen(&kernel, &[2, 2], &[2, 2], &[1, 1], PaddingMode::Zeros, &upstream);
            println!("w_grad: {:?}", kernel_grad);
            println!("x_grad: {:?}", input_grad);
            assert_eq!(kernel_grad, GenTensor::<f32>::new_raw(&vec![176.0, 200.0, 284.0, 172.0, 192.0, 296.0, 320.0, 449.0, 272.0, 292.0, 420.0, 447.0, 624.0, 375.0, 396.0, 164.0, 176.0, 233.0, 128.0, 136.0, 224.0, 236.0, 308.0, 168.0, 176.0, 776.0, 800.0, 1109.0, 672.0, 692.0, 896.0, 920.0, 1274.0, 772.0, 792.0, 1095.0, 1122.0, 1524.0, 900.0, 921.0, 464.0, 476.0, 608.0, 328.0, 336.0, 524.0, 536.0, 683.0, 368.0, 376.0, 1376.0, 1400.0, 1934.0, 1172.0, 1192.0, 1496.0, 1520.0, 2099.0, 1272.0, 1292.0, 1770.0, 1797.0, 2424.0, 1425.0, 1446.0, 764.0, 776.0, 983.0, 528.0, 536.0, 824.0, 836.0, 1058.0, 568.0, 576.0, 392.0, 452.0, 662.0, 424.0, 480.0, 692.0, 752.0, 1097.0, 704.0, 760.0, 1014.0, 1095.0, 1596.0, 1023.0, 1098.0, 560.0, 608.0, 881.0, 560.0, 604.0, 800.0, 848.0, 1226.0, 780.0, 824.0, 1892.0, 1952.0, 2837.0, 1824.0, 1880.0, 2192.0, 2252.0, 3272.0, 2104.0, 2160.0, 3039.0, 3120.0, 4521.0, 2898.0, 2973.0, 1760.0, 1808.0, 2606.0, 1660.0, 1704.0, 2000.0, 2048.0, 2951.0, 1880.0, 1924.0, 3392.0, 3452.0, 5012.0, 3224.0, 3280.0, 3692.0, 3752.0, 5447.0, 3504.0, 3560.0, 5064.0, 5145.0, 7446.0, 4773.0, 4848.0, 2960.0, 3008.0, 4331.0, 2760.0, 2804.0, 3200.0, 3248.0, 4676.0, 2980.0, 3024.0], &vec![2, 3, 5, 5]));
        }
    }

}