1use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use crate::op::Op;
9use crate::types::{AxisType, BinaryOp, ConstValue};
10use crate::uop::UOp;
11
12impl UOp {
13 pub fn const_factor(&self) -> i64 {
18 match &self.op {
19 Op::Const(cv) => match &cv.0 {
20 ConstValue::Int(i) => *i,
21 ConstValue::UInt(u) => *u as i64,
22 _ => 1,
23 },
24 Op::VConst { values } => values
26 .iter()
27 .filter_map(|v| match v {
28 ConstValue::Int(i) => Some(*i),
29 ConstValue::UInt(u) => Some(*u as i64),
30 _ => None,
31 })
32 .map(|v| v.abs())
33 .reduce(gcd)
34 .unwrap_or(1),
35 Op::Binary(BinaryOp::Mul, a, b) => {
37 if let Op::Const(cv) = &a.op
38 && let ConstValue::Int(i) = cv.0
39 {
40 return i;
41 }
42 if let Op::Const(cv) = &b.op
43 && let ConstValue::Int(i) = cv.0
44 {
45 return i;
46 }
47 1
48 }
49 Op::Binary(BinaryOp::Add, a, b) => gcd(a.const_factor().abs(), b.const_factor().abs()),
50 _ => 1,
51 }
52 }
53
54 pub fn divides(self: &Arc<Self>, v: &Arc<Self>) -> Option<Arc<Self>> {
59 if let Op::Const(cv) = v.op()
60 && let ConstValue::Int(divisor) = cv.0
61 {
62 return self.divides_int(divisor);
63 }
64 None
65 }
66
67 pub fn divides_int(self: &Arc<Self>, v: i64) -> Option<Arc<Self>> {
72 if v == 1 {
73 return Some(Arc::clone(self));
74 }
75 if v == 0 {
76 return None;
77 }
78 match self.op() {
79 Op::Const(cv) => {
80 let ConstValue::Int(val) = cv.0 else { return None };
81 if val % v == 0 { Some(Self::const_(self.dtype(), ConstValue::Int(val / v))) } else { None }
82 }
83 Op::VConst { values } => {
85 let divided: Option<Vec<ConstValue>> = values
86 .iter()
87 .map(|val| match val {
88 ConstValue::Int(i) if i % v == 0 => Some(ConstValue::Int(i / v)),
89 _ => None,
90 })
91 .collect();
92 divided.map(|v| UOp::vconst(v, self.dtype().scalar_dtype()))
93 }
94 Op::Binary(BinaryOp::Add, a, b) => {
95 let d0 = a.divides_int(v)?;
96 let d1 = b.divides_int(v)?;
97 d0.try_add(&d1).ok()
98 }
99 Op::Binary(BinaryOp::Mul, a, b) => {
100 if let Some(d0) = a.divides_int(v) {
101 return d0.try_mul(b).ok();
102 }
103 if let Some(d1) = b.divides_int(v) {
104 return a.try_mul(&d1).ok();
105 }
106 None
107 }
108 _ => None,
109 }
110 }
111
112 pub fn divide_exact(self: &Arc<Self>, v: &Arc<Self>) -> Option<Arc<Self>> {
117 if Arc::ptr_eq(self, v) {
118 return Some(self.const_like(1i64));
119 }
120 if let Op::Const(cv) = v.op()
121 && let ConstValue::Int(d) = cv.0
122 {
123 return self.divides_int(d);
124 }
125 if let Op::Binary(BinaryOp::Add, a, b) = self.op() {
126 let d0 = a.divide_exact(v)?;
127 let d1 = b.divide_exact(v)?;
128 return d0.try_add(&d1).ok();
129 }
130 if let Op::Binary(BinaryOp::Mul, a, b) = self.op() {
131 if let Some(d) = a.divide_exact(v) {
132 return d.try_mul(b).ok();
133 }
134 if let Some(d) = b.divide_exact(v) {
135 return a.try_mul(&d).ok();
136 }
137 }
138 None
139 }
140
141 pub fn symbolic_gcd(uops: &[Arc<Self>]) -> Arc<Self> {
148 assert!(!uops.is_empty(), "symbolic_gcd requires at least one uop");
149
150 let decomp: Vec<(Arc<Self>, i64)> = uops
152 .iter()
153 .map(|u| {
154 let f = u.const_factor();
155 let term = if f == 1 || f == 0 {
156 Arc::clone(u)
157 } else {
158 u.divides_int(f).unwrap_or_else(|| u.const_like(1i64))
159 };
160 (term, f)
161 })
162 .collect();
163
164 let counters: Vec<HashMap<*const Self, (Arc<Self>, usize)>> = decomp
166 .iter()
167 .map(|(term, _)| {
168 let mut counter: HashMap<*const Self, (Arc<Self>, usize)> = HashMap::new();
169 for factor in term.split_uop(BinaryOp::Mul) {
170 let ptr = Arc::as_ptr(&factor);
171 counter.entry(ptr).and_modify(|(_, c)| *c += 1).or_insert((factor, 1));
172 }
173 counter
174 })
175 .collect();
176
177 let mut common = counters[0].clone();
179 for other in &counters[1..] {
180 common.retain(|ptr, (_, count)| {
181 if let Some((_, other_count)) = other.get(ptr) {
182 *count = (*count).min(*other_count);
183 true
184 } else {
185 false
186 }
187 });
188 }
189
190 let numeric = decomp.iter().map(|(_, f)| f.abs()).reduce(gcd).unwrap_or(1);
192
193 let mut result = uops[0].const_like(numeric);
195 for (factor, count) in common.values() {
196 if let Op::Const(cv) = factor.op()
198 && matches!(cv.0, ConstValue::Int(1))
199 {
200 continue;
201 }
202 for _ in 0..*count {
203 result = result.try_mul(factor).expect("symbolic_gcd: mul failed");
204 }
205 }
206
207 result
208 }
209
210 pub fn pop_const(self: &Arc<Self>, op: BinaryOp) -> (Arc<Self>, Option<ConstValue>) {
223 if let Op::Binary(self_op, a, b) = self.op()
224 && *self_op == op
225 {
226 if let Op::Const(cv) = b.op() {
228 return (a.clone(), Some(cv.0));
229 }
230 if op.is_commutative()
232 && let Op::Const(cv) = a.op()
233 {
234 return (b.clone(), Some(cv.0));
235 }
236 }
237
238 (self.clone(), None)
239 }
240
241 pub fn split_uop(self: &Arc<Self>, sep: BinaryOp) -> Vec<Arc<Self>> {
253 let mut result = Vec::new();
254 let mut stack = vec![self.clone()];
255
256 while let Some(node) = stack.pop() {
257 if let Op::Binary(op, a, b) = node.op()
258 && *op == sep
259 {
260 stack.push(b.clone());
262 stack.push(a.clone());
263 continue;
264 }
265 result.push(node);
266 }
267
268 result
269 }
270
271 pub fn backward_slice_ids(self: &Arc<Self>) -> &HashSet<u64> {
277 use crate::uop::cached_property::CachedProperty;
278 use crate::uop::properties::BackwardSliceProperty;
279 BackwardSliceProperty::get(self)
280 }
281
282 pub fn backward_slice(self: &Arc<Self>) -> Vec<Arc<Self>> {
287 let mut visited = HashSet::new();
288 let mut result = Vec::new();
289 let mut stack = vec![self.clone()];
290
291 while let Some(node) = stack.pop() {
292 let ptr = Arc::as_ptr(&node);
293
294 if visited.contains(&ptr) {
295 continue;
296 }
297
298 visited.insert(ptr);
299 result.push(node.clone());
300
301 node.op.map_child(|child| {
303 stack.push(child.clone());
304 });
305 }
306
307 result
308 }
309
310 pub fn divisible_by(self: &Arc<Self>, amount: usize) -> Option<usize> {
323 if let Op::Range { end, .. } = self.op() {
325 if let Op::Const(cv) = end.op()
327 && let ConstValue::Int(sz) = cv.0
328 && sz > 0
329 && (sz as usize).is_multiple_of(amount)
330 {
331 return Some((sz as usize) / amount);
332 }
333
334 let factor = end.const_factor();
336 if factor > 0 && (factor as usize).is_multiple_of(amount) {
337 return Some((factor as usize) / amount);
338 }
339 }
340
341 if let Op::Const(cv) = self.op()
343 && let ConstValue::Int(val) = cv.0
344 && val > 0
345 && (val as usize).is_multiple_of(amount)
346 {
347 return Some((val as usize) / amount);
348 }
349
350 None
351 }
352
353 pub fn with_axis_type(self: &Arc<Self>, new_type: AxisType) -> Arc<Self> {
370 if let Op::Range { end, axis_id, .. } = self.op() {
371 Self::range_axis(end.clone(), *axis_id, new_type)
372 } else {
373 panic!("with_axis_type() called on non-RANGE operation: {:?}", self.op);
374 }
375 }
376
377 pub fn get_idx(self: &Arc<Self>) -> Arc<Self> {
399 use crate::types::TernaryOp;
400
401 match self.op() {
402 Op::Ternary(TernaryOp::Where, _, true_val, false_val) if Self::is_invalid_marker(false_val) => {
403 true_val.clone()
405 }
406 _ => self.clone(),
407 }
408 }
409
410 pub fn get_valid(self: &Arc<Self>) -> Arc<Self> {
434 use crate::types::TernaryOp;
435 use morok_dtype::DType;
436
437 match self.op() {
438 Op::Ternary(TernaryOp::Where, cond, _, false_val) if Self::is_invalid_marker(false_val) => {
439 cond.clone()
441 }
442 Op::Invalid => {
443 Self::const_(DType::Bool, ConstValue::Bool(false))
445 }
446 _ => {
447 Self::const_(DType::Bool, ConstValue::Bool(true))
449 }
450 }
451 }
452
453 pub fn is_invalid_marker(uop: &Arc<Self>) -> bool {
462 match uop.op() {
463 Op::Invalid => true,
464 Op::Vectorize { elements } => {
465 !elements.is_empty() && elements.iter().all(|e| matches!(e.op(), Op::Invalid))
466 }
467 _ => false,
468 }
469 }
470
471 pub fn invalid_marker() -> Arc<Self> {
489 use morok_dtype::DType;
490
491 Self::new(Op::Invalid, DType::Index)
493 }
494
495 pub fn is_increasing(self: &Arc<Self>) -> bool {
525 match self.op() {
526 Op::Range { .. } | Op::Const(_) | Op::DefineVar { .. } => true,
528
529 Op::Binary(BinaryOp::Add, a, b) => a.is_increasing() && b.is_increasing(),
531
532 Op::Binary(BinaryOp::Mul | BinaryOp::Idiv, a, b) => {
534 if let Op::Const(cv) = b.op() {
535 matches!(cv.0, ConstValue::Int(n) if n >= 0) && a.is_increasing()
536 } else {
537 false
538 }
539 }
540
541 _ => false,
542 }
543 }
544}
545
/// Greatest common divisor of `|a|` and `|b|`, computed with the Euclidean algorithm.
///
/// The result is non-negative, with `gcd(0, 0) == 0` and `gcd(x, 0) == |x|`. There is one
/// exception: when the true gcd is 2^63 (inputs drawn from `{0, i64::MIN}`, not both zero), it is
/// not representable as `i64` and the result is `i64::MIN`.
pub fn gcd(a: i64, b: i64) -> i64 {
    // Work on unsigned magnitudes. `i64::MIN.abs()` overflows: it panics in debug builds and
    // stays negative in release builds, which used to make e.g. `gcd(i64::MIN, 6)` return -2.
    let (mut x, mut y) = (a.unsigned_abs(), b.unsigned_abs());
    while y != 0 {
        (x, y) = (y, x % y);
    }
    // Only 2^63 exceeds i64::MAX; it maps to i64::MIN (the same bit pattern).
    i64::try_from(x).unwrap_or(i64::MIN)
}
557
558#[allow(dead_code)] trait BinaryOpExt {
561 fn is_commutative(&self) -> bool;
562}
563
564impl BinaryOpExt for BinaryOp {
565 fn is_commutative(&self) -> bool {
566 matches!(
567 self,
568 BinaryOp::Add
569 | BinaryOp::Mul
570 | BinaryOp::And
571 | BinaryOp::Or
572 | BinaryOp::Xor
573 | BinaryOp::Max
574 | BinaryOp::Eq
575 | BinaryOp::Ne
576 )
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;

    /// Builds an integer constant of `dtype` holding `value`.
    fn int_const(dtype: DType, value: i64) -> Arc<UOp> {
        UOp::const_(dtype, ConstValue::Int(value))
    }

    #[test]
    fn test_const_factor_constant() {
        assert_eq!(int_const(DType::Int32, 6).const_factor(), 6);
    }

    #[test]
    fn test_const_factor_multiplication() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let product = x.try_mul(&int_const(DType::Int32, 6)).unwrap();
        assert_eq!(product.const_factor(), 6);
    }

    #[test]
    fn test_const_factor_addition() {
        let lhs = int_const(DType::Int32, 6);
        let rhs = int_const(DType::Int32, 9);
        // gcd(6, 9) == 3
        assert_eq!(lhs.try_add(&rhs).unwrap().const_factor(), 3);
    }

    #[test]
    fn test_divides_constant_exact() {
        let twelve = int_const(DType::Int32, 12);
        let three = int_const(DType::Int32, 3);
        let outcome = twelve.divides(&three);

        assert!(outcome.is_some());
        let quotient = outcome.unwrap();
        match quotient.op() {
            Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(4)),
            _ => panic!("Expected constant result"),
        }
    }

    #[test]
    fn test_divides_constant_not_exact() {
        let ten = int_const(DType::Int32, 10);
        let three = int_const(DType::Int32, 3);
        assert!(ten.divides(&three).is_none());
    }

    #[test]
    fn test_pop_const_with_constant() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let sum = x.try_add(&int_const(DType::Int32, 5)).unwrap();

        let (remainder, popped) = sum.pop_const(BinaryOp::Add);

        assert!(Arc::ptr_eq(&remainder, &x));
        assert_eq!(popped, Some(ConstValue::Int(5)));
    }

    #[test]
    fn test_pop_const_without_constant() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let sum = x.try_add(&y).unwrap();

        let (remainder, popped) = sum.pop_const(BinaryOp::Add);

        assert!(Arc::ptr_eq(&remainder, &sum));
        assert_eq!(popped, None);
    }

    #[test]
    fn test_split_uop_chain() {
        let [x, y, z] = ["x", "y", "z"].map(|name| UOp::var(name, DType::Int32, 0, 100));

        // (x + y) + z
        let chain = x.try_add(&y).unwrap().try_add(&z).unwrap();
        let pieces = chain.split_uop(BinaryOp::Add);

        assert_eq!(pieces.len(), 3);
        for (piece, expected) in pieces.iter().zip([&x, &y, &z]) {
            assert!(Arc::ptr_eq(piece, expected));
        }
    }

    #[test]
    fn test_split_uop_single() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let pieces = x.split_uop(BinaryOp::Add);

        assert_eq!(pieces.len(), 1);
        assert!(Arc::ptr_eq(&pieces[0], &x));
    }

    #[test]
    fn test_gcd() {
        let cases = [(12, 8, 4), (17, 19, 1), (100, 50, 50), (-12, 8, 4), (12, -8, 4), (-12, -8, 4)];
        for (a, b, expected) in cases {
            assert_eq!(gcd(a, b), expected);
        }
    }

    #[test]
    fn test_symbolic_gcd_numeric_only() {
        let x = UOp::var("x", DType::Index, 0, 10);
        let y = UOp::var("y", DType::Index, 0, 10);
        let x_times_six = x.try_mul(&int_const(DType::Index, 6)).unwrap();
        let y_times_four = y.try_mul(&int_const(DType::Index, 4)).unwrap();

        // No shared symbolic factor, so only gcd(6, 4) == 2 remains.
        let g = UOp::symbolic_gcd(&[x_times_six, y_times_four]);
        match g.op() {
            Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(2)),
            _ => panic!("Expected constant GCD, got: {}", g.tree()),
        }
    }

    #[test]
    fn test_symbolic_gcd_with_common_factor() {
        let x = UOp::var("x", DType::Index, 0, 10);
        let x_times_six = x.try_mul(&int_const(DType::Index, 6)).unwrap();
        let x_times_four = x.try_mul(&int_const(DType::Index, 4)).unwrap();

        // Shared factor `x` times gcd(6, 4) == 2.
        let g = UOp::symbolic_gcd(&[x_times_six, x_times_four]);
        assert!(matches!(g.op(), Op::Binary(BinaryOp::Mul, _, _)), "Expected MUL, got: {}", g.tree());
    }

    #[test]
    fn test_const_factor_mul_only_immediate() {
        let x = UOp::var("x", DType::Index, 0, 10);
        let y = UOp::var("y", DType::Index, 0, 10);
        let x_times_six = x.try_mul(&int_const(DType::Index, 6)).unwrap();
        let y_times_four = y.try_mul(&int_const(DType::Index, 4)).unwrap();

        // Neither direct operand of the outer MUL is a constant.
        let nested = x_times_six.try_mul(&y_times_four).unwrap();
        assert_eq!(nested.const_factor(), 1);
    }

    #[test]
    fn test_const_factor_vconst() {
        let lanes = [6, 12, 18, 24].map(ConstValue::Int).to_vec();
        let vc = UOp::vconst(lanes, DType::Int64);
        // gcd(6, 12, 18, 24) == 6
        assert_eq!(vc.const_factor(), 6);
    }

    #[test]
    fn test_const_factor_vconst_no_common() {
        let vc = UOp::vconst(vec![ConstValue::Int(7), ConstValue::Int(11)], DType::Int64);
        // 7 and 11 are coprime.
        assert_eq!(vc.const_factor(), 1);
    }

    #[test]
    fn test_divides_int_vconst() {
        let vc = UOp::vconst(vec![ConstValue::Int(6), ConstValue::Int(12)], DType::Int64);
        let outcome = vc.divides_int(3);

        assert!(outcome.is_some());
        let quotient = outcome.unwrap();
        match quotient.op() {
            Op::VConst { values } => assert_eq!(values, &[ConstValue::Int(2), ConstValue::Int(4)]),
            _ => panic!("Expected VConst result"),
        }
    }

    #[test]
    fn test_divides_int_vconst_not_divisible() {
        // 7 is not a multiple of 3, so the whole vector is rejected.
        let vc = UOp::vconst(vec![ConstValue::Int(6), ConstValue::Int(7)], DType::Int64);
        assert!(vc.divides_int(3).is_none());
    }

    #[test]
    fn test_is_increasing_const() {
        // Constants count as increasing regardless of sign.
        for value in [5, -5] {
            assert!(int_const(DType::Int32, value).is_increasing());
        }
    }

    #[test]
    fn test_is_increasing_add() {
        let sum = int_const(DType::Int32, 5).try_add(&int_const(DType::Int32, 3)).unwrap();
        assert!(sum.is_increasing());
    }

    #[test]
    fn test_is_increasing_mul_positive_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let doubled = x.try_mul(&int_const(DType::Int32, 2)).unwrap();
        assert!(doubled.is_increasing());
    }

    #[test]
    fn test_is_increasing_mul_negative_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let negated = x.try_mul(&int_const(DType::Int32, -2)).unwrap();
        // Scaling by a negative constant reverses the ordering.
        assert!(!negated.is_increasing());
    }

    #[test]
    fn test_is_increasing_idiv_positive_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let halved = x.idiv(&int_const(DType::Int32, 2));
        assert!(halved.is_increasing());
    }

    #[test]
    fn test_is_increasing_complex() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let shifted = x.try_add(&int_const(DType::Int32, 5)).unwrap();
        // (x + 5) * 2
        let scaled = shifted.try_mul(&int_const(DType::Int32, 2)).unwrap();
        assert!(scaled.is_increasing());
    }
}