runmat_runtime/
comparison.rs

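//! Comparison operators for the runtime.
//!
//! Implements element-wise matrix comparisons (`matrix_gt`, `matrix_ge`, `matrix_lt`,
//! `matrix_le`, `matrix_eq`, `matrix_ne`) and the `gt`/`ge`/`lt`/`le`/`eq`/`ne` builtins.
//! Results are logical values encoded as `0.0`/`1.0`, returned either as scalars or as tensors.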
4
5use runmat_builtins::Tensor;
6use runmat_builtins::Value;
7use runmat_macros::runtime_builtin;
8
9/// Element-wise greater than comparison
10pub fn matrix_gt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11    if a.rows() != b.rows() || a.cols() != b.cols() {
12        return Err(format!(
13            "Matrix dimensions must agree: {}x{} > {}x{}",
14            a.rows(),
15            a.cols(),
16            b.rows(),
17            b.cols()
18        ));
19    }
20
21    let data: Vec<f64> = a
22        .data
23        .iter()
24        .zip(b.data.iter())
25        .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
26        .collect();
27
28    Tensor::new_2d(data, a.rows(), a.cols())
29}
30
31/// Element-wise greater than or equal comparison
32pub fn matrix_ge(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
33    if a.rows() != b.rows() || a.cols() != b.cols() {
34        return Err(format!(
35            "Matrix dimensions must agree: {}x{} >= {}x{}",
36            a.rows(),
37            a.cols(),
38            b.rows(),
39            b.cols()
40        ));
41    }
42
43    let data: Vec<f64> = a
44        .data
45        .iter()
46        .zip(b.data.iter())
47        .map(|(x, y)| if x >= y { 1.0 } else { 0.0 })
48        .collect();
49
50    Tensor::new_2d(data, a.rows(), a.cols())
51}
52
53/// Element-wise less than comparison
54pub fn matrix_lt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
55    if a.rows() != b.rows() || a.cols() != b.cols() {
56        return Err(format!(
57            "Matrix dimensions must agree: {}x{} < {}x{}",
58            a.rows(),
59            a.cols(),
60            b.rows(),
61            b.cols()
62        ));
63    }
64
65    let data: Vec<f64> = a
66        .data
67        .iter()
68        .zip(b.data.iter())
69        .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
70        .collect();
71
72    Tensor::new_2d(data, a.rows(), a.cols())
73}
74
75/// Element-wise less than or equal comparison
76pub fn matrix_le(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
77    if a.rows() != b.rows() || a.cols() != b.cols() {
78        return Err(format!(
79            "Matrix dimensions must agree: {}x{} <= {}x{}",
80            a.rows(),
81            a.cols(),
82            b.rows(),
83            b.cols()
84        ));
85    }
86
87    let data: Vec<f64> = a
88        .data
89        .iter()
90        .zip(b.data.iter())
91        .map(|(x, y)| if x <= y { 1.0 } else { 0.0 })
92        .collect();
93
94    Tensor::new_2d(data, a.rows(), a.cols())
95}
96
97/// Element-wise equality comparison
98pub fn matrix_eq(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
99    if a.rows() != b.rows() || a.cols() != b.cols() {
100        return Err(format!(
101            "Matrix dimensions must agree: {}x{} == {}x{}",
102            a.rows(),
103            a.cols(),
104            b.rows(),
105            b.cols()
106        ));
107    }
108
109    let data: Vec<f64> = a
110        .data
111        .iter()
112        .zip(b.data.iter())
113        .map(|(x, y)| {
114            if (x - y).abs() < f64::EPSILON {
115                1.0
116            } else {
117                0.0
118            }
119        })
120        .collect();
121
122    Tensor::new_2d(data, a.rows(), a.cols())
123}
124
125/// Element-wise inequality comparison
126pub fn matrix_ne(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
127    if a.rows() != b.rows() || a.cols() != b.cols() {
128        return Err(format!(
129            "Matrix dimensions must agree: {}x{} != {}x{}",
130            a.rows(),
131            a.cols(),
132            b.rows(),
133            b.cols()
134        ));
135    }
136
137    let data: Vec<f64> = a
138        .data
139        .iter()
140        .zip(b.data.iter())
141        .map(|(x, y)| {
142            if (x - y).abs() >= f64::EPSILON {
143                1.0
144            } else {
145                0.0
146            }
147        })
148        .collect();
149
150    Tensor::new_2d(data, a.rows(), a.cols())
151}
152
153// Built-in comparison functions
154#[runtime_builtin(name = "gt")]
155fn gt_builtin(a: f64, b: f64) -> Result<f64, String> {
156    Ok(if a > b { 1.0 } else { 0.0 })
157}
158
159#[runtime_builtin(name = "ge")]
160fn ge_builtin(a: f64, b: f64) -> Result<f64, String> {
161    Ok(if a >= b { 1.0 } else { 0.0 })
162}
163
164#[runtime_builtin(name = "lt")]
165fn lt_builtin(a: f64, b: f64) -> Result<f64, String> {
166    Ok(if a < b { 1.0 } else { 0.0 })
167}
168
169#[runtime_builtin(name = "le")]
170fn le_builtin(a: f64, b: f64) -> Result<f64, String> {
171    Ok(if a <= b { 1.0 } else { 0.0 })
172}
173
174#[runtime_builtin(name = "eq")]
175fn eq_builtin(a: Value, b: Value) -> Result<Value, String> {
176    match (a, b) {
177        // Handle identity semantics
178        (Value::HandleObject(ha), Value::HandleObject(hb)) => {
179            let pa = unsafe { ha.target.as_raw() } as usize;
180            let pb = unsafe { hb.target.as_raw() } as usize;
181            Ok(Value::Num(if pa == pb { 1.0 } else { 0.0 }))
182        }
183        (Value::HandleObject(ha), other) => {
184            let pb = match other {
185                Value::HandleObject(hb) => (unsafe { hb.target.as_raw() }) as usize,
186                _ => 0usize,
187            };
188            let pa = (unsafe { ha.target.as_raw() }) as usize;
189            Ok(Value::Num(if pa == pb && pb != 0 { 1.0 } else { 0.0 }))
190        }
191        (other, Value::HandleObject(hb)) => {
192            let pa = match other {
193                Value::HandleObject(ha) => (unsafe { ha.target.as_raw() }) as usize,
194                _ => 0usize,
195            };
196            let pb = (unsafe { hb.target.as_raw() }) as usize;
197            Ok(Value::Num(if pa == pb && pa != 0 { 1.0 } else { 0.0 }))
198        }
199        // Complex equality: element-wise on scalars; uses exact float compare for now
200        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
201            Ok(Value::Num(((ar == br) && (ai == bi)) as i32 as f64))
202        }
203        (Value::Complex(ar, ai), Value::Num(bn)) => {
204            Ok(Value::Num(((ai == 0.0) && (ar == bn)) as i32 as f64))
205        }
206        (Value::Num(an), Value::Complex(br, bi)) => {
207            Ok(Value::Num(((bi == 0.0) && (br == an)) as i32 as f64))
208        }
209        (Value::CharArray(ca), Value::CharArray(cb)) => {
210            if ca.rows != cb.rows || ca.cols != cb.cols {
211                return Err("shape mismatch for char array comparison".to_string());
212            }
213            let out: Vec<f64> = ca
214                .data
215                .iter()
216                .zip(cb.data.iter())
217                .map(|(x, y)| if x == y { 1.0 } else { 0.0 })
218                .collect();
219            Ok(Value::Tensor(
220                Tensor::new(out, vec![ca.rows, cb.cols]).map_err(|e| format!("eq: {e}"))?,
221            ))
222        }
223        (Value::CharArray(ca), Value::String(s)) => {
224            let ss: String = ca.data.iter().collect();
225            Ok(Value::Num(if ss == s { 1.0 } else { 0.0 }))
226        }
227        (Value::String(s), Value::CharArray(ca)) => {
228            let ss: String = ca.data.iter().collect();
229            Ok(Value::Num(if s == ss { 1.0 } else { 0.0 }))
230        }
231        (Value::StringArray(sa), Value::StringArray(sb)) => {
232            if sa.shape != sb.shape {
233                return Err("shape mismatch for string array comparison".to_string());
234            }
235            let out: Vec<f64> = sa
236                .data
237                .iter()
238                .zip(sb.data.iter())
239                .map(|(x, y)| if x == y { 1.0 } else { 0.0 })
240                .collect();
241            Ok(Value::Tensor(
242                Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
243            ))
244        }
245        (Value::StringArray(sa), Value::String(s)) => {
246            let out: Vec<f64> = sa
247                .data
248                .iter()
249                .map(|x| if x == &s { 1.0 } else { 0.0 })
250                .collect();
251            Ok(Value::Tensor(
252                Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
253            ))
254        }
255        (Value::String(s), Value::StringArray(sa)) => {
256            let out: Vec<f64> = sa
257                .data
258                .iter()
259                .map(|x| if &s == x { 1.0 } else { 0.0 })
260                .collect();
261            Ok(Value::Tensor(
262                Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
263            ))
264        }
265        (Value::String(a), Value::String(b)) => Ok(Value::Num(if a == b { 1.0 } else { 0.0 })),
266        (Value::Num(a), Value::Num(b)) => Ok(Value::Num(if (a - b).abs() < f64::EPSILON {
267            1.0
268        } else {
269            0.0
270        })),
271        (Value::Int(a), Value::Int(b)) => {
272            Ok(Value::Num(if a.to_i64() == b.to_i64() { 1.0 } else { 0.0 }))
273        }
274        (Value::Int(a), Value::Num(b)) => {
275            Ok(Value::Num(if (a.to_f64() - b).abs() < f64::EPSILON {
276                1.0
277            } else {
278                0.0
279            }))
280        }
281        (Value::Num(a), Value::Int(b)) => {
282            Ok(Value::Num(if (a - b.to_f64()).abs() < f64::EPSILON {
283                1.0
284            } else {
285                0.0
286            }))
287        }
288        (a, b) => {
289            let aa: f64 = (&a).try_into()?;
290            let bb: f64 = (&b).try_into()?;
291            Ok(Value::Num(if (aa - bb).abs() < f64::EPSILON {
292                1.0
293            } else {
294                0.0
295            }))
296        }
297    }
298}
299
300#[runtime_builtin(name = "ne")]
301fn ne_builtin(a: Value, b: Value) -> Result<Value, String> {
302    match (a, b) {
303        // Handle identity semantics
304        (Value::HandleObject(ha), Value::HandleObject(hb)) => {
305            let pa = unsafe { ha.target.as_raw() } as usize;
306            let pb = unsafe { hb.target.as_raw() } as usize;
307            Ok(Value::Num(if pa != pb { 1.0 } else { 0.0 }))
308        }
309        (Value::HandleObject(ha), other) => {
310            let pb = match other {
311                Value::HandleObject(hb) => (unsafe { hb.target.as_raw() }) as usize,
312                _ => 0usize,
313            };
314            let pa = (unsafe { ha.target.as_raw() }) as usize;
315            Ok(Value::Num(if pa != pb || pb == 0 { 1.0 } else { 0.0 }))
316        }
317        (other, Value::HandleObject(hb)) => {
318            let pa = match other {
319                Value::HandleObject(ha) => (unsafe { ha.target.as_raw() }) as usize,
320                _ => 0usize,
321            };
322            let pb = (unsafe { hb.target.as_raw() }) as usize;
323            Ok(Value::Num(if pa != pb || pa == 0 { 1.0 } else { 0.0 }))
324        }
325        // Complex inequality
326        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
327            Ok(Value::Num(((ar != br) || (ai != bi)) as i32 as f64))
328        }
329        (Value::Complex(ar, ai), Value::Num(bn)) => {
330            Ok(Value::Num(((ai != 0.0) || (ar != bn)) as i32 as f64))
331        }
332        (Value::Num(an), Value::Complex(br, bi)) => {
333            Ok(Value::Num(((bi != 0.0) || (br != an)) as i32 as f64))
334        }
335        (Value::CharArray(ca), Value::CharArray(cb)) => {
336            if ca.rows != cb.rows || ca.cols != cb.cols {
337                return Err("shape mismatch for char array comparison".to_string());
338            }
339            let out: Vec<f64> = ca
340                .data
341                .iter()
342                .zip(cb.data.iter())
343                .map(|(x, y)| if x != y { 1.0 } else { 0.0 })
344                .collect();
345            Ok(Value::Tensor(
346                Tensor::new(out, vec![ca.rows, cb.cols]).map_err(|e| format!("ne: {e}"))?,
347            ))
348        }
349        (Value::CharArray(ca), Value::String(s)) => {
350            let ss: String = ca.data.iter().collect();
351            Ok(Value::Num(if ss != s { 1.0 } else { 0.0 }))
352        }
353        (Value::String(s), Value::CharArray(ca)) => {
354            let ss: String = ca.data.iter().collect();
355            Ok(Value::Num(if s != ss { 1.0 } else { 0.0 }))
356        }
357        (Value::StringArray(sa), Value::StringArray(sb)) => {
358            if sa.shape != sb.shape {
359                return Err("shape mismatch for string array comparison".to_string());
360            }
361            let out: Vec<f64> = sa
362                .data
363                .iter()
364                .zip(sb.data.iter())
365                .map(|(x, y)| if x != y { 1.0 } else { 0.0 })
366                .collect();
367            Ok(Value::Tensor(
368                Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
369            ))
370        }
371        (Value::StringArray(sa), Value::String(s)) => {
372            let out: Vec<f64> = sa
373                .data
374                .iter()
375                .map(|x| if x != &s { 1.0 } else { 0.0 })
376                .collect();
377            Ok(Value::Tensor(
378                Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
379            ))
380        }
381        (Value::String(s), Value::StringArray(sa)) => {
382            let out: Vec<f64> = sa
383                .data
384                .iter()
385                .map(|x| if &s != x { 1.0 } else { 0.0 })
386                .collect();
387            Ok(Value::Tensor(
388                Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
389            ))
390        }
391        (Value::String(a), Value::String(b)) => Ok(Value::Num(if a != b { 1.0 } else { 0.0 })),
392        (Value::Num(a), Value::Num(b)) => Ok(Value::Num(if (a - b).abs() >= f64::EPSILON {
393            1.0
394        } else {
395            0.0
396        })),
397        (Value::Int(a), Value::Int(b)) => {
398            Ok(Value::Num(if a.to_i64() != b.to_i64() { 1.0 } else { 0.0 }))
399        }
400        (Value::Int(a), Value::Num(b)) => {
401            Ok(Value::Num(if (a.to_f64() - b).abs() >= f64::EPSILON {
402                1.0
403            } else {
404                0.0
405            }))
406        }
407        (Value::Num(a), Value::Int(b)) => {
408            Ok(Value::Num(if (a - b.to_f64()).abs() >= f64::EPSILON {
409                1.0
410            } else {
411                0.0
412            }))
413        }
414        (a, b) => {
415            let aa: f64 = (&a).try_into()?;
416            let bb: f64 = (&b).try_into()?;
417            Ok(Value::Num(if (aa - bb).abs() >= f64::EPSILON {
418                1.0
419            } else {
420                0.0
421            }))
422        }
423    }
424}