//! tensor-rs 0.5.9
//!
//! A typeless tensor library.
//! Documentation: see the crate-level docs.
use crate::tensor_impl::gen_tensor::GenTensor;
use crate::tensor_trait::index_slicing::IndexSlicing;
use crate::tensor::PaddingMode;
#[cfg(feature = "use-blas-lapack")]
use super::blas_api::BlasAPI;

#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_conv {
    ($a:ty, $b: ident) => {
        /// N-dimensional convolution via im2col + BLAS GEMM for this element type.
        ///
        /// `data` is laid out as `[batch, in_channels, d1, d2, ...]` and
        /// `filter` as `[out_channels, in_channels, k1, k2, ...]`.
        /// `stride`, `padding` and `dilation` each carry one entry per spatial
        /// dimension. Only `PaddingMode::Zeros` is implemented; any other mode
        /// hits `unimplemented!()`.
        pub fn $b(
            data: &GenTensor<$a>,
            filter: &GenTensor<$a>,
            stride: &[usize],
            padding: &[usize],
            dilation: &[usize],
            padding_mode: PaddingMode
        ) -> GenTensor<$a> {
            let self_dim = data.size();
            let filter_dim = filter.size();

            let out_channels = filter_dim[0];
            let in_channels = filter_dim[1];
            let sample_size = self_dim[0];

            // Spatial extent of the (conceptually) zero-padded input.
            let mut padded_dim = Vec::new();
            for i in 2..self_dim.len() {
                padded_dim.push(self_dim[i] + padding[i-2]*2);
            }

            // Coordinate of the filter's start center point in padded space:
            // half the dilated kernel extent per spatial dim.
            // If filter_dim[i] is even, start_point is the half;
            // if odd, it is the center.
            let mut start_point = Vec::new();
            for i in 0..stride.len() {
                let half = filter_dim[2+i]/2;
                let dilated = half*dilation[i];
                start_point.push(dilated);
            }

            // Standard convolution output-size formula per spatial dimension.
            let mut output_size = Vec::new();
            for i in 0..stride.len() {
                let output_dim = (padded_dim[i] - dilation[i]*(filter_dim[2+i]-1)-1)/stride[i] + 1;
                output_size.push(output_dim);
            }
            let mut output_tensor_size = vec![sample_size, out_channels];
            output_tensor_size.append(&mut output_size.clone()); // keep output_size usable below.
            let output_inner_size = output_size.iter().product::<usize>();

            // Length of one im2col column: Cin x k1 x k2 x ...
            let conv_size = filter_dim.iter().product::<usize>()/out_channels;

            // im2col buffer: `conv_size` values per output location per sample.
            let mut columned_data = Vec::<$a>::with_capacity(sample_size*output_inner_size*conv_size);

            let mut left_upper = vec![0; stride.len()];
            let mut current_data_elem = left_upper.to_vec();
            let mut push_value: $a;
            let mut in_margin: bool;

            for i in 0..sample_size {
                // Reset the sliding-window origin for each sample.
                for x in left_upper.iter_mut() {
                    *x = 0;
                }

                for _k in 0..output_inner_size { // every possible data block
                    // Walk the kernel window anchored at `left_upper`.
                    current_data_elem.clone_from_slice(&left_upper);
                    for in_channel_index in 0..in_channels {
                        for _inner_index in 0..conv_size/in_channels {

                            // Fetch one scalar; it is zero when the position
                            // falls inside the padding margin.
                            push_value = 0.;
                            in_margin = false;
                            for i in 0..current_data_elem.len() {
                                 if current_data_elem[i] < padding[i]
                                    || current_data_elem[i] >= (padding[i] + self_dim[i+2]) {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            push_value = 0.;
                                            in_margin = true;
                                            break;
                                        },
                                        _ => {unimplemented!();}
                                    }
                                }
                            }
                            if ! in_margin {
                                // Translate padded coordinates back to real
                                // input coordinates, then index the tensor.
                                let real_data_elem = current_data_elem.iter()
                                    .zip(padding.iter())
                                    .map(|(x, y)| x - y)
                                    .collect::<Vec::<usize>>();
                                let mut real_data_elem2 = vec![i, in_channel_index];
                                real_data_elem2.append(&mut real_data_elem.clone());
                                push_value = data.get(&real_data_elem2);
                            }

                            columned_data.push(push_value);

                            // Advance to the next dilated position inside the
                            // kernel window, odometer-style from the last axis.
                            let mut current_pos = current_data_elem.len()-1;
                            loop {
                                current_data_elem[current_pos] += dilation[current_pos];
                                if current_data_elem[current_pos] >= dilation[current_pos]*filter_dim[current_pos+2] + left_upper[current_pos] {
                                    current_data_elem[current_pos] = left_upper[current_pos];
                                    if current_pos > 0 {
                                        current_pos -= 1;
                                    } else {
                                        break;
                                    }
                                } else {
                                    break;
                                }
                            };
                        }
                    };

                    // Advance the window origin by `stride`, odometer-style,
                    // for the next product-sum position.
                    let mut current_pos = left_upper.len()-1;
                    loop {
                        left_upper[current_pos] += stride[current_pos];
                        let mut compare_pos = padded_dim[current_pos] - start_point[current_pos]*2;
                        if filter_dim[current_pos+2] % 2 == 0 {
                            compare_pos += 1;
                        }
                        if left_upper[current_pos] >= compare_pos {
                            left_upper[current_pos] = 0;
                            if current_pos > 0 {
                                current_pos -= 1;
                            } else {
                                break;
                            }
                        } else {
                            break;
                        }
                    };
                }
            }

            // (im2col data)^T x filter:
            //   [sample*out_locs, conv_size] x [conv_size, out_channels].
            // beta = 1 is safe only because the result buffer is zero-filled.
            let mut columned_result = vec![0.; sample_size*out_channels*output_inner_size];
            BlasAPI::<$a>::gemm(true, false, sample_size*output_inner_size, out_channels, conv_size,
                                 1., &columned_data, conv_size,
                                 filter.get_data(), conv_size,
                                 1., &mut columned_result, sample_size*output_inner_size
            );

            // The GEMM result comes out with the leading two dims swapped
            // ([out_channels, sample, spatial...]); build the tensor that way,
            // then permute back to [sample, out_channels, spatial...].
            let mut result_dim = output_tensor_size.to_vec();
            result_dim.swap(0, 1);
            let mut result = GenTensor::<$a>::new_move(columned_result,
                                                        result_dim);
            let mut permute_dim: Vec<usize> = (0..output_tensor_size.len()).collect();
            permute_dim[0] = 1;
            permute_dim[1] = 0;
            result = result.permute(&permute_dim);
            result
        }
    }
}

