1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
; Symbolic dimension sizes: a product of two dimensions, a named size,
; or an integer literal.
(datatype Dim
    (Times Dim Dim)
    (NamedDim String)
    (Lit i64))
; Times is associative; birewrite installs both directions of the law.
(birewrite (Times x (Times y z)) (Times (Times x y) z))
; Fold a product of two literals into a single literal.
(rewrite (Times (Lit p) (Lit q)) (Lit (* p q)))
; Times is commutative.
(rewrite (Times x y) (Times y x))
; Matrix expressions over symbolic dimensions (Dim).
(datatype MExpr
(MMul MExpr MExpr)   ; matrix product
(Kron MExpr MExpr)   ; Kronecker product
(NamedMat String)    ; opaque matrix identified by name
(Id Dim)             ; identity matrix; nrows/ncols rewrites make it n x n
; Not yet modeled (commented out):
; DSum
; HStack
; VStack
; Transpose
; Inverse
; Zero Dim Dim   ; originally written "Zero Math Math", but no sort Math exists;
;                ; presumably the row and column dimensions — confirm
; ScalarMul
)
; Alternative encoding: (type A) = (Matrix n m), i.e. a single type term carrying both dimensions of A; this may be more useful for the "large story example".
; Dimension accessors: (nrows M) and (ncols M) are the row and column
; counts of matrix expression M.
(constructor nrows (MExpr) Dim)
(constructor ncols (MExpr) Dim)
; Kronecker products multiply the operands' dimensions.
(rewrite (nrows (Kron A B)) (Times (nrows A) (nrows B)))
(rewrite (ncols (Kron A B)) (Times (ncols A) (ncols B)))
; A matrix product takes its rows from the left factor and its columns
; from the right factor.
(rewrite (nrows (MMul A B)) (nrows A))
(rewrite (ncols (MMul A B)) (ncols B))
; (Id n) is n x n.
(rewrite (nrows (Id n)) n)
(rewrite (ncols (Id n)) n)
; Identity elimination. These rules are guarded on the matching dimension
; so that a mismatched product such as (MMul (Id p) A) with p != (nrows A)
; is not collapsed to A, consistent with the guarded mixed-product rule
; below. The guard only matches once the (nrows A) / (ncols A) term exists
; in the e-graph.
(rewrite (MMul (Id n) A) A :when ((= (nrows A) n)))
(rewrite (MMul A (Id n)) A :when ((= (ncols A) n)))
; Associativity of the matrix product, in both directions.
(rewrite (MMul A (MMul B C)) (MMul (MMul A B) C))
(rewrite (MMul (MMul A B) C) (MMul A (MMul B C)))
; Associativity of the Kronecker product, in both directions.
(rewrite (Kron A (Kron B C)) (Kron (Kron A B) C))
(rewrite (Kron (Kron A B) C) (Kron A (Kron B C)))
; Mixed-product property: (A kron B)(C kron D) = (AC) kron (BD).
; Left to right: if AC and BD are defined, the product on the right is too.
(rewrite (Kron (MMul A C) (MMul B D)) (MMul (Kron A B) (Kron C D)))
; Right to left: factor-wise dimensions must agree. Agreement of the total
; sizes of the two Kronecker products alone is not sufficient.
(rewrite (MMul (Kron A B) (Kron C D))
         (Kron (MMul A C) (MMul B D))
         :when
         ((= (ncols A) (nrows C))
          (= (ncols B) (nrows D)))
)
; demand
(rule ((= e (MMul A B)))
((ncols A)
(nrows A)
(ncols B)
(nrows B))
)
(rule ((= e (Kron A B)))
((ncols A)
(nrows A)
(ncols B)
(nrows B))
)
; Example setup: A is n x n, B is m x m, C is p x p.
(let $n (NamedDim "n"))
(let $m (NamedDim "m"))
(let $p (NamedDim "p"))
(let $A (NamedMat "A"))
(let $B (NamedMat "B"))
(let $C (NamedMat "C"))
(union (nrows $A) $n)
(union (ncols $A) $n)
(union (nrows $B) $m)
(union (ncols $B) $m)
(union (nrows $C) $p)
(union (ncols $C) $p)
; Positive example: (Id n kron B)(A kron Id m) should equal A kron B.
(let $ex1 (MMul (Kron (Id $n) $B) (Kron $A (Id $m))))
(let $rows1 (nrows $ex1))
(let $cols1 (ncols $ex1))
(run 20)
(check (= (nrows $B) $m))
(check (= (nrows (Kron (Id $n) $B)) (Times $n $m)))
; The whole product is (n*m) x (n*m).
(check (= $rows1 (Times $n $m)))
(check (= $cols1 (Times $n $m)))
(let $simple_ex1 (Kron $A $B))
(check (= $ex1 $simple_ex1))
; Negative example: the factor dimensions do not line up
; (ncols (Id p) = p but nrows A = n), so the mixed-product rule must not fire.
; The target term is bound before running, so the `fail` below checks that
; the two terms are NOT equal, rather than passing only because (Kron $A $C)
; was never present in the e-graph.
(let $ex2 (MMul (Kron (Id $p) $C) (Kron $A (Id $m))))
(let $simple_ex2 (Kron $A $C))
(run 10)
(fail (check (= $ex2 $simple_ex2)))