// mech_set/operations/symmetric_difference.rs

use crate::*;

use indexmap::set::IndexSet;
use mech_core::set::MechSet;

// Symmetric Difference ------------------------------------------------------------------------

8#[derive(Debug)]
9struct SetSymDifferenceFxn {
10  lhs: Ref<MechSet>,
11  rhs: Ref<MechSet>,
12  out: Ref<MechSet>,
13}
14
15impl MechFunctionFactory for SetSymDifferenceFxn {
16  fn new(args: FunctionArgs) -> MResult<Box<dyn MechFunction>> {
17    match args {
18      FunctionArgs::Binary(out, arg1, arg2) => {
19        let lhs: Ref<MechSet> = unsafe { arg1.as_unchecked() }.clone();
20        let rhs: Ref<MechSet> = unsafe { arg2.as_unchecked() }.clone();
21        let out: Ref<MechSet> = unsafe { out.as_unchecked() }.clone();
22        Ok(Box::new(SetSymDifferenceFxn { lhs, rhs, out }))
23      },
24      _ => Err(MechError2::new(IncorrectNumberOfArguments { expected: 2, found: args.len() }, None).with_compiler_loc()),
25    }
26  }
27}
28
29impl MechFunctionImpl for SetSymDifferenceFxn {
30  fn solve(&self) {
31    unsafe {
32      // Get mutable reference to the output set
33      let out_ptr: &mut MechSet = &mut *(self.out.as_mut_ptr());
34
35      // Get references to lhs and rhs sets
36      let lhs_ptr: &MechSet = &*(self.lhs.as_ptr());
37      let rhs_ptr: &MechSet = &*(self.rhs.as_ptr());
38
39      // Clear the output set first
40      out_ptr.set.clear();
41
42      // Compute (lhs \ rhs) ∪ (rhs \ lhs) into output
43      out_ptr.set = lhs_ptr.set.symmetric_difference(&(rhs_ptr.set)).cloned().collect();
44
45      // Update metadata
46      out_ptr.num_elements = out_ptr.set.len();
47      out_ptr.kind = if out_ptr.set.len() > 0 {
48        out_ptr.set.iter().next().unwrap().kind()
49      } else {
50        ValueKind::Empty
51      };
52    }
53  }
54  fn out(&self) -> Value { Value::Set(self.out.clone()) }
55  fn to_string(&self) -> String { format!("{:#?}", self) }
56}
57
#[cfg(feature = "compiler")]
impl MechFunctionCompiler for SetSymDifferenceFxn {
  /// Compiles this function via `compile_binop!`.
  ///
  /// The macro receives the output and operand registers plus a custom feature
  /// flag hashed from `"set/symmetric-difference"`.
  fn compile(&self, ctx: &mut CompileCtx) -> MResult<Register> {
    // A plain `to_string` yields the same `String` as an argument-less `format!`
    // (clippy::useless_format).
    let name = "SetSymDifferenceFxn".to_string();
    compile_binop!(name, self.out, self.lhs, self.rhs, ctx, FeatureFlag::Custom(hash_str("set/symmetric-difference")));
  }
}

// Registers `SetSymDifferenceFxn::new` under the name "SetSymDifferenceFxn".
// With this registration the function can be rebuilt by name from `FunctionArgs`
// (presumably when loading compiled programs; confirm against `register_descriptor!`).
register_descriptor! {
  FunctionDescriptor {
    name: "SetSymDifferenceFxn",
    ptr: SetSymDifferenceFxn::new,
  }
}

73fn set_sym_difference_fxn(lhs: Value, rhs: Value) -> MResult<Box<dyn MechFunction>> {
74  match (lhs, rhs) {
75    (Value::Set(lhs), Value::Set(rhs)) => {
76      Ok(Box::new(SetSymDifferenceFxn {
77        lhs: lhs.clone(),
78        rhs: rhs.clone(),
79        out: Ref::new(MechSet::new(
80          lhs.borrow().kind.clone(),
81          lhs.borrow().num_elements + rhs.borrow().num_elements
82        ))
83      }))
84    },
85    x => Err(MechError2::new(
86      UnhandledFunctionArgumentKind2 { arg: (x.0.kind(), x.1.kind()), fxn_name: "set/symmetric-difference".to_string() },
87      None
88    ).with_compiler_loc()),
89  }
90}
91
92pub struct SetSymmetricDifference {}
93impl NativeFunctionCompiler for SetSymmetricDifference {
94  fn compile(&self, arguments: &Vec<Value>) -> MResult<Box<dyn MechFunction>> {
95    if arguments.len() != 2 {
96      return Err(MechError2::new(IncorrectNumberOfArguments { expected: 2, found: arguments.len() }, None).with_compiler_loc());
97    }
98    let lhs = arguments[0].clone();
99    let rhs = arguments[1].clone();
100    match set_sym_difference_fxn(lhs.clone(), rhs.clone()) {
101      Ok(fxn) => Ok(fxn),
102      Err(_) => {
103        match (lhs, rhs) {
104          (Value::MutableReference(lhs), Value::MutableReference(rhs)) => { set_sym_difference_fxn(lhs.borrow().clone(), rhs.borrow().clone()) },
105          (lhs, Value::MutableReference(rhs)) => { set_sym_difference_fxn(lhs.clone(), rhs.borrow().clone()) },
106          (Value::MutableReference(lhs), rhs) => { set_sym_difference_fxn(lhs.borrow().clone(), rhs.clone()) },
107          x => Err(MechError2::new(
108            UnhandledFunctionArgumentKind2 { arg: (x.0.kind(), x.1.kind()), fxn_name: "set/symmetric-difference".to_string() },
109            None
110          ).with_compiler_loc()),
111        }
112      }
113    }
114  }
115}
116
// Registers `SetSymmetricDifference` as the native compiler for the
// "set/symmetric-difference" function name.
register_descriptor! {
  FunctionCompilerDescriptor {
    name: "set/symmetric-difference",
    ptr: &SetSymmetricDifference{},
  }
}