// Instantiate the im2col + GEMM convolution for both supported float types.
#[cfg(feature = "use-blas-lapack")]
blas_conv!(f32, gemm_conv_f32);

#[cfg(feature = "use-blas-lapack")]
blas_conv!(f64, gemm_conv_f64);


#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;


    // gemm_conv
    #[test]
    #[cfg(feature = "use-blas-lapack")]
    fn test_gemm_conv() {
        // 1-D conv: batch 2, 3 channels, length 5; two filters of width 3;
        // unit stride/dilation, no padding.
        {
            let data = GenTensor::<f32>::arange(30).reshape(&vec![2, 3, 5]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 3, 3]);
            let stride = vec![1];
            let padding = vec![0];
            let dilation = vec![1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output size: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 798.0, 915.0, 1032.0, 852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0], &vec![2, 2, 3]));
        }

        // 2-D conv: 1x3x5x5 input, two 3x3x3 filters, no padding -> 3x3 output.
        {
            let mut raw_data = Vec::new();
            for i in 0..75 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 5]);
            let mut raw_data = Vec::new();
            for i in 0..54 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 3]);
            
            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            
            println!("output size: {:?}", result.size());
            println!("output size: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0, 18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0, 43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0], &vec![1, 2, 3, 3]));    
        }

        // 2-D conv with non-square input (5x4) and non-square kernel (3x2).
        {
            let mut raw_data = Vec::new();
            for i in 0..60 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 4]);
            let mut raw_data = Vec::new();
            for i in 0..36 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 2]);
            
            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            
            println!("output size: {:?}", result.size());
            println!("output size: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0, 6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0, 15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0], &vec![1, 2, 3, 3]));    
        }

        // 3-D conv: 1x3x5x5x5 input, two 3x3x3x3 filters, no padding.
        {
            let data = GenTensor::<f32>::arange(375).reshape(&vec![1, 3, 5, 5, 5]);
            let filter = GenTensor::<f32>::arange(162).reshape(&vec![2, 3, 3, 3, 3]);
            let stride = vec![1, 1, 1];
            let padding = vec![0, 0, 0];
            let dilation = vec![1, 1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output size: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0, 733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0, 797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0, 862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0, 895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0, 1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0, 1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0, 2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0, 2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0], &vec![1, 2, 3, 3, 3]));
        }

        // 2-D conv with zero padding of 1: output keeps the 4x4 spatial size.
        {
            let data = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![1, 1];
            let padding = vec![1, 1];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0, 279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0, 163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0, 738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0], &vec![1, 2, 4, 4]));
        }

        // 2-D conv with stride 2: 7x7 input downsampled to 3x3 output.
        {
            let data = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![2, 2];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }
        
        // Same strided case again (duplicate of the block above, minus prints).
        {
            
            let data = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![2, 2];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode);
            //println!("final output size: {:?}", result.size());
            //println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }

        // f64 variant of the strided case: gemm_conv_f64 must match the
        // f32 expected values exactly.
        {
            
            let data = GenTensor::<f64>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let filter = GenTensor::<f64>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![2, 2];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = gemm_conv_f64(&data, &filter, &stride, &padding, &dilation, padding_mode);
            //println!("final output size: {:?}", result.size());
            //println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f64>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }
    }
}