1use crate::{AutogradMetaT, Error, Layout, Result, Storage, TensorOrScalar, WithDType};
2use super::Tensor;
3
4impl Tensor<bool> {
5 pub fn if_else<T: WithDType>(&self, true_val: impl Into<TensorOrScalar<T>>, false_val: impl Into<TensorOrScalar<T>>) -> Result<Tensor<T>> {
6 let true_val = true_val.into();
7 let false_val = false_val.into();
8
9 if let TensorOrScalar::Tensor(tensor) = &true_val && tensor.shape() != self.shape() {
10 Err(Error::ShapeMismatchSelect { mask: self.shape().clone(), who: "true_val", })?
11 }
12 if let TensorOrScalar::Tensor(tensor) = &false_val && tensor.shape() != self.shape() {
13 Err(Error::ShapeMismatchSelect { mask: self.shape().clone(), who: "false_val", })?
14 }
15
16 let (mut new_storage, tv) = match &true_val {
17 TensorOrScalar::Tensor(tensor) => (tensor.storage_read()?.copy(self.layout()), Some(tensor)),
18 TensorOrScalar::Scalar(v) => (Storage::full(*v, self.shape()), None),
19 };
20 let layout = Layout::contiguous(self.shape());
21
22 let fv = match &false_val {
23 TensorOrScalar::Tensor(false_val) => {
24 for ((result_index, condition), fv) in layout.storage_indices().zip(self.iter()?).zip(false_val.iter()?) {
25 if !condition {
26 new_storage.set_unchecked(result_index, fv);
27 }
28 }
29 Some(false_val)
30 }
31 TensorOrScalar::Scalar(fv) => {
32 for (result_index, condition) in layout.storage_indices().zip(self.iter()?) {
33 if !condition {
34 new_storage.set_unchecked(result_index, *fv);
35 }
36 }
37 None
38 }
39 };
40
41 let meta = T::AutogradMeta::on_ifelse_op(self, tv, fv);
42 Ok(Tensor::from_storage(new_storage, layout, meta))
43 }
44
45 pub fn true_count(&self) -> crate::Result<usize> {
46 self.iter().map(|i| i.filter(|v| *v).count())
47 }
48
49 pub fn false_count(&self) -> crate::Result<usize> {
50 self.iter().map(|i| i.filter(|v| !*v).count())
51 }
52}
53
54impl<T: WithDType> Tensor<T> {
55 pub fn masked_fill(&self, mask: &Tensor<bool>, value: impl Into<TensorOrScalar<T>>) -> Result<Tensor<T>> {
56 mask.if_else(value, self)
57 }
58}
59
#[cfg(test)]
mod test {
    use crate::Tensor;

    #[test]
    fn test_if_else_scalar_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 0, 1, 0]);
    }

    #[test]
    fn test_if_else_array_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();

        let true_vals = Tensor::new(&[10, 20, 30, 40]).unwrap();
        let false_vals = Tensor::new(&[100, 200, 300, 400]).unwrap();

        let result = mask.if_else(&true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [10, 200, 30, 400]);
    }

    #[test]
    fn test_if_else_mixed_values() {
        let mask = Tensor::new(&[true, false, true, false]).unwrap();

        let true_vals = 5;
        let false_vals = Tensor::new(&[100, 200, 300, 400]).unwrap();

        let result = Tensor::if_else(&mask, true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [5, 200, 5, 400]);
    }

    #[test]
    fn test_if_else_shape_mismatch() {
        let mask = Tensor::new(&[true, false, true]).unwrap();
        let true_vals = Tensor::new(&[1, 2, 3, 4]).unwrap();
        let false_vals = 0;

        let result = Tensor::if_else(&mask, &true_vals, false_vals);
        assert!(result.is_err());
    }

    #[test]
    fn test_if_else_false_val_shape_mismatch() {
        let mask = Tensor::new(&[true, false, true]).unwrap();
        let false_vals = Tensor::new(&[1, 2, 3, 4]).unwrap();

        let result = Tensor::if_else(&mask, 0, &false_vals);
        assert!(result.is_err());
    }

    #[test]
    fn test_if_else_all_true_or_all_false() {
        let mask = Tensor::new(&[true, true, true]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 1, 1]);

        let mask = Tensor::new(&[false, false, false]).unwrap();
        let result = Tensor::if_else(&mask, 1, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [0, 0, 0]);
    }

    #[test]
    fn test_if_else_2d_array_values() {
        let mask = Tensor::new(&[[true, false, true], [false, true, false]]).unwrap();
        let true_vals = Tensor::new(&[[10, 20, 30], [40, 50, 60]]).unwrap();
        let false_vals = Tensor::new(&[[100, 200, 300], [400, 500, 600]]).unwrap();

        let result = Tensor::if_else(&mask, &true_vals, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [10, 200, 30, 400, 50, 600]);
    }

    #[test]
    fn test_if_else_3d_mixed_values() {
        let mask = Tensor::new(&[
            [[true, false], [false, true]],
            [[true, true], [false, false]],
        ])
        .unwrap();
        let true_val = 1;
        let false_vals = Tensor::new(&[
            [[10, 20], [30, 40]],
            [[50, 60], [70, 80]],
        ])
        .unwrap();

        let result = Tensor::if_else(&mask, true_val, &false_vals).unwrap();
        assert_eq!(result.to_vec().unwrap(), [1, 20, 30, 1, 1, 1, 70, 80]);
    }

    #[test]
    fn test_if_else() {
        let scores = Tensor::new(&[
            [45., 12., 34., 90.],
            [31., 19., 84., 60.],
            [55., 34., 44., 82.],
            [85., 89., 54., 67.],
        ])
        .unwrap();

        // Keep scores in the inclusive range [60, 85]; everything else becomes -1.
        let mask = scores.ge(60.).unwrap().and(&scores.le(85.).unwrap()).unwrap();

        let selected_scores = Tensor::if_else(&mask, &scores, -1.).unwrap();
        assert_eq!(
            selected_scores.to_vec().unwrap(),
            [
                -1., -1., -1., -1.,
                -1., -1., 84., 60.,
                -1., -1., -1., 82.,
                85., -1., -1., 67.,
            ]
        );
    }

    #[test]
    fn test_masked_fill_scalar() {
        let values = Tensor::new(&[1, 2, 3, 4]).unwrap();
        let mask = Tensor::new(&[true, false, true, false]).unwrap();

        let result = values.masked_fill(&mask, 0).unwrap();
        assert_eq!(result.to_vec().unwrap(), [0, 2, 0, 4]);
    }

    #[test]
    fn test_true_false_count() {
        let mask = Tensor::new(&[[true, false], [true, true]]).unwrap();
        assert_eq!(mask.true_count().unwrap(), 3);
        assert_eq!(mask.false_count().unwrap(), 1);
    }
}