hugr_core/std_extensions/
logic.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::{ConstFold, ConstFoldResult};
8use crate::ops::constant::ValueName;
9use crate::ops::Value;
10use crate::types::Signature;
11use crate::{
12 extension::{
13 prelude::bool_t,
14 simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError},
15 ExtensionId, OpDef, SignatureFunc,
16 },
17 ops,
18 types::type_param::TypeArg,
19 utils::sorted_consts,
20 Extension, IncomingPort,
21};
22use lazy_static::lazy_static;
23pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
25pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");
27
28impl ConstFold for LogicOp {
29 fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
30 match self {
31 Self::And => {
32 let inps = read_inputs(consts)?;
33 let res = inps.iter().all(|x| *x);
34 (!res || inps.len() as u64 == 2)
36 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
37 }
38 Self::Or => {
39 let inps = read_inputs(consts)?;
40 let res = inps.iter().any(|x| *x);
41 (res || inps.len() as u64 == 2)
43 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
44 }
45 Self::Eq => {
46 let inps = read_inputs(consts)?;
47 let res = inps.iter().copied().reduce(|a, b| a == b)?;
48 (!res || inps.len() as u64 == 2)
50 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
51 }
52 Self::Not => {
53 let inps = read_inputs(consts)?;
54 let res = inps.iter().all(|x| !*x);
55 (!res || inps.len() as u64 == 1)
56 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
57 }
58 Self::Xor => {
59 let inps = read_inputs(consts)?;
60 let res = inps.iter().fold(false, |acc, x| acc ^ *x);
61 (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
62 }
63 }
64 }
65}
66
67#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
69#[allow(missing_docs)]
70#[non_exhaustive]
71pub enum LogicOp {
72 And,
73 Or,
74 Eq,
75 Not,
76 Xor,
77}
78
79impl MakeOpDef for LogicOp {
80 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
81 match self {
82 LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
83 Signature::new(vec![bool_t(); 2], vec![bool_t()])
84 }
85 LogicOp::Not => Signature::new_endo(vec![bool_t()]),
86 }
87 .into()
88 }
89
90 fn extension_ref(&self) -> Weak<Extension> {
91 Arc::downgrade(&EXTENSION)
92 }
93
94 fn description(&self) -> String {
95 match self {
96 LogicOp::And => "logical 'and'",
97 LogicOp::Or => "logical 'or'",
98 LogicOp::Eq => "test if bools are equal",
99 LogicOp::Not => "logical 'not'",
100 LogicOp::Xor => "logical 'xor'",
101 }
102 .to_string()
103 }
104
105 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
106 try_from_name(op_def.name(), op_def.extension_id())
107 }
108
109 fn extension(&self) -> ExtensionId {
110 EXTENSION_ID.to_owned()
111 }
112
113 fn post_opdef(&self, def: &mut OpDef) {
114 def.set_constant_folder(*self);
115 }
116}
117
118pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");
120pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
122
123fn extension() -> Arc<Extension> {
125 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
126 LogicOp::load_all_ops(extension, extension_ref).unwrap();
127
128 extension
129 .add_value(FALSE_NAME, ops::Value::false_val())
130 .unwrap();
131 extension
132 .add_value(TRUE_NAME, ops::Value::true_val())
133 .unwrap();
134 })
135}
136
137lazy_static! {
138 pub static ref EXTENSION: Arc<Extension> = extension();
140}
141
142impl MakeRegisteredOp for LogicOp {
143 fn extension_id(&self) -> ExtensionId {
144 EXTENSION_ID.to_owned()
145 }
146
147 fn extension_ref(&self) -> Weak<Extension> {
148 Arc::downgrade(&EXTENSION)
149 }
150}
151
152fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
153 let true_val = ops::Value::true_val();
154 let false_val = ops::Value::false_val();
155 let inps: Option<Vec<bool>> = sorted_consts(consts)
156 .into_iter()
157 .map(|c| {
158 if c == &true_val {
159 Some(true)
160 } else if c == &false_val {
161 Some(false)
162 } else {
163 None
164 }
165 })
166 .collect();
167 let inps = inps?;
168 Some(inps)
169}
170
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME};
    use crate::{
        extension::{
            prelude::bool_t,
            simple_op::{MakeOpDef, MakeRegisteredOp},
        },
        ops::{NamedOp, Value},
        Extension,
    };

    use rstest::rstest;
    use strum::IntoEnumIterator;

    /// The extension is named `logic`, has five ops, and every op round-trips through
    /// its registered definition.
    #[test]
    fn test_logic_extension() {
        let ext: Arc<Extension> = extension();
        assert_eq!(ext.name() as &str, "logic");
        assert_eq!(ext.operations().count(), 5);

        LogicOp::iter().for_each(|expected| {
            let def = ext.get_op(&expected.name()).unwrap();
            assert_eq!(LogicOp::from_def(def).unwrap(), expected);
        });
    }

    /// Every op round-trips through its extension-op form.
    #[test]
    fn test_conversions() {
        for op in LogicOp::iter() {
            let ext_op = op.to_extension_op().unwrap();
            let back = LogicOp::from_op(&ext_op).unwrap();
            assert_eq!(back, op);
        }
    }

    /// Both registered constants have type `bool_t()`.
    #[test]
    fn test_values() {
        let ext: Arc<Extension> = extension();
        let values = [
            ext.get_value(&FALSE_NAME).unwrap(),
            ext.get_value(&TRUE_NAME).unwrap(),
        ];

        for value in values {
            assert_eq!(value.typed_value().get_type(), bool_t());
        }
    }

    /// Shorthand for the `And` op, for use by other test modules.
    pub(crate) fn and_op() -> LogicOp {
        LogicOp::And
    }

    /// Shorthand for the `Or` op, for use by other test modules.
    pub(crate) fn or_op() -> LogicOp {
        LogicOp::Or
    }

    #[rstest]
    #[case(LogicOp::And, [true, true], true)]
    #[case(LogicOp::And, [true, false], false)]
    #[case(LogicOp::Or, [false, true], true)]
    #[case(LogicOp::Or, [false, false], false)]
    #[case(LogicOp::Eq, [true, false], false)]
    #[case(LogicOp::Eq, [false, false], true)]
    #[case(LogicOp::Not, [false], true)]
    #[case(LogicOp::Not, [true], false)]
    #[case(LogicOp::Xor, [true, false], true)]
    #[case(LogicOp::Xor, [true, true], false)]
    fn const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = bool>,
        #[case] out: bool,
    ) {
        use crate::extension::ConstFold;

        // Every input is a constant, fed to consecutive ports starting at 0.
        let inputs = ins
            .into_iter()
            .enumerate()
            .map(|(port, b)| (port.into(), Value::from_bool(b)))
            .collect::<Vec<_>>();
        let arity = inputs.len() as u64;
        let expected = Some(vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[arity.into()], &inputs));
    }

    #[rstest]
    #[case(LogicOp::And, [Some(true), None], None)]
    #[case(LogicOp::And, [Some(false), None], Some(false))]
    #[case(LogicOp::Or, [None, Some(false)], None)]
    #[case(LogicOp::Or, [None, Some(true)], Some(true))]
    #[case(LogicOp::Eq, [None, Some(true)], None)]
    #[case(LogicOp::Not, [None], None)]
    #[case(LogicOp::Xor, [None, Some(true)], None)]
    fn partial_const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = Option<bool>>,
        #[case] mb_out: Option<bool>,
    ) {
        use crate::extension::ConstFold;

        let all_inputs: Vec<Option<bool>> = ins.into_iter().collect();
        let num_args = all_inputs.len() as u64;
        // Only the inputs given as `Some` are passed to the folder as constants.
        let known_inputs = all_inputs
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(port, mb_b)| mb_b.map(|b| (port.into(), Value::from_bool(b))))
            .collect::<Vec<_>>();
        let expected = mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[num_args.into()], &known_inputs));
    }
}