Skip to main content

lumen_core/tensor/
boolean.rs

1use crate::{AutogradMetaT, Error, Layout, Result, Storage, TensorOrScalar, WithDType};
2use super::Tensor;
3
4impl Tensor<bool> {
5    pub fn if_else<T: WithDType>(&self, true_val: impl Into<TensorOrScalar<T>>, false_val: impl Into<TensorOrScalar<T>>) -> Result<Tensor<T>> {
6        let true_val = true_val.into();
7        let false_val = false_val.into();
8
9        if let TensorOrScalar::Tensor(tensor) = &true_val && tensor.shape() != self.shape() {
10            Err(Error::ShapeMismatchSelect { mask: self.shape().clone(), who: "true_val", })?
11        }
12        if let TensorOrScalar::Tensor(tensor) = &false_val && tensor.shape() != self.shape() {
13            Err(Error::ShapeMismatchSelect { mask: self.shape().clone(), who: "false_val", })?
14        }
15
16        let (mut new_storage, tv) = match &true_val {
17            TensorOrScalar::Tensor(tensor) => (tensor.storage_read()?.copy(self.layout()), Some(tensor)),
18            TensorOrScalar::Scalar(v) => (Storage::full(*v, self.shape()), None),
19        };
20        let layout = Layout::contiguous(self.shape());
21
22        let fv = match &false_val {
23            TensorOrScalar::Tensor(false_val) => {
24                for ((result_index, condition), fv) in layout.storage_indices().zip(self.iter()?).zip(false_val.iter()?) {
25                    if !condition {
26                        new_storage.set_unchecked(result_index, fv);
27                    }
28                }
29                Some(false_val)
30            }
31            TensorOrScalar::Scalar(fv) => {
32                for (result_index, condition) in layout.storage_indices().zip(self.iter()?) {
33                    if !condition {
34                        new_storage.set_unchecked(result_index, *fv);
35                    }
36                }
37                None
38            }
39        };
40
41        let meta = T::AutogradMeta::on_ifelse_op(self, tv, fv);
42        Ok(Tensor::from_storage(new_storage, layout, meta))
43    }
44
45    pub fn true_count(&self) -> crate::Result<usize> {
46        self.iter().map(|i| i.filter(|v| *v).count())
47    }
48
49    pub fn false_count(&self) -> crate::Result<usize> {
50        self.iter().map(|i| i.filter(|v| !*v).count())
51    }
52}
53
54impl<T: WithDType> Tensor<T> {
55    pub fn masked_fill(&self, mask: &Tensor<bool>, value: impl Into<TensorOrScalar<T>>) -> Result<Tensor<T>> {
56        mask.if_else(value, self)
57    }
58}
59
#[cfg(test)]
mod test {
    use crate::Tensor;

    #[test]
    fn test_if_else_scalar_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 0, 1, 0]);
    }

    #[test]
    fn test_if_else_array_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();

        let true_vals = Tensor::new(&[10, 20, 30, 40]).unwrap();
        let false_vals = Tensor::new(&[100, 200, 300, 400]).unwrap();

        let result = mask.if_else(&true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [10, 200, 30, 400]);
    }

    #[test]
    fn test_if_else_mixed_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();

        let true_vals = 5; // scalar
        let false_vals = Tensor::new(&[100, 200, 300, 400]).unwrap();

        let result = Tensor::if_else(&mask, true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [5, 200, 5, 400]);
    }

    #[test]
    fn test_if_else_shape_mismatch() {
        let mask = Tensor::new(&[true, false, true]).unwrap();
        let true_vals = Tensor::new(&[1, 2, 3, 4]).unwrap();
        let false_vals = 0;

        let result = Tensor::if_else(&mask, &true_vals, false_vals);
        assert!(result.is_err());
    }

    #[test]
    fn test_if_else_all_true_or_all_false() {
        let mask = Tensor::new(&[true, true, true]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 1, 1]);

        let mask = Tensor::new(&[false, false, false]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [0, 0, 0]);
    }

    #[test]
    fn test_if_else_2d_array_values() {
        let mask = Tensor::new(&[[true, false, true], [false, true, false]]).unwrap();
        let true_vals = Tensor::new(&[[10, 20, 30], [40, 50, 60]]).unwrap();
        let false_vals = Tensor::new(&[[100, 200, 300], [400, 500, 600]]).unwrap();

        let result = Tensor::if_else(&mask, &true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [10, 200, 30, 400, 50, 600]);
    }

    #[test]
    fn test_if_else_3d_mixed_values() {
        let mask = Tensor::new(&[
            [[true, false], [false, true]],
            [[true, true], [false, false]],
        ])
        .unwrap();
        let true_val = 1; // scalar
        let false_vals = Tensor::new(&[
            [[10, 20], [30, 40]],
            [[50, 60], [70, 80]],
        ])
        .unwrap();

        let result = Tensor::if_else(&mask, true_val, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 20, 30, 1, 1, 1, 70, 80]);
    }

    #[test]
    fn test_if_else() {
        let scores = Tensor::new(&[
            [45., 12., 34., 90.],
            [31., 19., 84., 60.],
            [55., 34., 44., 82.],
            [85., 89., 54., 67.],
        ])
        .unwrap();

        // 60 <= scores <= 85 (`ge` / `le` are inclusive on both bounds)
        let mask = scores.ge(60.).unwrap().and(&scores.le(85.).unwrap()).unwrap();

        let selected = Tensor::if_else(&mask, &scores, -1.).unwrap();
        assert_eq!(
            selected.to_vec().unwrap(),
            [
                -1., -1., -1., -1., //
                -1., -1., 84., 60., //
                -1., -1., -1., 82., //
                85., -1., -1., 67., //
            ]
        );
    }

    #[test]
    fn test_true_false_count() {
        let mask = Tensor::new(&[[true, false, true], [false, false, false]]).unwrap();
        assert_eq!(mask.true_count().unwrap(), 2);
        assert_eq!(mask.false_count().unwrap(), 4);
    }

    #[test]
    fn test_masked_fill() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();
        let values = Tensor::new(&[1, 2, 3, 4]).unwrap();

        let filled = values.masked_fill(&mask, 0).unwrap();
        assert_eq!(filled.to_vec().unwrap(), [0, 2, 0, 4]);
    }
}