rstsr_core/tensor/tensor2_impl.rs

1use crate::prelude_dev::*;
2
3impl<S> TensorBase<S, Ix2> {
4    pub fn nrow(&self) -> usize {
5        self.shape()[0]
6    }
7
8    pub fn ncol(&self) -> usize {
9        self.shape()[1]
10    }
11
12    /// Leading dimension in row-major order case.
13    ///
14    /// This function will not return any value if the layout is not row-major.
15    pub fn ld_row(&self) -> Option<usize> {
16        if !self.c_prefer() {
17            // leading dimension is only defined if not c-prefer
18            return None;
19        } else if self.shape()[0] == 1 {
20            // col-vector, leading dimension must be larger than dimension of col
21            return Some(self.shape()[1]);
22        } else {
23            // usual definition that leading dimension is stride of row
24            return Some(self.stride()[0] as usize);
25        }
26    }
27
28    /// Leading dimension in column-major order case.
29    ///
30    /// This function will not return any value if the layout is not
31    /// column-major.
32    pub fn ld_col(&self) -> Option<usize> {
33        if !self.f_prefer() {
34            // leading dimension is only defined if not f-prefer
35            return None;
36        } else if self.shape()[1] == 1 {
37            // row-vector, leading dimension must be larger than dimension of row
38            return Some(self.shape()[0]);
39        } else {
40            // usual definition that leading dimension is stride of col
41            return Some(self.stride()[1] as usize);
42        }
43    }
44
45    /// Leading dimension by order.
46    pub fn ld(&self, order: FlagOrder) -> Option<usize> {
47        match order {
48            ColMajor => self.ld_col(),
49            RowMajor => self.ld_row(),
50        }
51    }
52}
53
#[cfg(test)]
mod test {
    use super::*;

    /// Scratch test: prints the F-/C-preference flags and the debug form of a
    /// strided `[6, 1]` layout, for inspection with `cargo test -- --nocapture`.
    #[test]
    fn playground() {
        let layout = Layout::new([6, 1], [100, 10], 0).unwrap();
        let (f_pref, c_pref) = (layout.f_prefer(), layout.c_prefer());
        println!("{f_pref:?}");
        println!("{c_pref:?}");
        println!("{layout:?}");
    }
}
65}