hugr_core/std_extensions/
logic.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::{ConstFold, ConstFoldResult};
8use crate::ops::constant::ValueName;
9use crate::ops::{OpName, Value};
10use crate::types::Signature;
11use crate::{
12 Extension, IncomingPort,
13 extension::{
14 ExtensionId, OpDef, SignatureFunc,
15 prelude::bool_t,
16 simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name},
17 },
18 ops,
19 types::type_param::TypeArg,
20 utils::sorted_consts,
21};
22use lazy_static::lazy_static;
23pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
25pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");
27
28impl ConstFold for LogicOp {
29 fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
30 match self {
31 Self::And => {
32 let inps = read_inputs(consts)?;
33 let res = inps.iter().all(|x| *x);
34 (!res || inps.len() as u64 == 2)
36 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
37 }
38 Self::Or => {
39 let inps = read_inputs(consts)?;
40 let res = inps.iter().any(|x| *x);
41 (res || inps.len() as u64 == 2)
43 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
44 }
45 Self::Eq => {
46 let inps = read_inputs(consts)?;
47 let res = inps.iter().copied().reduce(|a, b| a == b)?;
48 (!res || inps.len() as u64 == 2)
50 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
51 }
52 Self::Not => {
53 let inps = read_inputs(consts)?;
54 let res = inps.iter().all(|x| !*x);
55 (!res || inps.len() as u64 == 1)
56 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
57 }
58 Self::Xor => {
59 let inps = read_inputs(consts)?;
60 let res = inps.iter().fold(false, |acc, x| acc ^ *x);
61 (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
62 }
63 }
64 }
65}
66
67#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
69#[allow(missing_docs)]
70#[non_exhaustive]
71pub enum LogicOp {
72 And,
73 Or,
74 Eq,
75 Not,
76 Xor,
77}
78
79impl MakeOpDef for LogicOp {
80 fn opdef_id(&self) -> OpName {
81 <&'static str>::from(self).into()
82 }
83
84 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
85 match self {
86 LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
87 Signature::new(vec![bool_t(); 2], vec![bool_t()])
88 }
89 LogicOp::Not => Signature::new_endo(vec![bool_t()]),
90 }
91 .into()
92 }
93
94 fn extension_ref(&self) -> Weak<Extension> {
95 Arc::downgrade(&EXTENSION)
96 }
97
98 fn description(&self) -> String {
99 match self {
100 LogicOp::And => "logical 'and'",
101 LogicOp::Or => "logical 'or'",
102 LogicOp::Eq => "test if bools are equal",
103 LogicOp::Not => "logical 'not'",
104 LogicOp::Xor => "logical 'xor'",
105 }
106 .to_string()
107 }
108
109 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
110 try_from_name(op_def.name(), op_def.extension_id())
111 }
112
113 fn extension(&self) -> ExtensionId {
114 EXTENSION_ID.clone()
115 }
116
117 fn post_opdef(&self, def: &mut OpDef) {
118 def.set_constant_folder(*self);
119 }
120}
121
122pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");
124pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
126
127fn extension() -> Arc<Extension> {
129 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
130 LogicOp::load_all_ops(extension, extension_ref).unwrap();
131 })
132}
133
134lazy_static! {
135 pub static ref EXTENSION: Arc<Extension> = extension();
137}
138
139impl MakeRegisteredOp for LogicOp {
140 fn extension_id(&self) -> ExtensionId {
141 EXTENSION_ID.clone()
142 }
143
144 fn extension_ref(&self) -> Weak<Extension> {
145 Arc::downgrade(&EXTENSION)
146 }
147}
148
149fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
150 let true_val = ops::Value::true_val();
151 let false_val = ops::Value::false_val();
152 let inps: Option<Vec<bool>> = sorted_consts(consts)
153 .into_iter()
154 .map(|c| {
155 if c == &true_val {
156 Some(true)
157 } else if c == &false_val {
158 Some(false)
159 } else {
160 None
161 }
162 })
163 .collect();
164 let inps = inps?;
165 Some(inps)
166}
167
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use strum::IntoEnumIterator;

    use super::{LogicOp, extension};
    use crate::{
        Extension,
        extension::{
            ConstFold,
            simple_op::{MakeOpDef, MakeRegisteredOp},
        },
        ops::Value,
    };

    #[test]
    fn test_logic_extension() {
        let ext: Arc<Extension> = extension();
        assert_eq!(ext.name() as &str, "logic");
        assert_eq!(ext.operations().count(), 5);

        LogicOp::iter().for_each(|op| {
            let def = ext.get_op(op.into()).unwrap();
            assert_eq!(LogicOp::from_def(def).unwrap(), op);
        });
    }

    #[test]
    fn test_conversions() {
        LogicOp::iter().for_each(|op| {
            let ext_op = op.to_extension_op().unwrap();
            assert_eq!(LogicOp::from_op(&ext_op).unwrap(), op);
        });
    }

    /// Test helper: the `And` operation.
    pub(crate) fn and_op() -> LogicOp {
        LogicOp::And
    }

    /// Test helper: the `Or` operation.
    pub(crate) fn or_op() -> LogicOp {
        LogicOp::Or
    }

    /// Folding with every input known must always produce the expected output.
    #[rstest]
    #[case(LogicOp::And, [true, true], true)]
    #[case(LogicOp::And, [true, false], false)]
    #[case(LogicOp::Or, [false, true], true)]
    #[case(LogicOp::Or, [false, false], false)]
    #[case(LogicOp::Eq, [true, false], false)]
    #[case(LogicOp::Eq, [false, false], true)]
    #[case(LogicOp::Not, [false], true)]
    #[case(LogicOp::Not, [true], false)]
    #[case(LogicOp::Xor, [true, false], true)]
    #[case(LogicOp::Xor, [true, true], false)]
    fn const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = bool>,
        #[case] out: bool,
    ) {
        let in_vals = ins
            .into_iter()
            .enumerate()
            .map(|(port, b)| (port.into(), Value::from_bool(b)))
            .collect_vec();
        let arity = in_vals.len() as u64;
        assert_eq!(
            Some(vec![(0.into(), Value::from_bool(out))]),
            op.fold(&[arity.into()], &in_vals)
        );
    }

    /// Folding with some inputs unknown (`None`) folds only when the known ones decide.
    #[rstest]
    #[case(LogicOp::And, [Some(true), None], None)]
    #[case(LogicOp::And, [Some(false), None], Some(false))]
    #[case(LogicOp::Or, [None, Some(false)], None)]
    #[case(LogicOp::Or, [None, Some(true)], Some(true))]
    #[case(LogicOp::Eq, [None, Some(true)], None)]
    #[case(LogicOp::Not, [None], None)]
    #[case(LogicOp::Xor, [None, Some(true)], None)]
    fn partial_const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = Option<bool>>,
        #[case] mb_out: Option<bool>,
    ) {
        let ins = ins.into_iter().collect_vec();
        let num_args = ins.len() as u64;
        let in_vals = ins
            .into_iter()
            .enumerate()
            .filter_map(|(port, known)| known.map(|b| (port.into(), Value::from_bool(b))))
            .collect_vec();
        assert_eq!(
            mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]),
            op.fold(&[num_args.into()], &in_vals)
        );
    }
}