1use runmat_builtins::Tensor;
6use runmat_builtins::Value;
7use runmat_macros::runtime_builtin;
8
9pub fn matrix_gt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11 if a.rows() != b.rows() || a.cols() != b.cols() {
12 return Err(format!(
13 "Matrix dimensions must agree: {}x{} > {}x{}",
14 a.rows(),
15 a.cols(),
16 b.rows(),
17 b.cols()
18 ));
19 }
20
21 let data: Vec<f64> = a
22 .data
23 .iter()
24 .zip(b.data.iter())
25 .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
26 .collect();
27
28 Tensor::new_2d(data, a.rows(), a.cols())
29}
30
31pub fn matrix_ge(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
33 if a.rows() != b.rows() || a.cols() != b.cols() {
34 return Err(format!(
35 "Matrix dimensions must agree: {}x{} >= {}x{}",
36 a.rows(),
37 a.cols(),
38 b.rows(),
39 b.cols()
40 ));
41 }
42
43 let data: Vec<f64> = a
44 .data
45 .iter()
46 .zip(b.data.iter())
47 .map(|(x, y)| if x >= y { 1.0 } else { 0.0 })
48 .collect();
49
50 Tensor::new_2d(data, a.rows(), a.cols())
51}
52
53pub fn matrix_lt(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
55 if a.rows() != b.rows() || a.cols() != b.cols() {
56 return Err(format!(
57 "Matrix dimensions must agree: {}x{} < {}x{}",
58 a.rows(),
59 a.cols(),
60 b.rows(),
61 b.cols()
62 ));
63 }
64
65 let data: Vec<f64> = a
66 .data
67 .iter()
68 .zip(b.data.iter())
69 .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
70 .collect();
71
72 Tensor::new_2d(data, a.rows(), a.cols())
73}
74
75pub fn matrix_le(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
77 if a.rows() != b.rows() || a.cols() != b.cols() {
78 return Err(format!(
79 "Matrix dimensions must agree: {}x{} <= {}x{}",
80 a.rows(),
81 a.cols(),
82 b.rows(),
83 b.cols()
84 ));
85 }
86
87 let data: Vec<f64> = a
88 .data
89 .iter()
90 .zip(b.data.iter())
91 .map(|(x, y)| if x <= y { 1.0 } else { 0.0 })
92 .collect();
93
94 Tensor::new_2d(data, a.rows(), a.cols())
95}
96
97pub fn matrix_eq(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
99 if a.rows() != b.rows() || a.cols() != b.cols() {
100 return Err(format!(
101 "Matrix dimensions must agree: {}x{} == {}x{}",
102 a.rows(),
103 a.cols(),
104 b.rows(),
105 b.cols()
106 ));
107 }
108
109 let data: Vec<f64> = a
110 .data
111 .iter()
112 .zip(b.data.iter())
113 .map(|(x, y)| {
114 if (x - y).abs() < f64::EPSILON {
115 1.0
116 } else {
117 0.0
118 }
119 })
120 .collect();
121
122 Tensor::new_2d(data, a.rows(), a.cols())
123}
124
125pub fn matrix_ne(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
127 if a.rows() != b.rows() || a.cols() != b.cols() {
128 return Err(format!(
129 "Matrix dimensions must agree: {}x{} != {}x{}",
130 a.rows(),
131 a.cols(),
132 b.rows(),
133 b.cols()
134 ));
135 }
136
137 let data: Vec<f64> = a
138 .data
139 .iter()
140 .zip(b.data.iter())
141 .map(|(x, y)| {
142 if (x - y).abs() >= f64::EPSILON {
143 1.0
144 } else {
145 0.0
146 }
147 })
148 .collect();
149
150 Tensor::new_2d(data, a.rows(), a.cols())
151}
152
153#[runtime_builtin(name = "gt")]
155fn gt_builtin(a: f64, b: f64) -> Result<f64, String> {
156 Ok(if a > b { 1.0 } else { 0.0 })
157}
158
159#[runtime_builtin(name = "ge")]
160fn ge_builtin(a: f64, b: f64) -> Result<f64, String> {
161 Ok(if a >= b { 1.0 } else { 0.0 })
162}
163
164#[runtime_builtin(name = "lt")]
165fn lt_builtin(a: f64, b: f64) -> Result<f64, String> {
166 Ok(if a < b { 1.0 } else { 0.0 })
167}
168
169#[runtime_builtin(name = "le")]
170fn le_builtin(a: f64, b: f64) -> Result<f64, String> {
171 Ok(if a <= b { 1.0 } else { 0.0 })
172}
173
174#[runtime_builtin(name = "eq")]
175fn eq_builtin(a: Value, b: Value) -> Result<Value, String> {
176 match (a, b) {
177 (Value::HandleObject(ha), Value::HandleObject(hb)) => {
179 let pa = unsafe { ha.target.as_raw() } as usize;
180 let pb = unsafe { hb.target.as_raw() } as usize;
181 Ok(Value::Num(if pa == pb { 1.0 } else { 0.0 }))
182 }
183 (Value::HandleObject(ha), other) => {
184 let pb = match other {
185 Value::HandleObject(hb) => (unsafe { hb.target.as_raw() }) as usize,
186 _ => 0usize,
187 };
188 let pa = (unsafe { ha.target.as_raw() }) as usize;
189 Ok(Value::Num(if pa == pb && pb != 0 { 1.0 } else { 0.0 }))
190 }
191 (other, Value::HandleObject(hb)) => {
192 let pa = match other {
193 Value::HandleObject(ha) => (unsafe { ha.target.as_raw() }) as usize,
194 _ => 0usize,
195 };
196 let pb = (unsafe { hb.target.as_raw() }) as usize;
197 Ok(Value::Num(if pa == pb && pa != 0 { 1.0 } else { 0.0 }))
198 }
199 (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
201 Ok(Value::Num(((ar == br) && (ai == bi)) as i32 as f64))
202 }
203 (Value::Complex(ar, ai), Value::Num(bn)) => {
204 Ok(Value::Num(((ai == 0.0) && (ar == bn)) as i32 as f64))
205 }
206 (Value::Num(an), Value::Complex(br, bi)) => {
207 Ok(Value::Num(((bi == 0.0) && (br == an)) as i32 as f64))
208 }
209 (Value::CharArray(ca), Value::CharArray(cb)) => {
210 if ca.rows != cb.rows || ca.cols != cb.cols {
211 return Err("shape mismatch for char array comparison".to_string());
212 }
213 let out: Vec<f64> = ca
214 .data
215 .iter()
216 .zip(cb.data.iter())
217 .map(|(x, y)| if x == y { 1.0 } else { 0.0 })
218 .collect();
219 Ok(Value::Tensor(
220 Tensor::new(out, vec![ca.rows, cb.cols]).map_err(|e| format!("eq: {e}"))?,
221 ))
222 }
223 (Value::CharArray(ca), Value::String(s)) => {
224 let ss: String = ca.data.iter().collect();
225 Ok(Value::Num(if ss == s { 1.0 } else { 0.0 }))
226 }
227 (Value::String(s), Value::CharArray(ca)) => {
228 let ss: String = ca.data.iter().collect();
229 Ok(Value::Num(if s == ss { 1.0 } else { 0.0 }))
230 }
231 (Value::StringArray(sa), Value::StringArray(sb)) => {
232 if sa.shape != sb.shape {
233 return Err("shape mismatch for string array comparison".to_string());
234 }
235 let out: Vec<f64> = sa
236 .data
237 .iter()
238 .zip(sb.data.iter())
239 .map(|(x, y)| if x == y { 1.0 } else { 0.0 })
240 .collect();
241 Ok(Value::Tensor(
242 Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
243 ))
244 }
245 (Value::StringArray(sa), Value::String(s)) => {
246 let out: Vec<f64> = sa
247 .data
248 .iter()
249 .map(|x| if x == &s { 1.0 } else { 0.0 })
250 .collect();
251 Ok(Value::Tensor(
252 Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
253 ))
254 }
255 (Value::String(s), Value::StringArray(sa)) => {
256 let out: Vec<f64> = sa
257 .data
258 .iter()
259 .map(|x| if &s == x { 1.0 } else { 0.0 })
260 .collect();
261 Ok(Value::Tensor(
262 Tensor::new(out, sa.shape).map_err(|e| format!("eq: {e}"))?,
263 ))
264 }
265 (Value::String(a), Value::String(b)) => Ok(Value::Num(if a == b { 1.0 } else { 0.0 })),
266 (Value::Num(a), Value::Num(b)) => Ok(Value::Num(if (a - b).abs() < f64::EPSILON {
267 1.0
268 } else {
269 0.0
270 })),
271 (Value::Int(a), Value::Int(b)) => {
272 Ok(Value::Num(if a.to_i64() == b.to_i64() { 1.0 } else { 0.0 }))
273 }
274 (Value::Int(a), Value::Num(b)) => {
275 Ok(Value::Num(if (a.to_f64() - b).abs() < f64::EPSILON {
276 1.0
277 } else {
278 0.0
279 }))
280 }
281 (Value::Num(a), Value::Int(b)) => {
282 Ok(Value::Num(if (a - b.to_f64()).abs() < f64::EPSILON {
283 1.0
284 } else {
285 0.0
286 }))
287 }
288 (a, b) => {
289 let aa: f64 = (&a).try_into()?;
290 let bb: f64 = (&b).try_into()?;
291 Ok(Value::Num(if (aa - bb).abs() < f64::EPSILON {
292 1.0
293 } else {
294 0.0
295 }))
296 }
297 }
298}
299
300#[runtime_builtin(name = "ne")]
301fn ne_builtin(a: Value, b: Value) -> Result<Value, String> {
302 match (a, b) {
303 (Value::HandleObject(ha), Value::HandleObject(hb)) => {
305 let pa = unsafe { ha.target.as_raw() } as usize;
306 let pb = unsafe { hb.target.as_raw() } as usize;
307 Ok(Value::Num(if pa != pb { 1.0 } else { 0.0 }))
308 }
309 (Value::HandleObject(ha), other) => {
310 let pb = match other {
311 Value::HandleObject(hb) => (unsafe { hb.target.as_raw() }) as usize,
312 _ => 0usize,
313 };
314 let pa = (unsafe { ha.target.as_raw() }) as usize;
315 Ok(Value::Num(if pa != pb || pb == 0 { 1.0 } else { 0.0 }))
316 }
317 (other, Value::HandleObject(hb)) => {
318 let pa = match other {
319 Value::HandleObject(ha) => (unsafe { ha.target.as_raw() }) as usize,
320 _ => 0usize,
321 };
322 let pb = (unsafe { hb.target.as_raw() }) as usize;
323 Ok(Value::Num(if pa != pb || pa == 0 { 1.0 } else { 0.0 }))
324 }
325 (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
327 Ok(Value::Num(((ar != br) || (ai != bi)) as i32 as f64))
328 }
329 (Value::Complex(ar, ai), Value::Num(bn)) => {
330 Ok(Value::Num(((ai != 0.0) || (ar != bn)) as i32 as f64))
331 }
332 (Value::Num(an), Value::Complex(br, bi)) => {
333 Ok(Value::Num(((bi != 0.0) || (br != an)) as i32 as f64))
334 }
335 (Value::CharArray(ca), Value::CharArray(cb)) => {
336 if ca.rows != cb.rows || ca.cols != cb.cols {
337 return Err("shape mismatch for char array comparison".to_string());
338 }
339 let out: Vec<f64> = ca
340 .data
341 .iter()
342 .zip(cb.data.iter())
343 .map(|(x, y)| if x != y { 1.0 } else { 0.0 })
344 .collect();
345 Ok(Value::Tensor(
346 Tensor::new(out, vec![ca.rows, cb.cols]).map_err(|e| format!("ne: {e}"))?,
347 ))
348 }
349 (Value::CharArray(ca), Value::String(s)) => {
350 let ss: String = ca.data.iter().collect();
351 Ok(Value::Num(if ss != s { 1.0 } else { 0.0 }))
352 }
353 (Value::String(s), Value::CharArray(ca)) => {
354 let ss: String = ca.data.iter().collect();
355 Ok(Value::Num(if s != ss { 1.0 } else { 0.0 }))
356 }
357 (Value::StringArray(sa), Value::StringArray(sb)) => {
358 if sa.shape != sb.shape {
359 return Err("shape mismatch for string array comparison".to_string());
360 }
361 let out: Vec<f64> = sa
362 .data
363 .iter()
364 .zip(sb.data.iter())
365 .map(|(x, y)| if x != y { 1.0 } else { 0.0 })
366 .collect();
367 Ok(Value::Tensor(
368 Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
369 ))
370 }
371 (Value::StringArray(sa), Value::String(s)) => {
372 let out: Vec<f64> = sa
373 .data
374 .iter()
375 .map(|x| if x != &s { 1.0 } else { 0.0 })
376 .collect();
377 Ok(Value::Tensor(
378 Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
379 ))
380 }
381 (Value::String(s), Value::StringArray(sa)) => {
382 let out: Vec<f64> = sa
383 .data
384 .iter()
385 .map(|x| if &s != x { 1.0 } else { 0.0 })
386 .collect();
387 Ok(Value::Tensor(
388 Tensor::new(out, sa.shape).map_err(|e| format!("ne: {e}"))?,
389 ))
390 }
391 (Value::String(a), Value::String(b)) => Ok(Value::Num(if a != b { 1.0 } else { 0.0 })),
392 (Value::Num(a), Value::Num(b)) => Ok(Value::Num(if (a - b).abs() >= f64::EPSILON {
393 1.0
394 } else {
395 0.0
396 })),
397 (Value::Int(a), Value::Int(b)) => {
398 Ok(Value::Num(if a.to_i64() != b.to_i64() { 1.0 } else { 0.0 }))
399 }
400 (Value::Int(a), Value::Num(b)) => {
401 Ok(Value::Num(if (a.to_f64() - b).abs() >= f64::EPSILON {
402 1.0
403 } else {
404 0.0
405 }))
406 }
407 (Value::Num(a), Value::Int(b)) => {
408 Ok(Value::Num(if (a - b.to_f64()).abs() >= f64::EPSILON {
409 1.0
410 } else {
411 0.0
412 }))
413 }
414 (a, b) => {
415 let aa: f64 = (&a).try_into()?;
416 let bb: f64 = (&b).try_into()?;
417 Ok(Value::Num(if (aa - bb).abs() >= f64::EPSILON {
418 1.0
419 } else {
420 0.0
421 }))
422 }
423 }
424}