runmat_runtime/
comparison.rs

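//! Comparison operations for language-compatible logic.
//!
//! Implements the element-wise relational operators (`>`, `>=`, `<`, `<=`,
//! `==`, `!=`) on two equally sized [`Tensor`] operands, returning logical
//! results as tensors whose elements are `1.0` (true) or `0.0` (false).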
4
5use runmat_builtins::Tensor;
6
7/// Element-wise greater than comparison
8pub fn matrix_gt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
9    if a.rows() != b.rows() || a.cols() != b.cols() {
10        return Err(format!(
11            "Matrix dimensions must agree: {}x{} > {}x{}",
12            a.rows(),
13            a.cols(),
14            b.rows(),
15            b.cols()
16        ));
17    }
18
19    let data: Vec<f64> = a
20        .data
21        .iter()
22        .zip(b.data.iter())
23        .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
24        .collect();
25
26    Tensor::new_2d(data, a.rows(), a.cols())
27}
28
29/// Element-wise greater than or equal comparison
30pub fn matrix_ge(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
31    if a.rows() != b.rows() || a.cols() != b.cols() {
32        return Err(format!(
33            "Matrix dimensions must agree: {}x{} >= {}x{}",
34            a.rows(),
35            a.cols(),
36            b.rows(),
37            b.cols()
38        ));
39    }
40
41    let data: Vec<f64> = a
42        .data
43        .iter()
44        .zip(b.data.iter())
45        .map(|(x, y)| if x >= y { 1.0 } else { 0.0 })
46        .collect();
47
48    Tensor::new_2d(data, a.rows(), a.cols())
49}
50
51/// Element-wise less than comparison
52pub fn matrix_lt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
53    if a.rows() != b.rows() || a.cols() != b.cols() {
54        return Err(format!(
55            "Matrix dimensions must agree: {}x{} < {}x{}",
56            a.rows(),
57            a.cols(),
58            b.rows(),
59            b.cols()
60        ));
61    }
62
63    let data: Vec<f64> = a
64        .data
65        .iter()
66        .zip(b.data.iter())
67        .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
68        .collect();
69
70    Tensor::new_2d(data, a.rows(), a.cols())
71}
72
73/// Element-wise less than or equal comparison
74pub fn matrix_le(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
75    if a.rows() != b.rows() || a.cols() != b.cols() {
76        return Err(format!(
77            "Matrix dimensions must agree: {}x{} <= {}x{}",
78            a.rows(),
79            a.cols(),
80            b.rows(),
81            b.cols()
82        ));
83    }
84
85    let data: Vec<f64> = a
86        .data
87        .iter()
88        .zip(b.data.iter())
89        .map(|(x, y)| if x <= y { 1.0 } else { 0.0 })
90        .collect();
91
92    Tensor::new_2d(data, a.rows(), a.cols())
93}
94
95/// Element-wise equality comparison
96pub fn matrix_eq(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
97    if a.rows() != b.rows() || a.cols() != b.cols() {
98        return Err(format!(
99            "Matrix dimensions must agree: {}x{} == {}x{}",
100            a.rows(),
101            a.cols(),
102            b.rows(),
103            b.cols()
104        ));
105    }
106
107    let data: Vec<f64> = a
108        .data
109        .iter()
110        .zip(b.data.iter())
111        .map(|(x, y)| {
112            if (x - y).abs() < f64::EPSILON {
113                1.0
114            } else {
115                0.0
116            }
117        })
118        .collect();
119
120    Tensor::new_2d(data, a.rows(), a.cols())
121}
122
123/// Element-wise inequality comparison
124pub fn matrix_ne(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
125    if a.rows() != b.rows() || a.cols() != b.cols() {
126        return Err(format!(
127            "Matrix dimensions must agree: {}x{} != {}x{}",
128            a.rows(),
129            a.cols(),
130            b.rows(),
131            b.cols()
132        ));
133    }
134
135    let data: Vec<f64> = a
136        .data
137        .iter()
138        .zip(b.data.iter())
139        .map(|(x, y)| {
140            if (x - y).abs() >= f64::EPSILON {
141                1.0
142            } else {
143                0.0
144            }
145        })
146        .collect();
147
148    Tensor::new_2d(data, a.rows(), a.cols())
149}