open_hypergraphs/lax/var/
operators.rs

1use crate::lax::hypergraph::Hyperedge;
2use crate::lax::open_hypergraph::OpenHypergraph;
3
4use super::var::*;
5
6use std::cell::RefCell;
7use std::ops::*;
8use std::rc::Rc;
9
10/// A general helper for constructing `m → n` maps
11pub fn operation<O: Clone, A: HasVar>(
12    // TODO: generalise to something which borrows a mutable OpenHypergraph?
13    builder: &Rc<RefCell<OpenHypergraph<O, A>>>,
14    vars: &[Var<O, A>],
15    result_types: Vec<O>, // types of output vars
16    op: A,
17) -> Vec<Var<O, A>> {
18    let mut nodes = Vec::with_capacity(vars.len());
19    for v in vars {
20        nodes.push(v.new_target());
21    }
22
23    let result_vars: Vec<Var<O, A>> = result_types
24        .into_iter()
25        .map(|t| Var::new(builder.clone(), t))
26        .collect();
27    let result_nodes = result_vars.iter().map(|v| v.new_source()).collect();
28
29    let mut term = builder.borrow_mut();
30    let _ = term.new_edge(
31        op,
32        Hyperedge {
33            sources: nodes,
34            targets: result_nodes,
35        },
36    );
37
38    result_vars
39}
40
41/// An `n → 1` operation, returning its sole target `Var`.
42pub fn fn_operation<O: Clone, A: HasVar>(
43    // TODO: generalise to something which borrows a mutable OpenHypergraph?
44    builder: &Rc<RefCell<OpenHypergraph<O, A>>>,
45    vars: &[Var<O, A>],
46    result_type: O,
47    op: A,
48) -> Var<O, A> {
49    let vs = operation(builder, vars, vec![result_type], op);
50    assert_eq!(vs.len(), 1);
51    vs.into_iter().next().unwrap()
52}
53
54/// Vars can be XORed when the underlying signature has an operation for 'xor'.
55pub trait HasBitXor<O, A> {
56    fn bitxor(lhs_type: O, rhs_type: O) -> (O, A);
57}
58
59impl<O: Clone, A: HasVar + HasBitXor<O, A>> BitXor for Var<O, A> {
60    type Output = Var<O, A>;
61
62    fn bitxor(self, rhs: Self) -> Self::Output {
63        // only difference between impls
64        let (result_label, op) = A::bitxor(self.label.clone(), rhs.label.clone());
65        fn_operation(&self.state.clone(), &[self, rhs], result_label, op)
66    }
67}
68
// Generates the boilerplate for one binary operator: a signature-side trait
// (`$sig_trait`) with an associated function `$method`, plus the `$op_trait` impl on
// `Var` that turns the operator into a `2 → 1` operation.
macro_rules! define_binary_op {
    ($op_trait:ident, $method:ident, $sig_trait:ident) => {
        #[doc = r" Vars support this operator when the underlying signature has the appropriate operation."]
        pub trait $sig_trait<O, A> {
            fn $method(lhs_type: O, rhs_type: O) -> (O, A);
        }

        impl<O: Clone, A: HasVar + $sig_trait<O, A>> $op_trait for Var<O, A> {
            type Output = Var<O, A>;

            fn $method(self, rhs: Self) -> Self::Output {
                // The signature picks the result label and the operation from the operand labels.
                let (out_label, op) = A::$method(self.label.clone(), rhs.label.clone());
                // Keep a handle to the builder; `self` is moved into the slice below.
                let state = Rc::clone(&self.state);
                fn_operation(&state, &[self, rhs], out_label, op)
            }
        }
    };
}
88
// Generates the boilerplate for one unary operator: a signature-side trait
// (`$sig_trait`) with an associated function `$method`, plus the `$op_trait` impl on
// `Var` that turns the operator into a `1 → 1` operation.
macro_rules! define_unary_op {
    ($op_trait:ident, $method:ident, $sig_trait:ident) => {
        #[doc = r" Vars support this unary operator when the underlying signature has the appropriate operation."]
        pub trait $sig_trait<O, A> {
            fn $method(operand_type: O) -> (O, A);
        }

        impl<O: Clone, A: HasVar + $sig_trait<O, A>> $op_trait for Var<O, A> {
            type Output = Var<O, A>;

            fn $method(self) -> Self::Output {
                // The signature picks the result label and the operation from the operand label.
                let (out_label, op) = A::$method(self.label.clone());
                // Keep a handle to the builder; `self` is moved into the slice below.
                let state = Rc::clone(&self.state);
                fn_operation(&state, &[self], out_label, op)
            }
        }
    };
}
107
// NOTE: `BitXor` is implemented by hand rather than via
// `define_binary_op!(BitXor, bitxor, HasBitXor);`.
109define_binary_op!(BitAnd, bitand, HasBitAnd);
110define_binary_op!(BitOr, bitor, HasBitOr);
111define_binary_op!(Shl, shl, HasShl);
112define_binary_op!(Shr, shr, HasShr);
113define_unary_op!(Not, not, HasNot);
114
115define_binary_op!(Add, add, HasAdd);
116define_binary_op!(Mul, mul, HasMul);
117define_binary_op!(Sub, sub, HasSub);
118define_binary_op!(Div, div, HasDiv);
119define_unary_op!(Neg, neg, HasNeg);